cql3: functions: add set_intersection()

Given two sets of equivalent types, return the set
intersection.

This is a generic function which adapts to the actual
input type.

A unit test is added.

Closes scylladb/scylladb#22763
This commit is contained in:
Avi Kivity
2024-12-31 22:23:27 +02:00
committed by Nadav Har'El
parent 4a2654865d
commit 81821d26cd
2 changed files with 118 additions and 0 deletions

View File

@@ -22,6 +22,8 @@
#include "cql3/prepare_context.hh"
#include "user_aggregate.hh"
#include "cql3/expr/expression.hh"
#include "types/set.hh"
#include "types/listlike_partial_deserializing_iterator.hh"
#include "error_injection_fcts.hh"
@@ -42,6 +44,8 @@ namespace functions {
// Logger used by this file; it is passed to on_internal_error() by set_intersection().
// NOTE(review): the logger name is spelled "cql3_fuctions" (sic). It is a runtime
// string, so correcting the spelling would change the name the logger is known by.
logging::logger log("cql3_fuctions");
// Name of the generic set_intersection() native function. functions::get()
// compares the requested function name against it before dispatching to
// get_set_intersection_function().
static auto SET_INTERSECTION_FUNCTION_NAME = function_name::native_function("set_intersection");
// Both of these function kinds report that they do not require a thread.
bool abstract_function::requires_thread() const { return false; }
bool as_json_function::requires_thread() const { return false; }
@@ -298,6 +302,81 @@ static shared_ptr<function> get_dynamic_aggregate(const function_name &name, con
return {};
}
// Builds the set_intersection() scalar function, specialized for the set type
// of the arguments it is being called with.
//
// Requirements on provided_args:
//   - exactly two arguments;
//   - at least one argument whose type is known;
//   - every known argument type is a set type;
//   - after freezing, all known argument types are one and the same set type.
// That frozen set type is used both as the return type and as the type of the
// two parameters of the returned function.
//
// The returned function yields a null result if either input is null;
// otherwise it returns the serialized set of elements present in both inputs.
//
// Throws exceptions::invalid_request_exception if any requirement above is
// violated. The db, keyspace, name, receiver_ks, receiver_cf and receiver
// parameters are not used by this function.
static
shared_ptr<function>
get_set_intersection_function(data_dictionary::database db,
        const sstring& keyspace,
        const function_name& name,
        const std::vector<shared_ptr<assignment_testable>>& provided_args,
        const sstring& receiver_ks,
        std::optional<const std::string_view> receiver_cf,
        const column_specification* receiver) {
    if (provided_args.size() != 2) {
        throw exceptions::invalid_request_exception("set_intersection() accepts 2 arguments only");
    }

    // Gather the types of those arguments whose type can be determined.
    std::vector<data_type> arg_types;
    for (const auto& arg : provided_args) {
        if (auto type_opt = arg->assignment_testable_type_opt()) {
            arg_types.push_back(*type_opt);
        }
    }
    if (arg_types.empty()) {
        throw exceptions::invalid_request_exception("set_intersection() can only be called if at least one argument type is known");
    }

    // Every known type has to be a set; a failed cast yields a null pointer.
    std::vector<shared_ptr<const set_type_impl>> set_types;
    set_types.reserve(arg_types.size());
    for (const auto& arg_type : arg_types) {
        auto as_set = dynamic_pointer_cast<const set_type_impl>(arg_type);
        if (!as_set) {
            throw exceptions::invalid_request_exception("set_intersection() can only be called if both arguments are of set type");
        }
        set_types.push_back(std::move(as_set));
    }

    // Normalize everything to be a frozen set, so that frozen and non-frozen
    // variants of the same set type are not treated as different types below.
    std::ranges::transform(set_types, set_types.begin(), [] (const auto& t) {
        return dynamic_pointer_cast<const set_type_impl>(t->freeze());
    });

    // Collapse adjacent identical types; exactly one type must remain.
    const auto [dup_first, dup_last] = std::ranges::unique(set_types);
    set_types.erase(dup_first, dup_last);
    if (set_types.size() != 1) {
        throw exceptions::invalid_request_exception(fmt::format("set_intersection() can only be called if both arguments are of the same set type: {}",
                set_types | std::views::transform(&abstract_type::name)));
    }

    auto frozen_set = set_types.front();
    auto value_type = frozen_set->get_elements_type();
    return make_native_scalar_function<true>("set_intersection", frozen_set, {frozen_set, frozen_set},
            [frozen_set, value_type] (std::span<const bytes_opt> parameters) -> bytes_opt {
        // A null input produces a null result.
        if (!(parameters[0].has_value() && parameters[1].has_value())) {
            return std::nullopt;
        }

        // Exposes the elements of a serialized set as an iterator range.
        // The view is supplied by the caller so that it outlives the range.
        auto elements_of = [] (const bytes_opt& serialized, managed_bytes_view& view) {
            view = managed_bytes_view(*serialized);
            return std::ranges::subrange<listlike_partial_deserializing_iterator>(
                    listlike_partial_deserializing_iterator::begin(view),
                    listlike_partial_deserializing_iterator::end(view));
        };

        // Strict ordering of elements according to the set's element type.
        auto precedes = [&] (managed_bytes_view_opt a, managed_bytes_view_opt b) {
            if (!(a && b)) {
                on_internal_error(log, "set_intersection: unexpected null value");
            }
            return value_type->compare(*a, *b) < 0;
        };

        managed_bytes_view lhs_view, rhs_view;
        std::vector<managed_bytes_view_opt> common;
        std::ranges::set_intersection(
                elements_of(parameters[0], lhs_view),
                elements_of(parameters[1], rhs_view),
                std::back_inserter(common),
                precedes);
        return to_bytes(frozen_set->pack_fragmented(common.begin(), common.end(), common.size()));
    });
}
shared_ptr<function>
functions::get(data_dictionary::database db,
const sstring& keyspace,
@@ -359,6 +438,14 @@ functions::get(data_dictionary::database db,
return make_from_json_function(db, keyspace, receiver->type);
}
// FIXME: add proper support for generic functions
if (name.has_keyspace()
? name == SET_INTERSECTION_FUNCTION_NAME
: name.name == SET_INTERSECTION_FUNCTION_NAME.name) {
return get_set_intersection_function(db, keyspace, name, provided_args, receiver_ks, receiver_cf, receiver);
}
auto aggr_fun = get_dynamic_aggregate(name, provided_args);
if (aggr_fun) {
return aggr_fun;

View File

@@ -19,6 +19,13 @@ def table1(cql, test_keyspace):
with new_test_table(cql, test_keyspace, "p int, i int, g bigint, b blob, s text, t timestamp, u timeuuid, d date, PRIMARY KEY (p)") as table:
yield table
# Module-scoped table with an int key, three set columns (two set<int>, one
# set<tinyint>), a map column and a frozen set<int> column.
@pytest.fixture(scope="module")
def tbl_set(cql, test_keyspace):
    schema = ("p int PRIMARY KEY, s1 set<int>, s2 set<int>,"
              "s3 set<tinyint>, m map<int, int>, s4 frozen<set<int>>")
    with new_test_table(cql, test_keyspace, schema) as tbl:
        yield tbl
# Check that a function that can take a column name as a parameter, can also
# take a constant. This feature is barely useful for WHERE clauses, and
# even less useful for selectors, but should be allowed for both.
@@ -161,3 +168,27 @@ def test_totimestamp_date_extreme(cql, table1):
# and work fine.
cql.execute(f"INSERT INTO {table1} (p, d) VALUES ({p}, {2**30})")
cql.execute(f"SELECT totimestamp(d) FROM {table1} WHERE p={p}")
# Test set_intersection() function. Not supported in Cassandra.
# Covers a plain intersection, a NULL argument, mixing frozen and non-frozen
# sets, nested calls, and the error messages for bad argument lists.
def test_set_intersection_fn(cql, tbl_set, scylla_only):
    p1 = unique_key_int()
    p2 = unique_key_int()
    # Row p1 has both s1 and s2 populated; row p2 has a NULL s2.
    for key, s2 in ((p1, "{2,3,4}"), (p2, "NULL")):
        cql.execute(f"INSERT INTO {tbl_set} (p, s1, s2, s3, m, s4) VALUES ({key}, {{1,2,3}}, {s2}, {{1}}, {{1:2, 2:3}}, {{-1,2}})")

    # Run set_intersection(<args>) on the row with the given key and return
    # the result rows as a list.
    def intersect(args, key):
        return list(cql.execute(f"SELECT set_intersection({args}) FROM {tbl_set} WHERE p={key}"))

    # Normal intersection.
    assert [({2, 3},)] == intersect("s1, s2", p1)
    # Intersecting with NULL.
    assert [(None,)] == intersect("s1, s2", p2)
    # Frozen and non-frozen.
    assert [({2},)] == intersect("s1, s4", p1)
    # Nesting
    assert [({2},)] == intersect("s1, set_intersection(s2, s4)", p1)

    # Some error cases: (argument list, expected error message pattern)
    bad_calls = [
        ("s1, s2, s3", 'accepts 2 arguments'),
        ("s1", 'accepts 2 arguments'),
        ("s1, p", 'both arguments are of set type'),
        ("s1, s3", 'both arguments are of the same set type'),
    ]
    for args, message in bad_calls:
        with pytest.raises(InvalidRequest, match=message):
            cql.execute(f"SELECT set_intersection({args}) FROM {tbl_set} WHERE p={p1}")