Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: TRPC_STREAM do not have context #84

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions trpc/server/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ cc_library(
"//trpc/codec/trpc",
"//trpc/common:status",
"//trpc/compressor:trpc_compressor",
"//trpc/coroutine:fiber_local",
"//trpc/filter:server_filter_controller_h",
"//trpc/serialization:serialization_type",
"//trpc/stream:stream_provider",
Expand Down
25 changes: 25 additions & 0 deletions trpc/server/server_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "trpc/common/config/trpc_config.h"
#include "trpc/compressor/trpc_compressor.h"
#include "trpc/coroutine/fiber_local.h"
#include "trpc/runtime/common/stats/frame_stats.h"
#include "trpc/serialization/serialization_factory.h"
#include "trpc/server/service.h"
Expand Down Expand Up @@ -303,4 +304,28 @@ void ServerContext::ThrottleConnection(bool set) {
}
}

// Context used for storing data in a fiber environment.
FiberLocal<ServerContext*> fls_server_context;

// Context used for storing data in a regular thread environment, such as setting it in a business thread and releasing
// it when the business request processing is completed.
thread_local ServerContext* tls_server_context = nullptr;

void SetLocalServerContext(const ServerContextPtr& context) {
// Set to fiberLocal in a fiber environment, and set to threadLocal in a regular thread environment.
if (trpc::fiber::detail::GetCurrentFiberEntity()) {
*fls_server_context = context.Get();
} else {
tls_server_context = context.Get();
}
}

ServerContextPtr GetLocalServerContext() {
// Retrieve from fiberLocal in a fiber environment, and retrieve from threadLocal in a regular thread environment.
if (trpc::fiber::detail::GetCurrentFiberEntity()) {
return RefPtr(ref_ptr, *fls_server_context);
}
return RefPtr(ref_ptr, tls_server_context);
}

} // namespace trpc
20 changes: 11 additions & 9 deletions trpc/server/server_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,27 +349,21 @@ class ServerContext : public RefCounted<ServerContext> {
void SetReqEncodeType(uint8_t type) { invoke_info_.req_encode_type = type; }

/// @brief Framework use or for testing. Set the compression type of request data.
void SetReqCompressType(uint8_t compress_type) {
invoke_info_.req_compress_type = compress_type;
}
void SetReqCompressType(uint8_t compress_type) { invoke_info_.req_compress_type = compress_type; }

/// @brief Get the compression type for compressing request message.
uint8_t GetReqCompressType() const { return invoke_info_.req_compress_type; }
/// @brief Deprecated: use `GetReqCompressType` instead.
[[deprecated("use GetReqCompressType instead")]] uint8_t GetCompressType() const { return GetReqCompressType(); }

/// @brief Set the compression type for decompressing response message.
void SetRspCompressType(uint8_t compress_type) {
invoke_info_.rsp_compress_type = compress_type;
}
void SetRspCompressType(uint8_t compress_type) { invoke_info_.rsp_compress_type = compress_type; }

/// @brief Get the compression type for decompressing response message.
uint8_t GetRspCompressType() const { return invoke_info_.rsp_compress_type; }

/// @brief Set the compression level for decompressing response message.
void SetRspCompressLevel(uint8_t compress_level) {
invoke_info_.rsp_compress_level = compress_level;
}
void SetRspCompressLevel(uint8_t compress_level) { invoke_info_.rsp_compress_level = compress_level; }

/// @brief Get the compression level for decompressing response message.
uint8_t GetRspCompressLevel() const { return invoke_info_.rsp_compress_level; }
Expand Down Expand Up @@ -794,4 +788,12 @@ using ServerContextPtr = RefPtr<ServerContext>;
template <typename T>
using is_server_context = std::is_same<T, ServerContext>;

/// @brief Set the context to a thread-private variable. The private variable itself does not hold the context. The set
/// operation must be used when the ctx is valid within its lifecycle.
void SetLocalServerContext(const ServerContextPtr& context);

/// @brief Retrieve the context from a thread-private variable. The private variable itself does not hold the context.
/// The get operation must be used when the ctx is valid within its lifecycle.
ServerContextPtr GetLocalServerContext();

} // namespace trpc
5 changes: 5 additions & 0 deletions trpc/server/service_adapter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ bool ServiceAdapter::HandleMessage(const ConnectionPtr& conn, std::deque<std::an
succ = false;
}

SetLocalServerContext(req_msg->context);

