cql3: functions: add set_intersection()
Given two sets of equivalent types, return the set intersection. This is a generic function which adapts to the actual input type. A unit test is added. Closes scylladb/scylladb#22763
This commit is contained in:
@@ -22,6 +22,8 @@
|
||||
#include "cql3/prepare_context.hh"
|
||||
#include "user_aggregate.hh"
|
||||
#include "cql3/expr/expression.hh"
|
||||
#include "types/set.hh"
|
||||
#include "types/listlike_partial_deserializing_iterator.hh"
|
||||
|
||||
#include "error_injection_fcts.hh"
|
||||
|
||||
@@ -42,6 +44,8 @@ namespace functions {
|
||||
|
||||
logging::logger log("cql3_fuctions");
|
||||
|
||||
static auto SET_INTERSECTION_FUNCTION_NAME = function_name::native_function("set_intersection");
|
||||
|
||||
bool abstract_function::requires_thread() const { return false; }
|
||||
|
||||
bool as_json_function::requires_thread() const { return false; }
|
||||
@@ -298,6 +302,81 @@ static shared_ptr<function> get_dynamic_aggregate(const function_name &name, con
|
||||
return {};
|
||||
}
|
||||
|
||||
static
|
||||
shared_ptr<function>
|
||||
get_set_intersection_function(data_dictionary::database db,
|
||||
const sstring& keyspace,
|
||||
const function_name& name,
|
||||
const std::vector<shared_ptr<assignment_testable>>& provided_args,
|
||||
const sstring& receiver_ks,
|
||||
std::optional<const std::string_view> receiver_cf,
|
||||
const column_specification* receiver) {
|
||||
if (provided_args.size() != 2) {
|
||||
throw exceptions::invalid_request_exception("set_intersection() accepts 2 arguments only");
|
||||
}
|
||||
auto known_arg_types = provided_args
|
||||
| std::views::filter([] (const shared_ptr<assignment_testable>& arg) { return arg->assignment_testable_type_opt().has_value(); })
|
||||
| std::views::transform([] (const shared_ptr<assignment_testable>& arg) -> data_type { return *arg->assignment_testable_type_opt(); })
|
||||
| std::ranges::to<std::vector>();
|
||||
|
||||
if (known_arg_types.empty()) {
|
||||
throw exceptions::invalid_request_exception("set_intersection() can only be called if at least one argument type is known");
|
||||
}
|
||||
|
||||
auto known_set_types = known_arg_types
|
||||
| std::views::transform([] (const data_type& arg_type) {
|
||||
return dynamic_pointer_cast<const set_type_impl>(arg_type);
|
||||
})
|
||||
| std::ranges::to<std::vector>();
|
||||
|
||||
if (!std::ranges::all_of(known_set_types, [] (data_type t) { return t != nullptr; })) {
|
||||
throw exceptions::invalid_request_exception("set_intersection() can only be called if both arguments are of set type");
|
||||
}
|
||||
|
||||
// Normalize everything to be a frozen set
|
||||
for (auto& set_type : known_set_types) {
|
||||
set_type = dynamic_pointer_cast<const set_type_impl>(set_type->freeze());
|
||||
}
|
||||
|
||||
auto unique_remove = std::ranges::unique(known_set_types);
|
||||
known_set_types.erase(unique_remove.begin(), unique_remove.end());
|
||||
|
||||
if (known_set_types.size() != 1) {
|
||||
throw exceptions::invalid_request_exception(fmt::format("set_intersection() can only be called if both arguments are of the same set type: {}",
|
||||
known_set_types | std::views::transform(&abstract_type::name)));
|
||||
}
|
||||
|
||||
auto set_type = known_set_types.front();
|
||||
auto element_type = set_type->get_elements_type();
|
||||
|
||||
return make_native_scalar_function<true>("set_intersection", set_type, {set_type, set_type},
|
||||
[set_type, element_type] (std::span<const bytes_opt> parameters) -> bytes_opt {
|
||||
if (!parameters[0].has_value() || !parameters[1].has_value()) {
|
||||
return {};
|
||||
}
|
||||
auto set_as_range = [&] (const bytes_opt& serialized_set, managed_bytes_view& buffer) {
|
||||
buffer = managed_bytes_view(*serialized_set);
|
||||
return std::ranges::subrange<listlike_partial_deserializing_iterator>(
|
||||
listlike_partial_deserializing_iterator::begin(buffer),
|
||||
listlike_partial_deserializing_iterator::end(buffer));
|
||||
};
|
||||
auto element_less = [&] (managed_bytes_view_opt v1, managed_bytes_view_opt v2) {
|
||||
if (!v1 || !v2) {
|
||||
on_internal_error(log, "set_intersection: unexpected null value");
|
||||
}
|
||||
return element_type->compare(*v1, *v2) < 0;
|
||||
};
|
||||
std::vector<managed_bytes_view_opt> result_vector;
|
||||
managed_bytes_view buffer1, buffer2;
|
||||
std::ranges::set_intersection(
|
||||
set_as_range(parameters[0], buffer1),
|
||||
set_as_range(parameters[1], buffer2),
|
||||
std::back_inserter(result_vector),
|
||||
element_less);
|
||||
return to_bytes(set_type->pack_fragmented(result_vector.begin(), result_vector.end(), result_vector.size()));
|
||||
});
|
||||
}
|
||||
|
||||
shared_ptr<function>
|
||||
functions::get(data_dictionary::database db,
|
||||
const sstring& keyspace,
|
||||
@@ -359,6 +438,14 @@ functions::get(data_dictionary::database db,
|
||||
return make_from_json_function(db, keyspace, receiver->type);
|
||||
}
|
||||
|
||||
// FIXME: add proper support for generic functions
|
||||
if (name.has_keyspace()
|
||||
? name == SET_INTERSECTION_FUNCTION_NAME
|
||||
: name.name == SET_INTERSECTION_FUNCTION_NAME.name) {
|
||||
return get_set_intersection_function(db, keyspace, name, provided_args, receiver_ks, receiver_cf, receiver);
|
||||
}
|
||||
|
||||
|
||||
auto aggr_fun = get_dynamic_aggregate(name, provided_args);
|
||||
if (aggr_fun) {
|
||||
return aggr_fun;
|
||||
|
||||
@@ -19,6 +19,13 @@ def table1(cql, test_keyspace):
|
||||
with new_test_table(cql, test_keyspace, "p int, i int, g bigint, b blob, s text, t timestamp, u timeuuid, d date, PRIMARY KEY (p)") as table:
|
||||
yield table
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def tbl_set(cql, test_keyspace):
|
||||
with new_test_table(cql, test_keyspace,
|
||||
"p int PRIMARY KEY, s1 set<int>, s2 set<int>,"
|
||||
"s3 set<tinyint>, m map<int, int>, s4 frozen<set<int>>") as table:
|
||||
yield table
|
||||
|
||||
# Check that a function that can take a column name as a parameter, can also
|
||||
# take a constant. This feature is barely useful for WHERE clauses, and
|
||||
# even less useful for selectors, but should be allowed for both.
|
||||
@@ -161,3 +168,27 @@ def test_totimestamp_date_extreme(cql, table1):
|
||||
# and work fine.
|
||||
cql.execute(f"INSERT INTO {table1} (p, d) VALUES ({p}, {2**30})")
|
||||
cql.execute(f"SELECT totimestamp(d) FROM {table1} WHERE p={p}")
|
||||
|
||||
# Test set_intersection() function. Not supported in Cassandra.
|
||||
def test_set_intersection_fn(cql, tbl_set, scylla_only):
|
||||
p1 = unique_key_int()
|
||||
p2 = unique_key_int()
|
||||
cql.execute(f"INSERT INTO {tbl_set} (p, s1, s2, s3, m, s4) VALUES ({p1}, {{1,2,3}}, {{2,3,4}}, {{1}}, {{1:2, 2:3}}, {{-1,2}})")
|
||||
cql.execute(f"INSERT INTO {tbl_set} (p, s1, s2, s3, m, s4) VALUES ({p2}, {{1,2,3}}, NULL, {{1}}, {{1:2, 2:3}}, {{-1,2}})")
|
||||
# Normal intersection.
|
||||
assert [(set([2,3]),)] == list(cql.execute(f"SELECT set_intersection(s1, s2) FROM {tbl_set} WHERE p={p1}"))
|
||||
# Intersecting with NULL.
|
||||
assert [(None,)] == list(cql.execute(f"SELECT set_intersection(s1, s2) FROM {tbl_set} WHERE p={p2}"))
|
||||
# Frozen and non-frozen.
|
||||
assert [(set([2]),)] == list(cql.execute(f"SELECT set_intersection(s1, s4) FROM {tbl_set} WHERE p={p1}"))
|
||||
# Nesting
|
||||
assert [(set([2]),)] == list(cql.execute(f"SELECT set_intersection(s1, set_intersection(s2, s4)) FROM {tbl_set} WHERE p={p1}"))
|
||||
# Some error cases
|
||||
with pytest.raises(InvalidRequest, match='accepts 2 arguments'):
|
||||
cql.execute(f"SELECT set_intersection(s1, s2, s3) FROM {tbl_set} WHERE p={p1}")
|
||||
with pytest.raises(InvalidRequest, match='accepts 2 arguments'):
|
||||
cql.execute(f"SELECT set_intersection(s1) FROM {tbl_set} WHERE p={p1}")
|
||||
with pytest.raises(InvalidRequest, match='both arguments are of set type'):
|
||||
cql.execute(f"SELECT set_intersection(s1, p) FROM {tbl_set} WHERE p={p1}")
|
||||
with pytest.raises(InvalidRequest, match='both arguments are of the same set type'):
|
||||
cql.execute(f"SELECT set_intersection(s1, s3) FROM {tbl_set} WHERE p={p1}")
|
||||
|
||||
Reference in New Issue
Block a user