diff --git a/client_data.hh b/client_data.hh index 91a77c13b7..f2f6a2370f 100644 --- a/client_data.hh +++ b/client_data.hh @@ -45,6 +45,7 @@ struct client_data { std::optional ssl_enabled; std::optional ssl_protocol; std::optional username; + std::optional scheduling_group_name; sstring stage_str() const { return to_string(connection_stage); } sstring client_type_str() const { return to_string(ct); } diff --git a/generic_server.cc b/generic_server.cc index 96da666fe9..9076034bb8 100644 --- a/generic_server.cc +++ b/generic_server.cc @@ -38,6 +38,19 @@ connection::~connection() _server._connections_list.erase(iter); } +connection::execute_under_tenant_type +connection::no_tenant() { + // return a function that runs the process loop directly, without changing the scheduling group + return [] (connection_process_loop loop) { + return loop(); + }; +} + +void connection::switch_tenant(execute_under_tenant_type exec) { + _execute_under_current_tenant = std::move(exec); + _tenant_switch = true; +} + future<> server::for_each_gently(noncopyable_function fn) { _gentle_iterators.emplace_front(*this); std::list::iterator gi = _gentle_iterators.begin(); @@ -63,13 +76,30 @@ static bool is_broken_pipe_or_connection_reset(std::exception_ptr ep) { return false; } +// Runs process_request() repeatedly until the read buffer reaches EOF or +// switch_tenant() raises _tenant_switch. The flag is cleared by process(), +// not here, so a switch requested before this loop starts is not lost. +future<> connection::process_until_tenant_switch() { + return do_until([this] { + return _read_buf.eof() || _tenant_switch; + }, [this] { + return process_request(); + }); +} + future<> connection::process() { return with_gate(_pending_requests_gate, [this] { return do_until([this] { return _read_buf.eof(); }, [this] { - return process_request(); + // Clear the flag before invoking the executor: an executor may suspend + // before it starts the loop, and a switch_tenant() call made in that + // window must make the loop exit at once so the new executor is used. + _tenant_switch = false; + return _execute_under_current_tenant([this] { + return process_until_tenant_switch(); + }); }).then_wrapped([this] (future<> f) { handle_error(std::move(f)); }); diff --git a/generic_server.hh b/generic_server.hh index 9a1db7ab92..5d856f9b87 100644 --- a/generic_server.hh +++ b/generic_server.hh @@ -17,6 +17,7 @@ #include #include #include +#include #include 
#include @@ -35,6 +36,11 @@ class server; // member function to perform request processing. This base class provides a // `_read_buf` and a `_write_buf` for reading requests and writing responses. class connection : public boost::intrusive::list_base_hook<> { +public: + using connection_process_loop = noncopyable_function ()>; + using execute_under_tenant_type = noncopyable_function (connection_process_loop)>; + bool _tenant_switch = false; + execute_under_tenant_type _execute_under_current_tenant = no_tenant(); protected: server& _server; connected_socket _fd; @@ -44,6 +50,8 @@ protected: seastar::gate _pending_requests_gate; seastar::gate::holder _hold_server; +private: + future<> process_until_tenant_switch(); public: connection(server& server, connected_socket&& fd); virtual ~connection(); @@ -57,6 +65,12 @@ public: virtual void on_connection_close(); virtual future<> shutdown(); + + // Installs `exec` as the executor for the request loop and raises _tenant_switch. + void switch_tenant(execute_under_tenant_type exec); + + // Returns an executor that just invokes the loop it is given. + static execute_under_tenant_type no_tenant(); }; // A generic TCP socket server. diff --git a/service/qos/service_level_controller.cc b/service/qos/service_level_controller.cc index ea62020fdf..6e56ef5dfb 100644 --- a/service/qos/service_level_controller.cc +++ b/service/qos/service_level_controller.cc @@ -536,6 +536,17 @@ scheduling_group service_level_controller::get_scheduling_group(sstring service_ } } +future service_level_controller::get_user_scheduling_group(const std::optional& usr) { + if (usr && usr->name) { + auto sl_opt = co_await find_effective_service_level(*usr->name); + auto& sl_name = (sl_opt && sl_opt->shares_name) ? 
*sl_opt->shares_name : default_service_level_name; + co_return get_scheduling_group(sl_name); + } + else { + co_return get_default_scheduling_group(); + } +} + std::optional service_level_controller::get_active_service_level() { unsigned sched_idx = internal::scheduling_group_index(current_scheduling_group()); if (_sl_lookup[sched_idx].first) { diff --git a/service/qos/service_level_controller.hh b/service/qos/service_level_controller.hh index 5b85b1fd5d..9004a1625c 100644 --- a/service/qos/service_level_controller.hh +++ b/service/qos/service_level_controller.hh @@ -206,6 +206,28 @@ public: void abort_group0_operations(); + /** + * Executes a function under the effective service level of a given user, + * falling back to the default service level when none can be determined. + * @param usr - the user for determining the service level + * @param func - the function to be executed + * @return a future that is resolved when the function's operation is resolved + * (if it returns a future), or a ready future containing the returned value + * from the function. + */ + template > + requires std::invocable + futurize_t with_user_service_level(const std::optional& usr, Func&& func) { + if (usr && usr->name) { + return find_effective_service_level(*usr->name).then([this, func = std::forward<Func>(func)] (std::optional opts) mutable { + auto& service_level_name = (opts && opts->shares_name) ? *opts->shares_name : default_service_level_name; + return with_service_level(service_level_name, std::move(func)); + }); + } else { + return with_service_level(default_service_level_name, std::forward<Func>(func)); + } + } + /** * this is an executor of a function with arguments under a specific * service level. @@ -235,6 +257,12 @@ public: * get_scheduling_group("default") */ scheduling_group get_scheduling_group(sstring service_level_name); + /** + * Get the scheduling group of a specific user + * @param usr - the user for determining the service level + * @return for a named user, the scheduling group of its effective service level (the default service level's group if none applies); 
otherwise get_default_scheduling_group() + */ + future get_user_scheduling_group(const std::optional& usr); /** * @return the name of the currently active service level if such exists or an empty * optional if no active service level. diff --git a/transport/controller.cc b/transport/controller.cc index e3b3fd435a..458ba7beeb 100644 --- a/transport/controller.cc +++ b/transport/controller.cc @@ -351,6 +351,16 @@ future> controller::get_client_data() { return _server ? _server->local().get_client_data() : protocol_server::get_client_data(); } +future<> controller::update_connections_scheduling_group() { + if (!_server) { + co_return; + } + + co_await _server->invoke_on_all([] (auto& server) { + return server.update_connections_scheduling_group(); + }); +} + future> controller::get_connections_service_level_params() { if (!_server) { co_return std::vector(); diff --git a/transport/controller.hh b/transport/controller.hh index 30d02077b4..35ab8fdbb7 100644 --- a/transport/controller.hh +++ b/transport/controller.hh @@ -79,6 +79,7 @@ public: virtual future<> stop_server() override; virtual future<> request_stop_server() override; virtual future> get_client_data() override; + future<> update_connections_scheduling_group(); future> get_connections_service_level_params(); }; diff --git a/transport/server.cc b/transport/server.cc index 64690029c2..39a9e65858 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -16,14 +16,17 @@ #include "cql3/statements/batch_statement.hh" #include "cql3/statements/modification_statement.hh" +#include <seastar/core/scheduling.hh> #include "types/collection.hh" #include "types/list.hh" #include "types/set.hh" #include "types/map.hh" #include "dht/token-sharding.hh" #include "service/migration_manager.hh" +#include "service/storage_service.hh" #include "service/memory_limiter.hh" #include "service/storage_proxy.hh" +#include "service/qos/service_level_controller.hh" #include "db/consistency_level_type.hh" #include "db/write_type.hh" 
#include @@ -199,12 +202,15 @@ cql_sg_stats::cql_sg_stats(maintenance_socket_enabled used_by_maintenance_socket if (std::find(vector_ref.begin(), vector_ref.end(), current_scheduling_group().name()) != vector_ref.end()) { return; } + + _use_metrics = true; register_metrics(); } void cql_sg_stats::register_metrics() { namespace sm = seastar::metrics; + auto new_metrics = sm::metric_groups(); std::vector transport_metrics; auto cur_sg_name = current_scheduling_group().name(); @@ -230,7 +236,14 @@ void cql_sg_stats::register_metrics() ); } - _metrics.add_group("transport", std::move(transport_metrics)); + new_metrics.add_group("transport", std::move(transport_metrics)); + _metrics = std::exchange(new_metrics, {}); +} + +void cql_sg_stats::rename_metrics() { + if (_use_metrics) { + register_metrics(); + } } cql_server::cql_server(distributed& qp, auth::service& auth_service, @@ -605,6 +618,7 @@ cql_server::connection::connection(cql_server& server, socket_address server_add , _server(server) , _server_addr(server_addr) , _client_state(service::client_state::external_tag{}, server._auth_service, &server._sl_controller, server.timeout_config(), addr) + , _current_scheduling_group(default_scheduling_group()) { _shedding_timer.set_callback([this] { clogger.debug("Shedding all incoming requests due to overload"); @@ -640,6 +654,7 @@ client_data cql_server::connection::make_client_data() const { } else if (_authenticating) { cd.connection_stage = client_connection_stage::authenticating; } + cd.scheduling_group_name = _current_scheduling_group.name(); return cd; } @@ -933,6 +948,18 @@ future> cql_server::connection::process_st co_return res; } +void cql_server::connection::update_scheduling_group() { + switch_tenant([this] (noncopyable_function ()> process_loop) -> future<> { + // Keep a reference to the connection in the coroutine frame before the first + // suspension: switch_tenant() may replace this closure while the coroutine + // is suspended, after which its captures must no longer be read. + auto& conn = *this; + auto shg = co_await conn._server._sl_controller.get_user_scheduling_group(conn._client_state.user()); + conn._current_scheduling_group = shg; + co_return co_await 
conn._server._sl_controller.with_user_service_level(conn._client_state.user(), std::move(process_loop)); + }); +} + future> cql_server::connection::process_auth_response(uint16_t stream, request_reader in, service::client_state& client_state, tracing::trace_state_ptr trace_state) { auto sasl_challenge = client_state.get_auth_service()->underlying_authenticator().new_sasl_challenge(); @@ -941,6 +968,7 @@ future> cql_server::connection::process_au if (sasl_challenge->is_complete()) { return sasl_challenge->get_authenticated_user().then([this, sasl_challenge, stream, &client_state, challenge = std::move(challenge), trace_state](auth::authenticated_user user) mutable { client_state.set_login(std::move(user)); + update_scheduling_group(); auto f = client_state.check_user_can_login(); f = f.then([&client_state] { return client_state.maybe_update_per_service_level_params(); @@ -1230,7 +1258,6 @@ process_batch_internal(service::client_state& client_state, distributed(ps->statement.get()) == nullptr) { throw exceptions::invalid_request_exception("Invalid statement in batch: only UPDATE, INSERT and DELETE statements are allowed."); } - ::shared_ptr modif_statement_ptr = static_pointer_cast(ps->statement); if (init_trace) { tracing::add_table_name(trace_state, modif_statement_ptr->keyspace(), modif_statement_ptr->column_family()); @@ -2053,6 +2080,13 @@ future> cql_server::get_client_data() { co_return ret; } +future<> cql_server::update_connections_scheduling_group() { + return for_each_gently([] (generic_server::connection& conn) { + connection& cql_conn = dynamic_cast(conn); + cql_conn.update_scheduling_group(); + }); +} + future<> cql_server::update_connections_service_level_params() { if (!_sl_controller.is_v2()) { // Auto update of connections' service level params requires @@ -2071,6 +2105,7 @@ future<> cql_server::update_connections_service_level_params() { cs.update_per_service_level_params(*slo); } } + cql_conn.update_scheduling_group(); }); } diff --git a/transport/server.hh 
b/transport/server.hh index 527ea508e7..9e4308a6df 100644 --- a/transport/server.hh +++ b/transport/server.hh @@ -10,6 +10,7 @@ #include "auth/service.hh" #include +#include <seastar/core/scheduling.hh> #include "service/endpoint_lifecycle_subscriber.hh" #include "service/migration_listener.hh" #include "auth/authenticator.hh" @@ -130,7 +131,9 @@ struct cql_sg_stats { cql_sg_stats(maintenance_socket_enabled); request_kind_stats& get_cql_opcode_stats(cql_binary_opcode op) { return _cql_requests_stats[static_cast(op)]; } void register_metrics(); + void rename_metrics(); private: + bool _use_metrics = false; seastar::metrics::metric_groups _metrics; std::vector _cql_requests_stats; }; @@ -198,6 +201,7 @@ public: } future> get_client_data(); + future<> update_connections_scheduling_group(); future<> update_connections_service_level_params(); future> get_connections_service_level_params(); private: @@ -214,10 +218,11 @@ private: cql_compression _compression = cql_compression::none; service::client_state _client_state; timer _shedding_timer; + scheduling_group _current_scheduling_group; bool _shed_incoming_requests = false; unsigned _request_cpu = 0; bool _ready = false; bool _authenticating = false; enum class tracing_request_type : uint8_t { not_requested, @@ -244,6 +249,8 @@ private: static std::pair make_client_key(const service::client_state& cli_state); client_data make_client_data() const; const service::client_state& get_client_state() const { return _client_state; } + // Installs an executor that runs the request loop under the scheduling group of the connection's current user. + void update_scheduling_group(); service::client_state& get_client_state() { return _client_state; } private: friend class process_request_executor;