diff --git a/idl/mapreduce_request.idl.hh b/idl/mapreduce_request.idl.hh index 1bf205d924..b9c8bcc7cc 100644 --- a/idl/mapreduce_request.idl.hh +++ b/idl/mapreduce_request.idl.hh @@ -45,5 +45,5 @@ struct mapreduce_result { std::vector query_results; }; -verb mapreduce_request(query::mapreduce_request req [[ref]], std::optional trace_info [[ref]]) -> query::mapreduce_result; +verb [[cancellable]] mapreduce_request(query::mapreduce_request req [[ref]], std::optional trace_info [[ref]]) -> query::mapreduce_result; } diff --git a/service/mapreduce_service.cc b/service/mapreduce_service.cc index 2a434b196f..3195d30c94 100644 --- a/service/mapreduce_service.cc +++ b/service/mapreduce_service.cc @@ -21,6 +21,7 @@ #include "gms/gossiper.hh" #include "idl/mapreduce_request.dist.hh" #include "locator/abstract_replication_strategy.hh" +#include "utils/error_injection.hh" #include "utils/log.hh" #include "message/messaging_service.hh" #include "query-request.hh" @@ -283,7 +284,7 @@ public: // Try to send this mapreduce_request to another node. 
try { co_return co_await ser::mapreduce_request_rpc_verbs::send_mapreduce_request( - &_mapreducer._messaging, id, req, _tr_info + &_mapreducer._messaging, id, _mapreducer._abort_outgoing_tasks, req, _tr_info ); } catch (rpc::closed_error& e) { if (_mapreducer._shutdown) { @@ -364,6 +365,8 @@ future mapreduce_service::dispatch_to_shards( query::mapreduce_request req, std::optional tr_info ) { + co_await utils::get_local_injector().inject("mapreduce_pause_dispatch_to_shards", utils::wait_for_message(5min)); + _stats.requests_dispatched_to_own_shards += 1; std::optional result; std::vector> futures; diff --git a/service/mapreduce_service.hh b/service/mapreduce_service.hh index 486be5a48e..8a87f29994 100644 --- a/service/mapreduce_service.hh +++ b/service/mapreduce_service.hh @@ -123,6 +123,7 @@ class mapreduce_service : public seastar::peering_sharded_service& _db; const locator::shared_token_metadata& _shared_token_metadata; + abort_source _abort_outgoing_tasks; struct stats { uint64_t requests_dispatched_to_other_nodes = 0; @@ -141,7 +142,10 @@ public: , _proxy(p) , _db(db) , _shared_token_metadata(stm) - , _early_abort_subscription(as.subscribe([this] () noexcept { _shutdown = true; })) + , _early_abort_subscription(as.subscribe([this] () noexcept { + _shutdown = true; + _abort_outgoing_tasks.request_abort(); + })) { register_metrics(); init_messaging_service(); diff --git a/test/topology/test_aggregation.py b/test/topology/test_aggregation.py new file mode 100644 index 0000000000..97c6cf917d --- /dev/null +++ b/test/topology/test_aggregation.py @@ -0,0 +1,79 @@ +# +# Copyright (C) 2025-present ScyllaDB +# +# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 +# +import asyncio +import pytest +import time +import logging +import random + +from cassandra.cluster import NoHostAvailable # type: ignore + +from test.pylib.manager_client import ManagerClient +from test.pylib.rest_client import inject_error +from test.pylib.util import wait_for, 
wait_for_cql_and_get_hosts
+from test.topology.conftest import skip_mode
+from test.topology.util import new_test_keyspace, new_test_table
+
+logger = logging.getLogger(__name__)
+
+@pytest.mark.asyncio
+@skip_mode("release", "error injections are not supported in release mode")
+async def test_cancel_mapreduce(manager: ManagerClient):
+    """
+    Verify that gracefully stopping the supercoordinator of a mapreduce task cancels its outgoing
+    mapreduce requests to other nodes; a request left pending would otherwise prevent the shutdown.
+    """
+
+    running_servers = await manager.running_servers()
+    assert len(running_servers) >= 2
+
+    s1, s2 = running_servers[0], running_servers[1]
+    cql = manager.get_cql()
+    hosts = await wait_for_cql_and_get_hosts(cql, [s1, s2], time.time() + 30)
+
+    await manager.api.set_logger_level(s1.ip_addr, "forward_service", "debug")  # NOTE(review): presumably the logger that emits the "dispatching mapreduce_request" line awaited below - confirm
+
+    [host1] = filter(lambda host: host.address == s1.ip_addr, hosts)
+    host_id2 = await manager.get_host_id(s2.server_id)
+
+    async with new_test_keyspace(cql, "WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1}") as ks:
+        async with new_test_table(cql, ks, "pk int PRIMARY KEY, v int") as t:
+            # Insert random rows so that the data is distributed across the nodes.
+            for _ in range(250):
+                # CQL int is a 32-bit integer, so the random values are kept within [-2**30, 2**30].
+                pk = random.randint(-2**30, 2**30)
+                v = random.randint(-2**30, 2**30)
+                await cql.run_async(f"INSERT INTO {t} (pk, v) VALUES ({pk}, {v})")
+
+            s1_log = await manager.server_open_log(s1.server_id)
+            s2_log = await manager.server_open_log(s2.server_id)
+
+            s1_mark = await s1_log.mark()
+            s2_mark = await s2_log.mark()
+
+            # Enable the pause injection on node 2 so that its local mapreduce tasks cannot finish.
+            async with inject_error(manager.api, s2.ip_addr, "mapreduce_pause_dispatch_to_shards"):
+                async def do_select():
+                    # Send the aggregation to node 1 so that it becomes the supercoordinator of the
+                    # corresponding mapreduce task. The 600s timeout exceeds the cumulative timeout of the
+                    # steps in wait_and_shutdown (60s + 60s + 120s): the query must not end on its own.
+                    try:
+                        await cql.run_async(f"SELECT count(*) FROM {t} BYPASS CACHE USING TIMEOUT 600s", host=host1)
+                        pytest.fail("Query finished, but it wasn't supposed to")
+                    except NoHostAvailable:  # presumably raised by the driver once node 1 goes down - confirm
+                        pass
+
+                async def wait_and_shutdown():
+                    # Wait until node 1 (the supercoordinator) logs that it sent a mapreduce task to node 2.
+                    await s1_log.wait_for(f"dispatching mapreduce_request=.* to address={host_id2}", from_mark=s1_mark, timeout=60)
+                    # Wait until node 2 logs that it is paused on the injection, i.e. its task cannot finish.
+                    await s2_log.wait_for("mapreduce_pause_dispatch_to_shards: waiting for message", from_mark=s2_mark, timeout=60)
+                    # The supercoordinator must stop within the timeout despite the ongoing mapreduce task.
+                    await manager.server_stop_gracefully(s1.server_id, timeout=120)
+
+                async with asyncio.TaskGroup() as tg:
+                    _ = tg.create_task(do_select())
+                    _ = tg.create_task(wait_and_shutdown())