diff --git a/include/envoy/thread_local/thread_local.h b/include/envoy/thread_local/thread_local.h index 6f082a4607f5..41c77d730d19 100644 --- a/include/envoy/thread_local/thread_local.h +++ b/include/envoy/thread_local/thread_local.h @@ -74,6 +74,17 @@ class Slot { */ using InitializeCb = std::function; virtual void set(InitializeCb cb) PURE; + + /** + * UpdateCb takes the current stored data, and returns an updated/new version data. + * TLS will run the callback and replace the stored data with the returned value *in each thread*. + * + * NOTE: The update callback is not supposed to capture the Slot, or its owner. As the owner may + * be destructed in main thread before the update_cb gets called in a worker thread. + **/ + using UpdateCb = std::function; + virtual void runOnAllThreads(const UpdateCb& update_cb) PURE; + virtual void runOnAllThreads(const UpdateCb& update_cb, Event::PostCb complete_cb) PURE; }; using SlotPtr = std::unique_ptr; diff --git a/source/common/common/non_copyable.h b/source/common/common/non_copyable.h index c248a37f48e4..fb356770c3f5 100644 --- a/source/common/common/non_copyable.h +++ b/source/common/common/non_copyable.h @@ -2,14 +2,19 @@ namespace Envoy { /** - * Mixin class that makes derived classes not copyable. Like boost::noncopyable without boost. + * Mixin class that makes derived classes not copyable and not moveable. Like boost::noncopyable + * without boost. */ class NonCopyable { protected: NonCopyable() = default; -private: - NonCopyable(const NonCopyable&); - NonCopyable& operator=(const NonCopyable&); + // Non-moveable. + NonCopyable(NonCopyable&&) noexcept = delete; + NonCopyable& operator=(NonCopyable&&) noexcept = delete; + + // Non-copyable. + NonCopyable(const NonCopyable&) = delete; + NonCopyable& operator=(const NonCopyable&) = delete; }; } // namespace Envoy diff --git a/source/common/config/config_provider_impl.cc b/source/common/config/config_provider_impl.cc index 11cbf993e51c..5745647e2dbf 100644 --- a/source/common/config/config_provider_impl.cc +++ b/source/common/config/config_provider_impl.cc @@ -23,6 +23,16 @@ ConfigSubscriptionCommonBase::~ConfigSubscriptionCommonBase() { init_target_.ready(); config_provider_manager_.unbindSubscription(manager_identifier_); } + +void ConfigSubscriptionCommonBase::applyConfigUpdate(const ConfigUpdateCb& update_fn) { + tls_->runOnAllThreads([update_fn](ThreadLocal::ThreadLocalObjectSharedPtr previous) + -> ThreadLocal::ThreadLocalObjectSharedPtr { + auto prev_thread_local_config = std::dynamic_pointer_cast(previous); + prev_thread_local_config->config_ = update_fn(prev_thread_local_config->config_); + return previous; + }); +} + bool ConfigSubscriptionInstance::checkAndApplyConfigUpdate(const Protobuf::Message& config_proto, const std::string& config_name, const std::string& version_info) { diff --git a/source/common/config/config_provider_impl.h b/source/common/config/config_provider_impl.h index a1a7b02d71b7..02acf3b91f70 100644 --- a/source/common/config/config_provider_impl.h +++ b/source/common/config/config_provider_impl.h @@ -220,26 +220,8 @@ class ConfigSubscriptionCommonBase * * @param update_fn the callback to run on each thread, it takes the previous version Config and * returns a updated/new version Config. - * @param complete_cb the callback to run when the update propagation is done. */ - void applyConfigUpdate( - const ConfigUpdateCb& update_fn, const Event::PostCb& complete_cb = []() {}) { - // It is safe to call shared_from_this here as this is in main thread, and destruction of a - // ConfigSubscriptionCommonBase owner (i.e., a provider) happens in main thread as well. - auto shared_this = shared_from_this(); - tls_->runOnAllThreads( - [this, update_fn]() { - tls_->getTyped().config_ = update_fn(this->getConfig()); - }, - // During the update propagation, a subscription may get teared down in main thread due to - // all owners/providers destructed in a xDS update (e.g. LDS demolishes a - // RouteConfigProvider and its subscription). - // If such a race condition happens, holding a reference to the "*this" subscription - // instance in this cb will ensure the shared "*this" gets posted back to main thread, after - // all the workers finish calling the update_fn, at which point it's safe to destruct - // "*this" instance. - [shared_this, complete_cb]() { complete_cb(); }); - } + void applyConfigUpdate(const ConfigUpdateCb& update_fn); void setLastUpdated() { last_updated_ = time_source_.systemTime(); } diff --git a/source/common/router/rds_impl.cc b/source/common/router/rds_impl.cc index e316460e0f09..2efa53e5de9d 100644 --- a/source/common/router/rds_impl.cc +++ b/source/common/router/rds_impl.cc @@ -194,8 +194,12 @@ Router::ConfigConstSharedPtr RdsRouteConfigProviderImpl::config() { void RdsRouteConfigProviderImpl::onConfigUpdate() { ConfigConstSharedPtr new_config( new ConfigImpl(config_update_info_->routeConfiguration(), factory_context_, false)); - tls_->runOnAllThreads( - [this, new_config]() -> void { tls_->getTyped().config_ = new_config; }); + tls_->runOnAllThreads([new_config](ThreadLocal::ThreadLocalObjectSharedPtr previous) + -> ThreadLocal::ThreadLocalObjectSharedPtr { + auto prev_config = std::dynamic_pointer_cast(previous); + prev_config->config_ = new_config; + return previous; + }); } RouteConfigProviderManagerImpl::RouteConfigProviderManagerImpl(Server::Admin& admin) { diff --git a/source/common/thread_local/thread_local_impl.cc b/source/common/thread_local/thread_local_impl.cc index 9781db07797b..5d9f584b517e 100644 --- a/source/common/thread_local/thread_local_impl.cc +++ b/source/common/thread_local/thread_local_impl.cc @@ -27,8 +27,9 @@ SlotPtr InstanceImpl::allocateSlot() { if (free_slot_indexes_.empty()) { std::unique_ptr slot(new SlotImpl(*this, slots_.size())); - slots_.push_back(slot.get()); - return slot; + auto wrapper = std::make_unique(*this, std::move(slot)); + slots_.push_back(wrapper->slot_.get()); + return wrapper; } const uint32_t idx = free_slot_indexes_.front(); free_slot_indexes_.pop_front(); @@ -42,11 +43,64 @@ bool InstanceImpl::SlotImpl::currentThreadRegistered() { return thread_local_data_.data_.size() > index_; } +void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb) { + parent_.runOnAllThreads([this, cb]() { setThreadLocal(index_, cb(get())); }); +} + +void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) { + parent_.runOnAllThreads([this, cb]() { setThreadLocal(index_, cb(get())); }, complete_cb); +} + ThreadLocalObjectSharedPtr InstanceImpl::SlotImpl::get() { ASSERT(currentThreadRegistered()); return thread_local_data_.data_[index_]; } +InstanceImpl::Bookkeeper::Bookkeeper(InstanceImpl& parent, std::unique_ptr&& slot) + : parent_(parent), slot_(std::move(slot)), + ref_count_(/*not used.*/ nullptr, + [slot = slot_.get(), &parent = this->parent_](uint32_t* /* not used */) { + // On destruction, post a cleanup callback on main thread, this could happen on + // any thread. + parent.scheduleCleanup(slot); + }) {} + +ThreadLocalObjectSharedPtr InstanceImpl::Bookkeeper::get() { return slot_->get(); } + +void InstanceImpl::Bookkeeper::runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) { + slot_->runOnAllThreads( + [cb, ref_count = this->ref_count_](ThreadLocalObjectSharedPtr previous) { + return cb(std::move(previous)); + }, + complete_cb); +} + +void InstanceImpl::Bookkeeper::runOnAllThreads(const UpdateCb& cb) { + slot_->runOnAllThreads([cb, ref_count = this->ref_count_](ThreadLocalObjectSharedPtr previous) { + return cb(std::move(previous)); + }); +} + +bool InstanceImpl::Bookkeeper::currentThreadRegistered() { + return slot_->currentThreadRegistered(); +} + +void InstanceImpl::Bookkeeper::runOnAllThreads(Event::PostCb cb) { + // Use ref_count_ to bookkeep how many on-the-fly callback are out there. + slot_->runOnAllThreads([cb, ref_count = this->ref_count_]() { cb(); }); +} + +void InstanceImpl::Bookkeeper::runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback) { + // Use ref_count_ to bookkeep how many on-the-fly callback are out there. + slot_->runOnAllThreads([cb, main_callback, ref_count = this->ref_count_]() { cb(); }, + main_callback); +} + +void InstanceImpl::Bookkeeper::set(InitializeCb cb) { + slot_->set([cb, ref_count = this->ref_count_](Event::Dispatcher& dispatcher) + -> ThreadLocalObjectSharedPtr { return cb(dispatcher); }); +} + void InstanceImpl::registerThread(Event::Dispatcher& dispatcher, bool main_thread) { ASSERT(std::this_thread::get_id() == main_thread_id_); ASSERT(!shutdown_); @@ -61,6 +115,38 @@ void InstanceImpl::registerThread(Event::Dispatcher& dispatcher, bool main_threa } } +// Puts the slot into a deferred delete container, the slot will be destructed when its out-going +// callback reference count goes to 0. +void InstanceImpl::recycle(std::unique_ptr&& slot) { + ASSERT(std::this_thread::get_id() == main_thread_id_); + ASSERT(slot != nullptr); + auto* slot_addr = slot.get(); + deferred_deletes_.insert({slot_addr, std::move(slot)}); +} + +// Called by the Bookkeeper ref_count destructor, the SlotImpl in the deferred deletes map can be +// destructed now. +void InstanceImpl::scheduleCleanup(SlotImpl* slot) { + if (shutdown_) { + // If server is shutting down, do nothing here. + // The destruction of Bookkeeper has already transferred the SlotImpl to the deferred_deletes_ + // queue. No matter if this method is called from a Worker thread, the SlotImpl will be + // destructed on main thread when InstanceImpl destructs. + return; + } + if (std::this_thread::get_id() == main_thread_id_) { + // If called from main thread, save a callback. + ASSERT(deferred_deletes_.contains(slot)); + deferred_deletes_.erase(slot); + return; + } + main_thread_dispatcher_->post([slot, this]() { + ASSERT(deferred_deletes_.contains(slot)); + // The slot is guaranteed to be put into the deferred_deletes_ map by Bookkeeper destructor. + deferred_deletes_.erase(slot); + }); +} + void InstanceImpl::removeSlot(SlotImpl& slot) { ASSERT(std::this_thread::get_id() == main_thread_id_); diff --git a/source/common/thread_local/thread_local_impl.h b/source/common/thread_local/thread_local_impl.h index 39a0f12a3e4f..49f1889e44d7 100644 --- a/source/common/thread_local/thread_local_impl.h +++ b/source/common/thread_local/thread_local_impl.h @@ -8,6 +8,9 @@ #include "envoy/thread_local/thread_local.h" #include "common/common/logger.h" +#include "common/common/non_copyable.h" + +#include "absl/container/flat_hash_map.h" namespace Envoy { namespace ThreadLocal { @@ -15,7 +18,7 @@ namespace ThreadLocal { /** * Implementation of ThreadLocal that relies on static thread_local objects. */ -class InstanceImpl : Logger::Loggable, public Instance { +class InstanceImpl : Logger::Loggable, public NonCopyable, public Instance { public: InstanceImpl() : main_thread_id_(std::this_thread::get_id()) {} ~InstanceImpl() override; @@ -35,6 +38,8 @@ class InstanceImpl : Logger::Loggable, public Instance { // ThreadLocal::Slot ThreadLocalObjectSharedPtr get() override; bool currentThreadRegistered() override; + void runOnAllThreads(const UpdateCb& cb) override; + void runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) override; void runOnAllThreads(Event::PostCb cb) override { parent_.runOnAllThreads(cb); } void runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback) override { parent_.runOnAllThreads(cb, main_callback); @@ -45,17 +50,47 @@ class InstanceImpl : Logger::Loggable, public Instance { const uint64_t index_; }; + // A Wrapper of SlotImpl which on destruction returns the SlotImpl to the deferred delete queue + // (detaches it). + struct Bookkeeper : public Slot { + Bookkeeper(InstanceImpl& parent, std::unique_ptr&& slot); + ~Bookkeeper() override { parent_.recycle(std::move(slot_)); } + + // ThreadLocal::Slot + ThreadLocalObjectSharedPtr get() override; + void runOnAllThreads(const UpdateCb& cb) override; + void runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) override; + bool currentThreadRegistered() override; + void runOnAllThreads(Event::PostCb cb) override; + void runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback) override; + void set(InitializeCb cb) override; + + InstanceImpl& parent_; + std::unique_ptr slot_; + std::shared_ptr ref_count_; + }; + struct ThreadLocalData { Event::Dispatcher* dispatcher_{}; std::vector data_; }; + void recycle(std::unique_ptr&& slot); + // Cleanup the deferred deletes queue. + void scheduleCleanup(SlotImpl* slot); + void removeSlot(SlotImpl& slot); void runOnAllThreads(Event::PostCb cb); void runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback); static void setThreadLocal(uint32_t index, ThreadLocalObjectSharedPtr object); static thread_local ThreadLocalData thread_local_data_; + + // A indexed container for Slots that has to be deferred to delete due to out-going callbacks + // pointing to the Slot. To let the ref_count_ deleter find the SlotImpl by address, the container + // is defined as a map of SlotImpl address to the unique_ptr. + absl::flat_hash_map> deferred_deletes_; + std::vector slots_; // A list of index of freed slots. std::list free_slot_indexes_; diff --git a/test/common/thread_local/thread_local_impl_test.cc b/test/common/thread_local/thread_local_impl_test.cc index 5a678c839384..fb3bd1cf3962 100644 --- a/test/common/thread_local/thread_local_impl_test.cc +++ b/test/common/thread_local/thread_local_impl_test.cc @@ -85,6 +85,40 @@ TEST_F(ThreadLocalInstanceImplTest, All) { tls_.shutdownThread(); } +// Test that the config passed into the update callback is the previous version stored in the slot. +TEST_F(ThreadLocalInstanceImplTest, UpdateCallback) { + InSequence s; + + SlotPtr slot = tls_.allocateSlot(); + + auto newer_version = std::make_shared(); + bool update_called = false; + + TestThreadLocalObject& object_ref = setObject(*slot); + auto update_cb = [&object_ref, &update_called, + newer_version](ThreadLocalObjectSharedPtr obj) -> ThreadLocalObjectSharedPtr { + // The unit test setup have two dispatchers registered, but only one thread, this lambda will be + // called twice in the same thread. + if (!update_called) { + EXPECT_EQ(obj.get(), &object_ref); + update_called = true; + } else { + EXPECT_EQ(obj.get(), newer_version.get()); + } + + return newer_version; + }; + EXPECT_CALL(thread_dispatcher_, post(_)); + EXPECT_CALL(object_ref, onDestroy()); + EXPECT_CALL(*newer_version, onDestroy()); + slot->runOnAllThreads(update_cb); + + EXPECT_EQ(newer_version.get(), &slot->getTyped()); + + tls_.shutdownGlobalThreading(); + tls_.shutdownThread(); +} + // TODO(ramaraochavali): Run this test with real threads. The current issue in the unit // testing environment is, the post to main_dispatcher is not working as expected. diff --git a/test/mocks/thread_local/mocks.h b/test/mocks/thread_local/mocks.h index 3d7a43efaef8..a9abc6a6d562 100644 --- a/test/mocks/thread_local/mocks.h +++ b/test/mocks/thread_local/mocks.h @@ -63,6 +63,14 @@ class MockInstance : public Instance { void runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback) override { parent_.runOnAllThreads(cb, main_callback); } + void runOnAllThreads(const UpdateCb& cb) override { + parent_.runOnAllThreads([cb, this]() { parent_.data_[index_] = cb(parent_.data_[index_]); }); + } + void runOnAllThreads(const UpdateCb& cb, Event::PostCb main_callback) override { + parent_.runOnAllThreads([cb, this]() { parent_.data_[index_] = cb(parent_.data_[index_]); }, + main_callback); + } + void set(InitializeCb cb) override { parent_.data_[index_] = cb(parent_.dispatcher_); } MockInstance& parent_;