cql3/selection, forward_service: use use stateless_aggregate_function directly

Now that stateless_aggregate_function is directly exposed by
aggregate_function, we can use it directly, avoiding the intermediary
aggregate_function::aggregate, which is removed.
This commit is contained in:
Avi Kivity
2023-03-07 22:36:11 +02:00
parent 58eb21aa5d
commit 6977df5539
5 changed files with 36 additions and 113 deletions

View File

@@ -35,11 +35,13 @@ protected:
public:
static shared_ptr<factory> new_factory(shared_ptr<functions::function> fun, shared_ptr<selector_factories> factories);
abstract_function_selector(shared_ptr<functions::function> fun, std::vector<shared_ptr<selector>> arg_selectors)
// If reserve_extra_arg is set, the internal buffer used for holding function argument lists is enlarged to account for
// an aggregate function's accumulator.
abstract_function_selector(shared_ptr<functions::function> fun, std::vector<shared_ptr<selector>> arg_selectors, bool reserve_extra_arg = false)
: _fun(std::move(fun)), _arg_selectors(std::move(arg_selectors)),
_requires_thread(boost::algorithm::any_of(_arg_selectors, [] (auto& s) { return s->requires_thread(); })
|| _fun->requires_thread()) {
_args.resize(_arg_selectors.size());
_args.resize(_arg_selectors.size() + unsigned(reserve_extra_arg));
}
virtual bool requires_thread() const override;
@@ -74,8 +76,10 @@ protected:
shared_ptr<const T> fun() const { return _tfun; }
public:
abstract_function_selector_for(shared_ptr<T> fun, std::vector<shared_ptr<selector>> arg_selectors)
: abstract_function_selector(fun, std::move(arg_selectors))
// If reserve_extra_arg is set, the internal buffer used for holding function argument lists is enlarged to account for
// an aggregate function's accumulator.
abstract_function_selector_for(shared_ptr<T> fun, std::vector<shared_ptr<selector>> arg_selectors, bool reserve_extra_arg = false)
: abstract_function_selector(fun, std::move(arg_selectors), reserve_extra_arg)
, _tfun(dynamic_pointer_cast<T>(fun)) {
}

View File

