diff --git a/sstables_loader.cc b/sstables_loader.cc index eab2642318..c2ab351a46 100644 --- a/sstables_loader.cc +++ b/sstables_loader.cc @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -581,7 +582,22 @@ class sstables_loader::download_task_impl : public tasks::task_manager::task::im sstring _prefix; sstables_loader::stream_scope _scope; std::vector _sstables; - std::vector> _progress_per_shard; + struct progress_holder { + // Wrap stream_progress in a smart pointer to enable polymorphism. + // This allows derived progress types to be passed down for per-tablet + // progress tracking while maintaining the base interface. + shared_ptr progress = make_shared(); + }; + // user could query for the progress even before _progress_per_shard + // is completed started, and this._status.state does not reflect the + // state of progress, so we have to track it separately. + enum class progress_state { + uninitialized, + initialized, + finalized, + } _progress_state = progress_state::uninitialized; + sharded _progress_per_shard; + tasks::task_manager::task::progress _final_progress; protected: virtual future<> run() override; @@ -599,7 +615,6 @@ public: , _prefix(std::move(prefix)) , _scope(scope) , _sstables(std::move(sstables)) - , _progress_per_shard(smp::count) { _status.progress_units = "batches"; } @@ -620,29 +635,29 @@ public: return tasks::is_abortable::yes; } + virtual future<> release_resources() noexcept override { + // preserve the final progress, so we can access it after the task is + // finished + _final_progress = co_await get_progress(); + _progress_state = progress_state::finalized; + co_await _progress_per_shard.stop(); + } + virtual future get_progress() const override { - struct adder { - stream_progress result; - future<> operator()(stream_progress p) { - llog.debug("get_progress: {} / {}", p.completed, p.total); - result.completed += p.completed; - result.total += p.total; - return make_ready_future<>(); - } - stream_progress get() const { - return result; - } - }; - auto p = co_await _loader.map_reduce( - adder{}, - [this] (auto&) -> stream_progress { - auto p = _progress_per_shard[this_shard_id()]; - if (p) { - return *p; - } else { - // the task was aborted - return {}; - } + switch (_progress_state) { + case progress_state::uninitialized: + co_return tasks::task_manager::task::progress{}; + case progress_state::finalized: + co_return _final_progress; + case progress_state::initialized: + break; + } + auto p = co_await _progress_per_shard.map_reduce( + adder{}, + [] (const progress_holder& holder) -> stream_progress { + auto p = holder.progress; + SCYLLA_ASSERT(p); + return *p; }); co_return tasks::task_manager::task::progress { .completed = p.completed, @@ -678,11 +693,11 @@ future<> sstables_loader::download_task_impl::run() { } catch (...) { } }); + co_await _progress_per_shard.start(); + _progress_state = progress_state::initialized; co_await _loader.invoke_on_all([this, &sstables_on_shards, table_id] (sstables_loader& loader) mutable -> future<> { - auto progress = make_shared(); - _progress_per_shard[this_shard_id()] = progress; co_await loader.load_and_stream(_ks, _cf, table_id, std::move(sstables_on_shards[this_shard_id()]), false, false, _scope, - progress); + _progress_per_shard.local().progress); }); } catch (...) { ex = std::current_exception(); diff --git a/sstables_loader.hh b/sstables_loader.hh index 7475b09d2e..d3454ff3c9 100644 --- a/sstables_loader.hh +++ b/sstables_loader.hh @@ -34,7 +34,11 @@ struct stream_progress { float completed = 0.; virtual ~stream_progress() = default; - + stream_progress& operator+=(const stream_progress& p) { + total += p.total; + completed += p.completed; + return *this; + } void start(float amount) { assert(amount >= 0); total = amount;