diff --git a/cql3/selection/abstract_function_selector.hh b/cql3/selection/abstract_function_selector.hh index 6236b5dfa5..e53997d865 100644 --- a/cql3/selection/abstract_function_selector.hh +++ b/cql3/selection/abstract_function_selector.hh @@ -35,11 +35,13 @@ protected: public: static shared_ptr new_factory(shared_ptr fun, shared_ptr factories); - abstract_function_selector(shared_ptr fun, std::vector> arg_selectors) + // If reserve_extra_arg is set, the internal buffer used for holding function argument lists is enlarged to account for + // an aggregate function's accumulator. + abstract_function_selector(shared_ptr fun, std::vector> arg_selectors, bool reserve_extra_arg = false) : _fun(std::move(fun)), _arg_selectors(std::move(arg_selectors)), _requires_thread(boost::algorithm::any_of(_arg_selectors, [] (auto& s) { return s->requires_thread(); }) || _fun->requires_thread()) { - _args.resize(_arg_selectors.size()); + _args.resize(_arg_selectors.size() + unsigned(reserve_extra_arg)); } virtual bool requires_thread() const override; @@ -74,8 +76,10 @@ protected: shared_ptr fun() const { return _tfun; } public: - abstract_function_selector_for(shared_ptr fun, std::vector> arg_selectors) - : abstract_function_selector(fun, std::move(arg_selectors)) + // If reserve_extra_arg is set, the internal buffer used for holding function argument lists is enlarged to account for + // an aggregate function's accumulator. + abstract_function_selector_for(shared_ptr fun, std::vector> arg_selectors, bool reserve_extra_arg = false) + : abstract_function_selector(fun, std::move(arg_selectors), reserve_extra_arg) , _tfun(dynamic_pointer_cast(fun)) { } diff --git a/cql3/selection/aggregate_function_selector.hh b/cql3/selection/aggregate_function_selector.hh index 97427da068..8b811c8036 100644 --- a/cql3/selection/aggregate_function_selector.hh +++ b/cql3/selection/aggregate_function_selector.hh @@ -17,7 +17,8 @@ namespace cql3 { namespace selection { class aggregate_function_selector : public abstract_function_selector_for { - std::unique_ptr _aggregate; + const db::functions::stateless_aggregate_function& _aggregate; + bytes_opt _accumulator; public: virtual bool is_aggregate() const override { return true; @@ -26,28 +27,32 @@ public: virtual void add_input(result_set_builder& rs) override { // Aggregation of aggregation is not supported size_t m = _arg_selectors.size(); + _args[0] = std::move(_accumulator); for (size_t i = 0; i < m; ++i) { auto&& s = _arg_selectors[i]; s->add_input(rs); - _args[i] = s->get_output(); + _args[i + 1] = s->get_output(); s->reset(); } - _aggregate->add_input(_args); + _accumulator = _aggregate.aggregation_function->execute(_args); } virtual bytes_opt get_output() override { - return _aggregate->compute(); + return _aggregate.state_to_result_function + ? _aggregate.state_to_result_function->execute({std::move(_accumulator)}) + : std::move(_accumulator); } virtual void reset() override { - _aggregate->reset(); + _accumulator = _aggregate.initial_state; } aggregate_function_selector(shared_ptr func, std::vector> arg_selectors) : abstract_function_selector_for( - dynamic_pointer_cast(func), std::move(arg_selectors)) - , _aggregate(fun()->new_aggregate()) { + dynamic_pointer_cast(func), std::move(arg_selectors), true) + , _aggregate(fun()->get_aggregate()) + , _accumulator(_aggregate.initial_state) { } }; diff --git a/db/functions/aggregate_function.hh b/db/functions/aggregate_function.hh index 4edb449d44..430501c108 100644 --- a/db/functions/aggregate_function.hh +++ b/db/functions/aggregate_function.hh @@ -27,11 +27,8 @@ protected: private: shared_ptr _reducible; private: - class aggregate_adapter; static shared_ptr make_reducible_variant(stateless_aggregate_function saf); public: - class aggregate; - explicit aggregate_function(stateless_aggregate_function saf, bool reducible_variant = false); /** @@ -39,7 +36,7 @@ public: * * @return a new Aggregate instance. */ - std::unique_ptr new_aggregate(); + const stateless_aggregate_function& get_aggregate() const; /** * Checks wheather the function can be distributed and is able to reduce states. @@ -70,41 +67,6 @@ public: virtual bool is_aggregate() const override; virtual void print(std::ostream& os) const override; virtual sstring column_name(const std::vector& column_names) const override; - - /** - * An aggregation operation. - */ - class aggregate { - public: - using opt_bytes = aggregate_function::opt_bytes; - - virtual ~aggregate() {} - - /** - * Adds the specified input to this aggregate. - * - * @param values the values to add to the aggregate. - */ - virtual void add_input(const std::vector& values) = 0; - - /** - * Computes and returns the aggregate current value. - * - * @return the aggregate current value. - */ - virtual opt_bytes compute() = 0; - - virtual void set_accumulator(const opt_bytes& acc) = 0; - - virtual opt_bytes get_accumulator() const = 0; - - virtual void reduce(const opt_bytes& acc) = 0; - - /** - * Reset this aggregate. - */ - virtual void reset() = 0; - }; }; } diff --git a/db/functions/function.cc b/db/functions/function.cc index cbea068235..62a1e6f62f 100644 --- a/db/functions/function.cc +++ b/db/functions/function.cc @@ -5,59 +5,16 @@ namespace db::functions { -class aggregate_function::aggregate_adapter : public aggregate { - const stateless_aggregate_function& _agg; - bytes_opt _state; -public: - explicit aggregate_adapter(const stateless_aggregate_function& agg) - : _agg(agg) - , _state(agg.initial_state) { - } - - virtual void add_input(const std::vector& values) override { - std::vector state_and_values; - state_and_values.reserve(values.size() + 1); - state_and_values.push_back(std::move(_state)); - std::copy(values.begin(), values.end(), std::back_inserter(state_and_values)); - _state = _agg.aggregation_function->execute(state_and_values); - } - - virtual opt_bytes compute() override { - if (_agg.state_to_result_function) { - std::vector state_vec; - state_vec.push_back(std::move(_state)); - return _agg.state_to_result_function->execute(state_vec); - } else { - return std::move(_state); - } - } - - virtual void set_accumulator(const opt_bytes& acc) override { - _state = acc; - } - - virtual opt_bytes get_accumulator() const override { - return _state; - } - - virtual void reduce(const opt_bytes& acc) override { - std::vector two_states; - two_states.reserve(2); - two_states.push_back(std::move(_state)); - two_states.push_back(acc); - _state = _agg.state_reduction_function->execute(two_states); - } - - virtual void reset() override { - _state = _agg.initial_state; - } -}; - aggregate_function::aggregate_function(stateless_aggregate_function agg, bool reducible_variant) : _agg(std::move(agg)) , _reducible(!reducible_variant ? make_reducible_variant(_agg) : nullptr) { } +const stateless_aggregate_function& +aggregate_function::get_aggregate() const { + return _agg; +} + shared_ptr aggregate_function::make_reducible_variant(stateless_aggregate_function agg) { if (!agg.state_reduction_function) { @@ -69,11 +26,6 @@ aggregate_function::make_reducible_variant(stateless_aggregate_function agg) { return make_shared(new_agg, true); } -std::unique_ptr -aggregate_function::new_aggregate() { - return std::make_unique(_agg); -} - bool aggregate_function::is_reducible() const { return bool(_agg.state_reduction_function); diff --git a/service/forward_service.cc b/service/forward_service.cc index 94d60d80b3..43d24efd99 100644 --- a/service/forward_service.cc +++ b/service/forward_service.cc @@ -60,8 +60,7 @@ static std::vector<::shared_ptr> get_function class forward_aggregates { private: std::vector<::shared_ptr> _funcs; - std::vector> _aggrs; - + std::vector _aggrs; public: forward_aggregates(const query::forward_request& request); void merge(query::forward_result& result, query::forward_result&& other); @@ -85,10 +84,10 @@ public: forward_aggregates::forward_aggregates(const query::forward_request& request) { _funcs = get_functions(request); - std::vector> aggrs; + std::vector aggrs; for (auto& func: _funcs) { - aggrs.push_back(func->new_aggregate()); + aggrs.push_back(func->get_aggregate()); } _aggrs = std::move(aggrs); } @@ -113,9 +112,7 @@ void forward_aggregates::merge(query::forward_result &result, query::forward_res } for (size_t i = 0; i < _aggrs.size(); i++) { - _aggrs[i]->set_accumulator(result.query_results[i]); - _aggrs[i]->reduce(std::move(other.query_results[i])); - result.query_results[i] = _aggrs[i]->get_accumulator(); + result.query_results[i] = _aggrs[i].state_reduction_function->execute(std::vector({std::move(result.query_results[i]), std::move(other.query_results[i])})); } } @@ -126,7 +123,9 @@ void forward_aggregates::finalize(query::forward_result &result) { // as "WHERE p IN ()". We need to build a fake result with the result // of empty aggregation. for (size_t i = 0; i < _aggrs.size(); i++) { - result.query_results.push_back(_aggrs[i]->compute()); + result.query_results.push_back(_aggrs[i].state_to_result_function + ? _aggrs[i].state_to_result_function->execute(std::vector({_aggrs[i].initial_state})) + : _aggrs[i].initial_state); } return; } @@ -141,8 +140,9 @@ void forward_aggregates::finalize(query::forward_result &result) { } for (size_t i = 0; i < _aggrs.size(); i++) { - _aggrs[i]->set_accumulator(result.query_results[i]); - result.query_results[i] = _aggrs[i]->compute(); + result.query_results[i] = _aggrs[i].state_to_result_function + ? _aggrs[i].state_to_result_function->execute(std::vector({std::move(result.query_results[i])})) + : result.query_results[i]; } }