cql3/selection, forward_service: use use stateless_aggregate_function directly
Now that stateless_aggregate_function is directly exposed by aggregate_function, we can use it directly, avoiding the intermediary aggregate_function::aggregate, which is removed.
This commit is contained in:
@@ -35,11 +35,13 @@ protected:
|
||||
public:
|
||||
static shared_ptr<factory> new_factory(shared_ptr<functions::function> fun, shared_ptr<selector_factories> factories);
|
||||
|
||||
abstract_function_selector(shared_ptr<functions::function> fun, std::vector<shared_ptr<selector>> arg_selectors)
|
||||
// If reserve_extra_arg is set, the internal buffer used for holding function argument lists is enlarged to account for
|
||||
// an aggregate function's accumulator.
|
||||
abstract_function_selector(shared_ptr<functions::function> fun, std::vector<shared_ptr<selector>> arg_selectors, bool reserve_extra_arg = false)
|
||||
: _fun(std::move(fun)), _arg_selectors(std::move(arg_selectors)),
|
||||
_requires_thread(boost::algorithm::any_of(_arg_selectors, [] (auto& s) { return s->requires_thread(); })
|
||||
|| _fun->requires_thread()) {
|
||||
_args.resize(_arg_selectors.size());
|
||||
_args.resize(_arg_selectors.size() + unsigned(reserve_extra_arg));
|
||||
}
|
||||
|
||||
virtual bool requires_thread() const override;
|
||||
@@ -74,8 +76,10 @@ protected:
|
||||
|
||||
shared_ptr<const T> fun() const { return _tfun; }
|
||||
public:
|
||||
abstract_function_selector_for(shared_ptr<T> fun, std::vector<shared_ptr<selector>> arg_selectors)
|
||||
: abstract_function_selector(fun, std::move(arg_selectors))
|
||||
// If reserve_extra_arg is set, the internal buffer used for holding function argument lists is enlarged to account for
|
||||
// an aggregate function's accumulator.
|
||||
abstract_function_selector_for(shared_ptr<T> fun, std::vector<shared_ptr<selector>> arg_selectors, bool reserve_extra_arg = false)
|
||||
: abstract_function_selector(fun, std::move(arg_selectors), reserve_extra_arg)
|
||||
, _tfun(dynamic_pointer_cast<T>(fun)) {
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,8 @@ namespace cql3 {
|
||||
namespace selection {
|
||||
|
||||
class aggregate_function_selector : public abstract_function_selector_for<functions::aggregate_function> {
|
||||
std::unique_ptr<functions::aggregate_function::aggregate> _aggregate;
|
||||
const db::functions::stateless_aggregate_function& _aggregate;
|
||||
bytes_opt _accumulator;
|
||||
public:
|
||||
virtual bool is_aggregate() const override {
|
||||
return true;
|
||||
@@ -26,28 +27,32 @@ public:
|
||||
virtual void add_input(result_set_builder& rs) override {
|
||||
// Aggregation of aggregation is not supported
|
||||
size_t m = _arg_selectors.size();
|
||||
_args[0] = std::move(_accumulator);
|
||||
for (size_t i = 0; i < m; ++i) {
|
||||
auto&& s = _arg_selectors[i];
|
||||
s->add_input(rs);
|
||||
_args[i] = s->get_output();
|
||||
_args[i + 1] = s->get_output();
|
||||
s->reset();
|
||||
}
|
||||
_aggregate->add_input(_args);
|
||||
_accumulator = _aggregate.aggregation_function->execute(_args);
|
||||
}
|
||||
|
||||
virtual bytes_opt get_output() override {
|
||||
return _aggregate->compute();
|
||||
return _aggregate.state_to_result_function
|
||||
? _aggregate.state_to_result_function->execute({std::move(_accumulator)})
|
||||
: std::move(_accumulator);
|
||||
}
|
||||
|
||||
virtual void reset() override {
|
||||
_aggregate->reset();
|
||||
_accumulator = _aggregate.initial_state;
|
||||
}
|
||||
|
||||
aggregate_function_selector(shared_ptr<functions::function> func,
|
||||
std::vector<shared_ptr<selector>> arg_selectors)
|
||||
: abstract_function_selector_for<functions::aggregate_function>(
|
||||
dynamic_pointer_cast<functions::aggregate_function>(func), std::move(arg_selectors))
|
||||
, _aggregate(fun()->new_aggregate()) {
|
||||
dynamic_pointer_cast<functions::aggregate_function>(func), std::move(arg_selectors), true)
|
||||
, _aggregate(fun()->get_aggregate())
|
||||
, _accumulator(_aggregate.initial_state) {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -27,11 +27,8 @@ protected:
|
||||
private:
|
||||
shared_ptr<aggregate_function> _reducible;
|
||||
private:
|
||||
class aggregate_adapter;
|
||||
static shared_ptr<aggregate_function> make_reducible_variant(stateless_aggregate_function saf);
|
||||
public:
|
||||
class aggregate;
|
||||
|
||||
explicit aggregate_function(stateless_aggregate_function saf, bool reducible_variant = false);
|
||||
|
||||
/**
|
||||
@@ -39,7 +36,7 @@ public:
|
||||
*
|
||||
* @return a new <code>Aggregate</code> instance.
|
||||
*/
|
||||
std::unique_ptr<aggregate> new_aggregate();
|
||||
const stateless_aggregate_function& get_aggregate() const;
|
||||
|
||||
/**
|
||||
* Checks wheather the function can be distributed and is able to reduce states.
|
||||
@@ -70,41 +67,6 @@ public:
|
||||
virtual bool is_aggregate() const override;
|
||||
virtual void print(std::ostream& os) const override;
|
||||
virtual sstring column_name(const std::vector<sstring>& column_names) const override;
|
||||
|
||||
/**
|
||||
* An aggregation operation.
|
||||
*/
|
||||
class aggregate {
|
||||
public:
|
||||
using opt_bytes = aggregate_function::opt_bytes;
|
||||
|
||||
virtual ~aggregate() {}
|
||||
|
||||
/**
|
||||
* Adds the specified input to this aggregate.
|
||||
*
|
||||
* @param values the values to add to the aggregate.
|
||||
*/
|
||||
virtual void add_input(const std::vector<opt_bytes>& values) = 0;
|
||||
|
||||
/**
|
||||
* Computes and returns the aggregate current value.
|
||||
*
|
||||
* @return the aggregate current value.
|
||||
*/
|
||||
virtual opt_bytes compute() = 0;
|
||||
|
||||
virtual void set_accumulator(const opt_bytes& acc) = 0;
|
||||
|
||||
virtual opt_bytes get_accumulator() const = 0;
|
||||
|
||||
virtual void reduce(const opt_bytes& acc) = 0;
|
||||
|
||||
/**
|
||||
* Reset this aggregate.
|
||||
*/
|
||||
virtual void reset() = 0;
|
||||
};
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
@@ -5,59 +5,16 @@
|
||||
|
||||
namespace db::functions {
|
||||
|
||||
class aggregate_function::aggregate_adapter : public aggregate {
|
||||
const stateless_aggregate_function& _agg;
|
||||
bytes_opt _state;
|
||||
public:
|
||||
explicit aggregate_adapter(const stateless_aggregate_function& agg)
|
||||
: _agg(agg)
|
||||
, _state(agg.initial_state) {
|
||||
}
|
||||
|
||||
virtual void add_input(const std::vector<opt_bytes>& values) override {
|
||||
std::vector<bytes_opt> state_and_values;
|
||||
state_and_values.reserve(values.size() + 1);
|
||||
state_and_values.push_back(std::move(_state));
|
||||
std::copy(values.begin(), values.end(), std::back_inserter(state_and_values));
|
||||
_state = _agg.aggregation_function->execute(state_and_values);
|
||||
}
|
||||
|
||||
virtual opt_bytes compute() override {
|
||||
if (_agg.state_to_result_function) {
|
||||
std::vector<bytes_opt> state_vec;
|
||||
state_vec.push_back(std::move(_state));
|
||||
return _agg.state_to_result_function->execute(state_vec);
|
||||
} else {
|
||||
return std::move(_state);
|
||||
}
|
||||
}
|
||||
|
||||
virtual void set_accumulator(const opt_bytes& acc) override {
|
||||
_state = acc;
|
||||
}
|
||||
|
||||
virtual opt_bytes get_accumulator() const override {
|
||||
return _state;
|
||||
}
|
||||
|
||||
virtual void reduce(const opt_bytes& acc) override {
|
||||
std::vector<bytes_opt> two_states;
|
||||
two_states.reserve(2);
|
||||
two_states.push_back(std::move(_state));
|
||||
two_states.push_back(acc);
|
||||
_state = _agg.state_reduction_function->execute(two_states);
|
||||
}
|
||||
|
||||
virtual void reset() override {
|
||||
_state = _agg.initial_state;
|
||||
}
|
||||
};
|
||||
|
||||
aggregate_function::aggregate_function(stateless_aggregate_function agg, bool reducible_variant)
|
||||
: _agg(std::move(agg))
|
||||
, _reducible(!reducible_variant ? make_reducible_variant(_agg) : nullptr) {
|
||||
}
|
||||
|
||||
const stateless_aggregate_function&
|
||||
aggregate_function::get_aggregate() const {
|
||||
return _agg;
|
||||
}
|
||||
|
||||
shared_ptr<aggregate_function>
|
||||
aggregate_function::make_reducible_variant(stateless_aggregate_function agg) {
|
||||
if (!agg.state_reduction_function) {
|
||||
@@ -69,11 +26,6 @@ aggregate_function::make_reducible_variant(stateless_aggregate_function agg) {
|
||||
return make_shared<aggregate_function>(new_agg, true);
|
||||
}
|
||||
|
||||
std::unique_ptr<aggregate_function::aggregate>
|
||||
aggregate_function::new_aggregate() {
|
||||
return std::make_unique<aggregate_adapter>(_agg);
|
||||
}
|
||||
|
||||
bool
|
||||
aggregate_function::is_reducible() const {
|
||||
return bool(_agg.state_reduction_function);
|
||||
|
||||
@@ -60,8 +60,7 @@ static std::vector<::shared_ptr<db::functions::aggregate_function>> get_function
|
||||
class forward_aggregates {
|
||||
private:
|
||||
std::vector<::shared_ptr<db::functions::aggregate_function>> _funcs;
|
||||
std::vector<std::unique_ptr<db::functions::aggregate_function::aggregate>> _aggrs;
|
||||
|
||||
std::vector<db::functions::stateless_aggregate_function> _aggrs;
|
||||
public:
|
||||
forward_aggregates(const query::forward_request& request);
|
||||
void merge(query::forward_result& result, query::forward_result&& other);
|
||||
@@ -85,10 +84,10 @@ public:
|
||||
|
||||
forward_aggregates::forward_aggregates(const query::forward_request& request) {
|
||||
_funcs = get_functions(request);
|
||||
std::vector<std::unique_ptr<db::functions::aggregate_function::aggregate>> aggrs;
|
||||
std::vector<db::functions::stateless_aggregate_function> aggrs;
|
||||
|
||||
for (auto& func: _funcs) {
|
||||
aggrs.push_back(func->new_aggregate());
|
||||
aggrs.push_back(func->get_aggregate());
|
||||
}
|
||||
_aggrs = std::move(aggrs);
|
||||
}
|
||||
@@ -113,9 +112,7 @@ void forward_aggregates::merge(query::forward_result &result, query::forward_res
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < _aggrs.size(); i++) {
|
||||
_aggrs[i]->set_accumulator(result.query_results[i]);
|
||||
_aggrs[i]->reduce(std::move(other.query_results[i]));
|
||||
result.query_results[i] = _aggrs[i]->get_accumulator();
|
||||
result.query_results[i] = _aggrs[i].state_reduction_function->execute(std::vector({std::move(result.query_results[i]), std::move(other.query_results[i])}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,7 +123,9 @@ void forward_aggregates::finalize(query::forward_result &result) {
|
||||
// as "WHERE p IN ()". We need to build a fake result with the result
|
||||
// of empty aggregation.
|
||||
for (size_t i = 0; i < _aggrs.size(); i++) {
|
||||
result.query_results.push_back(_aggrs[i]->compute());
|
||||
result.query_results.push_back(_aggrs[i].state_to_result_function
|
||||
? _aggrs[i].state_to_result_function->execute(std::vector({_aggrs[i].initial_state}))
|
||||
: _aggrs[i].initial_state);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -141,8 +140,9 @@ void forward_aggregates::finalize(query::forward_result &result) {
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < _aggrs.size(); i++) {
|
||||
_aggrs[i]->set_accumulator(result.query_results[i]);
|
||||
result.query_results[i] = _aggrs[i]->compute();
|
||||
result.query_results[i] = _aggrs[i].state_to_result_function
|
||||
? _aggrs[i].state_to_result_function->execute(std::vector({std::move(result.query_results[i])}))
|
||||
: result.query_results[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user