MsgTaskHandler msg_handler = [this, req_msg]() mutable {
auto& context = req_msg->context;

Expand All @@ -302,6 +304,7 @@ bool ServiceAdapter::HandleMessage(const ConnectionPtr& conn, std::deque<std::an

if (send) {
this->transport_->SendMsg(send);
SetLocalServerContext(nullptr);
}
};

Expand Down Expand Up @@ -364,6 +367,7 @@ bool ServiceAdapter::HandleFiberMessage(const ConnectionPtr& conn, std::deque<st
context->SetBeginTimestampUs(trpc::time::GetMicroSeconds());

RunServerFilters(FilterPoint::SERVER_POST_SCHED_RECV_MSG, req_msg);
SetLocalServerContext(req_msg->context);

STransportRspMsg* send = nullptr;
Service* service = context->GetService();
Expand All @@ -381,6 +385,7 @@ bool ServiceAdapter::HandleFiberMessage(const ConnectionPtr& conn, std::deque<st
if (send) {
send->context->SetReserved(static_cast<void*>(conn.Get()));
this->transport_->SendMsg(send);
SetLocalServerContext(nullptr);
}

conn->Deref();
Expand Down
2 changes: 1 addition & 1 deletion trpc/util/log/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ using trpc::kTrpcLogCacheStringDefault;
/// @note Use case: Separate business logs from framework logs,
/// Different business logs specify different loggers.
/// For example, if remote logs are connected, business logs can be output to remote.
#define TRPC_FLOW_LOG(instance, msg) TRPC_STREAM(instance, ::trpc::Log::info, msg)
#define TRPC_FLOW_LOG(instance, msg) TRPC_STREAM(instance, ::trpc::Log::info, nullptr, msg)
#define TRPC_FLOW_LOG_EX(context, instance, msg) TRPC_STREAM(instance, ::trpc::Log::info, context, msg)

/// @brief Provides ASSERT that does not invalidate in release mode
Expand Down
66 changes: 35 additions & 31 deletions trpc/util/log/stream_like.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,34 @@
__TRPC_STREAM__ << msg

/// @brief stream-like log macros
#define TRPC_STREAM(instance, level, msg) \
do { \
const auto& __TRPC_CPP_STREAM_LOGGER_INSTANCE__ = ::trpc::LogFactory::GetInstance()->Get(); \
if (__TRPC_CPP_STREAM_LOGGER_INSTANCE__) { \
if (__TRPC_CPP_STREAM_LOGGER_INSTANCE__->ShouldLog(instance, level)) { \
TRPC_LOG_TRY { \
STREAM_APPENDER(msg); \
__TRPC_CPP_STREAM_LOGGER_INSTANCE__->LogIt(instance, level, __FILE__, __LINE__, __FUNCTION__, \
__TRPC_STREAM__.str()); \
} \
TRPC_LOG_CATCH(instance) \
} \
} else { \
if (::trpc::Log::ShouldNoLog(instance, level)) { \
TRPC_LOG_TRY { \
STREAM_APPENDER(msg); \
::trpc::Log::NoLog(instance, level, __FILE__, __LINE__, __FUNCTION__, __TRPC_STREAM__.str()); \
} \
TRPC_LOG_CATCH(instance) \
} \
} \
#define TRPC_STREAM(instance, level, context, msg) \
do { \
const auto& __TRPC_CPP_STREAM_LOGGER_INSTANCE__ = ::trpc::LogFactory::GetInstance()->Get(); \
if (__TRPC_CPP_STREAM_LOGGER_INSTANCE__) { \
if (__TRPC_CPP_STREAM_LOGGER_INSTANCE__->ShouldLog(instance, level)) { \
TRPC_LOG_TRY { \
STREAM_APPENDER(msg); \
if (context) { \
__TRPC_CPP_STREAM_LOGGER_INSTANCE__->LogIt(instance, level, __FILE__, __LINE__, __FUNCTION__, \
__TRPC_STREAM__.str(), context->GetAllFilterData()); \
} else { \
__TRPC_CPP_STREAM_LOGGER_INSTANCE__->LogIt(instance, level, __FILE__, __LINE__, __FUNCTION__, \
__TRPC_STREAM__.str()); \
} \
} \
TRPC_LOG_CATCH(instance) \
} \
} else { \
if (::trpc::Log::ShouldNoLog(instance, level)) { \
TRPC_LOG_TRY { \
STREAM_APPENDER(msg); \
::trpc::Log::NoLog(instance, level, __FILE__, __LINE__, __FUNCTION__, __TRPC_STREAM__.str()); \
} \
TRPC_LOG_CATCH(instance) \
} \
} \
} while (0)


/// @brief stream-like log macros for tRPC-Cpp framework log
#define TRPC_STREAM_DEFAULT(instance, level, msg) \
do { \
Expand All @@ -73,7 +77,7 @@
} while (0)

/// @brief stream-like log macros for tRPC-Cpp framework
#define TRPC_STREAM_EX_DEFAULT(instance, level, context, msg) \
#define TRPC_STREAM_EX_DEFAULT(instance, level, context, msg) \
do { \
const auto& __TRPC_CPP_STREAM_LOGGER_INSTANCE__ = ::trpc::LogFactory::GetInstance()->Get(); \
if (__TRPC_CPP_STREAM_LOGGER_INSTANCE__) { \
Expand All @@ -96,7 +100,6 @@
} \
} while (0)


/// @brief stream-like log macros
#define TRPC_STREAM_EX(instance, level, context, msg) \
do { \
Expand Down Expand Up @@ -136,16 +139,17 @@
}

/// @brief uses default logger for logging with context
#define TRPC_LOG_MSG_IF_EX(level, context, condition, msg) \
if (condition) { \
TRPC_LOG_MSG_EX(level, context, msg); \
#define TRPC_LOG_MSG_IF_EX(level, context, condition, msg) \
if (condition) { \
TRPC_LOG_MSG_EX(level, context, msg); \
}

#define TRPC_LOGGER_MSG_IF_EX(level, context, instance, condition, msg) \
if (condition) { \
TRPC_LOGGER_MSG_EX(level, context, instance, msg); \
#define TRPC_LOGGER_MSG_IF_EX(level, context, instance, condition, msg) \
if (condition) { \
TRPC_LOGGER_MSG_EX(level, context, instance, msg); \
}

#define TRPC_LOGGER_MSG_EX(level, context, instance, msg) TRPC_STREAM_EX(instance, level, context, msg)

#define TRPC_LOG_MSG_EX(level, context, msg) TRPC_STREAM_EX_DEFAULT(::trpc::log::kTrpcLogCacheStringDefault, level, context, msg)
#define TRPC_LOG_MSG_EX(level, context, msg) \
TRPC_STREAM_EX_DEFAULT(::trpc::log::kTrpcLogCacheStringDefault, level, context, msg)