Skip to content

Commit

Permalink
Feature: add abstract and factory classes for overload controller module
Browse files Browse the repository at this point in the history
  • Loading branch information
weimch committed Jun 25, 2024
1 parent 00e845f commit ee0947f
Show file tree
Hide file tree
Showing 12 changed files with 392 additions and 0 deletions.
1 change: 1 addition & 0 deletions trpc/common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ cc_library(
"//trpc/transport/common:ssl_helper",
"//trpc/util/log/default:default_log",
"//trpc/util:net_util",
"//trpc/overload_control:trpc_overload_control",
] + select({
"//trpc:trpc_include_rpcz": [
"//trpc/rpcz:collector",
Expand Down
7 changes: 7 additions & 0 deletions trpc/common/trpc_plugin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#ifdef TRPC_BUILD_INCLUDE_RPCZ
#include "trpc/rpcz/collector.h"
#endif
#include "trpc/overload_control/trpc_overload_control.h"
#include "trpc/runtime/common/periphery_task_scheduler.h"
#include "trpc/runtime/common/runtime_info_report/runtime_info_reporter.h"
#include "trpc/runtime/common/stats/frame_stats.h"
Expand Down Expand Up @@ -82,6 +83,8 @@ int TrpcPlugin::RegisterPlugins() {
TRPC_ASSERT(telemetry::Init());
TRPC_ASSERT(naming::Init());

TRPC_ASSERT(overload_control::Init());

CollectPlugins();
InitPlugins();

Expand Down Expand Up @@ -229,6 +232,8 @@ int TrpcPlugin::UnregisterPlugins() {

StopPlugins();

overload_control::Stop();

PeripheryTaskScheduler::GetInstance()->Stop();
PeripheryTaskScheduler::GetInstance()->Join();

Expand Down Expand Up @@ -529,6 +534,8 @@ void TrpcPlugin::DestroyResource() {

log::Destroy();

overload_control::Destroy();

GetTrpcClient()->Destroy();

is_all_inited_ = false;
Expand Down
37 changes: 37 additions & 0 deletions trpc/overload_control/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,40 @@ cc_library(
}),
visibility = ["//visibility:public"],
)

cc_library(
name = "server_overload_controller",
hdrs = ["server_overload_controller.h"],
deps = [
"//trpc/server:server_context",
],
)

cc_library(
name = "server_overload_controller_factory",
hdrs = ["server_overload_controller_factory.h"],
deps = [
":server_overload_controller",
"//trpc/overload_control/common:overload_control_factory",
],
)

cc_test(
name = "server_overload_controller_factory_test",
srcs = ["server_overload_controller_factory_test.cc"],
deps = [
":server_overload_controller_factory",
"//trpc/overload_control/testing:overload_control_testing",
"@com_google_googletest//:gtest_main",
],
)

cc_library(
name = "trpc_overload_control",
srcs = ["trpc_overload_control.cc"],
hdrs = ["trpc_overload_control.h"],
deps = [
":server_overload_controller_factory",
"//trpc/filter:filter_manager",
],
)
8 changes: 8 additions & 0 deletions trpc/overload_control/common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,11 @@ cc_test(
"@com_google_googletest//:gtest_main",
],
)

cc_library(
name = "overload_control_factory",
hdrs = ["overload_control_factory.h"],
deps = [
"//trpc/log:trpc_log",
],
)
105 changes: 105 additions & 0 deletions trpc/overload_control/common/overload_control_factory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright (c) 2024, Tencent Inc.
// All rights reserved.

#pragma once

#include <string>
#include <unordered_map>

#include "trpc/log/trpc_log.h"

namespace trpc::overload_control {

/// @brief Factory of overload control strategy(as template T).
template <class T>
class OverloadControlFactory {
public:
/// @brief Singleton
static OverloadControlFactory* GetInstance() {
static OverloadControlFactory instance;
return &instance;
}

/// @brief Can't construct by user.
OverloadControlFactory(const OverloadControlFactory&) = delete;
OverloadControlFactory& operator=(const OverloadControlFactory&) = delete;

/// @brief Register the overload control strategy.
/// @param obj strategy
/// @note Non-thread-safe
bool Register(const T& obj);

/// @brief Get the overload control strategy by name
/// @param name name of strategy
/// @return strategy
T Get(const std::string& name);

/// @brief Get number of strategies.
/// @return number of strategies
/// @note Non-thread-safe
size_t Size() const { return objs_map_.size(); }

/// @brief Stop all of overload control strategies.
// Mainly used to stop inner thread createdy by each strategy
/// @note Non-thread-safe.
void Stop();

/// @brief Destroy resource of overload control strategies.
/// @note Non-thread-safe.
void Destroy();

/// @brief Clear overload control strategies in this factory.
/// @note Non-thread-safe.
void Clear() { objs_map_.clear(); }

private:
OverloadControlFactory() = default;

private:
// strategies mapping(name->stratege obj)
std::unordered_map<std::string, T> objs_map_;
};

template <class T>
bool OverloadControlFactory<T>::Register(const T& obj) {
if (!obj) {
TRPC_FMT_ERROR("register object is nullptr");
return false;
}
if (Get(obj->Name())) {
return false;
}
if (obj->Init()) {
objs_map_.emplace(obj->Name(), obj);
return true;
}
TRPC_FMT_ERROR("{} is `Init` failed ", obj->Name());
return false;
}

template <class T>
T OverloadControlFactory<T>::Get(const std::string& name) {
T obj = nullptr;
auto iter = objs_map_.find(name);
if (iter != objs_map_.end()) {
obj = iter->second;
}
return obj;
}

template <class T>
void OverloadControlFactory<T>::Stop() {
for (auto& obj : objs_map_) {
obj.second->Stop();
}
}

template <class T>
void OverloadControlFactory<T>::Destroy() {
for (auto& obj : objs_map_) {
obj.second->Destroy();
}
Clear();
}

} // namespace trpc::overload_control
48 changes: 48 additions & 0 deletions trpc/overload_control/server_overload_controller.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2024, Tencent Inc.
// All rights reserved.

#pragma once

#include <memory>
#include <string>

#include "trpc/server/server_context.h"

namespace trpc::overload_control {

/// @brief Base class of overload controller.
class ServerOverloadController {
public:
virtual ~ServerOverloadController() = default;

/// @brief Name of this controller.
virtual std::string Name() = 0;

/// @brief Initialize controller.
/// You can allocate resources or start thread as controller need.
/// @return bool true: succ; false: failed
virtual bool Init() { return true; }

/// @brief Whether this request should be scheduled to handle.
/// When reject this request, you should also set status with error code TRPC_SERVER_OVERLOAD_ERR
/// into context.
/// @param context server context.
/// @return bool true: this request will be handled; false: this request should be rejected.
virtual bool BeforeSchedule(const ServerContextPtr& context) = 0;

/// @brief After this request being sheduled. At this point, it may be handled or rejected.
// You can check status from context to distinguish these 2 scenes when implement.
/// @param context server context.
/// @return bool true: succ; false: failed.
virtual bool AfterSchedule(const ServerContextPtr& context) = 0;

/// @brief Stop controller. One can stop the thread execution of controller implemetation.
virtual void Stop() {}

/// @brief Destroy resources of controller.
virtual void Destroy() {}
};

using ServerOverloadControllerPtr = std::shared_ptr<ServerOverloadController>;

} // namespace trpc::overload_control
16 changes: 16 additions & 0 deletions trpc/overload_control/server_overload_controller_factory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) 2024, Tencent Inc.
// All rights reserved.

#pragma once

#include <string>
#include <unordered_map>

#include "trpc/overload_control/common/overload_control_factory.h"
#include "trpc/overload_control/server_overload_controller.h"

namespace trpc::overload_control {

using ServerOverloadControllerFactory = OverloadControlFactory<ServerOverloadControllerPtr>;

} // namespace trpc::overload_control
55 changes: 55 additions & 0 deletions trpc/overload_control/server_overload_controller_factory_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (c) 2024, Tencent Inc.
// All rights reserved.

#include "trpc/overload_control/server_overload_controller_factory.h"

#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include "trpc/log/trpc_log.h"
#include "trpc/overload_control/testing/overload_control_testing.h"

namespace trpc::overload_control {

namespace testing {

TEST(ServerOverloadControllerFactory, All) {
// Testing register interface
{
// 1. Register nullptr, failed.
ASSERT_FALSE(ServerOverloadControllerFactory::GetInstance()->Register(nullptr));
// 2. First time register, succ.
ServerOverloadControllerPtr controller = std::make_shared<MockServerOverloadController>();
MockServerOverloadController* mock_controller = static_cast<MockServerOverloadController*>(controller.get());
EXPECT_CALL(*mock_controller, Init()).WillOnce(::testing::Return(false));
EXPECT_CALL(*mock_controller, Name()).WillRepeatedly(::testing::Return(std::string("mock_controller")));
ASSERT_FALSE(ServerOverloadControllerFactory::GetInstance()->Register(controller));

EXPECT_CALL(*mock_controller, Init()).WillOnce(::testing::Return(true));
ASSERT_TRUE(ServerOverloadControllerFactory::GetInstance()->Register(controller));
// 3. Duplicated register, failed.
ASSERT_FALSE(ServerOverloadControllerFactory::GetInstance()->Register(controller));

auto size = ServerOverloadControllerFactory::GetInstance()->Size();
ASSERT_EQ(size, 1);
}
// Testing get interface
{
ServerOverloadControllerPtr controller = ServerOverloadControllerFactory::GetInstance()->Get("xxx");
ASSERT_EQ(controller, nullptr);
controller = ServerOverloadControllerFactory::GetInstance()->Get("mock_controller");
ASSERT_NE(controller, nullptr);
}

// Testing series of cleaning interface
{
ServerOverloadControllerFactory::GetInstance()->Stop();
ServerOverloadControllerFactory::GetInstance()->Destroy();
auto size = ServerOverloadControllerFactory::GetInstance()->Size();
ASSERT_EQ(size, 0);
}
}

} // namespace testing

} // namespace trpc::overload_control
19 changes: 19 additions & 0 deletions trpc/overload_control/testing/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
licenses(["notice"])

package(default_visibility = ["//visibility:public"])

cc_library(
name = "overload_control_testing",
hdrs = ["overload_control_testing.h"],
visibility = ["//visibility:public"],
deps = [
"//trpc/codec:protocol",
"//trpc/coroutine:fiber",
"//trpc/coroutine/testing:fiber_runtime_test",
"//trpc/filter:filter_manager",
"//trpc/overload_control:server_overload_controller",
"//trpc/server:service",
"//trpc/server/testing:service_adapter_testing",
"@com_google_googletest//:gtest_main",
],
)
54 changes: 54 additions & 0 deletions trpc/overload_control/testing/overload_control_testing.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) 2024, Tencent Inc.
// All rights reserved.

#pragma once

#include <atomic>

#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include "trpc/codec/protocol.h"
#include "trpc/coroutine/fiber.h"
#include "trpc/coroutine/fiber_latch.h"
#include "trpc/coroutine/testing/fiber_runtime.h"
#include "trpc/filter/filter_manager.h"
#include "trpc/overload_control/server_overload_controller.h"
#include "trpc/server/service.h"
#include "trpc/server/testing/service_adapter_testing.h"

namespace trpc::overload_control {
namespace testing {

// Mock protocol, only allowed to be used at overload control module.
class MockProtocol : public Protocol {
public:
MOCK_METHOD1(ZeroCopyDecode, bool(NoncontiguousBuffer&));
MOCK_METHOD1(ZeroCopyEncode, bool(NoncontiguousBuffer&));
MOCK_METHOD1(SetCallType, void(RpcCallType));
MOCK_METHOD0(GetCallType, RpcCallType());
};

using MockProtocolPtr = std::shared_ptr<MockProtocol>;

// Get filter object by filter point and filter name.
inline MessageServerFilterPtr GetGlobalServerFilterByName(FilterPoint type, const std::string& name) {
const std::deque<MessageServerFilterPtr>& filters = FilterManager::GetInstance()->GetMessageServerGlobalFilter(type);
for (auto& filter : filters) {
if (!filter->Name().compare(name)) {
return filter;
}
}
return nullptr;
}

class MockServerOverloadController : public ServerOverloadController {
public:
MOCK_METHOD0(Name, std::string());
MOCK_METHOD0(Init, bool());
MOCK_METHOD1(BeforeSchedule, bool(const ServerContextPtr&));
MOCK_METHOD1(AfterSchedule, bool(const ServerContextPtr&));
};

} // namespace testing
} // namespace trpc::overload_control
Loading

0 comments on commit ee0947f

Please sign in to comment.