Skip to content

Commit

Permalink
Remove per-channel registered methods map
Browse files Browse the repository at this point in the history
  • Loading branch information
ananda1066 committed Dec 6, 2023
1 parent e497eed commit 6b15454
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 181 deletions.
167 changes: 36 additions & 131 deletions src/core/lib/surface/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -839,9 +839,9 @@ void Server::Start() {
if (unregistered_request_matcher_ == nullptr) {
unregistered_request_matcher_ = make_real_request_matcher();
}
for (std::unique_ptr<RegisteredMethod>& rm : registered_methods_) {
if (rm->matcher == nullptr) {
rm->matcher = make_real_request_matcher();
for (auto& rm : registered_methods_) {
if (rm.second->matcher == nullptr) {
rm.second->matcher = make_real_request_matcher();
}
}
{
Expand Down Expand Up @@ -928,20 +928,11 @@ void Server::RegisterCompletionQueue(grpc_completion_queue* cq) {
cqs_.push_back(cq);
}

namespace {

bool streq(const std::string& a, const char* b) {
return (a.empty() && b == nullptr) ||
((b != nullptr) && !strcmp(a.c_str(), b));
}

} // namespace

Server::RegisteredMethod* Server::RegisterMethod(
const char* method, const char* host,
grpc_server_register_method_payload_handling payload_handling,
uint32_t flags) {
if (IsRegisteredMethodsMapEnabled() && started_) {
if (started_) {
Crash("Attempting to register method after server started");
}

Expand All @@ -950,21 +941,21 @@ Server::RegisteredMethod* Server::RegisterMethod(
"grpc_server_register_method method string cannot be NULL");
return nullptr;
}
for (std::unique_ptr<RegisteredMethod>& m : registered_methods_) {
if (streq(m->method, method) && streq(m->host, host)) {
gpr_log(GPR_ERROR, "duplicate registration for %s@%s", method,
host ? host : "*");
return nullptr;
}
auto key = std::make_pair(host ? host : "", method);
if (registered_methods_.find(key) != registered_methods_.end()) {
gpr_log(GPR_ERROR, "duplicate registration for %s@%s", method,
host ? host : "*");
return nullptr;
}
if (flags != 0) {
gpr_log(GPR_ERROR, "grpc_server_register_method invalid flags 0x%08x",
flags);
return nullptr;
}
registered_methods_.emplace_back(std::make_unique<RegisteredMethod>(
method, host, payload_handling, flags));
return registered_methods_.back().get();
auto it = registered_methods_.emplace(
key, std::make_unique<RegisteredMethod>(method, host, payload_handling,
flags));
return it.first->second.get();
}

void Server::DoneRequestEvent(void* req, grpc_cq_completion* /*c*/) {
Expand Down Expand Up @@ -1015,9 +1006,9 @@ void Server::KillPendingWorkLocked(grpc_error_handle error) {
if (started_) {
unregistered_request_matcher_->KillRequests(error);
unregistered_request_matcher_->ZombifyPending();
for (std::unique_ptr<RegisteredMethod>& rm : registered_methods_) {
rm->matcher->KillRequests(error);
rm->matcher->ZombifyPending();
for (auto& rm : registered_methods_) {
rm.second->matcher->KillRequests(error);
rm.second->matcher->ZombifyPending();
}
}
}
Expand Down Expand Up @@ -1252,7 +1243,6 @@ class Server::ChannelData::ConnectivityWatcher
//

Server::ChannelData::~ChannelData() {
old_registered_methods_.reset();
if (server_ != nullptr) {
if (server_->channelz_node_ != nullptr && channelz_socket_uuid_ != 0) {
server_->channelz_node_->RemoveChildSocket(channelz_socket_uuid_);
Expand All @@ -1276,50 +1266,6 @@ void Server::ChannelData::InitTransport(RefCountedPtr<Server> server,
channel_ = channel;
cq_idx_ = cq_idx;
channelz_socket_uuid_ = channelz_socket_uuid;
// Build a lookup table phrased in terms of mdstr's in this channels context
// to quickly find registered methods.
size_t num_registered_methods = server_->registered_methods_.size();
if (!IsRegisteredMethodsMapEnabled() && num_registered_methods > 0) {
uint32_t max_probes = 0;
size_t slots = 2 * num_registered_methods;
old_registered_methods_ =
std::make_unique<std::vector<ChannelRegisteredMethod>>(slots);
for (std::unique_ptr<RegisteredMethod>& rm : server_->registered_methods_) {
Slice host;
Slice method = Slice::FromExternalString(rm->method);
const bool has_host = !rm->host.empty();
if (has_host) {
host = Slice::FromExternalString(rm->host);
}
uint32_t hash = MixHash32(has_host ? host.Hash() : 0, method.Hash());
uint32_t probes = 0;
for (probes = 0; (*old_registered_methods_)[(hash + probes) % slots]
.server_registered_method != nullptr;
probes++) {
}
if (probes > max_probes) max_probes = probes;
ChannelRegisteredMethod* crm =
&(*old_registered_methods_)[(hash + probes) % slots];
crm->server_registered_method = rm.get();
crm->flags = rm->flags;
crm->has_host = has_host;
if (has_host) {
crm->host = std::move(host);
}
crm->method = std::move(method);
}
GPR_ASSERT(slots <= UINT32_MAX);
registered_method_max_probes_ = max_probes;
} else if (IsRegisteredMethodsMapEnabled()) {
for (std::unique_ptr<RegisteredMethod>& rm : server_->registered_methods_) {
auto key = std::make_pair(!rm->host.empty() ? rm->host : "", rm->method);
registered_methods_.emplace(
key, std::make_unique<ChannelRegisteredMethod>(
rm.get(), rm->flags, /*has_host=*/!rm->host.empty(),
Slice::FromExternalString(rm->method),
Slice::FromExternalString(rm->host)));
}
}
// Publish channel.
{
MutexLock lock(&server_->mu_global_);
Expand All @@ -1345,45 +1291,17 @@ void Server::ChannelData::InitTransport(RefCountedPtr<Server> server,
transport->PerformOp(op);
}

Server::ChannelRegisteredMethod* Server::ChannelData::GetRegisteredMethod(
const grpc_slice& host, const grpc_slice& path) {
if (old_registered_methods_ == nullptr) return nullptr;
// TODO(ctiller): unify these two searches
// check for an exact match with host
uint32_t hash = MixHash32(grpc_slice_hash(host), grpc_slice_hash(path));
for (size_t i = 0; i <= registered_method_max_probes_; i++) {
ChannelRegisteredMethod* rm = &(
*old_registered_methods_)[(hash + i) % old_registered_methods_->size()];
if (rm->server_registered_method == nullptr) break;
if (!rm->has_host) continue;
if (rm->host != host) continue;
if (rm->method != path) continue;
return rm;
}
// check for a wildcard method definition (no host set)
hash = MixHash32(0, grpc_slice_hash(path));
for (size_t i = 0; i <= registered_method_max_probes_; i++) {
ChannelRegisteredMethod* rm = &(
*old_registered_methods_)[(hash + i) % old_registered_methods_->size()];
if (rm->server_registered_method == nullptr) break;
if (rm->has_host) continue;
if (rm->method != path) continue;
return rm;
}
return nullptr;
}

Server::ChannelRegisteredMethod* Server::ChannelData::GetRegisteredMethod(
Server::RegisteredMethod* Server::ChannelData::GetRegisteredMethod(
const absl::string_view& host, const absl::string_view& path) {
if (registered_methods_.empty()) return nullptr;
if (server_->registered_methods_.empty()) return nullptr;
// check for an exact match with host
auto it = registered_methods_.find(std::make_pair(host, path));
if (it != registered_methods_.end()) {
auto it = server_->registered_methods_.find(std::make_pair(host, path));
if (it != server_->registered_methods_.end()) {
return it->second.get();
}
// check for wildcard method definition (no host set)
it = registered_methods_.find(std::make_pair("", path));
if (it != registered_methods_.end()) {
it = server_->registered_methods_.find(std::make_pair("", path));
if (it != server_->registered_methods_.end()) {
return it->second.get();
}
return nullptr;
Expand All @@ -1404,13 +1322,8 @@ void Server::ChannelData::SetRegisteredMethodOnMetadata(
// Path not being set would result in an RPC error.
return;
}
ChannelRegisteredMethod* method;
if (!IsRegisteredMethodsMapEnabled()) {
method = GetRegisteredMethod(authority->c_slice(), path->c_slice());
} else {
method = GetRegisteredMethod(authority->as_string_view(),
path->as_string_view());
}
RegisteredMethod* method =
GetRegisteredMethod(authority->as_string_view(), path->as_string_view());
// insert in metadata
metadata.Set(GrpcRegisteredMethod(), method);
}
Expand Down Expand Up @@ -1481,24 +1394,20 @@ ArenaPromise<ServerMetadataHandle> Server::ChannelData::MakeCallPromise(
Timestamp deadline = GetContext<CallContext>()->deadline();
// Find request matcher.
RequestMatcherInterface* matcher;
ChannelRegisteredMethod* rm = nullptr;
RegisteredMethod* rm = nullptr;
if (IsRegisteredMethodLookupInTransportEnabled()) {
rm = static_cast<ChannelRegisteredMethod*>(
rm = static_cast<RegisteredMethod*>(
call_args.client_initial_metadata->get(GrpcRegisteredMethod())
.value_or(nullptr));
} else {
if (!IsRegisteredMethodsMapEnabled()) {
rm = chand->GetRegisteredMethod(host_ptr->c_slice(), path->c_slice());
} else {
rm = chand->GetRegisteredMethod(host_ptr->as_string_view(),
path->as_string_view());
}
rm = chand->GetRegisteredMethod(host_ptr->as_string_view(),
path->as_string_view());
}
ArenaPromise<absl::StatusOr<NextResult<MessageHandle>>>
maybe_read_first_message([] { return NextResult<MessageHandle>(); });
if (rm != nullptr) {
matcher = rm->server_registered_method->matcher.get();
switch (rm->server_registered_method->payload_handling) {
matcher = rm->matcher.get();
switch (rm->payload_handling) {
case GRPC_SRM_PAYLOAD_NONE:
break;
case GRPC_SRM_PAYLOAD_READ_INITIAL_BYTE_BUFFER:
Expand Down Expand Up @@ -1752,22 +1661,18 @@ void Server::CallData::StartNewRpc(grpc_call_element* elem) {
grpc_server_register_method_payload_handling payload_handling =
GRPC_SRM_PAYLOAD_NONE;
if (path_.has_value() && host_.has_value()) {
ChannelRegisteredMethod* rm;
RegisteredMethod* rm;
if (IsRegisteredMethodLookupInTransportEnabled()) {
rm = static_cast<ChannelRegisteredMethod*>(
rm = static_cast<RegisteredMethod*>(
recv_initial_metadata_->get(GrpcRegisteredMethod())
.value_or(nullptr));
} else {
if (!IsRegisteredMethodsMapEnabled()) {
rm = chand->GetRegisteredMethod(host_->c_slice(), path_->c_slice());
} else {
rm = chand->GetRegisteredMethod(host_->as_string_view(),
path_->as_string_view());
}
rm = chand->GetRegisteredMethod(host_->as_string_view(),
path_->as_string_view());
}
if (rm != nullptr) {
matcher_ = rm->server_registered_method->matcher.get();
payload_handling = rm->server_registered_method->payload_handling;
matcher_ = rm->matcher.get();
payload_handling = rm->payload_handling;
}
}
// Start recv_message op if needed.
Expand Down
68 changes: 18 additions & 50 deletions src/core/lib/surface/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,26 +211,6 @@ class Server : public InternallyRefCounted<Server>,
private:
struct RequestedCall;

struct ChannelRegisteredMethod {
ChannelRegisteredMethod() = default;
ChannelRegisteredMethod(RegisteredMethod* server_registered_method_arg,
uint32_t flags_arg, bool has_host_arg,
Slice method_arg, Slice host_arg)
: server_registered_method(server_registered_method_arg),
flags(flags_arg),
has_host(has_host_arg),
method(std::move(method_arg)),
host(std::move(host_arg)) {}

~ChannelRegisteredMethod() = default;

RegisteredMethod* server_registered_method = nullptr;
uint32_t flags;
bool has_host;
Slice method;
Slice host;
};

class RequestMatcherInterface;
class RealRequestMatcherFilterStack;
class RealRequestMatcherPromises;
Expand All @@ -251,11 +231,8 @@ class Server : public InternallyRefCounted<Server>,
Channel* channel() const { return channel_.get(); }
size_t cq_idx() const { return cq_idx_; }

ChannelRegisteredMethod* GetRegisteredMethod(const grpc_slice& host,
const grpc_slice& path);

ChannelRegisteredMethod* GetRegisteredMethod(const absl::string_view& host,
const absl::string_view& path);
RegisteredMethod* GetRegisteredMethod(const absl::string_view& host,
const absl::string_view& path);
// Filter vtable functions.
static grpc_error_handle InitChannelElement(
grpc_channel_element* elem, grpc_channel_element_args* args);
Expand All @@ -274,36 +251,12 @@ class Server : public InternallyRefCounted<Server>,

static void FinishDestroy(void* arg, grpc_error_handle error);

struct StringViewStringViewPairHash
: absl::flat_hash_set<
std::pair<absl::string_view, absl::string_view>>::hasher {
using is_transparent = void;
};

struct StringViewStringViewPairEq
: std::equal_to<std::pair<absl::string_view, absl::string_view>> {
using is_transparent = void;
};

RefCountedPtr<Server> server_;
RefCountedPtr<Channel> channel_;
// The index into Server::cqs_ of the CQ used as a starting point for
// where to publish new incoming calls.
size_t cq_idx_;
absl::optional<std::list<ChannelData*>::iterator> list_position_;
// A hash-table of the methods and hosts of the registered methods.
// TODO(vjpai): Convert this to an STL map type as opposed to a direct
// bucket implementation. (Consider performance impact, hash function to
// use, etc.)
std::unique_ptr<std::vector<ChannelRegisteredMethod>>
old_registered_methods_;
// Map of registered methods.
absl::flat_hash_map<std::pair<std::string, std::string> /*host, method*/,
std::unique_ptr<ChannelRegisteredMethod>,
StringViewStringViewPairHash,
StringViewStringViewPairEq>
registered_methods_;
uint32_t registered_method_max_probes_;
grpc_closure finish_destroy_channel_closure_;
intptr_t channelz_socket_uuid_;
};
Expand Down Expand Up @@ -412,6 +365,17 @@ class Server : public InternallyRefCounted<Server>,
grpc_cq_completion completion;
};

struct StringViewStringViewPairHash
: absl::flat_hash_set<
std::pair<absl::string_view, absl::string_view>>::hasher {
using is_transparent = void;
};

struct StringViewStringViewPairEq
: std::equal_to<std::pair<absl::string_view, absl::string_view>> {
using is_transparent = void;
};

static void ListenerDestroyDone(void* arg, grpc_error_handle error);

static void DoneShutdownEvent(void* server,
Expand Down Expand Up @@ -497,7 +461,11 @@ class Server : public InternallyRefCounted<Server>,
bool starting_ ABSL_GUARDED_BY(mu_global_) = false;
CondVar starting_cv_;

std::vector<std::unique_ptr<RegisteredMethod>> registered_methods_;
// Map of registered methods.
absl::flat_hash_map<std::pair<std::string, std::string> /*host, method*/,
std::unique_ptr<RegisteredMethod>,
StringViewStringViewPairHash, StringViewStringViewPairEq>
registered_methods_;

// Request matcher for unregistered methods.
std::unique_ptr<RequestMatcherInterface> unregistered_request_matcher_;
Expand Down

0 comments on commit 6b15454

Please sign in to comment.