Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: add abstract and factory classes for overload controller module #134

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions trpc/common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ cc_library(
"//trpc/transport/common:ssl_helper",
"//trpc/util/log/default:default_log",
"//trpc/util:net_util",
"//trpc/overload_control:trpc_overload_control",
] + select({
"//trpc:trpc_include_rpcz": [
"//trpc/rpcz:collector",
Expand Down
7 changes: 7 additions & 0 deletions trpc/common/trpc_plugin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#ifdef TRPC_BUILD_INCLUDE_RPCZ
#include "trpc/rpcz/collector.h"
#endif
#include "trpc/overload_control/trpc_overload_control.h"
#include "trpc/runtime/common/periphery_task_scheduler.h"
#include "trpc/runtime/common/runtime_info_report/runtime_info_reporter.h"
#include "trpc/runtime/common/stats/frame_stats.h"
Expand Down Expand Up @@ -82,6 +83,8 @@ int TrpcPlugin::RegisterPlugins() {
TRPC_ASSERT(telemetry::Init());
TRPC_ASSERT(naming::Init());

TRPC_ASSERT(overload_control::Init());

CollectPlugins();
InitPlugins();

Expand Down Expand Up @@ -229,6 +232,8 @@ int TrpcPlugin::UnregisterPlugins() {

StopPlugins();

overload_control::Stop();

PeripheryTaskScheduler::GetInstance()->Stop();
PeripheryTaskScheduler::GetInstance()->Join();

Expand Down Expand Up @@ -529,6 +534,8 @@ void TrpcPlugin::DestroyResource() {

log::Destroy();

overload_control::Destroy();

GetTrpcClient()->Destroy();

is_all_inited_ = false;
Expand Down
37 changes: 37 additions & 0 deletions trpc/overload_control/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,40 @@ cc_library(
}),
visibility = ["//visibility:public"],
)

cc_library(
name = "server_overload_controller",
hdrs = ["server_overload_controller.h"],
deps = [
"//trpc/server:server_context",
],
)

cc_library(
name = "server_overload_controller_factory",
hdrs = ["server_overload_controller_factory.h"],
deps = [
":server_overload_controller",
"//trpc/overload_control/common:overload_control_factory",
],
)

cc_test(
name = "server_overload_controller_factory_test",
srcs = ["server_overload_controller_factory_test.cc"],
deps = [
":server_overload_controller_factory",
"//trpc/overload_control/testing:overload_control_testing",
"@com_google_googletest//:gtest_main",
],
)

cc_library(
name = "trpc_overload_control",
srcs = ["trpc_overload_control.cc"],
hdrs = ["trpc_overload_control.h"],
deps = [
":server_overload_controller_factory",
"//trpc/filter:filter_manager",
],
)
8 changes: 8 additions & 0 deletions trpc/overload_control/common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,11 @@ cc_test(
"@com_google_googletest//:gtest_main",
],
)

cc_library(
name = "overload_control_factory",
hdrs = ["overload_control_factory.h"],
deps = [
"//trpc/log:trpc_log",
],
)
105 changes: 105 additions & 0 deletions trpc/overload_control/common/overload_control_factory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright (c) 2024, Tencent Inc.
// All rights reserved.

#pragma once

#include <string>
#include <unordered_map>

#include "trpc/log/trpc_log.h"

namespace trpc::overload_control {

/// @brief Factory of overload control strategy(as template T).
template <class T>
class OverloadControlFactory {
public:
/// @brief Singleton
static OverloadControlFactory* GetInstance() {
static OverloadControlFactory instance;
return &instance;
}

/// @brief Can't construct by user.
OverloadControlFactory(const OverloadControlFactory&) = delete;
OverloadControlFactory& operator=(const OverloadControlFactory&) = delete;

/// @brief Register the overload control strategy.
/// @param obj strategy
/// @note Non-thread-safe
bool Register(const T& obj);

/// @brief Get the overload control strategy by name
/// @param name name of strategy
/// @return strategy
T Get(const std::string& name);

/// @brief Get number of strategies.
/// @return number of strategies
/// @note Non-thread-safe
size_t Size() const { return objs_map_.size(); }

/// @brief Stop all of overload control strategies.
// Mainly used to stop inner thread createdy by each strategy
/// @note Non-thread-safe.
void Stop();

/// @brief Destroy resource of overload control strategies.
/// @note Non-thread-safe.
void Destroy();

/// @brief Clear overload control strategies in this factory.
/// @note Non-thread-safe.
void Clear() { objs_map_.clear(); }

private:
OverloadControlFactory() = default;

private:
// strategies mapping(name->stratege obj)
std::unordered_map<std::string, T> objs_map_;
};

template <class T>
bool OverloadControlFactory<T>::Register(const T& obj) {
if (!obj) {
TRPC_FMT_ERROR("register object is nullptr");
return false;
}
if (Get(obj->Name())) {
return false;
}
if (obj->Init()) {
objs_map_.emplace(obj->Name(), obj);
return true;
}
TRPC_FMT_ERROR("{} is `Init` failed ", obj->Name());
return false;
}

template <class T>
T OverloadControlFactory<T>::Get(const std::string& name) {
T obj = nullptr;
auto iter = objs_map_.find(name);
if (iter != objs_map_.end()) {
obj = iter->second;
}
return obj;
}

template <class T>
void OverloadControlFactory<T>::Stop() {
for (auto& obj : objs_map_) {
obj.second->Stop();
}
}

template <class T>
void OverloadControlFactory<T>::Destroy() {
for (auto& obj : objs_map_) {
obj.second->Destroy();
}
Clear();
}

} // namespace trpc::overload_control
48 changes: 48 additions & 0 deletions trpc/overload_control/server_overload_controller.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2024, Tencent Inc.
// All rights reserved.

#pragma once

#include <memory>
#include <string>

#include "trpc/server/server_context.h"

namespace trpc::overload_control {

/// @brief Base class of overload controller.
class ServerOverloadController {
public:
virtual ~ServerOverloadController() = default;

/// @brief Name of this controller.
virtual std::string Name() = 0;

/// @brief Initialize controller.
/// You can allocate resources or start thread as controller need.
/// @return bool true: succ; false: failed
virtual bool Init() { return true; }

/// @brief Whether this request should be scheduled to handle.
/// When reject this request, you should also set status with error code TRPC_SERVER_OVERLOAD_ERR
/// into context.
/// @param context server context.
/// @return bool true: this request will be handled; false: this request should be rejected.
virtual bool BeforeSchedule(const ServerContextPtr& context) = 0;

/// @brief After this request being sheduled. At this point, it may be handled or rejected.
// You can check status from context to distinguish these 2 scenes when implement.
/// @param context server context.
/// @return bool true: succ; false: failed.
virtual bool AfterSchedule(const ServerContextPtr& context) = 0;

/// @brief Stop controller. One can stop the thread execution of controller implemetation.
virtual void Stop() {}

/// @brief Destroy resources of controller.
virtual void Destroy() {}
};

using ServerOverloadControllerPtr = std::shared_ptr<ServerOverloadController>;

} // namespace trpc::overload_control
16 changes: 16 additions & 0 deletions trpc/overload_control/server_overload_controller_factory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) 2024, Tencent Inc.
// All rights reserved.

#pragma once

#include <string>
#include <unordered_map>

#include "trpc/overload_control/common/overload_control_factory.h"
#include "trpc/overload_control/server_overload_controller.h"

namespace trpc::overload_control {

using ServerOverloadControllerFactory = OverloadControlFactory<ServerOverloadControllerPtr>;

} // namespace trpc::overload_control
55 changes: 55 additions & 0 deletions trpc/overload_control/server_overload_controller_factory_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (c) 2024, Tencent Inc.
// All rights reserved.

#include "trpc/overload_control/server_overload_controller_factory.h"

#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include "trpc/log/trpc_log.h"
#include "trpc/overload_control/testing/overload_control_testing.h"

namespace trpc::overload_control {

namespace testing {

TEST(ServerOverloadControllerFactory, All) {
// Testing register interface
{
// 1. Register nullptr, failed.
ASSERT_FALSE(ServerOverloadControllerFactory::GetInstance()->Register(nullptr));
// 2. First time register, succ.
ServerOverloadControllerPtr controller = std::make_shared<MockServerOverloadController>();
MockServerOverloadController* mock_controller = static_cast<MockServerOverloadController*>(controller.get());
EXPECT_CALL(*mock_controller, Init()).WillOnce(::testing::Return(false));
EXPECT_CALL(*mock_controller, Name()).WillRepeatedly(::testing::Return(std::string("mock_controller")));
ASSERT_FALSE(ServerOverloadControllerFactory::GetInstance()->Register(controller));

EXPECT_CALL(*mock_controller, Init()).WillOnce(::testing::Return(true));
ASSERT_TRUE(ServerOverloadControllerFactory::GetInstance()->Register(controller));
// 3. Duplicated register, failed.
ASSERT_FALSE(ServerOverloadControllerFactory::GetInstance()->Register(controller));

auto size = ServerOverloadControllerFactory::GetInstance()->Size();
ASSERT_EQ(size, 1);
}
// Testing get interface
{
ServerOverloadControllerPtr controller = ServerOverloadControllerFactory::GetInstance()->Get("xxx");
ASSERT_EQ(controller, nullptr);
controller = ServerOverloadControllerFactory::GetInstance()->Get("mock_controller");
ASSERT_NE(controller, nullptr);
}

// Testing series of cleaning interface
{
ServerOverloadControllerFactory::GetInstance()->Stop();
ServerOverloadControllerFactory::GetInstance()->Destroy();
auto size = ServerOverloadControllerFactory::GetInstance()->Size();
ASSERT_EQ(size, 0);
}
}

} // namespace testing

} // namespace trpc::overload_control
19 changes: 19 additions & 0 deletions trpc/overload_control/testing/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
licenses(["notice"])

package(default_visibility = ["//visibility:public"])

cc_library(
name = "overload_control_testing",
hdrs = ["overload_control_testing.h"],
visibility = ["//visibility:public"],
deps = [
"//trpc/codec:protocol",
"//trpc/coroutine:fiber",
"//trpc/coroutine/testing:fiber_runtime_test",
"//trpc/filter:filter_manager",
"//trpc/overload_control:server_overload_controller",
"//trpc/server:service",
"//trpc/server/testing:service_adapter_testing",
"@com_google_googletest//:gtest_main",
],
)
54 changes: 54 additions & 0 deletions trpc/overload_control/testing/overload_control_testing.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) 2024, Tencent Inc.
// All rights reserved.

#pragma once

#include <atomic>

#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include "trpc/codec/protocol.h"
#include "trpc/coroutine/fiber.h"
#include "trpc/coroutine/fiber_latch.h"
#include "trpc/coroutine/testing/fiber_runtime.h"
#include "trpc/filter/filter_manager.h"
#include "trpc/overload_control/server_overload_controller.h"
#include "trpc/server/service.h"
#include "trpc/server/testing/service_adapter_testing.h"

namespace trpc::overload_control {
namespace testing {

// Mock protocol, only allowed to be used at overload control module.
class MockProtocol : public Protocol {
public:
MOCK_METHOD1(ZeroCopyDecode, bool(NoncontiguousBuffer&));
MOCK_METHOD1(ZeroCopyEncode, bool(NoncontiguousBuffer&));
MOCK_METHOD1(SetCallType, void(RpcCallType));
MOCK_METHOD0(GetCallType, RpcCallType());
};

using MockProtocolPtr = std::shared_ptr<MockProtocol>;

// Get filter object by filter point and filter name.
inline MessageServerFilterPtr GetGlobalServerFilterByName(FilterPoint type, const std::string& name) {
const std::deque<MessageServerFilterPtr>& filters = FilterManager::GetInstance()->GetMessageServerGlobalFilter(type);
for (auto& filter : filters) {
if (!filter->Name().compare(name)) {
return filter;
}
}
return nullptr;
}

class MockServerOverloadController : public ServerOverloadController {
public:
MOCK_METHOD0(Name, std::string());
MOCK_METHOD0(Init, bool());
MOCK_METHOD1(BeforeSchedule, bool(const ServerContextPtr&));
MOCK_METHOD1(AfterSchedule, bool(const ServerContextPtr&));
};

} // namespace testing
} // namespace trpc::overload_control
Loading
Loading