From ee0947fe8e4266f9f064adee2995f11a30b35555 Mon Sep 17 00:00:00 2001 From: weimch Date: Thu, 13 Jun 2024 16:26:48 +0800 Subject: [PATCH] Feature: add abstract and factory classes for overload controller module --- trpc/common/BUILD | 1 + trpc/common/trpc_plugin.cc | 7 ++ trpc/overload_control/BUILD | 37 ++++++ trpc/overload_control/common/BUILD | 8 ++ .../common/overload_control_factory.h | 105 ++++++++++++++++++ .../server_overload_controller.h | 48 ++++++++ .../server_overload_controller_factory.h | 16 +++ ...server_overload_controller_factory_test.cc | 55 +++++++++ trpc/overload_control/testing/BUILD | 19 ++++ .../testing/overload_control_testing.h | 54 +++++++++ .../overload_control/trpc_overload_control.cc | 25 +++++ trpc/overload_control/trpc_overload_control.h | 17 +++ 12 files changed, 392 insertions(+) create mode 100644 trpc/overload_control/common/overload_control_factory.h create mode 100644 trpc/overload_control/server_overload_controller.h create mode 100644 trpc/overload_control/server_overload_controller_factory.h create mode 100644 trpc/overload_control/server_overload_controller_factory_test.cc create mode 100644 trpc/overload_control/testing/BUILD create mode 100644 trpc/overload_control/testing/overload_control_testing.h create mode 100644 trpc/overload_control/trpc_overload_control.cc create mode 100644 trpc/overload_control/trpc_overload_control.h diff --git a/trpc/common/BUILD b/trpc/common/BUILD index ef8e9d6a..9dc9dad1 100644 --- a/trpc/common/BUILD +++ b/trpc/common/BUILD @@ -136,6 +136,7 @@ cc_library( "//trpc/transport/common:ssl_helper", "//trpc/util/log/default:default_log", "//trpc/util:net_util", + "//trpc/overload_control:trpc_overload_control", ] + select({ "//trpc:trpc_include_rpcz": [ "//trpc/rpcz:collector", diff --git a/trpc/common/trpc_plugin.cc b/trpc/common/trpc_plugin.cc index 7d184b22..b2854aa1 100644 --- a/trpc/common/trpc_plugin.cc +++ b/trpc/common/trpc_plugin.cc @@ -35,6 +35,7 @@ #ifdef TRPC_BUILD_INCLUDE_RPCZ #include "trpc/rpcz/collector.h" #endif +#include "trpc/overload_control/trpc_overload_control.h" #include "trpc/runtime/common/periphery_task_scheduler.h" #include "trpc/runtime/common/runtime_info_report/runtime_info_reporter.h" #include "trpc/runtime/common/stats/frame_stats.h" @@ -82,6 +83,8 @@ int TrpcPlugin::RegisterPlugins() { TRPC_ASSERT(telemetry::Init()); TRPC_ASSERT(naming::Init()); + TRPC_ASSERT(overload_control::Init()); + CollectPlugins(); InitPlugins(); @@ -229,6 +232,8 @@ int TrpcPlugin::UnregisterPlugins() { StopPlugins(); + overload_control::Stop(); + PeripheryTaskScheduler::GetInstance()->Stop(); PeripheryTaskScheduler::GetInstance()->Join(); @@ -529,6 +534,8 @@ void TrpcPlugin::DestroyResource() { log::Destroy(); + overload_control::Destroy(); + GetTrpcClient()->Destroy(); is_all_inited_ = false; diff --git a/trpc/overload_control/BUILD b/trpc/overload_control/BUILD index 3bbf042c..edee480d 100644 --- a/trpc/overload_control/BUILD +++ b/trpc/overload_control/BUILD @@ -12,3 +12,40 @@ cc_library( }), visibility = ["//visibility:public"], ) + +cc_library( + name = "server_overload_controller", + hdrs = ["server_overload_controller.h"], + deps = [ + "//trpc/server:server_context", + ], +) + +cc_library( + name = "server_overload_controller_factory", + hdrs = ["server_overload_controller_factory.h"], + deps = [ + ":server_overload_controller", + "//trpc/overload_control/common:overload_control_factory", + ], +) + +cc_test( + name = "server_overload_controller_factory_test", + srcs = ["server_overload_controller_factory_test.cc"], + deps = [ + ":server_overload_controller_factory", + "//trpc/overload_control/testing:overload_control_testing", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "trpc_overload_control", + srcs = ["trpc_overload_control.cc"], + hdrs = ["trpc_overload_control.h"], + deps = [ + ":server_overload_controller_factory", + "//trpc/filter:filter_manager", + ], +) diff --git a/trpc/overload_control/common/BUILD b/trpc/overload_control/common/BUILD index 97869057..811b7472 100644 --- a/trpc/overload_control/common/BUILD +++ b/trpc/overload_control/common/BUILD @@ -188,3 +188,11 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "overload_control_factory", + hdrs = ["overload_control_factory.h"], + deps = [ + "//trpc/log:trpc_log", + ], +) diff --git a/trpc/overload_control/common/overload_control_factory.h b/trpc/overload_control/common/overload_control_factory.h new file mode 100644 index 00000000..06cb7b32 --- /dev/null +++ b/trpc/overload_control/common/overload_control_factory.h @@ -0,0 +1,105 @@ +// Copyright (c) 2024, Tencent Inc. +// All rights reserved. + +#pragma once + +#include +#include + +#include "trpc/log/trpc_log.h" + +namespace trpc::overload_control { + +/// @brief Factory of overload control strategy(as template T). +template +class OverloadControlFactory { + public: + /// @brief Singleton + static OverloadControlFactory* GetInstance() { + static OverloadControlFactory instance; + return &instance; + } + + /// @brief Can't construct by user. + OverloadControlFactory(const OverloadControlFactory&) = delete; + OverloadControlFactory& operator=(const OverloadControlFactory&) = delete; + + /// @brief Register the overload control strategy. + /// @param obj strategy + /// @note Non-thread-safe + bool Register(const T& obj); + + /// @brief Get the overload control strategy by name + /// @param name name of strategy + /// @return strategy + T Get(const std::string& name); + + /// @brief Get number of strategies. + /// @return number of strategies + /// @note Non-thread-safe + size_t Size() const { return objs_map_.size(); } + + /// @brief Stop all of overload control strategies. + // Mainly used to stop inner thread createdy by each strategy + /// @note Non-thread-safe. + void Stop(); + + /// @brief Destroy resource of overload control strategies. + /// @note Non-thread-safe. + void Destroy(); + + /// @brief Clear overload control strategies in this factory. + /// @note Non-thread-safe. + void Clear() { objs_map_.clear(); } + + private: + OverloadControlFactory() = default; + + private: + // strategies mapping(name->stratege obj) + std::unordered_map objs_map_; +}; + +template +bool OverloadControlFactory::Register(const T& obj) { + if (!obj) { + TRPC_FMT_ERROR("register object is nullptr"); + return false; + } + if (Get(obj->Name())) { + return false; + } + if (obj->Init()) { + objs_map_.emplace(obj->Name(), obj); + return true; + } + TRPC_FMT_ERROR("{} is `Init` failed ", obj->Name()); + return false; +} + +template +T OverloadControlFactory::Get(const std::string& name) { + T obj = nullptr; + auto iter = objs_map_.find(name); + if (iter != objs_map_.end()) { + obj = iter->second; + } + return obj; +} + +template +void OverloadControlFactory::Stop() { + for (auto& obj : objs_map_) { + obj.second->Stop(); + } +} + +template +void OverloadControlFactory::Destroy() { + for (auto& obj : objs_map_) { + obj.second->Destroy(); + } + Clear(); +} + +} // namespace trpc::overload_control diff --git a/trpc/overload_control/server_overload_controller.h b/trpc/overload_control/server_overload_controller.h new file mode 100644 index 00000000..3c6771bf --- /dev/null +++ b/trpc/overload_control/server_overload_controller.h @@ -0,0 +1,48 @@ +// Copyright (c) 2024, Tencent Inc. +// All rights reserved. + +#pragma once + +#include +#include + +#include "trpc/server/server_context.h" + +namespace trpc::overload_control { + +/// @brief Base class of overload controller. +class ServerOverloadController { + public: + virtual ~ServerOverloadController() = default; + + /// @brief Name of this controller. + virtual std::string Name() = 0; + + /// @brief Initialize controller. + /// You can allocate resources or start thread as controller need. + /// @return bool true: succ; false: failed + virtual bool Init() { return true; } + + /// @brief Whether this request should be scheduled to handle. + /// When reject this request, you should also set status with error code TRPC_SERVER_OVERLOAD_ERR + /// into context. + /// @param context server context. + /// @return bool true: this request will be handled; false: this request should be rejected. + virtual bool BeforeSchedule(const ServerContextPtr& context) = 0; + + /// @brief After this request being sheduled. At this point, it may be handled or rejected. + // You can check status from context to distinguish these 2 scenes when implement. + /// @param context server context. + /// @return bool true: succ; false: failed. + virtual bool AfterSchedule(const ServerContextPtr& context) = 0; + + /// @brief Stop controller. One can stop the thread execution of controller implemetation. + virtual void Stop() {} + + /// @brief Destroy resources of controller. + virtual void Destroy() {} +}; + +using ServerOverloadControllerPtr = std::shared_ptr; + +} // namespace trpc::overload_control diff --git a/trpc/overload_control/server_overload_controller_factory.h b/trpc/overload_control/server_overload_controller_factory.h new file mode 100644 index 00000000..a6d1be50 --- /dev/null +++ b/trpc/overload_control/server_overload_controller_factory.h @@ -0,0 +1,16 @@ +// Copyright (c) 2024, Tencent Inc. +// All rights reserved. + +#pragma once + +#include +#include + +#include "trpc/overload_control/common/overload_control_factory.h" +#include "trpc/overload_control/server_overload_controller.h" + +namespace trpc::overload_control { + +using ServerOverloadControllerFactory = OverloadControlFactory; + +} // namespace trpc::overload_control diff --git a/trpc/overload_control/server_overload_controller_factory_test.cc b/trpc/overload_control/server_overload_controller_factory_test.cc new file mode 100644 index 00000000..0ec90dda --- /dev/null +++ b/trpc/overload_control/server_overload_controller_factory_test.cc @@ -0,0 +1,55 @@ +// Copyright (c) 2024, Tencent Inc. +// All rights reserved. + +#include "trpc/overload_control/server_overload_controller_factory.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "trpc/log/trpc_log.h" +#include "trpc/overload_control/testing/overload_control_testing.h" + +namespace trpc::overload_control { + +namespace testing { + +TEST(ServerOverloadControllerFactory, All) { + // Testing register interface + { + // 1. Register nullptr, failed. + ASSERT_FALSE(ServerOverloadControllerFactory::GetInstance()->Register(nullptr)); + // 2. First time register, succ. + ServerOverloadControllerPtr controller = std::make_shared(); + MockServerOverloadController* mock_controller = static_cast(controller.get()); + EXPECT_CALL(*mock_controller, Init()).WillOnce(::testing::Return(false)); + EXPECT_CALL(*mock_controller, Name()).WillRepeatedly(::testing::Return(std::string("mock_controller"))); + ASSERT_FALSE(ServerOverloadControllerFactory::GetInstance()->Register(controller)); + + EXPECT_CALL(*mock_controller, Init()).WillOnce(::testing::Return(true)); + ASSERT_TRUE(ServerOverloadControllerFactory::GetInstance()->Register(controller)); + // 3. Duplicated register, failed. + ASSERT_FALSE(ServerOverloadControllerFactory::GetInstance()->Register(controller)); + + auto size = ServerOverloadControllerFactory::GetInstance()->Size(); + ASSERT_EQ(size, 1); + } + // Testing get interface + { + ServerOverloadControllerPtr controller = ServerOverloadControllerFactory::GetInstance()->Get("xxx"); + ASSERT_EQ(controller, nullptr); + controller = ServerOverloadControllerFactory::GetInstance()->Get("mock_controller"); + ASSERT_NE(controller, nullptr); + } + + // Testing series of cleaning interface + { + ServerOverloadControllerFactory::GetInstance()->Stop(); + ServerOverloadControllerFactory::GetInstance()->Destroy(); + auto size = ServerOverloadControllerFactory::GetInstance()->Size(); + ASSERT_EQ(size, 0); + } +} + +} // namespace testing + +} // namespace trpc::overload_control diff --git a/trpc/overload_control/testing/BUILD b/trpc/overload_control/testing/BUILD new file mode 100644 index 00000000..13fc4b36 --- /dev/null +++ b/trpc/overload_control/testing/BUILD @@ -0,0 +1,19 @@ +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "overload_control_testing", + hdrs = ["overload_control_testing.h"], + visibility = ["//visibility:public"], + deps = [ + "//trpc/codec:protocol", + "//trpc/coroutine:fiber", + "//trpc/coroutine/testing:fiber_runtime_test", + "//trpc/filter:filter_manager", + "//trpc/overload_control:server_overload_controller", + "//trpc/server:service", + "//trpc/server/testing:service_adapter_testing", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/trpc/overload_control/testing/overload_control_testing.h b/trpc/overload_control/testing/overload_control_testing.h new file mode 100644 index 00000000..bd5c6f8d --- /dev/null +++ b/trpc/overload_control/testing/overload_control_testing.h @@ -0,0 +1,54 @@ +// Copyright (c) 2024, Tencent Inc. +// All rights reserved. + +#pragma once + +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "trpc/codec/protocol.h" +#include "trpc/coroutine/fiber.h" +#include "trpc/coroutine/fiber_latch.h" +#include "trpc/coroutine/testing/fiber_runtime.h" +#include "trpc/filter/filter_manager.h" +#include "trpc/overload_control/server_overload_controller.h" +#include "trpc/server/service.h" +#include "trpc/server/testing/service_adapter_testing.h" + +namespace trpc::overload_control { +namespace testing { + +// Mock protocol, only allowed to be used at overload control module. +class MockProtocol : public Protocol { + public: + MOCK_METHOD1(ZeroCopyDecode, bool(NoncontiguousBuffer&)); + MOCK_METHOD1(ZeroCopyEncode, bool(NoncontiguousBuffer&)); + MOCK_METHOD1(SetCallType, void(RpcCallType)); + MOCK_METHOD0(GetCallType, RpcCallType()); +}; + +using MockProtocolPtr = std::shared_ptr; + +// Get filter object by filter point and filter name. +inline MessageServerFilterPtr GetGlobalServerFilterByName(FilterPoint type, const std::string& name) { + const std::deque& filters = FilterManager::GetInstance()->GetMessageServerGlobalFilter(type); + for (auto& filter : filters) { + if (!filter->Name().compare(name)) { + return filter; + } + } + return nullptr; +} + +class MockServerOverloadController : public ServerOverloadController { + public: + MOCK_METHOD0(Name, std::string()); + MOCK_METHOD0(Init, bool()); + MOCK_METHOD1(BeforeSchedule, bool(const ServerContextPtr&)); + MOCK_METHOD1(AfterSchedule, bool(const ServerContextPtr&)); +}; + +} // namespace testing +} // namespace trpc::overload_control diff --git a/trpc/overload_control/trpc_overload_control.cc b/trpc/overload_control/trpc_overload_control.cc new file mode 100644 index 00000000..a960be39 --- /dev/null +++ b/trpc/overload_control/trpc_overload_control.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2024, Tencent Inc. +// All rights reserved. + +#include "trpc/overload_control/trpc_overload_control.h" + +#include "trpc/overload_control/server_overload_controller_factory.h" + +namespace trpc::overload_control { + +bool Init() { + // Register plugins here + return true; +} + +void Stop() { + // Stop plugins here + ServerOverloadControllerFactory::GetInstance()->Stop(); +} + +void Destroy() { + // Destroy plugins here + ServerOverloadControllerFactory::GetInstance()->Destroy(); +} + +} // namespace trpc::overload_control diff --git a/trpc/overload_control/trpc_overload_control.h b/trpc/overload_control/trpc_overload_control.h new file mode 100644 index 00000000..a14e77ba --- /dev/null +++ b/trpc/overload_control/trpc_overload_control.h @@ -0,0 +1,17 @@ +// Copyright (c) 2024, Tencent Inc. +// All rights reserved. + +#pragma once + +namespace trpc::overload_control { + +/// @brief Intialize overload plugins +bool Init(); + +/// @brief Stop inner started thread(maybe) of overload plugins +void Stop(); + +/// @brief Destroy resource of overload plugins +void Destroy(); + +} // namespace trpc::overload_control