From c59c87c23359be0346fa3dddb293dba432513872 Mon Sep 17 00:00:00 2001 From: Calle Wilund Date: Tue, 21 Jan 2025 12:19:32 +0000 Subject: [PATCH] generic_server: Allow sharing reloadability of certificates across shards Adds an optional callback to "listen", returning the shard local object instance. If provided, instead of creating a "full" reloadable cerificate object, only do so on shard 0, and use callback to reload other shards "manually". --- generic_server.cc | 43 +++++++++++++++++++++++++++++++------------ generic_server.hh | 8 +++++++- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/generic_server.cc b/generic_server.cc index 9076034bb8..4f143f79fb 100644 --- a/generic_server.cc +++ b/generic_server.cc @@ -13,6 +13,7 @@ #include #include #include +#include namespace generic_server { @@ -167,16 +168,34 @@ future<> server::shutdown() { } future<> -server::listen(socket_address addr, std::shared_ptr builder, bool is_shard_aware, bool keepalive, std::optional unix_domain_socket_permissions) { - shared_ptr creds = nullptr; - if (builder) { - creds = co_await builder->build_reloadable_server_credentials([this](const std::unordered_set& files, std::exception_ptr ep) { - if (ep) { - _logger.warn("Exception loading {}: {}", files, ep); - } else { - _logger.info("Reloaded {}", files); - } - }); +server::listen(socket_address addr, std::shared_ptr builder, bool is_shard_aware, bool keepalive, std::optional unix_domain_socket_permissions, std::function get_shard_instance) { + // Note: We are making the assumption that if builder is provided it will be the same for each + // invokation, regardless of address etc. In general, only CQL server will call this multiple times, + // and if TLS, it will use the same cert set. + // Could hold certs in a map and ensure separation, but then we will for all + // current uses of this class create duplicate reloadable certs for shard 0, which is + // kind of what we wanted to avoid in the first place... + if (builder && !_credentials) { + if (!get_shard_instance || this_shard_id() == 0) { + _credentials = co_await builder->build_reloadable_server_credentials([this, get_shard_instance = std::move(get_shard_instance)](const tls::credentials_builder& b, const std::unordered_set& files, std::exception_ptr ep) -> future<> { + if (ep) { + _logger.warn("Exception loading {}: {}", files, ep); + } else { + if (get_shard_instance) { + co_await smp::invoke_on_others([&]() { + auto& s = get_shard_instance(); + if (s._credentials) { + b.rebuild(*s._credentials); + } + }); + + } + _logger.info("Reloaded {}", files); + } + }); + } else { + _credentials = builder->build_server_credentials(); + } } listen_options lo; lo.reuse_address = true; @@ -186,8 +205,8 @@ server::listen(socket_address addr, std::shared_ptr {}", _server_name, addr, std::current_exception())); diff --git a/generic_server.hh b/generic_server.hh index 5d856f9b87..ab5e1954ee 100644 --- a/generic_server.hh +++ b/generic_server.hh @@ -104,6 +104,7 @@ protected: }; std::list _gentle_iterators; std::vector _listeners; + shared_ptr _credentials; public: server(const sstring& server_name, logging::logger& logger); @@ -119,7 +120,12 @@ public: future<> shutdown(); future<> stop(); - future<> listen(socket_address addr, std::shared_ptr creds, bool is_shard_aware, bool keepalive, std::optional unix_domain_socket_permissions); + future<> listen(socket_address addr, + std::shared_ptr creds, + bool is_shard_aware, bool keepalive, + std::optional unix_domain_socket_permissions, + std::function get_shard_instance = {} + ); future<> do_accepts(int which, bool keepalive, socket_address server_addr);