db: functions: fold stateless_aggregate_function_adapter into aggregate_function

Now that all aggregate functions are derived from
stateless_aggregate_function_adapter, we can just fold its functionality
into the base class. This exposes stateless_aggregate_function to
all users of aggregate_function, so they can begin to benefit from
the transformation, though this patch doesn't touch those users.

The aggregate_function base class is partiallly devirtualized since
there is just a single implementation now.
This commit is contained in:
Avi Kivity
2023-03-07 21:05:45 +02:00
parent 68529896aa
commit 58eb21aa5d
5 changed files with 49 additions and 54 deletions

View File

@@ -12,7 +12,7 @@
#include "types/types.hh"
#include "types/tuple.hh"
#include "cql3/functions/scalar_function.hh"
#include "db/functions/stateless_aggregate_function.hh"
#include "db/functions/aggregate_function.hh"
#include "cql3/util.hh"
#include "utils/big_decimal.hh"
#include "aggregate_fcts.hh"
@@ -163,7 +163,7 @@ static
shared_ptr<aggregate_function>
make_sum_function() {
using Acc = accumulator_for<Type>;
return make_shared<db::functions::stateless_aggregate_function_adapter>(
return make_shared<db::functions::aggregate_function>(
db::functions::stateless_aggregate_function{
.name = function_name::native_function("sum"),
.state_type = data_type_for<accumulator_for<Type>>(),
@@ -198,7 +198,7 @@ shared_ptr<aggregate_function>
make_avg_function() {
using sum_type = accumulator_for<Type>;
auto accumulator_tuple_type = tuple_type_impl::get_instance({data_type_for<sum_type>(), data_type_for<int64_t>()});
return make_shared<db::functions::stateless_aggregate_function_adapter>(
return make_shared<db::functions::aggregate_function>(
db::functions::stateless_aggregate_function{
.name = function_name::native_function("avg"),
.state_type = accumulator_tuple_type,
@@ -292,7 +292,7 @@ struct aggregate_type_for<time_native_type> {
*/
template <typename Type>
static shared_ptr<aggregate_function> make_count_function() {
return make_shared<db::functions::stateless_aggregate_function_adapter>(
return make_shared<db::functions::aggregate_function>(
db::functions::stateless_aggregate_function{
.name = function_name::native_function("count"),
.state_type = long_type,
@@ -333,7 +333,7 @@ static data_type uda_return_type(const ::shared_ptr<scalar_function>& ffunc, con
}
user_aggregate::user_aggregate(function_name fname, bytes_opt initcond, ::shared_ptr<scalar_function> sfunc, ::shared_ptr<scalar_function> reducefunc, ::shared_ptr<scalar_function> finalfunc)
: stateless_aggregate_function_adapter(db::functions::stateless_aggregate_function{
: aggregate_function(db::functions::stateless_aggregate_function{
.name = fname,
.state_type = sfunc->return_type(),
.result_type = finalfunc ? finalfunc->return_type() : sfunc->return_type(),
@@ -379,7 +379,7 @@ std::ostream& user_aggregate::describe(std::ostream& os) const {
shared_ptr<aggregate_function>
aggregate_fcts::make_count_rows_function() {
return make_shared<db::functions::stateless_aggregate_function_adapter>(
return make_shared<db::functions::aggregate_function>(
db::functions::stateless_aggregate_function{
.name = function_name::native_function(COUNT_ROWS_FUNCTION_NAME),
.column_name_override = "count",
@@ -412,7 +412,7 @@ aggregate_fcts::make_max_function(data_type io_type) {
}
return std::max(*args[0], *args[1], io_type->as_less_comparator());
});
return ::make_shared<db::functions::stateless_aggregate_function_adapter>(
return ::make_shared<db::functions::aggregate_function>(
db::functions::stateless_aggregate_function{
.name = function_name::native_function("max"),
.state_type = io_type,
@@ -440,7 +440,7 @@ aggregate_fcts::make_min_function(data_type io_type) {
}
return std::min(*args[0], *args[1], io_type->as_less_comparator());
});
return ::make_shared<db::functions::stateless_aggregate_function_adapter>(
return ::make_shared<db::functions::aggregate_function>(
db::functions::stateless_aggregate_function{
.name = function_name::native_function("min"),
.state_type = io_type,

View File

@@ -17,7 +17,7 @@
namespace cql3 {
namespace functions {
class user_aggregate : public db::functions::stateless_aggregate_function_adapter, public data_dictionary::keyspace_element {
class user_aggregate : public db::functions::aggregate_function, public data_dictionary::keyspace_element {
public:
user_aggregate(function_name fname, bytes_opt initcond, ::shared_ptr<scalar_function> sfunc, ::shared_ptr<scalar_function> reducefunc, ::shared_ptr<scalar_function> finalfunc);
bool has_finalfunc() const;

View File

@@ -11,6 +11,7 @@
#pragma once
#include "function.hh"
#include "stateless_aggregate_function.hh"
#include <optional>
namespace db {
@@ -21,22 +22,31 @@ namespace functions {
* Performs a calculation on a set of values and return a single value.
*/
class aggregate_function : public virtual function {
protected:
stateless_aggregate_function _agg;
private:
shared_ptr<aggregate_function> _reducible;
private:
class aggregate_adapter;
static shared_ptr<aggregate_function> make_reducible_variant(stateless_aggregate_function saf);
public:
class aggregate;
explicit aggregate_function(stateless_aggregate_function saf, bool reducible_variant = false);
/**
* Creates a new <code>Aggregate</code> instance.
*
* @return a new <code>Aggregate</code> instance.
*/
virtual std::unique_ptr<aggregate> new_aggregate() = 0;
std::unique_ptr<aggregate> new_aggregate();
/**
* Checks wheather the function can be distributed and is able to reduce states.
*
* @return <code>true</code> if the function is reducible, <code>false</code> otherwise.
*/
virtual bool is_reducible() const = 0;
bool is_reducible() const;
/**
* Creates a <code>Aggregate Function</code> that can be reduced.
@@ -49,7 +59,17 @@ public:
*
* @return a reducible <code>Aggregate Function</code>.
*/
virtual ::shared_ptr<aggregate_function> reducible_aggregate_function() = 0;
::shared_ptr<aggregate_function> reducible_aggregate_function();
virtual const function_name& name() const override;
virtual const std::vector<data_type>& arg_types() const override;
virtual const data_type& return_type() const override;
virtual bool is_pure() const override;
virtual bool is_native() const override;
virtual bool requires_thread() const override;
virtual bool is_aggregate() const override;
virtual void print(std::ostream& os) const override;
virtual sstring column_name(const std::vector<sstring>& column_names) const override;
/**
* An aggregation operation.

View File

@@ -1,11 +1,11 @@
// Copyright (C) 2023-present ScyllaDB
// SPDX-License-Identifier: AGPL-3.0-or-later
#include "stateless_aggregate_function.hh"
#include "aggregate_function.hh"
namespace db::functions {
class stateless_aggregate_function_adapter::aggregate_adapter : public aggregate {
class aggregate_function::aggregate_adapter : public aggregate {
const stateless_aggregate_function& _agg;
bytes_opt _state;
public:
@@ -53,85 +53,85 @@ public:
}
};
stateless_aggregate_function_adapter::stateless_aggregate_function_adapter(stateless_aggregate_function agg, bool reducible_variant)
aggregate_function::aggregate_function(stateless_aggregate_function agg, bool reducible_variant)
: _agg(std::move(agg))
, _reducible(!reducible_variant ? make_reducible_variant(_agg) : nullptr) {
}
shared_ptr<aggregate_function>
stateless_aggregate_function_adapter::make_reducible_variant(stateless_aggregate_function agg) {
aggregate_function::make_reducible_variant(stateless_aggregate_function agg) {
if (!agg.state_reduction_function) {
return nullptr;
}
auto new_agg = agg;
new_agg.state_to_result_function = nullptr;
new_agg.result_type = new_agg.aggregation_function->return_type();
return make_shared<stateless_aggregate_function_adapter>(new_agg, true);
return make_shared<aggregate_function>(new_agg, true);
}
std::unique_ptr<aggregate_function::aggregate>
stateless_aggregate_function_adapter::new_aggregate() {
aggregate_function::new_aggregate() {
return std::make_unique<aggregate_adapter>(_agg);
}
bool
stateless_aggregate_function_adapter::is_reducible() const {
aggregate_function::is_reducible() const {
return bool(_agg.state_reduction_function);
}
::shared_ptr<aggregate_function>
stateless_aggregate_function_adapter::reducible_aggregate_function() {
aggregate_function::reducible_aggregate_function() {
return _reducible;
}
const function_name&
stateless_aggregate_function_adapter::name() const {
aggregate_function::name() const {
return _agg.name;
}
const std::vector<data_type>&
stateless_aggregate_function_adapter::arg_types() const {
aggregate_function::arg_types() const {
return _agg.argument_types;
}
const data_type&
stateless_aggregate_function_adapter::return_type() const {
aggregate_function::return_type() const {
return _agg.result_type;
}
bool
stateless_aggregate_function_adapter::is_pure() const {
aggregate_function::is_pure() const {
return _agg.aggregation_function->is_pure()
&& (!_agg.state_to_result_function || _agg.state_to_result_function->is_pure())
&& (!_agg.state_reduction_function || _agg.state_reduction_function->is_pure());
}
bool
stateless_aggregate_function_adapter::is_native() const {
aggregate_function::is_native() const {
return _agg.aggregation_function->is_native()
&& (!_agg.state_to_result_function || _agg.state_to_result_function->is_native())
&& (!_agg.state_reduction_function || _agg.state_reduction_function->is_native());
}
bool
stateless_aggregate_function_adapter::requires_thread() const {
aggregate_function::requires_thread() const {
return _agg.aggregation_function->requires_thread()
|| (_agg.state_to_result_function && _agg.state_to_result_function->requires_thread())
|| (_agg.state_reduction_function && _agg.state_reduction_function->requires_thread());
}
bool
stateless_aggregate_function_adapter::is_aggregate() const {
aggregate_function::is_aggregate() const {
return true;
}
void
stateless_aggregate_function_adapter::print(std::ostream& os) const {
aggregate_function::print(std::ostream& os) const {
os << name();
}
sstring
stateless_aggregate_function_adapter::column_name(const std::vector<sstring>& column_names) const {
aggregate_function::column_name(const std::vector<sstring>& column_names) const {
if (_agg.column_name_override) {
return *_agg.column_name_override;
}

View File

@@ -3,7 +3,6 @@
#pragma once
#include "aggregate_function.hh"
#include "scalar_function.hh"
#include "function_name.hh"
#include <optional>
@@ -33,28 +32,4 @@ struct stateless_aggregate_function final {
shared_ptr<scalar_function> state_reduction_function;
};
class stateless_aggregate_function_adapter : public aggregate_function {
protected:
stateless_aggregate_function _agg;
private:
shared_ptr<aggregate_function> _reducible;
private:
class aggregate_adapter;
static shared_ptr<aggregate_function> make_reducible_variant(stateless_aggregate_function saf);
public:
explicit stateless_aggregate_function_adapter(stateless_aggregate_function saf, bool reducible_variant = false);
virtual std::unique_ptr<aggregate> new_aggregate() override;
virtual bool is_reducible() const override;
virtual ::shared_ptr<aggregate_function> reducible_aggregate_function() override;
virtual const function_name& name() const override;
virtual const std::vector<data_type>& arg_types() const override;
virtual const data_type& return_type() const override;
virtual bool is_pure() const override;
virtual bool is_native() const override;
virtual bool requires_thread() const override;
virtual bool is_aggregate() const override;
virtual void print(std::ostream& os) const override;
virtual sstring column_name(const std::vector<sstring>& column_names) const override;
};
}