service: client_state: Extend with abort_source

We make `client_state` store a pointer to an `abort_source`. This will
be useful in the following commit that will implement aborting ongoing
requests to strongly consistent tables upon connection shutdowns.
It might also be useful in some other places in the code in the future.

We set the abort source for client states in relevant places.
This commit is contained in:
Dawid Mędrek
2026-03-17 18:23:50 +01:00
parent 89c049b889
commit 4a87bdc778
2 changed files with 36 additions and 12 deletions

View File

@@ -76,14 +76,15 @@ public:
: _cs(cs), _auth_service(auth_service), _sl_controller(sl_controller) {}
friend client_state;
public:
client_state get() const {
return client_state(_cs, _auth_service, _sl_controller);
client_state get(abort_source* as = nullptr) const {
return client_state(_cs, _auth_service, _sl_controller, as);
}
};
private:
client_state(const client_state* cs,
seastar::sharded<auth::service>* auth_service,
seastar::sharded<qos::service_level_controller>* sl_controller)
seastar::sharded<qos::service_level_controller>* sl_controller,
abort_source* as)
: _keyspace(cs->_keyspace)
, _user(cs->_user)
, _auth_state(cs->_auth_state)
@@ -94,6 +95,7 @@ private:
, _sl_controller(sl_controller ? &sl_controller->local() : nullptr)
, _default_timeout_config(cs->_default_timeout_config)
, _timeout_config(cs->_timeout_config)
, _as(as)
, _enabled_protocol_extensions(cs->_enabled_protocol_extensions)
{}
friend client_state_for_another_shard;
@@ -153,6 +155,11 @@ private:
workload_type _workload_type = workload_type::unspecified;
// Used to communicate with the code executing user requests.
// It's a way to indicate that we might abort processing the
// request, e.g. if the corresponding connection has been severed.
abort_source* _as{nullptr};
public:
struct internal_tag {};
struct external_tag {};
@@ -211,14 +218,16 @@ public:
qos::service_level_controller* sl_controller,
timeout_config timeout_config,
const socket_address& remote_address = socket_address(),
bool bypass_auth_checks = false)
bool bypass_auth_checks = false,
abort_source* as = nullptr)
: _is_internal(false)
, _bypass_auth_checks(bypass_auth_checks)
, _remote_address(remote_address)
, _auth_service(&auth_service)
, _sl_controller(sl_controller)
, _default_timeout_config(timeout_config)
, _timeout_config(timeout_config) {
, _timeout_config(timeout_config)
, _as(as) {
if (!auth_service.underlying_authenticator().require_authentication()) {
_user = auth::authenticated_user();
}
@@ -244,29 +253,32 @@ public:
return *_sl_controller;
}
client_state(internal_tag) : client_state(internal_tag{}, infinite_timeout_config)
client_state(internal_tag, abort_source* as = nullptr) : client_state(internal_tag{}, infinite_timeout_config, as)
{}
client_state(internal_tag, const timeout_config& config)
client_state(internal_tag, const timeout_config& config, abort_source* as = nullptr)
: _keyspace("system")
, _is_internal(true)
, _bypass_auth_checks(true)
, _default_timeout_config(config)
, _timeout_config(config)
, _as(as)
{}
client_state(internal_tag, auth::service& auth_service, qos::service_level_controller& sl_controller, sstring username)
client_state(internal_tag, auth::service& auth_service, qos::service_level_controller& sl_controller, sstring username, abort_source* as = nullptr)
: _user(auth::authenticated_user(username))
, _auth_state(auth_state::READY)
, _is_internal(true)
, _bypass_auth_checks(true)
, _auth_service(&auth_service)
, _sl_controller(&sl_controller)
, _as(as)
{}
client_state(auth::service& auth_service,
qos::service_level_controller* sl_controller,
forwarded_client_state&& forwarded_state)
forwarded_client_state&& forwarded_state,
abort_source* as = nullptr)
: _keyspace(std::move(forwarded_state.keyspace))
, _user(forwarded_state.username ? auth::authenticated_user(*forwarded_state.username) : auth::authenticated_user{})
, _auth_state(auth_state::READY)
@@ -277,6 +289,7 @@ public:
, _sl_controller(sl_controller)
, _default_timeout_config(forwarded_state.timeout_config)
, _timeout_config(std::move(forwarded_state.timeout_config))
, _as(as)
, _enabled_protocol_extensions(cql_transport::cql_protocol_extension_enum_set::from_mask(
forwarded_state.protocol_extensions_mask))
{}
@@ -392,6 +405,16 @@ public:
return _keyspace;
}
abort_source& get_abort_source() {
if (_as == nullptr) {
utils::on_internal_error("client_state::get_abort_source(): Tried to dereference nullptr");
}
return *_as;
}
abort_source* get_abort_source_ptr() noexcept {
return _as;
}
/**
* Sets active user. Does _not_ validate anything
*/

View File

@@ -393,7 +393,8 @@ void cql_server::init_messaging_service() {
co_return co_await container().invoke_on(shard, [src_host, req = std::move(req)] (cql_server& shard_svc) mutable -> future<forward_cql_execute_response> {
service::client_state cs(shard_svc._auth_service,
&shard_svc._sl_controller,
std::move(req.client_state));
std::move(req.client_state),
&shard_svc._abort_source);
tracing::trace_state_ptr trace_state_ptr;
if (req.trace_info) {
trace_state_ptr = tracing::tracing::get_local_tracing_instance().create_session(*req.trace_info);
@@ -1052,7 +1053,7 @@ cql_server::connection::connection(cql_server& server, socket_address server_add
: generic_server::connection{server, std::move(fd), sem, std::move(initial_sem_units)}
, _server(server)
, _server_addr(server_addr)
, _client_state(service::client_state::external_tag{}, server._auth_service, &server._sl_controller, server.timeout_config(), addr, bool(server._used_by_maintenance_socket))
, _client_state(service::client_state::external_tag{}, server._auth_service, &server._sl_controller, server.timeout_config(), addr, bool(server._used_by_maintenance_socket), &server._abort_source)
, _current_scheduling_group(server.get_scheduling_group_for_new_connection())
{
_shedding_timer.set_callback([this] {
@@ -1835,7 +1836,7 @@ cql_server::process(uint16_t stream, request_reader in, service::client_state& c
msg = co_await container().invoke_on(shard, sg, [&, stream, dialect, version] (cql_server& server) -> future<process_fn_return_type> {
bytes_ostream linearization_buffer;
request_reader in(is, linearization_buffer);
auto local_client_state = gcs.get();
auto local_client_state = gcs.get(&server._abort_source);
auto local_trace_state = gt.get();
co_return co_await process_fn(local_client_state, server._query_processor, in, stream, version,
/* FIXME */empty_service_permit(), std::move(local_trace_state), false, cached_vals, dialect);