db: functions: fold stateless_aggregate_function_adapter into aggregate_function
Now that all aggregate functions are derived from stateless_aggregate_function_adapter, we can just fold its functionality into the base class. This exposes stateless_aggregate_function to all users of aggregate_function, so they can begin to benefit from the transformation, though this patch doesn't touch those users. The aggregate_function base class is partiallly devirtualized since there is just a single implementation now.
This commit is contained in:
@@ -12,7 +12,7 @@
|
||||
#include "types/types.hh"
|
||||
#include "types/tuple.hh"
|
||||
#include "cql3/functions/scalar_function.hh"
|
||||
#include "db/functions/stateless_aggregate_function.hh"
|
||||
#include "db/functions/aggregate_function.hh"
|
||||
#include "cql3/util.hh"
|
||||
#include "utils/big_decimal.hh"
|
||||
#include "aggregate_fcts.hh"
|
||||
@@ -163,7 +163,7 @@ static
|
||||
shared_ptr<aggregate_function>
|
||||
make_sum_function() {
|
||||
using Acc = accumulator_for<Type>;
|
||||
return make_shared<db::functions::stateless_aggregate_function_adapter>(
|
||||
return make_shared<db::functions::aggregate_function>(
|
||||
db::functions::stateless_aggregate_function{
|
||||
.name = function_name::native_function("sum"),
|
||||
.state_type = data_type_for<accumulator_for<Type>>(),
|
||||
@@ -198,7 +198,7 @@ shared_ptr<aggregate_function>
|
||||
make_avg_function() {
|
||||
using sum_type = accumulator_for<Type>;
|
||||
auto accumulator_tuple_type = tuple_type_impl::get_instance({data_type_for<sum_type>(), data_type_for<int64_t>()});
|
||||
return make_shared<db::functions::stateless_aggregate_function_adapter>(
|
||||
return make_shared<db::functions::aggregate_function>(
|
||||
db::functions::stateless_aggregate_function{
|
||||
.name = function_name::native_function("avg"),
|
||||
.state_type = accumulator_tuple_type,
|
||||
@@ -292,7 +292,7 @@ struct aggregate_type_for<time_native_type> {
|
||||
*/
|
||||
template <typename Type>
|
||||
static shared_ptr<aggregate_function> make_count_function() {
|
||||
return make_shared<db::functions::stateless_aggregate_function_adapter>(
|
||||
return make_shared<db::functions::aggregate_function>(
|
||||
db::functions::stateless_aggregate_function{
|
||||
.name = function_name::native_function("count"),
|
||||
.state_type = long_type,
|
||||
@@ -333,7 +333,7 @@ static data_type uda_return_type(const ::shared_ptr<scalar_function>& ffunc, con
|
||||
}
|
||||
|
||||
user_aggregate::user_aggregate(function_name fname, bytes_opt initcond, ::shared_ptr<scalar_function> sfunc, ::shared_ptr<scalar_function> reducefunc, ::shared_ptr<scalar_function> finalfunc)
|
||||
: stateless_aggregate_function_adapter(db::functions::stateless_aggregate_function{
|
||||
: aggregate_function(db::functions::stateless_aggregate_function{
|
||||
.name = fname,
|
||||
.state_type = sfunc->return_type(),
|
||||
.result_type = finalfunc ? finalfunc->return_type() : sfunc->return_type(),
|
||||
@@ -379,7 +379,7 @@ std::ostream& user_aggregate::describe(std::ostream& os) const {
|
||||
|
||||
shared_ptr<aggregate_function>
|
||||
aggregate_fcts::make_count_rows_function() {
|
||||
return make_shared<db::functions::stateless_aggregate_function_adapter>(
|
||||
return make_shared<db::functions::aggregate_function>(
|
||||
db::functions::stateless_aggregate_function{
|
||||
.name = function_name::native_function(COUNT_ROWS_FUNCTION_NAME),
|
||||
.column_name_override = "count",
|
||||
@@ -412,7 +412,7 @@ aggregate_fcts::make_max_function(data_type io_type) {
|
||||
}
|
||||
return std::max(*args[0], *args[1], io_type->as_less_comparator());
|
||||
});
|
||||
return ::make_shared<db::functions::stateless_aggregate_function_adapter>(
|
||||
return ::make_shared<db::functions::aggregate_function>(
|
||||
db::functions::stateless_aggregate_function{
|
||||
.name = function_name::native_function("max"),
|
||||
.state_type = io_type,
|
||||
@@ -440,7 +440,7 @@ aggregate_fcts::make_min_function(data_type io_type) {
|
||||
}
|
||||
return std::min(*args[0], *args[1], io_type->as_less_comparator());
|
||||
});
|
||||
return ::make_shared<db::functions::stateless_aggregate_function_adapter>(
|
||||
return ::make_shared<db::functions::aggregate_function>(
|
||||
db::functions::stateless_aggregate_function{
|
||||
.name = function_name::native_function("min"),
|
||||
.state_type = io_type,
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
namespace cql3 {
|
||||
namespace functions {
|
||||
|
||||
class user_aggregate : public db::functions::stateless_aggregate_function_adapter, public data_dictionary::keyspace_element {
|
||||
class user_aggregate : public db::functions::aggregate_function, public data_dictionary::keyspace_element {
|
||||
public:
|
||||
user_aggregate(function_name fname, bytes_opt initcond, ::shared_ptr<scalar_function> sfunc, ::shared_ptr<scalar_function> reducefunc, ::shared_ptr<scalar_function> finalfunc);
|
||||
bool has_finalfunc() const;
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "function.hh"
|
||||
#include "stateless_aggregate_function.hh"
|
||||
#include <optional>
|
||||
|
||||
namespace db {
|
||||
@@ -21,22 +22,31 @@ namespace functions {
|
||||
* Performs a calculation on a set of values and return a single value.
|
||||
*/
|
||||
class aggregate_function : public virtual function {
|
||||
protected:
|
||||
stateless_aggregate_function _agg;
|
||||
private:
|
||||
shared_ptr<aggregate_function> _reducible;
|
||||
private:
|
||||
class aggregate_adapter;
|
||||
static shared_ptr<aggregate_function> make_reducible_variant(stateless_aggregate_function saf);
|
||||
public:
|
||||
class aggregate;
|
||||
|
||||
explicit aggregate_function(stateless_aggregate_function saf, bool reducible_variant = false);
|
||||
|
||||
/**
|
||||
* Creates a new <code>Aggregate</code> instance.
|
||||
*
|
||||
* @return a new <code>Aggregate</code> instance.
|
||||
*/
|
||||
virtual std::unique_ptr<aggregate> new_aggregate() = 0;
|
||||
std::unique_ptr<aggregate> new_aggregate();
|
||||
|
||||
/**
|
||||
* Checks wheather the function can be distributed and is able to reduce states.
|
||||
*
|
||||
* @return <code>true</code> if the function is reducible, <code>false</code> otherwise.
|
||||
*/
|
||||
virtual bool is_reducible() const = 0;
|
||||
bool is_reducible() const;
|
||||
|
||||
/**
|
||||
* Creates a <code>Aggregate Function</code> that can be reduced.
|
||||
@@ -49,7 +59,17 @@ public:
|
||||
*
|
||||
* @return a reducible <code>Aggregate Function</code>.
|
||||
*/
|
||||
virtual ::shared_ptr<aggregate_function> reducible_aggregate_function() = 0;
|
||||
::shared_ptr<aggregate_function> reducible_aggregate_function();
|
||||
|
||||
virtual const function_name& name() const override;
|
||||
virtual const std::vector<data_type>& arg_types() const override;
|
||||
virtual const data_type& return_type() const override;
|
||||
virtual bool is_pure() const override;
|
||||
virtual bool is_native() const override;
|
||||
virtual bool requires_thread() const override;
|
||||
virtual bool is_aggregate() const override;
|
||||
virtual void print(std::ostream& os) const override;
|
||||
virtual sstring column_name(const std::vector<sstring>& column_names) const override;
|
||||
|
||||
/**
|
||||
* An aggregation operation.
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
// Copyright (C) 2023-present ScyllaDB
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
|
||||
#include "stateless_aggregate_function.hh"
|
||||
#include "aggregate_function.hh"
|
||||
|
||||
namespace db::functions {
|
||||
|
||||
class stateless_aggregate_function_adapter::aggregate_adapter : public aggregate {
|
||||
class aggregate_function::aggregate_adapter : public aggregate {
|
||||
const stateless_aggregate_function& _agg;
|
||||
bytes_opt _state;
|
||||
public:
|
||||
@@ -53,85 +53,85 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
stateless_aggregate_function_adapter::stateless_aggregate_function_adapter(stateless_aggregate_function agg, bool reducible_variant)
|
||||
aggregate_function::aggregate_function(stateless_aggregate_function agg, bool reducible_variant)
|
||||
: _agg(std::move(agg))
|
||||
, _reducible(!reducible_variant ? make_reducible_variant(_agg) : nullptr) {
|
||||
}
|
||||
|
||||
shared_ptr<aggregate_function>
|
||||
stateless_aggregate_function_adapter::make_reducible_variant(stateless_aggregate_function agg) {
|
||||
aggregate_function::make_reducible_variant(stateless_aggregate_function agg) {
|
||||
if (!agg.state_reduction_function) {
|
||||
return nullptr;
|
||||
}
|
||||
auto new_agg = agg;
|
||||
new_agg.state_to_result_function = nullptr;
|
||||
new_agg.result_type = new_agg.aggregation_function->return_type();
|
||||
return make_shared<stateless_aggregate_function_adapter>(new_agg, true);
|
||||
return make_shared<aggregate_function>(new_agg, true);
|
||||
}
|
||||
|
||||
std::unique_ptr<aggregate_function::aggregate>
|
||||
stateless_aggregate_function_adapter::new_aggregate() {
|
||||
aggregate_function::new_aggregate() {
|
||||
return std::make_unique<aggregate_adapter>(_agg);
|
||||
}
|
||||
|
||||
bool
|
||||
stateless_aggregate_function_adapter::is_reducible() const {
|
||||
aggregate_function::is_reducible() const {
|
||||
return bool(_agg.state_reduction_function);
|
||||
}
|
||||
|
||||
::shared_ptr<aggregate_function>
|
||||
stateless_aggregate_function_adapter::reducible_aggregate_function() {
|
||||
aggregate_function::reducible_aggregate_function() {
|
||||
return _reducible;
|
||||
}
|
||||
|
||||
const function_name&
|
||||
stateless_aggregate_function_adapter::name() const {
|
||||
aggregate_function::name() const {
|
||||
return _agg.name;
|
||||
}
|
||||
|
||||
const std::vector<data_type>&
|
||||
stateless_aggregate_function_adapter::arg_types() const {
|
||||
aggregate_function::arg_types() const {
|
||||
return _agg.argument_types;
|
||||
}
|
||||
|
||||
const data_type&
|
||||
stateless_aggregate_function_adapter::return_type() const {
|
||||
aggregate_function::return_type() const {
|
||||
return _agg.result_type;
|
||||
}
|
||||
|
||||
bool
|
||||
stateless_aggregate_function_adapter::is_pure() const {
|
||||
aggregate_function::is_pure() const {
|
||||
return _agg.aggregation_function->is_pure()
|
||||
&& (!_agg.state_to_result_function || _agg.state_to_result_function->is_pure())
|
||||
&& (!_agg.state_reduction_function || _agg.state_reduction_function->is_pure());
|
||||
}
|
||||
|
||||
bool
|
||||
stateless_aggregate_function_adapter::is_native() const {
|
||||
aggregate_function::is_native() const {
|
||||
return _agg.aggregation_function->is_native()
|
||||
&& (!_agg.state_to_result_function || _agg.state_to_result_function->is_native())
|
||||
&& (!_agg.state_reduction_function || _agg.state_reduction_function->is_native());
|
||||
}
|
||||
|
||||
bool
|
||||
stateless_aggregate_function_adapter::requires_thread() const {
|
||||
aggregate_function::requires_thread() const {
|
||||
return _agg.aggregation_function->requires_thread()
|
||||
|| (_agg.state_to_result_function && _agg.state_to_result_function->requires_thread())
|
||||
|| (_agg.state_reduction_function && _agg.state_reduction_function->requires_thread());
|
||||
}
|
||||
|
||||
bool
|
||||
stateless_aggregate_function_adapter::is_aggregate() const {
|
||||
aggregate_function::is_aggregate() const {
|
||||
return true;
|
||||
}
|
||||
|
||||
void
|
||||
stateless_aggregate_function_adapter::print(std::ostream& os) const {
|
||||
aggregate_function::print(std::ostream& os) const {
|
||||
os << name();
|
||||
}
|
||||
|
||||
sstring
|
||||
stateless_aggregate_function_adapter::column_name(const std::vector<sstring>& column_names) const {
|
||||
aggregate_function::column_name(const std::vector<sstring>& column_names) const {
|
||||
if (_agg.column_name_override) {
|
||||
return *_agg.column_name_override;
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "aggregate_function.hh"
|
||||
#include "scalar_function.hh"
|
||||
#include "function_name.hh"
|
||||
#include <optional>
|
||||
@@ -33,28 +32,4 @@ struct stateless_aggregate_function final {
|
||||
shared_ptr<scalar_function> state_reduction_function;
|
||||
};
|
||||
|
||||
class stateless_aggregate_function_adapter : public aggregate_function {
|
||||
protected:
|
||||
stateless_aggregate_function _agg;
|
||||
private:
|
||||
shared_ptr<aggregate_function> _reducible;
|
||||
private:
|
||||
class aggregate_adapter;
|
||||
static shared_ptr<aggregate_function> make_reducible_variant(stateless_aggregate_function saf);
|
||||
public:
|
||||
explicit stateless_aggregate_function_adapter(stateless_aggregate_function saf, bool reducible_variant = false);
|
||||
virtual std::unique_ptr<aggregate> new_aggregate() override;
|
||||
virtual bool is_reducible() const override;
|
||||
virtual ::shared_ptr<aggregate_function> reducible_aggregate_function() override;
|
||||
virtual const function_name& name() const override;
|
||||
virtual const std::vector<data_type>& arg_types() const override;
|
||||
virtual const data_type& return_type() const override;
|
||||
virtual bool is_pure() const override;
|
||||
virtual bool is_native() const override;
|
||||
virtual bool requires_thread() const override;
|
||||
virtual bool is_aggregate() const override;
|
||||
virtual void print(std::ostream& os) const override;
|
||||
virtual sstring column_name(const std::vector<sstring>& column_names) const override;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user