diff --git a/api/api-doc/task_manager_test.json b/api/api-doc/task_manager_test.json index bbd57ad58c..ce4774d5f4 100644 --- a/api/api-doc/task_manager_test.json +++ b/api/api-doc/task_manager_test.json @@ -93,6 +93,14 @@ "allowMultiple":false, "type":"string", "paramType":"query" + }, + { + "name":"user_task", + "description":"A flag indicating whether a task was started by user (false by default)", + "required":false, + "allowMultiple":false, + "type":"boolean", + "paramType":"query" } ] }, diff --git a/api/task_manager_test.cc b/api/task_manager_test.cc index 3ebb94b8ad..8df86ac00e 100644 --- a/api/task_manager_test.cc +++ b/api/task_manager_test.cc @@ -11,6 +11,7 @@ #include #include "task_manager_test.hh" +#include "api/api.hh" #include "api/api-doc/task_manager_test.json.hh" #include "tasks/test_module.hh" #include "utils/overloaded_functor.hh" @@ -22,10 +23,10 @@ using namespace json; using namespace seastar::httpd; static future make_test_task(tasks::task_manager& task_manager, sstring module_name, unsigned shard, tasks::task_id id, std::string keyspace, - std::string table, std::string entity, tasks::task_info parent_d) { - return task_manager.container().invoke_on(shard, [id, module = std::move(module_name), keyspace = std::move(keyspace), table = std::move(table), entity = std::move(entity), parent_d] (tasks::task_manager& tm) { + std::string table, std::string entity, tasks::task_info parent_d, tasks::is_user_task user_task) { + return task_manager.container().invoke_on(shard, [id, module = std::move(module_name), keyspace = std::move(keyspace), table = std::move(table), entity = std::move(entity), parent_d, user_task] (tasks::task_manager& tm) { auto module_ptr = tm.find_module(module); - auto task_impl_ptr = seastar::make_shared(module_ptr, id ? id : tasks::task_id::create_random_id(), parent_d ? 0 : module_ptr->new_sequence_number(), std::move(keyspace), std::move(table), std::move(entity), parent_d.id); + auto task_impl_ptr = seastar::make_shared(module_ptr, id ? id : tasks::task_id::create_random_id(), parent_d ? 0 : module_ptr->new_sequence_number(), std::move(keyspace), std::move(table), std::move(entity), parent_d.id, user_task); return module_ptr->make_task(std::move(task_impl_ptr), parent_d).then([] (auto task) { return task->id(); }); @@ -69,9 +70,10 @@ void set_task_manager_test(http_context& ctx, routes& r, shardedget_status().shard; } + auto user_task = tasks::is_user_task{req_param(*req, "user_task", false)}; auto module = tms.local().find_module("test"); - id = co_await make_test_task(module->get_task_manager(), module->get_name(), shard, id, keyspace, table, entity, data); + id = co_await make_test_task(module->get_task_manager(), module->get_name(), shard, id, keyspace, table, entity, data, user_task); co_await tms.invoke_on(shard, [id] (tasks::task_manager& tm) { auto it = tm.get_local_tasks().find(id); if (it != tm.get_local_tasks().end()) { diff --git a/tasks/test_module.hh b/tasks/test_module.hh index c131f5a528..0c835f4d6a 100644 --- a/tasks/test_module.hh +++ b/tasks/test_module.hh @@ -23,9 +23,11 @@ class test_task_impl : public task_manager::task::impl { private: promise<> _finish_run; bool _finished = false; + tasks::is_user_task _user_task; public: - test_task_impl(task_manager::module_ptr module, task_id id, uint64_t sequence_number = 0, std::string keyspace = "", std::string table = "", std::string entity = "", task_id parent_id = task_id::create_null_id()) noexcept + test_task_impl(task_manager::module_ptr module, task_id id, uint64_t sequence_number = 0, std::string keyspace = "", std::string table = "", std::string entity = "", task_id parent_id = task_id::create_null_id(), tasks::is_user_task user_task = tasks::is_user_task::no) noexcept : task_manager::task::impl(module, id, sequence_number, "test", std::move(keyspace), std::move(table), std::move(entity), parent_id) + , _user_task(user_task) {} virtual std::string type() const override { @@ -36,6 +38,10 @@ public: return _finish_run.get_future(); } + tasks::is_user_task is_user_task() const noexcept override { + return _user_task; + } + friend class test_task; }; diff --git a/test/rest_api/rest_util.py b/test/rest_api/rest_util.py index 0e2c26b26b..e9a9786e36 100644 --- a/test/rest_api/rest_util.py +++ b/test/rest_api/rest_util.py @@ -86,6 +86,17 @@ def set_tmp_task_ttl(rest_api, seconds): resp = rest_api.send("POST", "task_manager/ttl", { "ttl" : old_ttl }) resp.raise_for_status() +@contextmanager +def set_tmp_user_task_ttl(rest_api, seconds): + resp = rest_api.send("POST", "task_manager/user_ttl", { "user_ttl" : seconds }) + resp.raise_for_status() + old_ttl = resp.json() + try: + yield old_ttl + finally: + resp = rest_api.send("POST", "task_manager/user_ttl", { "user_ttl" : old_ttl }) + resp.raise_for_status() + # Unfortunately by default Python threads print their exceptions # (e.g., assertion failures) but don't propagate them to the join(), # so the overall test doesn't fail. The following Thread wrapper diff --git a/test/rest_api/test_task_manager.py b/test/rest_api/test_task_manager.py index 815bb69011..8da38c4d45 100644 --- a/test/rest_api/test_task_manager.py +++ b/test/rest_api/test_task_manager.py @@ -6,7 +6,7 @@ import time # Use the util.py library from ../cqlpy: sys.path.insert(1, sys.path[0] + '/test/cqlpy') from util import new_test_table, new_test_keyspace -from test.rest_api.rest_util import new_test_module, new_test_task, set_tmp_task_ttl, ThreadWrapper, scylla_inject_error +from test.rest_api.rest_util import new_test_module, new_test_task, set_tmp_task_ttl, ThreadWrapper, scylla_inject_error, set_tmp_user_task_ttl from test.rest_api.task_manager_utils import check_field_correctness, check_status_correctness, assert_task_does_not_exist, list_modules, get_task_status, list_tasks, get_task_status_recursively, wait_for_task, drain_module_tasks, abort_task long_time = 1000000000 @@ -127,6 +127,27 @@ def test_task_manager_ttl(rest_api): assert_task_does_not_exist(rest_api, task0) assert_task_does_not_exist(rest_api, task1) +def test_task_manager_user_ttl(rest_api): + with new_test_module(rest_api): + args0 = {"keyspace": "keyspace0", "table": "table0", "user_task": True} + args1 = {"keyspace": "keyspace0", "table": "table0", "shard": "1", "user_task": True} + with new_test_task(rest_api, args0) as task0: + print(f"created test task {task0}") + with new_test_task(rest_api, args1) as task1: + print(f"created test task {task1}") + ttl = 10000 + user_ttl = 2 + with set_tmp_task_ttl(rest_api, ttl): + with set_tmp_user_task_ttl(rest_api, user_ttl): + resp = rest_api.send("POST", f"task_manager_test/finish_test_task/{task0}") + resp.raise_for_status() + resp = rest_api.send("POST", f"task_manager_test/finish_test_task/{task1}") + resp.raise_for_status() + + time.sleep(user_ttl + 1) + assert_task_does_not_exist(rest_api, task0) + assert_task_does_not_exist(rest_api, task1) + def test_task_manager_sequence_number(rest_api): with new_test_module(rest_api): args0 = { "shard": 0 } # sequence_number == 1