diff --git a/cql3/prepared_statements_cache.hh b/cql3/prepared_statements_cache.hh index 08b48e8f43..40ef0179c9 100644 --- a/cql3/prepared_statements_cache.hh +++ b/cql3/prepared_statements_cache.hh @@ -105,6 +105,7 @@ public: static const std::chrono::minutes entry_expiry; using key_type = prepared_cache_key_type; + using pinned_value_type = cache_value_ptr; using value_type = checked_weak_ptr; using statement_is_too_big = typename cache_type::entry_is_too_big; @@ -116,9 +117,14 @@ public: : _cache(size, entry_expiry, logger) {} + template + future get_pinned(const key_type& key, LoadFunc&& load) { + return _cache.get_ptr(key.key(), [load = std::forward(load)] (const cache_key_type&) { return load(); }); + } + template future get(const key_type& key, LoadFunc&& load) { - return _cache.get_ptr(key.key(), [load = std::forward(load)] (const cache_key_type&) { return load(); }).then([] (cache_value_ptr v_ptr) { + return get_pinned(key, std::forward(load)).then([] (cache_value_ptr v_ptr) { return make_ready_future((*v_ptr)->checked_weak_from_this()); }); } diff --git a/cql3/query_processor.cc b/cql3/query_processor.cc index a38cd6b81e..2254134d1c 100644 --- a/cql3/query_processor.cc +++ b/cql3/query_processor.cc @@ -697,7 +697,7 @@ future<::shared_ptr> query_processor::prepare(sstring query_string, const service::client_state& client_state, cql3::dialect d) { try { auto key = compute_id(query_string, client_state.get_raw_keyspace(), d); - auto prep_ptr = co_await _prepared_cache.get(key, [this, &query_string, &client_state, d] { + auto prep_entry = co_await _prepared_cache.get_pinned(key, [this, &query_string, &client_state, d] { auto prepared = get_statement(query_string, client_state, d); prepared->calculate_metadata_id(); auto bound_terms = prepared->statement->get_bound_terms(); @@ -711,13 +711,13 @@ query_processor::prepare(sstring query_string, const service::client_state& clie return make_ready_future>(std::move(prepared)); }); - const auto& warnings = prep_ptr->warnings; - const auto msg = ::make_shared(prepared_cache_key_type::cql_id(key), std::move(prep_ptr), + co_await utils::get_local_injector().inject( + "query_processor_prepare_wait_after_cache_get", + utils::wait_for_message(std::chrono::seconds(60))); + + auto msg = ::make_shared(prepared_cache_key_type::cql_id(key), std::move(prep_entry), client_state.is_protocol_extension_set(cql_transport::cql_protocol_extension::LWT_ADD_METADATA_MARK)); - for (const auto& w : warnings) { - msg->add_warning(w); - } - co_return ::shared_ptr(std::move(msg)); + co_return std::move(msg); } catch(typename prepared_statements_cache::statement_is_too_big&) { throw prepared_statement_is_too_big(query_string); } diff --git a/service/paxos/paxos_state.cc b/service/paxos/paxos_state.cc index dd8ec19235..d2acc564f1 100644 --- a/service/paxos/paxos_state.cc +++ b/service/paxos/paxos_state.cc @@ -454,7 +454,7 @@ static future do_execute_cql_with_timeout(sstring req, auto ps_ptr = qp.get_prepared(cache_key); if (!ps_ptr) { const auto msg_ptr = co_await qp.prepare(req, qs, cql3::internal_dialect()); - ps_ptr = std::move(msg_ptr->get_prepared()); + ps_ptr = msg_ptr->get_prepared(); if (!ps_ptr) { on_internal_error(paxos_state::logger, "prepared statement is null"); } diff --git a/table_helper.cc b/table_helper.cc index 58a3ed4223..051d38e50e 100644 --- a/table_helper.cc +++ b/table_helper.cc @@ -75,7 +75,7 @@ future table_helper::try_prepare(bool fallback, cql3::query_processor& qp, auto& stmt = fallback ? _insert_cql_fallback.value() : _insert_cql; try { shared_ptr msg_ptr = co_await qp.prepare(stmt, qs.get_client_state(), dialect); - _prepared_stmt = std::move(msg_ptr->get_prepared()); + _prepared_stmt = msg_ptr->get_prepared(); shared_ptr cql_stmt = _prepared_stmt->statement; _insert_stmt = dynamic_pointer_cast(cql_stmt); _is_fallback_stmt = fallback; diff --git a/test/cluster/test_prepare_race.py b/test/cluster/test_prepare_race.py new file mode 100644 index 0000000000..c9776df704 --- /dev/null +++ b/test/cluster/test_prepare_race.py @@ -0,0 +1,65 @@ +# +# Copyright (C) 2026-present ScyllaDB +# +# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 +# + +import asyncio +import pytest + +from test.cluster.util import new_test_keyspace, new_test_table +from test.pylib.manager_client import ManagerClient +from test.pylib.rest_client import inject_error_one_shot + + +@pytest.mark.asyncio +@pytest.mark.skip_mode(mode="release", reason="error injections are not supported in release mode") +async def test_prepare_fails_if_cached_statement_is_invalidated_mid_prepare(manager: ManagerClient): + server = await manager.server_add() + cql = manager.get_cql() + log = await manager.server_open_log(server.server_id) + + async with new_test_keyspace(manager, "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1};") as ks: + async with new_test_table(manager, ks, "pk int PRIMARY KEY") as table: + query = f"SELECT * FROM {table} WHERE pk = ?" + loop = asyncio.get_running_loop() + await cql.run_async(f"INSERT INTO {table} (pk) VALUES (7)") + await cql.run_async(f"INSERT INTO {table} (pk) VALUES (8)") + + handler = await inject_error_one_shot(manager.api, server.ip_addr, "query_processor_prepare_wait_after_cache_get") + mark = await log.mark() + prepare_future = loop.run_in_executor(None, lambda: cql.prepare(query)) + await log.wait_for("query_processor_prepare_wait_after_cache_get: waiting for message", from_mark=mark, timeout=60) + + # Trigger table schema update (metadata-only) to invalidate prepared statements while PREPARE is paused. + await cql.run_async(f"ALTER TABLE {table} WITH comment = 'invalidate-prepared-race'") + + await handler.message() + done, _ = await asyncio.wait({prepare_future}, timeout=15) + if not done: + pytest.fail("Timed out waiting for PREPARE to complete after signaling injection") + + result = done.pop().result() + print(f"PREPARE succeeded as expected: {result!r}") + + rows = cql.execute(result, [7]) + row = rows.one() + assert row is not None and row.pk == 7 + + # Invalidate prepared statements again, then execute the same prepared object. + # The driver should transparently re-prepare and re-request execution. + await cql.run_async(f"ALTER TABLE {table} WITH comment = 'invalidate-prepared-race-again'") + + reprepare_handler = await inject_error_one_shot(manager.api, server.ip_addr, "query_processor_prepare_wait_after_cache_get") + reprepare_mark = await log.mark() + execute_future = loop.run_in_executor(None, lambda: cql.execute(result, [8])) + await log.wait_for("query_processor_prepare_wait_after_cache_get: waiting for message", from_mark=reprepare_mark, timeout=60) + + await reprepare_handler.message() + execute_done, _ = await asyncio.wait({execute_future}, timeout=15) + if not execute_done: + pytest.fail("Timed out waiting for driver execute to finish after re-prepare signaling") + + retried_rows = execute_done.pop().result() + retried_row = retried_rows.one() + assert retried_row is not None and retried_row.pk == 8 diff --git a/transport/messages/result_message.cc b/transport/messages/result_message.cc index bdba1e93d3..017bb12453 100644 --- a/transport/messages/result_message.cc +++ b/transport/messages/result_message.cc @@ -67,14 +67,17 @@ void result_message::visitor_base::visit(const result_message::exception& ex) { ex.throw_me(); } -result_message::prepared::prepared(cql3::statements::prepared_statement::checked_weak_ptr prepared, bool support_lwt_opt) - : _prepared(std::move(prepared)) +result_message::prepared::prepared(cql3::prepared_statements_cache::pinned_value_type prepared_entry, bool support_lwt_opt) + : _prepared_entry(std::move(prepared_entry)) , _metadata( - _prepared->bound_names, - _prepared->partition_key_bind_indices, - support_lwt_opt ? _prepared->statement->is_conditional() : false) - , _result_metadata{extract_result_metadata(_prepared->statement)} + (*_prepared_entry)->bound_names, + (*_prepared_entry)->partition_key_bind_indices, + support_lwt_opt ? (*_prepared_entry)->statement->is_conditional() : false) + , _result_metadata{extract_result_metadata((*_prepared_entry)->statement)} { + for (const auto& w : (*_prepared_entry)->warnings){ + add_warning(w); + } } ::shared_ptr result_message::prepared::extract_result_metadata(::shared_ptr statement) { diff --git a/transport/messages/result_message.hh b/transport/messages/result_message.hh index d92f94c6ef..9511d465da 100644 --- a/transport/messages/result_message.hh +++ b/transport/messages/result_message.hh @@ -13,6 +13,7 @@ #include #include "cql3/result_set.hh" +#include "cql3/prepared_statements_cache.hh" #include "cql3/statements/prepared_statement.hh" #include "cql3/query_options.hh" @@ -30,14 +31,14 @@ namespace messages { class result_message::prepared : public result_message { private: - cql3::statements::prepared_statement::checked_weak_ptr _prepared; + cql3::prepared_statements_cache::pinned_value_type _prepared_entry; cql3::prepared_metadata _metadata; ::shared_ptr _result_metadata; protected: - prepared(cql3::statements::prepared_statement::checked_weak_ptr prepared, bool support_lwt_opt); + prepared(cql3::prepared_statements_cache::pinned_value_type prepared_entry, bool support_lwt_opt); public: - cql3::statements::prepared_statement::checked_weak_ptr& get_prepared() { - return _prepared; + cql3::statements::prepared_statement::checked_weak_ptr get_prepared() { + return (*_prepared_entry)->checked_weak_from_this(); } const cql3::prepared_metadata& metadata() const { @@ -49,7 +50,7 @@ public: } cql3::cql_metadata_id_type get_metadata_id() const { - return _prepared->get_metadata_id(); + return (*_prepared_entry)->get_metadata_id(); } class cql; @@ -166,8 +167,8 @@ std::ostream& operator<<(std::ostream& os, const result_message::set_keyspace& m class result_message::prepared::cql : public result_message::prepared { bytes _id; public: - cql(const bytes& id, cql3::statements::prepared_statement::checked_weak_ptr p, bool support_lwt_opt) - : result_message::prepared(std::move(p), support_lwt_opt) + cql(const bytes& id, cql3::prepared_statements_cache::pinned_value_type prepared_entry, bool support_lwt_opt) + : result_message::prepared(std::move(prepared_entry), support_lwt_opt) , _id{id} { }