@@ -17,7 +17,8 @@ namespace cql3 {
namespace selection {
class aggregate_function_selector : public abstract_function_selector_for<functions::aggregate_function> {
std::unique_ptr<functions::aggregate_function::aggregate> _aggregate;
const db::functions::stateless_aggregate_function& _aggregate;
bytes_opt _accumulator;
public:
virtual bool is_aggregate() const override {
return true;
@@ -26,28 +27,32 @@ public:
virtual void add_input(result_set_builder& rs) override {
// Aggregation of aggregation is not supported
size_t m = _arg_selectors.size();
_args[0] = std::move(_accumulator);
for (size_t i = 0; i < m; ++i) {
auto&& s = _arg_selectors[i];
s->add_input(rs);
_args[i] = s->get_output();
_args[i + 1] = s->get_output();
s->reset();
}
_aggregate->add_input(_args);
_accumulator = _aggregate.aggregation_function->execute(_args);
}
virtual bytes_opt get_output() override {
return _aggregate->compute();
return _aggregate.state_to_result_function
? _aggregate.state_to_result_function->execute({std::move(_accumulator)})
: std::move(_accumulator);
}
virtual void reset() override {
_aggregate->reset();
_accumulator = _aggregate.initial_state;
}
aggregate_function_selector(shared_ptr<functions::function> func,
std::vector<shared_ptr<selector>> arg_selectors)
: abstract_function_selector_for<functions::aggregate_function>(
dynamic_pointer_cast<functions::aggregate_function>(func), std::move(arg_selectors))
, _aggregate(fun()->new_aggregate()) {
dynamic_pointer_cast<functions::aggregate_function>(func), std::move(arg_selectors), true)
, _aggregate(fun()->get_aggregate())
, _accumulator(_aggregate.initial_state) {
}
};

View File

@@ -27,11 +27,8 @@ protected:
private:
shared_ptr<aggregate_function> _reducible;
private:
class aggregate_adapter;
static shared_ptr<aggregate_function> make_reducible_variant(stateless_aggregate_function saf);
public:
class aggregate;
explicit aggregate_function(stateless_aggregate_function saf, bool reducible_variant = false);
/**
@@ -39,7 +36,7 @@ public:
*
* @return a new <code>Aggregate</code> instance.
*/
std::unique_ptr<aggregate> new_aggregate();
const stateless_aggregate_function& get_aggregate() const;
/**
* Checks wheather the function can be distributed and is able to reduce states.
@@ -70,41 +67,6 @@ public:
virtual bool is_aggregate() const override;
virtual void print(std::ostream& os) const override;
virtual sstring column_name(const std::vector<sstring>& column_names) const override;
/**
* An aggregation operation.
*/
class aggregate {
public:
using opt_bytes = aggregate_function::opt_bytes;
virtual ~aggregate() {}
/**
* Adds the specified input to this aggregate.
*
* @param values the values to add to the aggregate.
*/
virtual void add_input(const std::vector<opt_bytes>& values) = 0;
/**
* Computes and returns the aggregate current value.
*
* @return the aggregate current value.
*/
virtual opt_bytes compute() = 0;
virtual void set_accumulator(const opt_bytes& acc) = 0;
virtual opt_bytes get_accumulator() const = 0;
virtual void reduce(const opt_bytes& acc) = 0;
/**
* Reset this aggregate.
*/
virtual void reset() = 0;
};
};
}

View File

@@ -5,59 +5,16 @@
namespace db::functions {
class aggregate_function::aggregate_adapter : public aggregate {
const stateless_aggregate_function& _agg;
bytes_opt _state;
public:
explicit aggregate_adapter(const stateless_aggregate_function& agg)
: _agg(agg)
, _state(agg.initial_state) {
}
virtual void add_input(const std::vector<opt_bytes>& values) override {
std::vector<bytes_opt> state_and_values;
state_and_values.reserve(values.size() + 1);
state_and_values.push_back(std::move(_state));
std::copy(values.begin(), values.end(), std::back_inserter(state_and_values));
_state = _agg.aggregation_function->execute(state_and_values);
}
virtual opt_bytes compute() override {
if (_agg.state_to_result_function) {
std::vector<bytes_opt> state_vec;
state_vec.push_back(std::move(_state));
return _agg.state_to_result_function->execute(state_vec);
} else {
return std::move(_state);
}
}
virtual void set_accumulator(const opt_bytes& acc) override {
_state = acc;
}
virtual opt_bytes get_accumulator() const override {
return _state;
}
virtual void reduce(const opt_bytes& acc) override {
std::vector<bytes_opt> two_states;
two_states.reserve(2);
two_states.push_back(std::move(_state));
two_states.push_back(acc);
_state = _agg.state_reduction_function->execute(two_states);
}
virtual void reset() override {
_state = _agg.initial_state;
}
};
aggregate_function::aggregate_function(stateless_aggregate_function agg, bool reducible_variant)
: _agg(std::move(agg))
, _reducible(!reducible_variant ? make_reducible_variant(_agg) : nullptr) {
}
const stateless_aggregate_function&
aggregate_function::get_aggregate() const {
return _agg;
}
shared_ptr<aggregate_function>
aggregate_function::make_reducible_variant(stateless_aggregate_function agg) {
if (!agg.state_reduction_function) {
@@ -69,11 +26,6 @@ aggregate_function::make_reducible_variant(stateless_aggregate_function agg) {
return make_shared<aggregate_function>(new_agg, true);
}
std::unique_ptr<aggregate_function::aggregate>
aggregate_function::new_aggregate() {
return std::make_unique<aggregate_adapter>(_agg);
}
bool
aggregate_function::is_reducible() const {
return bool(_agg.state_reduction_function);

View File

@@ -60,8 +60,7 @@ static std::vector<::shared_ptr<db::functions::aggregate_function>> get_function
class forward_aggregates {
private:
std::vector<::shared_ptr<db::functions::aggregate_function>> _funcs;
std::vector<std::unique_ptr<db::functions::aggregate_function::aggregate>> _aggrs;
std::vector<db::functions::stateless_aggregate_function> _aggrs;
public:
forward_aggregates(const query::forward_request& request);
void merge(query::forward_result& result, query::forward_result&& other);
@@ -85,10 +84,10 @@ public:
forward_aggregates::forward_aggregates(const query::forward_request& request) {
_funcs = get_functions(request);
std::vector<std::unique_ptr<db::functions::aggregate_function::aggregate>> aggrs;
std::vector<db::functions::stateless_aggregate_function> aggrs;
for (auto& func: _funcs) {
aggrs.push_back(func->new_aggregate());
aggrs.push_back(func->get_aggregate());
}
_aggrs = std::move(aggrs);
}
@@ -113,9 +112,7 @@ void forward_aggregates::merge(query::forward_result &result, query::forward_res
}
for (size_t i = 0; i < _aggrs.size(); i++) {
_aggrs[i]->set_accumulator(result.query_results[i]);
_aggrs[i]->reduce(std::move(other.query_results[i]));
result.query_results[i] = _aggrs[i]->get_accumulator();
result.query_results[i] = _aggrs[i].state_reduction_function->execute(std::vector({std::move(result.query_results[i]), std::move(other.query_results[i])}));
}
}
@@ -126,7 +123,9 @@ void forward_aggregates::finalize(query::forward_result &result) {
// as "WHERE p IN ()". We need to build a fake result with the result
// of empty aggregation.
for (size_t i = 0; i < _aggrs.size(); i++) {
result.query_results.push_back(_aggrs[i]->compute());
result.query_results.push_back(_aggrs[i].state_to_result_function
? _aggrs[i].state_to_result_function->execute(std::vector({_aggrs[i].initial_state}))
: _aggrs[i].initial_state);
}
return;
}
@@ -141,8 +140,9 @@ void forward_aggregates::finalize(query::forward_result &result) {
}
for (size_t i = 0; i < _aggrs.size(); i++) {
_aggrs[i]->set_accumulator(result.query_results[i]);
result.query_results[i] = _aggrs[i]->compute();
result.query_results[i] = _aggrs[i].state_to_result_function
? _aggrs[i].state_to_result_function->execute(std::vector({std::move(result.query_results[i])}))
: result.query_results[i];
}
}