diff --git a/cql3/expr/expression.hh b/cql3/expr/expression.hh index 39c4494df0..9fb200205c 100644 --- a/cql3/expr/expression.hh +++ b/cql3/expr/expression.hh @@ -700,6 +700,9 @@ std::optional try_prepare_expression(const expression& expr, data_di // Does some basic type checks but no advanced validation. extern binary_operator prepare_binary_operator(binary_operator binop, data_dictionary::database db, const schema& table_schema); +// Pre-compile any constant LIKE patterns and return equivalent expression +expression optimize_like(const expression& e); + /** * @return whether this object can be assigned to the provided receiver. We distinguish diff --git a/cql3/expr/prepare_expr.cc b/cql3/expr/prepare_expr.cc index 3a485a4fb7..9e1616a464 100644 --- a/cql3/expr/prepare_expr.cc +++ b/cql3/expr/prepare_expr.cc @@ -19,6 +19,7 @@ #include "types/map.hh" #include "types/user.hh" #include "exceptions/unrecognized_entity_exception.hh" +#include "utils/like_matcher.hh" #include @@ -1242,6 +1243,86 @@ static lw_shared_ptr get_rhs_receiver(lw_shared_ptr _lhs_types; +public: + like_constant_function(data_type arg_type, bytes_view pattern) + : _name("system", fmt::format("like({})", + std::string_view(reinterpret_cast(pattern.data()), pattern.size()))) + , _matcher(pattern) { + _lhs_types.push_back(std::move(arg_type)); + } + + virtual const functions::function_name& name() const override { + return _name; + } + + virtual const std::vector& arg_types() const override { + return _lhs_types; + } + + virtual const data_type& return_type() const override { + return boolean_type; + } + + virtual bool is_pure() const override { + return true; + } + + virtual bool is_native() const override { + return true; + } + + virtual bool requires_thread() const override { + return false; + } + + virtual bool is_aggregate() const override { + return false; + } + + virtual void print(std::ostream& os) const override { + os << "LIKE(compiled)"; + } + + virtual sstring column_name(const std::vector& column_names) const override { + return "LIKE"; + } + + virtual bytes_opt execute(const std::vector& parameters) override { + auto& str_opt = parameters[0]; + if (!str_opt) { + return std::nullopt; + } + bool match_result = _matcher(*str_opt); + return data_value(match_result).serialize(); + } +}; + +expression +optimize_like(const expression& e) { + // Check for LIKE with constant pattern; replace with anonymous + // function that contains the compiled regex. + return search_and_replace(e, [] (const expression& subexpression) -> std::optional { + if (auto* binop = as_if(&subexpression)) { + if (binop->op == oper_t::LIKE) { + if (auto* rhs = as_if(&binop->rhs)) { + if ((type_of(*rhs) == utf8_type || type_of(*rhs) == ascii_type) && !rhs->is_null()) { + auto pattern = to_bytes(rhs->value.view()); + auto func = ::make_shared(type_of(binop->lhs), pattern); + auto args = std::vector(); + args.push_back(binop->lhs); + return function_call{std::move(func), std::move(args)}; + } + } + } + } + return std::nullopt; + }); +} + binary_operator prepare_binary_operator(binary_operator binop, data_dictionary::database db, const schema& table_schema) { std::optional prepared_lhs_opt = try_prepare_expression(binop.lhs, db, table_schema.ks_name(), &table_schema, {}); if (!prepared_lhs_opt) { diff --git a/test/boost/expr_test.cc b/test/boost/expr_test.cc index 06e0fa8698..a5d12425cb 100644 --- a/test/boost/expr_test.cc +++ b/test/boost/expr_test.cc @@ -4040,4 +4040,36 @@ BOOST_AUTO_TEST_CASE(prepare_binary_operator_with_null_rhs) { table_schema); } } -} \ No newline at end of file +} + +BOOST_AUTO_TEST_CASE(optimized_constant_like) { + auto check = [] (expression e, std::optional target, bool expect_optimization, std::optional pattern_arg = {}) { + auto optimized = optimize_like(e); + bool was_optimized = find_binop(optimized, [] (const binary_operator&) { return true; }) == nullptr; + if (was_optimized != expect_optimization) { + return false; + } + auto params = std::vector({target ? make_text_raw(*target) : raw_value::make_null()}); + if (pattern_arg) { + params.push_back(make_text_raw(*pattern_arg)); + } + return evaluate_with_bind_variables(optimized, params) == evaluate_with_bind_variables(e, params); + }; + + auto target_var = make_bind_variable(0, utf8_type); + auto pattern_var = make_bind_variable(1, utf8_type); + + BOOST_REQUIRE(check(binary_operator(target_var, oper_t::LIKE, make_text_const("xx%")), "xxyyz", true)); + BOOST_REQUIRE(check(binary_operator(target_var, oper_t::LIKE, make_text_const("xx%")), "qxyyz", true)); + BOOST_REQUIRE(check(binary_operator(target_var, oper_t::LIKE, make_text_const("xx%")), std::nullopt, true)); + BOOST_REQUIRE(check(binary_operator(target_var, oper_t::LIKE, pattern_var), "xxyyz", false, "xx%")); + BOOST_REQUIRE(check(binary_operator(target_var, oper_t::LIKE, pattern_var), "qxyyz", false, "xx%")); + BOOST_REQUIRE(check(binary_operator(target_var, oper_t::LIKE, pattern_var), std::nullopt, false, "xx%")); + + // Verify that optimization works for subexpressions, not just top-level expressions + auto complex = make_conjunction( + binary_operator(target_var, oper_t::LIKE, make_text_const("xx%")), + // repeated for simplicity + binary_operator(target_var, oper_t::LIKE, make_text_const("xx%"))); + BOOST_REQUIRE(check(std::move(complex), "xxyyz", true)); +}