From 3062a5ae214a1fea3f9ed48e21328340873f8aed Mon Sep 17 00:00:00 2001
From: Yichao Yang
Date: Sun, 24 Jul 2022 23:49:04 -0700
Subject: [PATCH 01/13] Persistence priority rate limiting

---
 common/headers/caller_info.go                 | 66 ++++++
 common/headers/headers.go                     | 30 +--
 common/headers/versionChecker.go              | 24 +++
 common/persistence/client/factory.go          |  4 +-
 common/persistence/client/fx.go               |  7 +-
 common/persistence/client/quotas.go           | 61 ++++++
 .../persistenceRateLimitedClients.go          | 192 ++++++++++--------
 common/quotas/priority_rate_limiter_impl.go   | 20 +-
 .../quotas/priority_rate_limiter_impl_test.go |  4 +-
 common/quotas/request.go                      | 15 +-
 common/quotas/request_rate_limiter.go         |  3 +
 common/rpc/context.go                         | 32 ++-
 common/rpc/grpc.go                            |  4 +-
 common/rpc/interceptor/caller_info.go         | 65 ++++++
 .../rpc/interceptor/namespace_rate_limit.go   |  1 +
 common/rpc/interceptor/rate_limit.go          |  3 +-
 common/xdc/nDCHistoryResender.go              | 13 +-
 common/xdc/nDCHistoryResender_mock.go         |  9 +-
 common/xdc/nDCHistoryResender_test.go         |  1 +
 host/client_integration_test.go               |  6 +-
 host/context.go                               |  2 +-
 host/integrationbase.go                       |  4 +-
 host/signal_workflow_test.go                  |  6 +-
 service/frontend/adminHandler.go              |  3 +-
 service/frontend/configs/quotas.go            | 42 ++--
 service/frontend/configs/quotas_test.go       | 34 +++-
 service/frontend/fx.go                        |  9 +
 service/history/configs/quotas.go             | 13 +-
 service/history/configs/quotas_test.go        | 12 +-
 service/history/queues/executable.go          |  2 +
 service/history/replication/task_executor.go  | 21 +-
 .../history/replication/task_executor_test.go |  2 +
 service/history/replication/task_fetcher.go   |  4 +-
 service/history/replication/task_processor.go |  3 +
 service/history/shard/context_impl.go         | 18 +-
 service/history/timerQueueAckMgr.go           |  6 +-
 service/history/timerQueueProcessor.go        |  6 +-
 .../history/timerQueueStandbyTaskExecutor.go  |  1 +
 .../timerQueueStandbyTaskExecutor_test.go     |  6 +
 service/history/transferQueueProcessorBase.go | 17 +-
 .../transferQueueStandbyTaskExecutor.go       |  1 +
 .../transferQueueStandbyTaskExecutor_test.go  |  5 +
 service/history/visibilityQueueProcessor.go   | 11 +-
 service/history/workflow/context.go           |  2 +-
 service/matching/configs/quotas.go            | 13 +-
 service/matching/configs/quotas_test.go       | 12 +-
 ...namespace_replication_message_processor.go |  2 +-
 47 files changed, 593 insertions(+), 224 deletions(-)
 create mode 100644 common/headers/caller_info.go
 create mode 100644 common/persistence/client/quotas.go
 create mode 100644 common/rpc/interceptor/caller_info.go

diff --git a/common/headers/caller_info.go b/common/headers/caller_info.go
new file mode 100644
index 00000000000..4bf6a7f3c10
--- /dev/null
+++ b/common/headers/caller_info.go
@@ -0,0 +1,66 @@
+// The MIT License
+//
+// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
+//
+// Copyright (c) 2020 Uber Technologies, Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+package headers
+
+import (
+	"context"
+
+	"google.golang.org/grpc/metadata"
+)
+
+const (
+	CallerTypeAPI        = "api"
+	CallerTypeSystem     = "system"
+	CallerTypeBackground = "background"
+)
+
+// SetCallerInfo sets the callerName and callerType values in the incoming context
+// if they are not already set.
+func SetCallerInfo(
+	ctx context.Context,
+	callerName string,
+	callerType string,
+) context.Context {
+	mdIncoming, ok := metadata.FromIncomingContext(ctx)
+	if !ok {
+		mdIncoming = metadata.MD{}
+	}
+
+	values := GetValues(
+		ctx,
+		CallerNameHeaderName,
+		CallerTypeHeaderName,
+	)
+
+	if values[0] == "" {
+		mdIncoming.Set(CallerNameHeaderName, callerName)
+	}
+
+	if values[1] == "" {
+		mdIncoming.Set(CallerTypeHeaderName, callerType)
+	}
+
+	return metadata.NewIncomingContext(ctx, mdIncoming)
+}
diff --git a/common/headers/headers.go b/common/headers/headers.go
index a8ab174f7d9..c1a9f8e9e14 100644
--- a/common/headers/headers.go
+++ b/common/headers/headers.go
@@ -36,6 +36,9 @@ const (
 	SupportedServerVersionsHeaderName = "supported-server-versions"
 	SupportedFeaturesHeaderName       = "supported-features"
 	SupportedFeaturesHeaderDelim      = ","
+
+	CallerNameHeaderName = "caller-name"
+	CallerTypeHeaderName = "caller-type"
 )
 
 var (
@@ -45,14 +48,9 @@ var (
 		ClientVersionHeaderName,
 		SupportedServerVersionsHeaderName,
 		SupportedFeaturesHeaderName,
+		CallerNameHeaderName,
+		CallerTypeHeaderName,
 	}
-
-	internalVersionHeaders = metadata.New(map[string]string{
-		ClientNameHeaderName:              ClientNameServer,
-		ClientVersionHeaderName:           ServerVersion,
-		SupportedServerVersionsHeaderName: SupportedServerVersions,
-		SupportedFeaturesHeaderName:       AllFeatures,
-	})
 )
 
 // GetValues returns header values for passed header names.
@@ -70,7 +68,7 @@ func GetValues(ctx context.Context, headerNames ...string) []string {
 }
 
 // Propagate propagates version headers from incoming context to outgoing context.
-// It copies all version headers to outgoing context only if they are exist in incoming context
+// It copies all headers to outgoing context only if they are exist in incoming context
 // and doesn't exist in outgoing context already.
 func Propagate(ctx context.Context) context.Context {
 	if mdIncoming, ok := metadata.FromIncomingContext(ctx); ok {
@@ -97,22 +95,6 @@ func Propagate(ctx context.Context) context.Context {
 	return ctx
 }
 
-// SetVersions sets headers for internal communications.
-func SetVersions(ctx context.Context) context.Context {
-	return metadata.NewOutgoingContext(ctx, internalVersionHeaders)
-}
-
-// SetVersionsForTests sets headers as they would be received from the client.
-// Must be used in tests only.
-func SetVersionsForTests(ctx context.Context, clientVersion, clientName, supportedServerVersions, supportedFeatures string) context.Context {
-	return metadata.NewIncomingContext(ctx, metadata.New(map[string]string{
-		ClientNameHeaderName:              clientName,
-		ClientVersionHeaderName:           clientVersion,
-		SupportedServerVersionsHeaderName: supportedServerVersions,
-		SupportedFeaturesHeaderName:       supportedFeatures,
-	}))
-}
-
 func getSingleHeaderValue(md metadata.MD, headerName string) string {
 	values := md.Get(headerName)
 	if len(values) == 0 {
diff --git a/common/headers/versionChecker.go b/common/headers/versionChecker.go
index 7cb8d724718..399060f04d4 100644
--- a/common/headers/versionChecker.go
+++ b/common/headers/versionChecker.go
@@ -31,6 +31,7 @@ import (
 
 	"github.com/blang/semver/v4"
 	"golang.org/x/exp/slices"
+	"google.golang.org/grpc/metadata"
 
 	"go.temporal.io/api/serviceerror"
 )
@@ -71,6 +72,13 @@ var (
 		ClientNameServer: "<2.0.0",
 		ClientNameUI:     "<3.0.0",
 	}
+
+	internalVersionHeaderPairs = []string{
+		ClientNameHeaderName, ClientNameServer,
+		ClientVersionHeaderName, ServerVersion,
+		SupportedServerVersionsHeaderName, SupportedServerVersions,
+		SupportedFeaturesHeaderName, AllFeatures,
+	}
 )
 
 type (
@@ -109,6 +117,22 @@ func GetClientNameAndVersion(ctx context.Context) (string, string) {
 	return clientName, clientVersion
 }
 
+// SetVersions sets headers for internal communications.
+func SetVersions(ctx context.Context) context.Context {
+	return metadata.AppendToOutgoingContext(ctx, internalVersionHeaderPairs...)
+}
+
+// SetVersionsForTests sets headers as they would be received from the client.
+// Must be used in tests only.
+func SetVersionsForTests(ctx context.Context, clientVersion, clientName, supportedServerVersions, supportedFeatures string) context.Context {
+	return metadata.NewIncomingContext(ctx, metadata.New(map[string]string{
+		ClientNameHeaderName:              clientName,
+		ClientVersionHeaderName:           clientVersion,
+		SupportedServerVersionsHeaderName: supportedServerVersions,
+		SupportedFeaturesHeaderName:       supportedFeatures,
+	}))
+}
+
 // ClientSupported returns an error if client is unsupported, nil otherwise.
func (vc *versionChecker) ClientSupported(ctx context.Context, enableClientVersionCheck bool) error { if !enableClientVersionCheck { diff --git a/common/persistence/client/factory.go b/common/persistence/client/factory.go index b5ab2965e68..8fd6a448650 100644 --- a/common/persistence/client/factory.go +++ b/common/persistence/client/factory.go @@ -61,7 +61,7 @@ type ( metricsClient metrics.Client logger log.Logger clusterName string - ratelimiter quotas.RateLimiter + ratelimiter quotas.RequestRateLimiter } ) @@ -75,7 +75,7 @@ type ( func NewFactory( dataStoreFactory DataStoreFactory, cfg *config.Persistence, - ratelimiter quotas.RateLimiter, + ratelimiter quotas.RequestRateLimiter, serializer serialization.Serializer, clusterName string, metricsClient metrics.Client, diff --git a/common/persistence/client/fx.go b/common/persistence/client/fx.go index a7bbf7c513c..926f55500c6 100644 --- a/common/persistence/client/fx.go +++ b/common/persistence/client/fx.go @@ -67,16 +67,17 @@ func ClusterNameProvider(config *cluster.Config) ClusterName { func FactoryProvider( params NewFactoryParams, ) Factory { - var ratelimiter quotas.RateLimiter + var requestRatelimiter quotas.RequestRateLimiter if params.PersistenceMaxQPS != nil && params.PersistenceMaxQPS() > 0 { - ratelimiter = quotas.NewDefaultOutgoingRateLimiter( + requestRatelimiter = NewPriorityRateLimiter( func() float64 { return float64(params.PersistenceMaxQPS()) }, ) } + return NewFactory( params.DataStoreFactory, params.Cfg, - ratelimiter, + requestRatelimiter, serialization.NewSerializer(), string(params.ClusterName), params.MetricsClient, diff --git a/common/persistence/client/quotas.go b/common/persistence/client/quotas.go new file mode 100644 index 00000000000..2807aebb09c --- /dev/null +++ b/common/persistence/client/quotas.go @@ -0,0 +1,61 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package client + +import ( + "go.temporal.io/server/common/headers" + "go.temporal.io/server/common/quotas" +) + +var ( + CallerTypeHeaderToPriority = map[string]int{ + headers.CallerTypeAPI: 0, + headers.CallerTypeSystem: 0, + headers.CallerTypeBackground: 1, + } + + RequestPrioritiesOrdered = []int{0, 1} +) + +func NewPriorityRateLimiter( + rateFn quotas.RateFn, +) quotas.RequestRateLimiter { + rateLimiters := make(map[int]quotas.RateLimiter) + for priority := range RequestPrioritiesOrdered { + rateLimiters[priority] = quotas.NewDefaultOutgoingRateLimiter(rateFn) + } + + return quotas.NewPriorityRateLimiter( + func(req quotas.Request) int { + if priority, ok := CallerTypeHeaderToPriority[req.Caller]; ok { + return priority + } + + // default requests to high priority to be consistent with existing behavior + return RequestPrioritiesOrdered[0] + }, + rateLimiters, + ) +} diff --git a/common/persistence/persistenceRateLimitedClients.go b/common/persistence/persistenceRateLimitedClients.go index 3cb280a99e9..c32ae7f72d1 100644 --- a/common/persistence/persistenceRateLimitedClients.go +++ b/common/persistence/persistenceRateLimitedClients.go @@ -26,15 +26,21 @@ package persistence import ( "context" + "time" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/api/serviceerror" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/quotas" ) +const ( + RateLimitDefaultToken = 1 +) + var ( // ErrPersistenceLimitExceeded is the error indicating QPS limit reached. ErrPersistenceLimitExceeded = serviceerror.NewResourceExhausted(enumspb.RESOURCE_EXHAUSTED_CAUSE_SYSTEM_OVERLOADED, "Persistence Max QPS Reached.") @@ -42,37 +48,37 @@ var ( type ( shardRateLimitedPersistenceClient struct { - rateLimiter quotas.RateLimiter + rateLimiter quotas.RequestRateLimiter persistence ShardManager logger log.Logger } executionRateLimitedPersistenceClient struct { - rateLimiter quotas.RateLimiter + rateLimiter quotas.RequestRateLimiter persistence ExecutionManager logger log.Logger } taskRateLimitedPersistenceClient struct { - rateLimiter quotas.RateLimiter + rateLimiter quotas.RequestRateLimiter persistence TaskManager logger log.Logger } metadataRateLimitedPersistenceClient struct { - rateLimiter quotas.RateLimiter + rateLimiter quotas.RequestRateLimiter persistence MetadataManager logger log.Logger } clusterMetadataRateLimitedPersistenceClient struct { - rateLimiter quotas.RateLimiter + rateLimiter quotas.RequestRateLimiter persistence ClusterMetadataManager logger log.Logger } queueRateLimitedPersistenceClient struct { - rateLimiter quotas.RateLimiter + rateLimiter quotas.RequestRateLimiter persistence Queue logger log.Logger } @@ -86,7 +92,7 @@ var _ ClusterMetadataManager = (*clusterMetadataRateLimitedPersistenceClient)(ni var _ Queue = (*queueRateLimitedPersistenceClient)(nil) // NewShardPersistenceRateLimitedClient creates a client to manage shards -func NewShardPersistenceRateLimitedClient(persistence ShardManager, rateLimiter quotas.RateLimiter, logger log.Logger) ShardManager { +func NewShardPersistenceRateLimitedClient(persistence ShardManager, rateLimiter quotas.RequestRateLimiter, logger log.Logger) ShardManager { return &shardRateLimitedPersistenceClient{ persistence: persistence, rateLimiter: rateLimiter, @@ -95,7 +101,7 @@ func NewShardPersistenceRateLimitedClient(persistence ShardManager, rateLimiter } // NewExecutionPersistenceRateLimitedClient creates a client to manage executions -func 
NewExecutionPersistenceRateLimitedClient(persistence ExecutionManager, rateLimiter quotas.RateLimiter, logger log.Logger) ExecutionManager { +func NewExecutionPersistenceRateLimitedClient(persistence ExecutionManager, rateLimiter quotas.RequestRateLimiter, logger log.Logger) ExecutionManager { return &executionRateLimitedPersistenceClient{ persistence: persistence, rateLimiter: rateLimiter, @@ -104,7 +110,7 @@ func NewExecutionPersistenceRateLimitedClient(persistence ExecutionManager, rate } // NewTaskPersistenceRateLimitedClient creates a client to manage tasks -func NewTaskPersistenceRateLimitedClient(persistence TaskManager, rateLimiter quotas.RateLimiter, logger log.Logger) TaskManager { +func NewTaskPersistenceRateLimitedClient(persistence TaskManager, rateLimiter quotas.RequestRateLimiter, logger log.Logger) TaskManager { return &taskRateLimitedPersistenceClient{ persistence: persistence, rateLimiter: rateLimiter, @@ -113,7 +119,7 @@ func NewTaskPersistenceRateLimitedClient(persistence TaskManager, rateLimiter qu } // NewMetadataPersistenceRateLimitedClient creates a MetadataManager client to manage metadata -func NewMetadataPersistenceRateLimitedClient(persistence MetadataManager, rateLimiter quotas.RateLimiter, logger log.Logger) MetadataManager { +func NewMetadataPersistenceRateLimitedClient(persistence MetadataManager, rateLimiter quotas.RequestRateLimiter, logger log.Logger) MetadataManager { return &metadataRateLimitedPersistenceClient{ persistence: persistence, rateLimiter: rateLimiter, @@ -122,7 +128,7 @@ func NewMetadataPersistenceRateLimitedClient(persistence MetadataManager, rateLi } // NewClusterMetadataPersistenceRateLimitedClient creates a MetadataManager client to manage metadata -func NewClusterMetadataPersistenceRateLimitedClient(persistence ClusterMetadataManager, rateLimiter quotas.RateLimiter, logger log.Logger) ClusterMetadataManager { +func NewClusterMetadataPersistenceRateLimitedClient(persistence ClusterMetadataManager, rateLimiter quotas.RequestRateLimiter, logger log.Logger) ClusterMetadataManager { return &clusterMetadataRateLimitedPersistenceClient{ persistence: persistence, rateLimiter: rateLimiter, @@ -131,7 +137,7 @@ func NewClusterMetadataPersistenceRateLimitedClient(persistence ClusterMetadataM } // NewQueuePersistenceRateLimitedClient creates a client to manage queue -func NewQueuePersistenceRateLimitedClient(persistence Queue, rateLimiter quotas.RateLimiter, logger log.Logger) Queue { +func NewQueuePersistenceRateLimitedClient(persistence Queue, rateLimiter quotas.RequestRateLimiter, logger log.Logger) Queue { return &queueRateLimitedPersistenceClient{ persistence: persistence, rateLimiter: rateLimiter, @@ -147,7 +153,7 @@ func (p *shardRateLimitedPersistenceClient) GetOrCreateShard( ctx context.Context, request *GetOrCreateShardRequest, ) (*GetOrCreateShardResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -159,7 +165,7 @@ func (p *shardRateLimitedPersistenceClient) UpdateShard( ctx context.Context, request *UpdateShardRequest, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -171,7 +177,7 @@ func (p *shardRateLimitedPersistenceClient) AssertShardOwnership( ctx context.Context, request *AssertShardOwnershipRequest, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -191,7 +197,7 @@ func 
(p *executionRateLimitedPersistenceClient) CreateWorkflowExecution( ctx context.Context, request *CreateWorkflowExecutionRequest, ) (*CreateWorkflowExecutionResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -203,7 +209,7 @@ func (p *executionRateLimitedPersistenceClient) GetWorkflowExecution( ctx context.Context, request *GetWorkflowExecutionRequest, ) (*GetWorkflowExecutionResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -215,7 +221,7 @@ func (p *executionRateLimitedPersistenceClient) SetWorkflowExecution( ctx context.Context, request *SetWorkflowExecutionRequest, ) (*SetWorkflowExecutionResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -227,7 +233,7 @@ func (p *executionRateLimitedPersistenceClient) UpdateWorkflowExecution( ctx context.Context, request *UpdateWorkflowExecutionRequest, ) (*UpdateWorkflowExecutionResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -239,7 +245,7 @@ func (p *executionRateLimitedPersistenceClient) ConflictResolveWorkflowExecution ctx context.Context, request *ConflictResolveWorkflowExecutionRequest, ) (*ConflictResolveWorkflowExecutionResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -251,7 +257,7 @@ func (p *executionRateLimitedPersistenceClient) DeleteWorkflowExecution( ctx context.Context, request *DeleteWorkflowExecutionRequest, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -263,7 +269,7 @@ func (p *executionRateLimitedPersistenceClient) DeleteCurrentWorkflowExecution( ctx context.Context, request *DeleteCurrentWorkflowExecutionRequest, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -275,7 +281,7 @@ func (p *executionRateLimitedPersistenceClient) GetCurrentExecution( ctx context.Context, request *GetCurrentExecutionRequest, ) (*GetCurrentExecutionResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -287,7 +293,7 @@ func (p *executionRateLimitedPersistenceClient) ListConcreteExecutions( ctx context.Context, request *ListConcreteExecutionsRequest, ) (*ListConcreteExecutionsResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -299,7 +305,7 @@ func (p *executionRateLimitedPersistenceClient) AddHistoryTasks( ctx context.Context, request *AddHistoryTasksRequest, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -311,7 +317,7 @@ func (p *executionRateLimitedPersistenceClient) GetHistoryTask( ctx context.Context, request *GetHistoryTaskRequest, ) (*GetHistoryTaskResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -323,7 +329,7 @@ func (p *executionRateLimitedPersistenceClient) GetHistoryTasks( ctx context.Context, request *GetHistoryTasksRequest, ) 
(*GetHistoryTasksResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -335,7 +341,7 @@ func (p *executionRateLimitedPersistenceClient) CompleteHistoryTask( ctx context.Context, request *CompleteHistoryTaskRequest, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -347,7 +353,7 @@ func (p *executionRateLimitedPersistenceClient) RangeCompleteHistoryTasks( ctx context.Context, request *RangeCompleteHistoryTasksRequest, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -359,7 +365,7 @@ func (p *executionRateLimitedPersistenceClient) PutReplicationTaskToDLQ( ctx context.Context, request *PutReplicationTaskToDLQRequest, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -370,7 +376,7 @@ func (p *executionRateLimitedPersistenceClient) GetReplicationTasksFromDLQ( ctx context.Context, request *GetReplicationTasksFromDLQRequest, ) (*GetHistoryTasksResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -381,7 +387,7 @@ func (p *executionRateLimitedPersistenceClient) DeleteReplicationTaskFromDLQ( ctx context.Context, request *DeleteReplicationTaskFromDLQRequest, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -392,7 +398,7 @@ func (p *executionRateLimitedPersistenceClient) RangeDeleteReplicationTaskFromDL ctx context.Context, request *RangeDeleteReplicationTaskFromDLQRequest, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -411,7 +417,7 @@ func (p *taskRateLimitedPersistenceClient) CreateTasks( ctx context.Context, request *CreateTasksRequest, ) (*CreateTasksResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -423,7 +429,7 @@ func (p *taskRateLimitedPersistenceClient) GetTasks( ctx context.Context, request *GetTasksRequest, ) (*GetTasksResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -435,7 +441,7 @@ func (p *taskRateLimitedPersistenceClient) CompleteTask( ctx context.Context, request *CompleteTaskRequest, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -447,7 +453,7 @@ func (p *taskRateLimitedPersistenceClient) CompleteTasksLessThan( ctx context.Context, request *CompleteTasksLessThanRequest, ) (int, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return 0, ErrPersistenceLimitExceeded } return p.persistence.CompleteTasksLessThan(ctx, request) @@ -457,7 +463,7 @@ func (p *taskRateLimitedPersistenceClient) CreateTaskQueue( ctx context.Context, request *CreateTaskQueueRequest, ) (*CreateTaskQueueResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return p.persistence.CreateTaskQueue(ctx, request) @@ -467,7 +473,7 @@ func (p *taskRateLimitedPersistenceClient) UpdateTaskQueue( ctx 
context.Context, request *UpdateTaskQueueRequest, ) (*UpdateTaskQueueResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return p.persistence.UpdateTaskQueue(ctx, request) @@ -477,7 +483,7 @@ func (p *taskRateLimitedPersistenceClient) GetTaskQueue( ctx context.Context, request *GetTaskQueueRequest, ) (*GetTaskQueueResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return p.persistence.GetTaskQueue(ctx, request) @@ -487,7 +493,7 @@ func (p *taskRateLimitedPersistenceClient) ListTaskQueue( ctx context.Context, request *ListTaskQueueRequest, ) (*ListTaskQueueResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return p.persistence.ListTaskQueue(ctx, request) @@ -497,7 +503,7 @@ func (p *taskRateLimitedPersistenceClient) DeleteTaskQueue( ctx context.Context, request *DeleteTaskQueueRequest, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } return p.persistence.DeleteTaskQueue(ctx, request) @@ -515,7 +521,7 @@ func (p *metadataRateLimitedPersistenceClient) CreateNamespace( ctx context.Context, request *CreateNamespaceRequest, ) (*CreateNamespaceResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -527,7 +533,7 @@ func (p *metadataRateLimitedPersistenceClient) GetNamespace( ctx context.Context, request *GetNamespaceRequest, ) (*GetNamespaceResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -539,7 +545,7 @@ func (p *metadataRateLimitedPersistenceClient) UpdateNamespace( ctx context.Context, request *UpdateNamespaceRequest, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -551,7 +557,7 @@ func (p *metadataRateLimitedPersistenceClient) RenameNamespace( ctx context.Context, request *RenameNamespaceRequest, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -563,7 +569,7 @@ func (p *metadataRateLimitedPersistenceClient) DeleteNamespace( ctx context.Context, request *DeleteNamespaceRequest, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -575,7 +581,7 @@ func (p *metadataRateLimitedPersistenceClient) DeleteNamespaceByName( ctx context.Context, request *DeleteNamespaceByNameRequest, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -587,7 +593,7 @@ func (p *metadataRateLimitedPersistenceClient) ListNamespaces( ctx context.Context, request *ListNamespacesRequest, ) (*ListNamespacesResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -598,7 +604,7 @@ func (p *metadataRateLimitedPersistenceClient) ListNamespaces( func (p *metadataRateLimitedPersistenceClient) GetMetadata( ctx context.Context, ) (*GetMetadataResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, 
ErrPersistenceLimitExceeded } @@ -606,6 +612,16 @@ func (p *metadataRateLimitedPersistenceClient) GetMetadata( return response, err } +func (p *metadataRateLimitedPersistenceClient) InitializeSystemNamespaces( + ctx context.Context, + currentClusterName string, +) error { + if ok := allow(ctx, p.rateLimiter); !ok { + return ErrPersistenceLimitExceeded + } + return p.persistence.InitializeSystemNamespaces(ctx, currentClusterName) +} + func (p *metadataRateLimitedPersistenceClient) Close() { p.persistence.Close() } @@ -615,7 +631,7 @@ func (p *executionRateLimitedPersistenceClient) AppendHistoryNodes( ctx context.Context, request *AppendHistoryNodesRequest, ) (*AppendHistoryNodesResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return p.persistence.AppendHistoryNodes(ctx, request) @@ -626,7 +642,7 @@ func (p *executionRateLimitedPersistenceClient) AppendRawHistoryNodes( ctx context.Context, request *AppendRawHistoryNodesRequest, ) (*AppendHistoryNodesResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return p.persistence.AppendRawHistoryNodes(ctx, request) @@ -637,7 +653,7 @@ func (p *executionRateLimitedPersistenceClient) ReadHistoryBranch( ctx context.Context, request *ReadHistoryBranchRequest, ) (*ReadHistoryBranchResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } response, err := p.persistence.ReadHistoryBranch(ctx, request) @@ -649,7 +665,7 @@ func (p *executionRateLimitedPersistenceClient) ReadHistoryBranchReverse( ctx context.Context, request *ReadHistoryBranchReverseRequest, ) (*ReadHistoryBranchReverseResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } response, err := p.persistence.ReadHistoryBranchReverse(ctx, request) @@ -661,7 +677,7 @@ func (p *executionRateLimitedPersistenceClient) ReadHistoryBranchByBatch( ctx context.Context, request *ReadHistoryBranchRequest, ) (*ReadHistoryBranchByBatchResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } response, err := p.persistence.ReadHistoryBranchByBatch(ctx, request) @@ -673,7 +689,7 @@ func (p *executionRateLimitedPersistenceClient) ReadRawHistoryBranch( ctx context.Context, request *ReadHistoryBranchRequest, ) (*ReadRawHistoryBranchResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } response, err := p.persistence.ReadRawHistoryBranch(ctx, request) @@ -685,7 +701,7 @@ func (p *executionRateLimitedPersistenceClient) ForkHistoryBranch( ctx context.Context, request *ForkHistoryBranchRequest, ) (*ForkHistoryBranchResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } response, err := p.persistence.ForkHistoryBranch(ctx, request) @@ -697,7 +713,7 @@ func (p *executionRateLimitedPersistenceClient) DeleteHistoryBranch( ctx context.Context, request *DeleteHistoryBranchRequest, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } err := p.persistence.DeleteHistoryBranch(ctx, request) @@ -709,7 +725,7 @@ func (p 
*executionRateLimitedPersistenceClient) TrimHistoryBranch( ctx context.Context, request *TrimHistoryBranchRequest, ) (*TrimHistoryBranchResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } resp, err := p.persistence.TrimHistoryBranch(ctx, request) @@ -721,7 +737,7 @@ func (p *executionRateLimitedPersistenceClient) GetHistoryTree( ctx context.Context, request *GetHistoryTreeRequest, ) (*GetHistoryTreeResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } response, err := p.persistence.GetHistoryTree(ctx, request) @@ -732,7 +748,7 @@ func (p *executionRateLimitedPersistenceClient) GetAllHistoryTreeBranches( ctx context.Context, request *GetAllHistoryTreeBranchesRequest, ) (*GetAllHistoryTreeBranchesResponse, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } response, err := p.persistence.GetAllHistoryTreeBranches(ctx, request) @@ -743,7 +759,7 @@ func (p *queueRateLimitedPersistenceClient) EnqueueMessage( ctx context.Context, blob commonpb.DataBlob, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -755,7 +771,7 @@ func (p *queueRateLimitedPersistenceClient) ReadMessages( lastMessageID int64, maxCount int, ) ([]*QueueMessage, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -766,7 +782,7 @@ func (p *queueRateLimitedPersistenceClient) UpdateAckLevel( ctx context.Context, metadata *InternalQueueMetadata, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -776,7 +792,7 @@ func (p *queueRateLimitedPersistenceClient) UpdateAckLevel( func (p *queueRateLimitedPersistenceClient) GetAckLevels( ctx context.Context, ) (*InternalQueueMetadata, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -787,7 +803,7 @@ func (p *queueRateLimitedPersistenceClient) DeleteMessagesBefore( ctx context.Context, messageID int64, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -798,7 +814,7 @@ func (p *queueRateLimitedPersistenceClient) EnqueueMessageToDLQ( ctx context.Context, blob commonpb.DataBlob, ) (int64, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return EmptyQueueMessageID, ErrPersistenceLimitExceeded } @@ -812,7 +828,7 @@ func (p *queueRateLimitedPersistenceClient) ReadMessagesFromDLQ( pageSize int, pageToken []byte, ) ([]*QueueMessage, []byte, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, nil, ErrPersistenceLimitExceeded } @@ -824,7 +840,7 @@ func (p *queueRateLimitedPersistenceClient) RangeDeleteMessagesFromDLQ( firstMessageID int64, lastMessageID int64, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -834,7 +850,7 @@ func (p *queueRateLimitedPersistenceClient) UpdateDLQAckLevel( ctx context.Context, metadata *InternalQueueMetadata, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return 
ErrPersistenceLimitExceeded } @@ -844,7 +860,7 @@ func (p *queueRateLimitedPersistenceClient) UpdateDLQAckLevel( func (p *queueRateLimitedPersistenceClient) GetDLQAckLevels( ctx context.Context, ) (*InternalQueueMetadata, error) { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -855,7 +871,7 @@ func (p *queueRateLimitedPersistenceClient) DeleteMessageFromDLQ( ctx context.Context, messageID int64, ) error { - if ok := p.rateLimiter.Allow(); !ok { + if ok := allow(ctx, p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -885,7 +901,7 @@ func (c *clusterMetadataRateLimitedPersistenceClient) GetClusterMembers( ctx context.Context, request *GetClusterMembersRequest, ) (*GetClusterMembersResponse, error) { - if ok := c.rateLimiter.Allow(); !ok { + if ok := allow(ctx, c.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return c.persistence.GetClusterMembers(ctx, request) @@ -895,7 +911,7 @@ func (c *clusterMetadataRateLimitedPersistenceClient) UpsertClusterMembership( ctx context.Context, request *UpsertClusterMembershipRequest, ) error { - if ok := c.rateLimiter.Allow(); !ok { + if ok := allow(ctx, c.rateLimiter); !ok { return ErrPersistenceLimitExceeded } return c.persistence.UpsertClusterMembership(ctx, request) @@ -905,7 +921,7 @@ func (c *clusterMetadataRateLimitedPersistenceClient) PruneClusterMembership( ctx context.Context, request *PruneClusterMembershipRequest, ) error { - if ok := c.rateLimiter.Allow(); !ok { + if ok := allow(ctx, c.rateLimiter); !ok { return ErrPersistenceLimitExceeded } return c.persistence.PruneClusterMembership(ctx, request) @@ -915,7 +931,7 @@ func (c *clusterMetadataRateLimitedPersistenceClient) ListClusterMetadata( ctx context.Context, request *ListClusterMetadataRequest, ) (*ListClusterMetadataResponse, error) { - if ok := c.rateLimiter.Allow(); !ok { + if ok := allow(ctx, c.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return c.persistence.ListClusterMetadata(ctx, request) @@ -924,7 +940,7 @@ func (c *clusterMetadataRateLimitedPersistenceClient) ListClusterMetadata( func (c *clusterMetadataRateLimitedPersistenceClient) GetCurrentClusterMetadata( ctx context.Context, ) (*GetClusterMetadataResponse, error) { - if ok := c.rateLimiter.Allow(); !ok { + if ok := allow(ctx, c.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return c.persistence.GetCurrentClusterMetadata(ctx) @@ -934,7 +950,7 @@ func (c *clusterMetadataRateLimitedPersistenceClient) GetClusterMetadata( ctx context.Context, request *GetClusterMetadataRequest, ) (*GetClusterMetadataResponse, error) { - if ok := c.rateLimiter.Allow(); !ok { + if ok := allow(ctx, c.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return c.persistence.GetClusterMetadata(ctx, request) @@ -944,7 +960,7 @@ func (c *clusterMetadataRateLimitedPersistenceClient) SaveClusterMetadata( ctx context.Context, request *SaveClusterMetadataRequest, ) (bool, error) { - if ok := c.rateLimiter.Allow(); !ok { + if ok := allow(ctx, c.rateLimiter); !ok { return false, ErrPersistenceLimitExceeded } return c.persistence.SaveClusterMetadata(ctx, request) @@ -954,18 +970,20 @@ func (c *clusterMetadataRateLimitedPersistenceClient) DeleteClusterMetadata( ctx context.Context, request *DeleteClusterMetadataRequest, ) error { - if ok := c.rateLimiter.Allow(); !ok { + if ok := allow(ctx, c.rateLimiter); !ok { return ErrPersistenceLimitExceeded } return c.persistence.DeleteClusterMetadata(ctx, request) } -func (c 
*metadataRateLimitedPersistenceClient) InitializeSystemNamespaces( +func allow( ctx context.Context, - currentClusterName string, -) error { - if ok := c.rateLimiter.Allow(); !ok { - return ErrPersistenceLimitExceeded - } - return c.persistence.InitializeSystemNamespaces(ctx, currentClusterName) + rateLimiter quotas.RequestRateLimiter, +) bool { + return rateLimiter.Allow(time.Now().UTC(), quotas.NewRequest( + "", // api: currently not used when calculating priority + RateLimitDefaultToken, + "", // caller: currently not used when calculating priority + headers.GetValues(ctx, headers.CallerTypeHeaderName)[0], + )) } diff --git a/common/quotas/priority_rate_limiter_impl.go b/common/quotas/priority_rate_limiter_impl.go index 1f90dfff680..291479b5bc6 100644 --- a/common/quotas/priority_rate_limiter_impl.go +++ b/common/quotas/priority_rate_limiter_impl.go @@ -34,7 +34,7 @@ import ( type ( // PriorityRateLimiterImpl is a wrapper around the golang rate limiter PriorityRateLimiterImpl struct { - apiToPriority map[string]int + requestPriorityFn RequestPriorityFn priorityToRateLimiters map[int]RateLimiter // priority value 0 means highest priority @@ -49,7 +49,7 @@ var _ RequestRateLimiter = (*PriorityRateLimiterImpl)(nil) // NewPriorityRateLimiter returns a new rate limiter that can handle dynamic // configuration updates func NewPriorityRateLimiter( - apiToPriority map[string]int, + requestPriorityFn RequestPriorityFn, priorityToRateLimiters map[int]RateLimiter, ) *PriorityRateLimiterImpl { priorities := make([]int, 0, len(priorityToRateLimiters)) @@ -66,15 +66,8 @@ func NewPriorityRateLimiter( rateLimiters = append(rateLimiters, priorityToRateLimiters[priority]) } - // sanity check priority within apiToPriority appears in priorityToRateLimiters - for _, priority := range apiToPriority { - if _, ok := priorityToRateLimiters[priority]; !ok { - panic("API to priority & priority to rate limiter does not match") - } - } - return &PriorityRateLimiterImpl{ - apiToPriority: apiToPriority, + requestPriorityFn: requestPriorityFn, priorityToRateLimiters: priorityToRateLimiters, priorityToIndex: priorityToIndex, @@ -161,10 +154,9 @@ func (p *PriorityRateLimiterImpl) Wait( func (p *PriorityRateLimiterImpl) getRateLimiters( request Request, ) (RateLimiter, []RateLimiter) { - priority, ok := p.apiToPriority[request.API] - if !ok { - // if API not assigned a priority use the lowest priority - return p.rateLimiters[len(p.rateLimiters)-1], nil + priority := p.requestPriorityFn(request) + if _, ok := p.priorityToRateLimiters[priority]; !ok { + panic("Request to priority & priority to rate limiter does not match") } rateLimiterIndex := p.priorityToIndex[priority] diff --git a/common/quotas/priority_rate_limiter_impl_test.go b/common/quotas/priority_rate_limiter_impl_test.go index d814c88cca5..cb57c3022e7 100644 --- a/common/quotas/priority_rate_limiter_impl_test.go +++ b/common/quotas/priority_rate_limiter_impl_test.go @@ -83,7 +83,9 @@ func (s *priorityStageRateLimiterSuite) SetupTest() { 0: s.highPriorityRateLimiter, 2: s.lowPriorityRateLimiter, } - s.rateLimiter = NewPriorityRateLimiter(apiToPriority, priorityToRateLimiters) + s.rateLimiter = NewPriorityRateLimiter(func(req Request) int { + return apiToPriority[req.API] + }, priorityToRateLimiters) } diff --git a/common/quotas/request.go b/common/quotas/request.go index 4b1e9981f47..cc28a462dc1 100644 --- a/common/quotas/request.go +++ b/common/quotas/request.go @@ -26,9 +26,10 @@ package quotas type ( Request struct { - API string - Token int - Caller string + 
API string + Token int + Caller string + CallerType string } ) @@ -36,10 +37,12 @@ func NewRequest( api string, token int, caller string, + callerType string, ) Request { return Request{ - API: api, - Token: token, - Caller: caller, + API: api, + Token: token, + Caller: caller, + CallerType: callerType, } } diff --git a/common/quotas/request_rate_limiter.go b/common/quotas/request_rate_limiter.go index 00ae45b47e8..5aeca61926b 100644 --- a/common/quotas/request_rate_limiter.go +++ b/common/quotas/request_rate_limiter.go @@ -35,6 +35,9 @@ type ( // RequestRateLimiterFn returns generate a namespace specific rate limiter RequestRateLimiterFn func(req Request) RequestRateLimiter + // RequestPriorityFn returns a priority for the given Request + RequestPriorityFn func(req Request) int + // RequestRateLimiter corresponds to basic rate limiting functionality. RequestRateLimiter interface { // Allow attempts to allow a request to go through. The method returns diff --git a/common/rpc/context.go b/common/rpc/context.go index b7eb1647371..1af077737d4 100644 --- a/common/rpc/context.go +++ b/common/rpc/context.go @@ -31,17 +31,41 @@ import ( "go.temporal.io/server/common/headers" ) +type ( + valueCopyCtx struct { + context.Context + + valueCtx context.Context + } +) + +func (c *valueCopyCtx) Value(key any) any { + if value := c.Context.Value(key); value != nil { + return value + } + + return c.valueCtx.Value(key) +} + +// CopyContextValues copies values in source Context to destination Context. +func CopyContextValues(dst context.Context, src context.Context) context.Context { + return &valueCopyCtx{ + Context: dst, + valueCtx: src, + } +} + // NewContextWithTimeout creates context with timeout. func NewContextWithTimeout(timeout time.Duration) (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), timeout) } -// NewContextWithTimeoutAndHeaders creates context with timeout and version headers. -func NewContextWithTimeoutAndHeaders(timeout time.Duration) (context.Context, context.CancelFunc) { +// NewContextWithTimeoutAndVersionHeaders creates context with timeout and version headers. +func NewContextWithTimeoutAndVersionHeaders(timeout time.Duration) (context.Context, context.CancelFunc) { return context.WithTimeout(headers.SetVersions(context.Background()), timeout) } -// NewContextFromParentWithTimeoutAndHeaders creates context from parent context with timeout and version headers. -func NewContextFromParentWithTimeoutAndHeaders(parentCtx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { +// NewContextFromParentWithTimeoutAndVersionHeaders creates context from parent context with timeout and version headers. 
+func NewContextFromParentWithTimeoutAndVersionHeaders(parentCtx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { return context.WithTimeout(headers.SetVersions(parentCtx), timeout) } diff --git a/common/rpc/grpc.go b/common/rpc/grpc.go index c1afbe29902..4f9b98877bc 100644 --- a/common/rpc/grpc.go +++ b/common/rpc/grpc.go @@ -87,7 +87,7 @@ func Dial(hostName string, tlsConfig *tls.Config, logger log.Logger, interceptor grpc.WithChainUnaryInterceptor( append( interceptors, - versionHeadersInterceptor, + headersInterceptor, metrics.NewClientMetricsTrailerPropagatorInterceptor(logger), errorInterceptor, )..., @@ -116,7 +116,7 @@ func errorInterceptor( return err } -func versionHeadersInterceptor( +func headersInterceptor( ctx context.Context, method string, req, reply interface{}, diff --git a/common/rpc/interceptor/caller_info.go b/common/rpc/interceptor/caller_info.go new file mode 100644 index 00000000000..b79feafbbca --- /dev/null +++ b/common/rpc/interceptor/caller_info.go @@ -0,0 +1,65 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package interceptor + +import ( + "context" + + "go.temporal.io/server/common/headers" + "go.temporal.io/server/common/namespace" + "google.golang.org/grpc" +) + +type ( + CallerInfoInterceptor struct { + namespaceRegistry namespace.Registry + } +) + +var _ grpc.UnaryServerInterceptor = (*CallerInfoInterceptor)(nil).Intercept + +func NewCallerInfoInterceptor( + namespaceRegistry namespace.Registry, +) *CallerInfoInterceptor { + return &CallerInfoInterceptor{ + namespaceRegistry: namespaceRegistry, + } +} + +func (i *CallerInfoInterceptor) Intercept( + ctx context.Context, + req interface{}, + info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler, +) (interface{}, error) { + return handler( + headers.SetCallerInfo( + ctx, + GetNamespace(i.namespaceRegistry, req).String(), + headers.CallerTypeAPI, + ), + req, + ) +} diff --git a/common/rpc/interceptor/namespace_rate_limit.go b/common/rpc/interceptor/namespace_rate_limit.go index 6ee3f0e0ea7..39654a6de50 100644 --- a/common/rpc/interceptor/namespace_rate_limit.go +++ b/common/rpc/interceptor/namespace_rate_limit.go @@ -83,6 +83,7 @@ func (ni *NamespaceRateLimitInterceptor) Intercept( methodName, token, namespace.String(), + "", // this interceptor layer does not throttle based on caller type )) { return nil, ErrNamespaceRateLimitServerBusy } diff --git a/common/rpc/interceptor/rate_limit.go b/common/rpc/interceptor/rate_limit.go index 052b9146d5d..47717488533 100644 --- a/common/rpc/interceptor/rate_limit.go +++ b/common/rpc/interceptor/rate_limit.go @@ -77,7 +77,8 @@ func (i *RateLimitInterceptor) Intercept( if !i.rateLimiter.Allow(time.Now().UTC(), quotas.NewRequest( methodName, token, - "", // this interceptor layer does not throttle based on caller + "", // this interceptor layer does not throttle based on caller name + "", // this interceptor layer does not throttle based on caller type )) { return nil, RateLimitServerBusy } diff --git a/common/xdc/nDCHistoryResender.go b/common/xdc/nDCHistoryResender.go index 18b11d528e7..51f8bdac4cf 100644 --- a/common/xdc/nDCHistoryResender.go +++ b/common/xdc/nDCHistoryResender.go @@ -58,6 +58,7 @@ type ( NDCHistoryResender interface { // SendSingleWorkflowHistory sends multiple run IDs's history events to remote SendSingleWorkflowHistory( + ctx context.Context, remoteClusterName string, namespaceID namespace.ID, workflowID string, @@ -111,6 +112,7 @@ func NewNDCHistoryResender( // SendSingleWorkflowHistory sends one run IDs's history events to remote func (n *NDCHistoryResenderImpl) SendSingleWorkflowHistory( + ctx context.Context, remoteClusterName string, namespaceID namespace.ID, workflowID string, @@ -121,18 +123,19 @@ func (n *NDCHistoryResenderImpl) SendSingleWorkflowHistory( endEventVersion int64, ) error { - ctx := context.Background() + resendCtx := context.Background() var cancel context.CancelFunc if n.rereplicationTimeout != nil { resendContextTimeout := n.rereplicationTimeout(namespaceID.String()) if resendContextTimeout > 0 { - ctx, cancel = context.WithTimeout(ctx, resendContextTimeout) + resendCtx, cancel = context.WithTimeout(resendCtx, resendContextTimeout) defer cancel() } } + rpc.CopyContextValues(resendCtx, ctx) historyIterator := collection.NewPagingIterator(n.getPaginationFn( - ctx, + resendCtx, remoteClusterName, namespaceID, workflowID, @@ -161,7 +164,7 @@ func (n *NDCHistoryResenderImpl) SendSingleWorkflowHistory( batch.rawEventBatch, batch.versionHistory.GetItems()) - err = n.sendReplicationRawRequest(ctx, replicationRequest) + err = 
n.sendReplicationRawRequest(resendCtx, replicationRequest) if err != nil { n.logger.Error("failed to replicate events", tag.WorkflowNamespaceID(namespaceID.String()), @@ -269,7 +272,7 @@ func (n *NDCHistoryResenderImpl) getHistory( return nil, err } - ctx, cancel := rpc.NewContextFromParentWithTimeoutAndHeaders(ctx, resendContextTimeout) + ctx, cancel := rpc.NewContextFromParentWithTimeoutAndVersionHeaders(ctx, resendContextTimeout) defer cancel() adminClient, err := n.clientBean.GetRemoteAdminClient(remoteClusterName) diff --git a/common/xdc/nDCHistoryResender_mock.go b/common/xdc/nDCHistoryResender_mock.go index 35c600e3bc6..e9f3b4b92ff 100644 --- a/common/xdc/nDCHistoryResender_mock.go +++ b/common/xdc/nDCHistoryResender_mock.go @@ -29,6 +29,7 @@ package xdc import ( + context "context" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -59,15 +60,15 @@ func (m *MockNDCHistoryResender) EXPECT() *MockNDCHistoryResenderMockRecorder { } // SendSingleWorkflowHistory mocks base method. -func (m *MockNDCHistoryResender) SendSingleWorkflowHistory(remoteClusterName string, namespaceID namespace.ID, workflowID, runID string, startEventID, startEventVersion, endEventID, endEventVersion int64) error { +func (m *MockNDCHistoryResender) SendSingleWorkflowHistory(ctx context.Context, remoteClusterName string, namespaceID namespace.ID, workflowID, runID string, startEventID, startEventVersion, endEventID, endEventVersion int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendSingleWorkflowHistory", remoteClusterName, namespaceID, workflowID, runID, startEventID, startEventVersion, endEventID, endEventVersion) + ret := m.ctrl.Call(m, "SendSingleWorkflowHistory", ctx, remoteClusterName, namespaceID, workflowID, runID, startEventID, startEventVersion, endEventID, endEventVersion) ret0, _ := ret[0].(error) return ret0 } // SendSingleWorkflowHistory indicates an expected call of SendSingleWorkflowHistory. 
-func (mr *MockNDCHistoryResenderMockRecorder) SendSingleWorkflowHistory(remoteClusterName, namespaceID, workflowID, runID, startEventID, startEventVersion, endEventID, endEventVersion interface{}) *gomock.Call { +func (mr *MockNDCHistoryResenderMockRecorder) SendSingleWorkflowHistory(ctx, remoteClusterName, namespaceID, workflowID, runID, startEventID, startEventVersion, endEventID, endEventVersion interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendSingleWorkflowHistory", reflect.TypeOf((*MockNDCHistoryResender)(nil).SendSingleWorkflowHistory), remoteClusterName, namespaceID, workflowID, runID, startEventID, startEventVersion, endEventID, endEventVersion) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendSingleWorkflowHistory", reflect.TypeOf((*MockNDCHistoryResender)(nil).SendSingleWorkflowHistory), ctx, remoteClusterName, namespaceID, workflowID, runID, startEventID, startEventVersion, endEventID, endEventVersion) } diff --git a/common/xdc/nDCHistoryResender_test.go b/common/xdc/nDCHistoryResender_test.go index 88798488cbf..2be2490b1b5 100644 --- a/common/xdc/nDCHistoryResender_test.go +++ b/common/xdc/nDCHistoryResender_test.go @@ -225,6 +225,7 @@ func (s *nDCHistoryResenderSuite) TestSendSingleWorkflowHistory() { }).Return(nil, nil).Times(2) err := s.rereplicator.SendSingleWorkflowHistory( + context.Background(), cluster.TestCurrentClusterName, s.namespaceID, workflowID, diff --git a/host/client_integration_test.go b/host/client_integration_test.go index 4bd997c18f5..540c051f27a 100644 --- a/host/client_integration_test.go +++ b/host/client_integration_test.go @@ -275,7 +275,7 @@ func (s *clientIntegrationSuite) TestClientDataConverter() { TaskQueue: s.taskQueue, WorkflowRunTimeout: time.Minute, } - ctx, cancel := rpc.NewContextWithTimeoutAndHeaders(time.Minute) + ctx, cancel := rpc.NewContextWithTimeoutAndVersionHeaders(time.Minute) defer cancel() s.worker.RegisterWorkflow(testDataConverterWorkflow) s.worker.RegisterActivity(testActivity) @@ -311,7 +311,7 @@ func (s *clientIntegrationSuite) TestClientDataConverter_Failed() { TaskQueue: s.taskQueue, WorkflowRunTimeout: time.Minute, } - ctx, cancel := rpc.NewContextWithTimeoutAndHeaders(time.Minute) + ctx, cancel := rpc.NewContextWithTimeoutAndVersionHeaders(time.Minute) defer cancel() s.worker.RegisterWorkflow(testDataConverterWorkflow) @@ -419,7 +419,7 @@ func (s *clientIntegrationSuite) TestClientDataConverter_WithChild() { TaskQueue: s.taskQueue, WorkflowRunTimeout: time.Minute, } - ctx, cancel := rpc.NewContextWithTimeoutAndHeaders(time.Minute) + ctx, cancel := rpc.NewContextWithTimeoutAndVersionHeaders(time.Minute) defer cancel() s.worker.RegisterWorkflow(testParentWorkflow) s.worker.RegisterWorkflow(testChildWorkflow) diff --git a/host/context.go b/host/context.go index c4b62d08d90..eec0972cddb 100644 --- a/host/context.go +++ b/host/context.go @@ -33,6 +33,6 @@ import ( // NewContext create new context with default timeout 90 seconds. 
func NewContext() context.Context { - ctx, _ := rpc.NewContextWithTimeoutAndHeaders(90 * time.Second) + ctx, _ := rpc.NewContextWithTimeoutAndVersionHeaders(90 * time.Second) return ctx } diff --git a/host/integrationbase.go b/host/integrationbase.go index aa0b32ce04b..1cb8c00d408 100644 --- a/host/integrationbase.go +++ b/host/integrationbase.go @@ -184,7 +184,7 @@ func (s *IntegrationBase) registerNamespace( visibilityArchivalState enumspb.ArchivalState, visibilityArchivalURI string, ) error { - ctx, cancel := rpc.NewContextWithTimeoutAndHeaders(10000 * time.Second) + ctx, cancel := rpc.NewContextWithTimeoutAndVersionHeaders(10000 * time.Second) defer cancel() _, err := s.engine.RegisterNamespace(ctx, &workflowservice.RegisterNamespaceRequest{ Namespace: namespace, @@ -202,7 +202,7 @@ func (s *IntegrationBase) registerNamespace( func (s *IntegrationBase) deleteNamespace( namespace string, ) error { - ctx, cancel := rpc.NewContextWithTimeoutAndHeaders(10000 * time.Second) + ctx, cancel := rpc.NewContextWithTimeoutAndVersionHeaders(10000 * time.Second) defer cancel() _, err := s.engine.UpdateNamespace(ctx, &workflowservice.UpdateNamespaceRequest{ Namespace: namespace, diff --git a/host/signal_workflow_test.go b/host/signal_workflow_test.go index 968daeb27ea..8ba1deedd1c 100644 --- a/host/signal_workflow_test.go +++ b/host/signal_workflow_test.go @@ -1519,7 +1519,7 @@ func (s *integrationSuite) TestSignalWithStartWorkflow_IDReusePolicy() { Identity: identity, WorkflowIdReusePolicy: enumspb.WORKFLOW_ID_REUSE_POLICY_REJECT_DUPLICATE, } - ctx, _ := rpc.NewContextWithTimeoutAndHeaders(5 * time.Second) + ctx, _ := rpc.NewContextWithTimeoutAndVersionHeaders(5 * time.Second) resp, err := s.engine.SignalWithStartWorkflowExecution(ctx, sRequest) s.Nil(resp) s.Error(err) @@ -1528,7 +1528,7 @@ func (s *integrationSuite) TestSignalWithStartWorkflow_IDReusePolicy() { // test policy WorkflowIdReusePolicyAllowDuplicateFailedOnly sRequest.WorkflowIdReusePolicy = enumspb.WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE_FAILED_ONLY - ctx, _ = rpc.NewContextWithTimeoutAndHeaders(5 * time.Second) + ctx, _ = rpc.NewContextWithTimeoutAndVersionHeaders(5 * time.Second) resp, err = s.engine.SignalWithStartWorkflowExecution(ctx, sRequest) s.Nil(resp) s.Error(err) @@ -1537,7 +1537,7 @@ func (s *integrationSuite) TestSignalWithStartWorkflow_IDReusePolicy() { // test policy WorkflowIdReusePolicyAllowDuplicate sRequest.WorkflowIdReusePolicy = enumspb.WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE - ctx, _ = rpc.NewContextWithTimeoutAndHeaders(5 * time.Second) + ctx, _ = rpc.NewContextWithTimeoutAndVersionHeaders(5 * time.Second) resp, err = s.engine.SignalWithStartWorkflowExecution(ctx, sRequest) s.NoError(err) s.NotEmpty(resp.GetRunId()) diff --git a/service/frontend/adminHandler.go b/service/frontend/adminHandler.go index 2a1d45bca05..536527afda6 100644 --- a/service/frontend/adminHandler.go +++ b/service/frontend/adminHandler.go @@ -1422,7 +1422,7 @@ func (adh *AdminHandler) RefreshWorkflowTasks( // ResendReplicationTasks requests replication task from remote cluster func (adh *AdminHandler) ResendReplicationTasks( - _ context.Context, + ctx context.Context, request *adminservice.ResendReplicationTasksRequest, ) (_ *adminservice.ResendReplicationTasksResponse, err error) { defer log.CapturePanic(adh.logger, &err) @@ -1444,6 +1444,7 @@ func (adh *AdminHandler) ResendReplicationTasks( adh.logger, ) if err := resender.SendSingleWorkflowHistory( + ctx, request.GetRemoteCluster(), namespace.ID(request.GetNamespaceId()), 
request.GetWorkflowId(), diff --git a/service/frontend/configs/quotas.go b/service/frontend/configs/quotas.go index 3a80ea6963e..3d6548ec59e 100644 --- a/service/frontend/configs/quotas.go +++ b/service/frontend/configs/quotas.go @@ -76,12 +76,7 @@ var ( "ListTaskQueuePartitions": 3, } - ExecutionAPIPriorities = map[int]struct{}{ - 0: {}, - 1: {}, - 2: {}, - 3: {}, - } + ExecutionAPIPrioritiesOrdered = []int{0, 1, 2, 3} VisibilityAPIToPriority = map[string]int{ "CountWorkflowExecutions": 0, @@ -92,9 +87,7 @@ var ( "ListArchivedWorkflowExecutions": 0, } - VisibilityAPIPriorities = map[int]struct{}{ - 0: {}, - } + VisibilityAPIPrioritiesOrdered = []int{0} OtherAPIToPriority = map[string]int{ "GetClusterInfo": 0, @@ -116,9 +109,7 @@ var ( "ListSchedules": 0, } - OtherAPIPriorities = map[int]struct{}{ - 0: {}, - } + OtherAPIPrioritiesOrdered = []int{0} ) type ( @@ -179,28 +170,43 @@ func NewExecutionPriorityRateLimiter( rateBurstFn quotas.RateBurst, ) quotas.RequestRateLimiter { rateLimiters := make(map[int]quotas.RateLimiter) - for priority := range ExecutionAPIPriorities { + for priority := range ExecutionAPIPrioritiesOrdered { rateLimiters[priority] = quotas.NewDynamicRateLimiter(rateBurstFn, time.Minute) } - return quotas.NewPriorityRateLimiter(ExecutionAPIToPriority, rateLimiters) + return quotas.NewPriorityRateLimiter(func(req quotas.Request) int { + if priority, ok := ExecutionAPIToPriority[req.API]; ok { + return priority + } + return ExecutionAPIPrioritiesOrdered[len(ExecutionAPIPrioritiesOrdered)-1] + }, rateLimiters) } func NewVisibilityPriorityRateLimiter( rateBurstFn quotas.RateBurst, ) quotas.RequestRateLimiter { rateLimiters := make(map[int]quotas.RateLimiter) - for priority := range VisibilityAPIPriorities { + for priority := range VisibilityAPIPrioritiesOrdered { rateLimiters[priority] = quotas.NewDynamicRateLimiter(rateBurstFn, time.Minute) } - return quotas.NewPriorityRateLimiter(VisibilityAPIToPriority, rateLimiters) + return quotas.NewPriorityRateLimiter(func(req quotas.Request) int { + if priority, ok := VisibilityAPIToPriority[req.API]; ok { + return priority + } + return VisibilityAPIPrioritiesOrdered[len(VisibilityAPIPrioritiesOrdered)-1] + }, rateLimiters) } func NewOtherAPIPriorityRateLimiter( rateBurstFn quotas.RateBurst, ) quotas.RequestRateLimiter { rateLimiters := make(map[int]quotas.RateLimiter) - for priority := range OtherAPIPriorities { + for priority := range OtherAPIPrioritiesOrdered { rateLimiters[priority] = quotas.NewDynamicRateLimiter(rateBurstFn, time.Minute) } - return quotas.NewPriorityRateLimiter(OtherAPIToPriority, rateLimiters) + return quotas.NewPriorityRateLimiter(func(req quotas.Request) int { + if priority, ok := OtherAPIToPriority[req.API]; ok { + return priority + } + return OtherAPIPrioritiesOrdered[len(OtherAPIPrioritiesOrdered)-1] + }, rateLimiters) } diff --git a/service/frontend/configs/quotas_test.go b/service/frontend/configs/quotas_test.go index d3ad7a6e79d..e57ed27a5ce 100644 --- a/service/frontend/configs/quotas_test.go +++ b/service/frontend/configs/quotas_test.go @@ -31,6 +31,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "go.temporal.io/api/workflowservice/v1" + "golang.org/x/exp/slices" ) type ( @@ -59,27 +60,42 @@ func (s *quotasSuite) TearDownTest() { } func (s *quotasSuite) TestExecutionAPIToPriorityMapping() { - mapping := make(map[int]struct{}) for _, priority := range ExecutionAPIToPriority { - mapping[priority] = struct{}{} + index := 
slices.Index(ExecutionAPIPrioritiesOrdered, priority) + s.NotEqual(-1, index) } - s.Equal(mapping, ExecutionAPIPriorities) } func (s *quotasSuite) TestVisibilityAPIToPriorityMapping() { - mapping := make(map[int]struct{}) for _, priority := range VisibilityAPIToPriority { - mapping[priority] = struct{}{} + index := slices.Index(VisibilityAPIPrioritiesOrdered, priority) + s.NotEqual(-1, index) } - s.Equal(mapping, VisibilityAPIPriorities) } func (s *quotasSuite) TestOtherAPIToPriorityMapping() { - mapping := make(map[int]struct{}) for _, priority := range OtherAPIToPriority { - mapping[priority] = struct{}{} + index := slices.Index(OtherAPIPrioritiesOrdered, priority) + s.NotEqual(-1, index) + } +} + +func (s *quotasSuite) TestExecutionAPIPrioritiesOrdered() { + for idx := range ExecutionAPIPrioritiesOrdered[1:] { + s.True(ExecutionAPIPrioritiesOrdered[idx] < ExecutionAPIPrioritiesOrdered[idx+1]) + } +} + +func (s *quotasSuite) TestVisibilityAPIPrioritiesOrdered() { + for idx := range VisibilityAPIPrioritiesOrdered[1:] { + s.True(VisibilityAPIPrioritiesOrdered[idx] < VisibilityAPIPrioritiesOrdered[idx+1]) + } +} + +func (s *quotasSuite) TestOtherAPIPrioritiesOrdered() { + for idx := range OtherAPIPrioritiesOrdered[1:] { + s.True(OtherAPIPrioritiesOrdered[idx] < OtherAPIPrioritiesOrdered[idx+1]) } - s.Equal(mapping, OtherAPIPriorities) } func (s *quotasSuite) TestExecutionAPIs() { diff --git a/service/frontend/fx.go b/service/frontend/fx.go index 7b0243e09e1..895353c09ce 100644 --- a/service/frontend/fx.go +++ b/service/frontend/fx.go @@ -79,6 +79,7 @@ var Module = fx.Options( fx.Provide(NamespaceValidatorInterceptorProvider), fx.Provide(NamespaceRateLimitInterceptorProvider), fx.Provide(SDKVersionInterceptorProvider), + fx.Provide(CallerInfoInterceptorProvider), fx.Provide(GrpcServerOptionsProvider), fx.Provide(VisibilityManagerProvider), fx.Provide(ThrottledLoggerRpsFnProvider), @@ -137,6 +138,7 @@ func GrpcServerOptionsProvider( rateLimitInterceptor *interceptor.RateLimitInterceptor, traceInterceptor telemetry.ServerTraceInterceptor, sdkVersionInterceptor *interceptor.SDKVersionInterceptor, + callerInfoInterceptor *interceptor.CallerInfoInterceptor, authorizer authorization.Authorizer, claimMapper authorization.ClaimMapper, audienceGetter authorization.JWTAudienceMapper, @@ -176,6 +178,7 @@ func GrpcServerOptionsProvider( audienceGetter, ), sdkVersionInterceptor.Intercept, + callerInfoInterceptor.Intercept, } if len(customInterceptors) > 0 { interceptors = append(interceptors, customInterceptors...) 
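The priority-function change above in service/frontend/configs/quotas.go (repeated below for the history and matching services) follows one pattern: an API name listed in the *APIToPriority map gets its mapped priority, and anything unmapped falls back to the last entry of the *PrioritiesOrdered slice, i.e. the lowest priority, so an unlisted RPC is throttled before the explicitly prioritized ones. A minimal standalone sketch of that fallback, using hypothetical map entries rather than the real quotas package:

package main

import "fmt"

// apiToPriority and prioritiesOrdered mirror the shape of the *APIToPriority
// maps and *PrioritiesOrdered slices in this patch; the entries below are
// illustrative only.
var (
	apiToPriority = map[string]int{
		"StartWorkflowExecution": 1,
		"PollWorkflowTaskQueue":  2,
	}
	prioritiesOrdered = []int{0, 1, 2, 3} // ascending, 0 = highest priority
)

// priorityFor returns the mapped priority for a known API and the last
// (lowest) priority in the ordered slice for everything else.
func priorityFor(api string) int {
	if priority, ok := apiToPriority[api]; ok {
		return priority
	}
	return prioritiesOrdered[len(prioritiesOrdered)-1]
}

func main() {
	fmt.Println(priorityFor("StartWorkflowExecution")) // 1
	fmt.Println(priorityFor("SomeUnlistedAPI"))         // 3 (lowest)
}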
@@ -303,6 +306,12 @@ func SDKVersionInterceptorProvider() *interceptor.SDKVersionInterceptor { return interceptor.NewSDKVersionInterceptor() } +func CallerInfoInterceptorProvider( + namespaceRegistry namespace.Registry, +) *interceptor.CallerInfoInterceptor { + return interceptor.NewCallerInfoInterceptor(namespaceRegistry) +} + func PersistenceMaxQpsProvider( serviceConfig *Config, ) persistenceClient.PersistenceMaxQps { diff --git a/service/history/configs/quotas.go b/service/history/configs/quotas.go index 0a501408c3e..381a56bc606 100644 --- a/service/history/configs/quotas.go +++ b/service/history/configs/quotas.go @@ -77,17 +77,20 @@ var ( "UpdateWorkflow": 0, } - APIPriorities = map[int]struct{}{ - 0: {}, - } + APIPrioritiesOrdered = []int{0} ) func NewPriorityRateLimiter( rateFn quotas.RateFn, ) quotas.RequestRateLimiter { rateLimiters := make(map[int]quotas.RateLimiter) - for priority := range APIPriorities { + for priority := range APIPrioritiesOrdered { rateLimiters[priority] = quotas.NewDefaultIncomingRateLimiter(rateFn) } - return quotas.NewPriorityRateLimiter(APIToPriority, rateLimiters) + return quotas.NewPriorityRateLimiter(func(req quotas.Request) int { + if priority, ok := APIToPriority[req.API]; ok { + return priority + } + return APIPrioritiesOrdered[len(APIPrioritiesOrdered)-1] + }, rateLimiters) } diff --git a/service/history/configs/quotas_test.go b/service/history/configs/quotas_test.go index 6ba1e364c6a..6ea7b6c0da7 100644 --- a/service/history/configs/quotas_test.go +++ b/service/history/configs/quotas_test.go @@ -30,6 +30,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "golang.org/x/exp/slices" "go.temporal.io/server/api/historyservice/v1" ) @@ -60,11 +61,16 @@ func (s *quotasSuite) TearDownTest() { } func (s *quotasSuite) TestAPIToPriorityMapping() { - mapping := make(map[int]struct{}) for _, priority := range APIToPriority { - mapping[priority] = struct{}{} + index := slices.Index(APIPrioritiesOrdered, priority) + s.NotEqual(-1, index) + } +} + +func (s *quotasSuite) TestAPIPrioritiesOrdered() { + for idx := range APIPrioritiesOrdered[1:] { + s.True(APIPrioritiesOrdered[idx] < APIPrioritiesOrdered[idx+1]) } - s.Equal(mapping, APIPriorities) } func (s *quotasSuite) TestAPIs() { diff --git a/service/history/queues/executable.go b/service/history/queues/executable.go index f339746a744..3d1a45e6628 100644 --- a/service/history/queues/executable.go +++ b/service/history/queues/executable.go @@ -37,6 +37,7 @@ import ( "go.temporal.io/server/common/backoff" "go.temporal.io/server/common/clock" "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" @@ -160,6 +161,7 @@ func (e *executableImpl) Execute() error { } ctx := metrics.AddMetricsContext(context.Background()) + ctx = headers.SetCallerInfo(ctx, e.GetNamespaceID(), headers.CallerTypeBackground) startTime := e.timeSource.Now() var err error diff --git a/service/history/replication/task_executor.go b/service/history/replication/task_executor.go index f2abf25d077..85a0c60c14e 100644 --- a/service/history/replication/task_executor.go +++ b/service/history/replication/task_executor.go @@ -35,6 +35,7 @@ import ( enumsspb "go.temporal.io/server/api/enums/v1" "go.temporal.io/server/api/historyservice/v1" replicationspb "go.temporal.io/server/api/replication/v1" + "go.temporal.io/server/common/headers" 
"go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" @@ -160,7 +161,7 @@ func (e *taskExecutorImpl) handleActivityTask( LastWorkerIdentity: attr.LastWorkerIdentity, VersionHistory: attr.GetVersionHistory(), } - ctx, cancel := context.WithTimeout(context.Background(), replicationTimeout) + ctx, cancel := e.newTaskContext(task.TaskType, attr.NamespaceId) defer cancel() err = e.historyEngine.SyncActivity(ctx, request) @@ -174,6 +175,7 @@ func (e *taskExecutorImpl) handleActivityTask( defer stopwatch.Stop() resendErr := e.nDCHistoryResender.SendSingleWorkflowHistory( + ctx, e.remoteCluster, namespace.ID(retryErr.NamespaceId), retryErr.WorkflowId, @@ -225,7 +227,7 @@ func (e *taskExecutorImpl) handleHistoryReplicationTask( // new run events does not need version history since there is no prior events NewRunEvents: attr.NewRunEvents, } - ctx, cancel := context.WithTimeout(context.Background(), replicationTimeout) + ctx, cancel := e.newTaskContext(task.TaskType, attr.NamespaceId) defer cancel() err = e.historyEngine.ReplicateEventsV2(ctx, request) @@ -239,6 +241,7 @@ func (e *taskExecutorImpl) handleHistoryReplicationTask( defer resendStopWatch.Stop() resendErr := e.nDCHistoryResender.SendSingleWorkflowHistory( + ctx, e.remoteCluster, namespace.ID(retryErr.NamespaceId), retryErr.WorkflowId, @@ -280,7 +283,7 @@ func (e *taskExecutorImpl) handleSyncWorkflowStateTask( return err } - ctx, cancel := context.WithTimeout(context.Background(), replicationTimeout) + ctx, cancel := e.newTaskContext(task.TaskType, executionInfo.NamespaceId) defer cancel() return e.historyEngine.ReplicateWorkflowState(ctx, &historyservice.ReplicateWorkflowStateRequest{ @@ -339,3 +342,15 @@ func (e *taskExecutorImpl) cleanupWorkflowExecution(ctx context.Context, namespa false, ) } + +func (e *taskExecutorImpl) newTaskContext( + taskType enumsspb.ReplicationTaskType, + namespaceID string, +) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithTimeout(context.Background(), replicationTimeout) + + namespace, _ := e.namespaceRegistry.GetNamespaceName(namespace.ID(namespaceID)) + ctx = headers.SetCallerInfo(ctx, namespace.String(), headers.CallerTypeBackground) + + return ctx, cancel +} diff --git a/service/history/replication/task_executor_test.go b/service/history/replication/task_executor_test.go index 9e76410083c..e0dda1d4ea3 100644 --- a/service/history/replication/task_executor_test.go +++ b/service/history/replication/task_executor_test.go @@ -271,6 +271,7 @@ func (s *taskExecutorSuite) TestProcessTaskOnce_SyncActivityReplicationTask_Rese ) s.mockEngine.EXPECT().SyncActivity(gomock.Any(), request).Return(resendErr) s.nDCHistoryResender.EXPECT().SendSingleWorkflowHistory( + gomock.Any(), s.remoteCluster, namespaceID, workflowID, @@ -358,6 +359,7 @@ func (s *taskExecutorSuite) TestProcess_HistoryReplicationTask_Resend() { ) s.mockEngine.EXPECT().ReplicateEventsV2(gomock.Any(), request).Return(resendErr) s.nDCHistoryResender.EXPECT().SendSingleWorkflowHistory( + gomock.Any(), s.remoteCluster, namespaceID, workflowID, diff --git a/service/history/replication/task_fetcher.go b/service/history/replication/task_fetcher.go index 9e5117846a5..3ebdeca56a4 100644 --- a/service/history/replication/task_fetcher.go +++ b/service/history/replication/task_fetcher.go @@ -37,6 +37,7 @@ import ( "go.temporal.io/server/common" "go.temporal.io/server/common/backoff" "go.temporal.io/server/common/cluster" + "go.temporal.io/server/common/headers" 
"go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/quotas" @@ -416,8 +417,9 @@ func (f *replicationTaskFetcherWorker) getMessages() error { tokens = append(tokens, request.token) } - ctx, cancel := rpc.NewContextWithTimeoutAndHeaders(fetchTaskRequestTimeout) + ctx, cancel := rpc.NewContextWithTimeoutAndVersionHeaders(fetchTaskRequestTimeout) defer cancel() + ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) request := &adminservice.GetReplicationMessagesRequest{ Tokens: tokens, diff --git a/service/history/replication/task_processor.go b/service/history/replication/task_processor.go index c7d8d4ac8fe..484d50e1604 100644 --- a/service/history/replication/task_processor.go +++ b/service/history/replication/task_processor.go @@ -42,6 +42,7 @@ import ( "go.temporal.io/server/common/backoff" "go.temporal.io/server/common/collection" "go.temporal.io/server/common/convert" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" @@ -302,6 +303,8 @@ func (p *taskProcessorImpl) handleSyncShardStatus( p.metricsClient.Scope(metrics.HistorySyncShardStatusScope).IncCounter(metrics.SyncShardFromRemoteCounter) ctx, cancel := context.WithTimeout(context.Background(), replicationTimeout) defer cancel() + ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) + return p.historyEngine.SyncShardStatus(ctx, &historyservice.SyncShardStatusRequest{ SourceCluster: p.sourceCluster, ShardId: p.shard.GetShardID(), diff --git a/service/history/shard/context_impl.go b/service/history/shard/context_impl.go index d01bdf1f34a..0ff895b7a42 100644 --- a/service/history/shard/context_impl.go +++ b/service/history/shard/context_impl.go @@ -49,6 +49,7 @@ import ( "go.temporal.io/server/common/convert" "go.temporal.io/server/common/definition" "go.temporal.io/server/common/future" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/membership" @@ -57,6 +58,7 @@ import ( "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/primitives/timestamp" + "go.temporal.io/server/common/rpc" "go.temporal.io/server/common/searchattribute" "go.temporal.io/server/common/util" "go.temporal.io/server/service/history/configs" @@ -1161,7 +1163,7 @@ func (s *ContextImpl) renewRangeLocked(isStealing bool) error { updatedShardInfo.StolenSinceRenew++ } - ctx, cancel := context.WithTimeout(s.lifecycleCtx, shardIOTimeout) + ctx, cancel := s.newIOContext() defer cancel() err := s.persistenceShardManager.UpdateShard(ctx, &persistence.UpdateShardRequest{ ShardInfo: updatedShardInfo.ShardInfo, @@ -1214,7 +1216,7 @@ func (s *ContextImpl) updateShardInfoLocked() error { updatedShardInfo := copyShardInfo(s.shardInfo) s.emitShardInfoMetricsLogsLocked() - ctx, cancel := context.WithTimeout(s.lifecycleCtx, shardIOTimeout) + ctx, cancel := s.newIOContext() defer cancel() err = s.persistenceShardManager.UpdateShard(ctx, &persistence.UpdateShardRequest{ ShardInfo: updatedShardInfo.ShardInfo, @@ -1664,7 +1666,7 @@ func (s *ContextImpl) loadShardMetadata(ownershipChanged *bool) error { s.rUnlock() // We don't have any shardInfo yet, load it (outside of context rwlock) - ctx, cancel := context.WithTimeout(s.lifecycleCtx, shardIOTimeout) + ctx, cancel := s.newIOContext() defer cancel() resp, err := 
s.persistenceShardManager.GetOrCreateShard(ctx, &persistence.GetOrCreateShardRequest{ ShardID: s.shardID, @@ -1790,7 +1792,7 @@ func (s *ContextImpl) getOrUpdateRemoteClusterInfoLocked(clusterName string) *re func (s *ContextImpl) acquireShard() { // This is called in two contexts: initially acquiring the rangeid lock, and trying to // re-acquire it after a persistence error. In both cases, we retry the acquire operation - // (renewRangeLocked) for 5 minutes. Each individual attempt uses shardIOTimeout (10s) as + // (renewRangeLocked) for 5 minutes. Each individual attempt uses shardIOTimeout (5s) as // the timeout. This lets us handle a few minutes of persistence unavailability without // dropping and reloading the whole shard context, which is relatively expensive (includes // caches that would have to be refilled, etc.). @@ -2035,9 +2037,17 @@ func (s *ContextImpl) ensureMinContextTimeout( } newContext, cancel := context.WithTimeout(s.lifecycleCtx, minContextTimeout) + newContext = rpc.CopyContextValues(newContext, ctx) return newContext, cancel, nil } +func (s *ContextImpl) newIOContext() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithTimeout(s.lifecycleCtx, shardIOTimeout) + ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeSystem) + + return ctx, cancel +} + func OperationPossiblySucceeded(err error) bool { if err == consts.ErrConflict { return false diff --git a/service/history/timerQueueAckMgr.go b/service/history/timerQueueAckMgr.go index 3b4b32b17c9..ca1736b9688 100644 --- a/service/history/timerQueueAckMgr.go +++ b/service/history/timerQueueAckMgr.go @@ -25,7 +25,6 @@ package history import ( - "context" "math" "sort" "sync" @@ -351,6 +350,9 @@ MoveAckLevelLoop: // this function does not take cluster name as parameter, due to we only have one timer queue on Cassandra // all timer tasks are in this queue and filter will be applied. 
func (t *timerQueueAckMgrImpl) getTimerTasks(minTimestamp time.Time, maxTimestamp time.Time, batchSize int, pageToken []byte) ([]tasks.Task, []byte, error) { + ctx, cancel := newQueueIOContext() + defer cancel() + request := &persistence.GetHistoryTasksRequest{ ShardID: t.shard.GetShardID(), TaskCategory: tasks.CategoryTimer, @@ -359,7 +361,7 @@ func (t *timerQueueAckMgrImpl) getTimerTasks(minTimestamp time.Time, maxTimestam BatchSize: batchSize, NextPageToken: pageToken, } - response, err := t.executionMgr.GetHistoryTasks(context.TODO(), request) + response, err := t.executionMgr.GetHistoryTasks(ctx, request) if err != nil { return nil, nil, err } diff --git a/service/history/timerQueueProcessor.go b/service/history/timerQueueProcessor.go index 7ba82472da5..1a4258be668 100644 --- a/service/history/timerQueueProcessor.go +++ b/service/history/timerQueueProcessor.go @@ -25,7 +25,6 @@ package history import ( - "context" "fmt" "sync" "sync/atomic" @@ -343,7 +342,10 @@ func (t *timerQueueProcessorImpl) completeTimers() error { t.metricsClient.IncCounter(metrics.TimerQueueProcessorScope, metrics.TaskBatchCompleteCounter) if lowerAckLevel.FireTime.Before(upperAckLevel.FireTime) { - err := t.shard.GetExecutionManager().RangeCompleteHistoryTasks(context.TODO(), &persistence.RangeCompleteHistoryTasksRequest{ + ctx, cancel := newQueueIOContext() + defer cancel() + + err := t.shard.GetExecutionManager().RangeCompleteHistoryTasks(ctx, &persistence.RangeCompleteHistoryTasksRequest{ ShardID: t.shard.GetShardID(), TaskCategory: tasks.CategoryTimer, InclusiveMinTaskKey: tasks.NewKey(lowerAckLevel.FireTime, 0), diff --git a/service/history/timerQueueStandbyTaskExecutor.go b/service/history/timerQueueStandbyTaskExecutor.go index 1057e362af3..0692e9e2442 100644 --- a/service/history/timerQueueStandbyTaskExecutor.go +++ b/service/history/timerQueueStandbyTaskExecutor.go @@ -556,6 +556,7 @@ func (t *timerQueueStandbyTaskExecutor) fetchHistoryFromRemote( // NOTE: history resend may take long time and its timeout is currently // controlled by a separate dynamicconfig config: StandbyTaskReReplicationContextTimeout if err = t.nDCHistoryResender.SendSingleWorkflowHistory( + ctx, remoteClusterName, namespace.ID(taskInfo.GetNamespaceID()), taskInfo.GetWorkflowID(), diff --git a/service/history/timerQueueStandbyTaskExecutor_test.go b/service/history/timerQueueStandbyTaskExecutor_test.go index 4141a9b1e69..c7a0907e83f 100644 --- a/service/history/timerQueueStandbyTaskExecutor_test.go +++ b/service/history/timerQueueStandbyTaskExecutor_test.go @@ -279,6 +279,7 @@ func (s *timerQueueStandbyTaskExecutorSuite) TestProcessUserTimerTimeout_Pending }, }).Return(&adminservice.RefreshWorkflowTasksResponse{}, nil) s.mockNDCHistoryResender.EXPECT().SendSingleWorkflowHistory( + gomock.Any(), s.clusterName, namespace.ID(timerTask.NamespaceID), timerTask.WorkflowID, @@ -507,6 +508,7 @@ func (s *timerQueueStandbyTaskExecutorSuite) TestProcessActivityTimeout_Pending( }, }).Return(&adminservice.RefreshWorkflowTasksResponse{}, nil) s.mockNDCHistoryResender.EXPECT().SendSingleWorkflowHistory( + gomock.Any(), s.clusterName, namespace.ID(timerTask.NamespaceID), timerTask.WorkflowID, @@ -850,6 +852,7 @@ func (s *timerQueueStandbyTaskExecutorSuite) TestProcessWorkflowTaskTimeout_Pend }, }).Return(&adminservice.RefreshWorkflowTasksResponse{}, nil) s.mockNDCHistoryResender.EXPECT().SendSingleWorkflowHistory( + gomock.Any(), s.clusterName, namespace.ID(timerTask.NamespaceID), timerTask.WorkflowID, @@ -1004,6 +1007,7 @@ func (s 
*timerQueueStandbyTaskExecutorSuite) TestProcessWorkflowBackoffTimer_Pen }, }).Return(&adminservice.RefreshWorkflowTasksResponse{}, nil) s.mockNDCHistoryResender.EXPECT().SendSingleWorkflowHistory( + gomock.Any(), s.clusterName, namespace.ID(timerTask.NamespaceID), timerTask.WorkflowID, @@ -1129,6 +1133,7 @@ func (s *timerQueueStandbyTaskExecutorSuite) TestProcessWorkflowTimeout_Pending( }, }).Return(&adminservice.RefreshWorkflowTasksResponse{}, nil) s.mockNDCHistoryResender.EXPECT().SendSingleWorkflowHistory( + gomock.Any(), s.clusterName, namespace.ID(timerTask.NamespaceID), timerTask.WorkflowID, @@ -1482,6 +1487,7 @@ func (s *timerQueueStandbyTaskExecutorSuite) TestProcessActivityRetryTimer_Pendi Execution: &execution, }).Return(&adminservice.RefreshWorkflowTasksResponse{}, nil) s.mockNDCHistoryResender.EXPECT().SendSingleWorkflowHistory( + gomock.Any(), s.clusterName, s.namespaceID, execution.WorkflowId, diff --git a/service/history/transferQueueProcessorBase.go b/service/history/transferQueueProcessorBase.go index 6d36ada2a74..1fa783c209c 100644 --- a/service/history/transferQueueProcessorBase.go +++ b/service/history/transferQueueProcessorBase.go @@ -26,7 +26,9 @@ package history import ( "context" + "time" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/persistence" @@ -37,6 +39,10 @@ import ( "go.temporal.io/server/service/history/tasks" ) +const ( + queueIOTimeout = 5 * time.Second +) + type ( maxReadLevel func() int64 @@ -76,7 +82,10 @@ func newTransferQueueProcessorBase( func (t *transferQueueProcessorBase) readTasks( readLevel int64, ) ([]tasks.Task, bool, error) { - response, err := t.executionManager.GetHistoryTasks(context.TODO(), &persistence.GetHistoryTasksRequest{ + ctx, cancel := newQueueIOContext() + defer cancel() + + response, err := t.executionManager.GetHistoryTasks(ctx, &persistence.GetHistoryTasksRequest{ ShardID: t.shard.GetShardID(), TaskCategory: tasks.CategoryTransfer, InclusiveMinTaskKey: tasks.NewImmediateKey(readLevel + 1), @@ -120,3 +129,9 @@ func newTransferTaskScheduler( logger, ) } + +func newQueueIOContext() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithTimeout(context.Background(), queueIOTimeout) + ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) + return ctx, cancel +} diff --git a/service/history/transferQueueStandbyTaskExecutor.go b/service/history/transferQueueStandbyTaskExecutor.go index 8dfae4c418e..ba4693798e3 100644 --- a/service/history/transferQueueStandbyTaskExecutor.go +++ b/service/history/transferQueueStandbyTaskExecutor.go @@ -659,6 +659,7 @@ func (t *transferQueueStandbyTaskExecutor) fetchHistoryFromRemote( // NOTE: history resend may take long time and its timeout is currently // controlled by a separate dynamicconfig config: StandbyTaskReReplicationContextTimeout if err = t.nDCHistoryResender.SendSingleWorkflowHistory( + ctx, remoteClusterName, namespace.ID(taskInfo.GetNamespaceID()), taskInfo.GetWorkflowID(), diff --git a/service/history/transferQueueStandbyTaskExecutor_test.go b/service/history/transferQueueStandbyTaskExecutor_test.go index 34d5fd47377..e95f8e0f238 100644 --- a/service/history/transferQueueStandbyTaskExecutor_test.go +++ b/service/history/transferQueueStandbyTaskExecutor_test.go @@ -279,6 +279,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessActivityTask_Pending( }, }).Return(&adminservice.RefreshWorkflowTasksResponse{}, nil)
s.mockNDCHistoryResender.EXPECT().SendSingleWorkflowHistory( + gomock.Any(), s.clusterName, namespace.ID(transferTask.NamespaceID), transferTask.WorkflowID, @@ -421,6 +422,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessWorkflowTask_Pending( }, }).Return(&adminservice.RefreshWorkflowTasksResponse{}, nil) s.mockNDCHistoryResender.EXPECT().SendSingleWorkflowHistory( + gomock.Any(), s.clusterName, namespace.ID(transferTask.NamespaceID), transferTask.WorkflowID, @@ -728,6 +730,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessCancelExecution_Pendi }, }).Return(&adminservice.RefreshWorkflowTasksResponse{}, nil) s.mockNDCHistoryResender.EXPECT().SendSingleWorkflowHistory( + gomock.Any(), s.clusterName, namespace.ID(transferTask.NamespaceID), transferTask.WorkflowID, @@ -883,6 +886,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessSignalExecution_Pendi }, }).Return(&adminservice.RefreshWorkflowTasksResponse{}, nil) s.mockNDCHistoryResender.EXPECT().SendSingleWorkflowHistory( + gomock.Any(), s.clusterName, namespace.ID(transferTask.NamespaceID), transferTask.WorkflowID, @@ -1037,6 +1041,7 @@ func (s *transferQueueStandbyTaskExecutorSuite) TestProcessStartChildExecution_P }, }).Return(&adminservice.RefreshWorkflowTasksResponse{}, nil) s.mockNDCHistoryResender.EXPECT().SendSingleWorkflowHistory( + gomock.Any(), s.clusterName, namespace.ID(transferTask.NamespaceID), transferTask.WorkflowID, diff --git a/service/history/visibilityQueueProcessor.go b/service/history/visibilityQueueProcessor.go index 0cab5680cc0..f56d0dfb7e4 100644 --- a/service/history/visibilityQueueProcessor.go +++ b/service/history/visibilityQueueProcessor.go @@ -25,7 +25,6 @@ package history import ( - "context" "errors" "sync/atomic" "time" @@ -295,7 +294,10 @@ func (t *visibilityQueueProcessorImpl) completeTask() error { t.metricsClient.IncCounter(metrics.VisibilityQueueProcessorScope, metrics.TaskBatchCompleteCounter) if lowerAckLevel < upperAckLevel { - err := t.shard.GetExecutionManager().RangeCompleteHistoryTasks(context.TODO(), &persistence.RangeCompleteHistoryTasksRequest{ + ctx, cancel := newQueueIOContext() + defer cancel() + + err := t.shard.GetExecutionManager().RangeCompleteHistoryTasks(ctx, &persistence.RangeCompleteHistoryTasksRequest{ ShardID: t.shard.GetShardID(), TaskCategory: tasks.CategoryVisibility, InclusiveMinTaskKey: tasks.NewImmediateKey(lowerAckLevel + 1), @@ -320,7 +322,10 @@ func (t *visibilityQueueProcessorImpl) notifyNewTask() { func (t *visibilityQueueProcessorImpl) readTasks( readLevel int64, ) ([]tasks.Task, bool, error) { - response, err := t.executionManager.GetHistoryTasks(context.TODO(), &persistence.GetHistoryTasksRequest{ + ctx, cancel := newQueueIOContext() + defer cancel() + + response, err := t.executionManager.GetHistoryTasks(ctx, &persistence.GetHistoryTasksRequest{ ShardID: t.shard.GetShardID(), TaskCategory: tasks.CategoryVisibility, InclusiveMinTaskKey: tasks.NewImmediateKey(readLevel + 1), diff --git a/service/history/workflow/context.go b/service/history/workflow/context.go index b3551a28c12..0c7cb8206e2 100644 --- a/service/history/workflow/context.go +++ b/service/history/workflow/context.go @@ -855,7 +855,7 @@ func (c *ContextImpl) ReapplyEvents( if sourceCluster == nil { return serviceerror.NewInternal(fmt.Sprintf("cannot find cluster config %v to do reapply", activeCluster)) } - ctx2, cancel2 := rpc.NewContextWithTimeoutAndHeaders(defaultRemoteCallTimeout) + ctx2, cancel2 := 
rpc.NewContextWithTimeoutAndVersionHeaders(defaultRemoteCallTimeout) defer cancel2() _, err = sourceCluster.ReapplyEvents( ctx2, diff --git a/service/matching/configs/quotas.go b/service/matching/configs/quotas.go index 5c57fe56e8e..b6c9b7e5925 100644 --- a/service/matching/configs/quotas.go +++ b/service/matching/configs/quotas.go @@ -43,17 +43,20 @@ var ( "UpdateWorkerBuildIdOrdering": 0, } - APIPriorities = map[int]struct{}{ - 0: {}, - } + APIPrioritiesOrdered = []int{0} ) func NewPriorityRateLimiter( rateFn quotas.RateFn, ) quotas.RequestRateLimiter { rateLimiters := make(map[int]quotas.RateLimiter) - for priority := range APIPriorities { + for priority := range APIPrioritiesOrdered { rateLimiters[priority] = quotas.NewDefaultIncomingRateLimiter(rateFn) } - return quotas.NewPriorityRateLimiter(APIToPriority, rateLimiters) + return quotas.NewPriorityRateLimiter(func(req quotas.Request) int { + if priority, ok := APIToPriority[req.API]; ok { + return priority + } + return APIPrioritiesOrdered[len(APIPrioritiesOrdered)-1] + }, rateLimiters) } diff --git a/service/matching/configs/quotas_test.go b/service/matching/configs/quotas_test.go index 073c5952db2..863cc50fc13 100644 --- a/service/matching/configs/quotas_test.go +++ b/service/matching/configs/quotas_test.go @@ -30,6 +30,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "golang.org/x/exp/slices" "go.temporal.io/server/api/matchingservice/v1" ) @@ -60,11 +61,16 @@ func (s *quotasSuite) TearDownTest() { } func (s *quotasSuite) TestAPIToPriorityMapping() { - mapping := make(map[int]struct{}) for _, priority := range APIToPriority { - mapping[priority] = struct{}{} + index := slices.Index(APIPrioritiesOrdered, priority) + s.NotEqual(-1, index) + } +} + +func (s *quotasSuite) TestAPIPrioritiesOrdered() { + for idx := range APIPrioritiesOrdered[1:] { + s.True(APIPrioritiesOrdered[idx] < APIPrioritiesOrdered[idx+1]) } - s.Equal(mapping, APIPriorities) } func (s *quotasSuite) TestAPIs() { diff --git a/service/worker/replicator/namespace_replication_message_processor.go b/service/worker/replicator/namespace_replication_message_processor.go index d90f6599b5b..76b3a3c36d4 100644 --- a/service/worker/replicator/namespace_replication_message_processor.go +++ b/service/worker/replicator/namespace_replication_message_processor.go @@ -145,7 +145,7 @@ func (p *namespaceReplicationMessageProcessor) getAndHandleNamespaceReplicationT return } - ctx, cancel := rpc.NewContextWithTimeoutAndHeaders(fetchTaskRequestTimeout) + ctx, cancel := rpc.NewContextWithTimeoutAndVersionHeaders(fetchTaskRequestTimeout) request := &adminservice.GetNamespaceReplicationMessagesRequest{ ClusterName: p.currentCluster, LastRetrievedMessageId: p.lastRetrievedMessageID, From 4e01af4289f312082c32deaf52999b4fc848d071 Mon Sep 17 00:00:00 2001 From: Yichao Yang Date: Mon, 25 Jul 2022 11:08:46 -0700 Subject: [PATCH 02/13] better ratelimiting --- common/headers/caller_info.go | 1 - common/persistence/client/quotas.go | 135 ++++++++++++++++- .../persistenceRateLimitedClients.go | 141 +++++++++--------- service/history/shard/context_impl.go | 2 +- 4 files changed, 202 insertions(+), 77 deletions(-) diff --git a/common/headers/caller_info.go b/common/headers/caller_info.go index 4bf6a7f3c10..d75b4d8d11a 100644 --- a/common/headers/caller_info.go +++ b/common/headers/caller_info.go @@ -32,7 +32,6 @@ import ( const ( CallerTypeAPI = "api" - CallerTypeSystem = "system" CallerTypeBackground = "background" ) diff --git 
a/common/persistence/client/quotas.go b/common/persistence/client/quotas.go index 2807aebb09c..01433becf3e 100644 --- a/common/persistence/client/quotas.go +++ b/common/persistence/client/quotas.go @@ -30,10 +30,135 @@ import ( ) var ( - CallerTypeHeaderToPriority = map[string]int{ - headers.CallerTypeAPI: 0, - headers.CallerTypeSystem: 0, - headers.CallerTypeBackground: 1, + CallerTypeAndAPIToPriority = map[string]map[string]int{ + headers.CallerTypeAPI: { + "GetOrCreateShard": 0, + "UpdateShard": 0, + "AssertShardOwnership": 0, + + "CreateWorkflowExecution": 0, + "UpdateWorkflowExecution": 0, + "ConflictResolveWorkflowExecution": 0, + "DeleteWorkflowExecution": 0, + "DeleteCurrentWorkflowExecution": 0, + "GetCurrentExecution": 0, + "GetWorkflowExecution": 0, + "SetWorkflowExecution": 0, + "ListConcreteExecutions": 0, + "AddHistoryTasks": 0, + "GetHistoryTask": 0, + "GetHistoryTasks": 0, + "CompleteHistoryTask": 0, + "RangeCompleteHistoryTasks": 0, + "PutReplicationTaskToDLQ": 0, + "GetReplicationTasksFromDLQ": 0, + "DeleteReplicationTaskFromDLQ": 0, + "RangeDeleteReplicationTaskFromDLQ": 0, + "AppendHistoryNodes": 0, + "AppendRawHistoryNodes": 0, + "ReadHistoryBranch": 0, + "ReadHistoryBranchByBatch": 0, + "ReadHistoryBranchReverse": 0, + "ReadRawHistoryBranch": 0, + "ForkHistoryBranch": 0, + "DeleteHistoryBranch": 0, + "TrimHistoryBranch": 0, + "GetHistoryTree": 0, + "GetAllHistoryTreeBranches": 0, + + "CreateTaskQueue": 0, + "UpdateTaskQueue": 0, + "GetTaskQueue": 0, + "ListTaskQueue": 0, + "DeleteTaskQueue": 0, + "CreateTasks": 0, + "GetTasks": 0, + "CompleteTask": 0, + "CompleteTasksLessThan": 0, + + "CreateNamespace": 0, + "GetNamespace": 0, + "UpdateNamespace": 0, + "RenameNamespace": 0, + "DeleteNamespace": 0, + "DeleteNamespaceByName": 0, + "ListNamespaces": 0, + "GetMetadata": 0, + "InitializeSystemNamespaces": 0, + + "GetClusterMembers": 0, + "UpsertClusterMembership": 0, + "PruneClusterMembership": 0, + "ListClusterMetadata": 0, + "GetCurrentClusterMetadata": 0, + "GetClusterMetadata": 0, + "SaveClusterMetadata": 0, + "DeleteClusterMetadata": 0, + }, + headers.CallerTypeBackground: { + "GetOrCreateShard": 0, + "UpdateShard": 0, + "AssertShardOwnership": 1, + + "CreateWorkflowExecution": 1, + "UpdateWorkflowExecution": 1, + "ConflictResolveWorkflowExecution": 1, + "DeleteWorkflowExecution": 1, + "DeleteCurrentWorkflowExecution": 1, + "GetCurrentExecution": 1, + "GetWorkflowExecution": 1, + "SetWorkflowExecution": 1, + "ListConcreteExecutions": 1, + "AddHistoryTasks": 1, + "GetHistoryTask": 1, + "GetHistoryTasks": 1, + "CompleteHistoryTask": 1, + "RangeCompleteHistoryTasks": 0, // this is a preprequisite for updating ack level + "PutReplicationTaskToDLQ": 1, + "GetReplicationTasksFromDLQ": 1, + "DeleteReplicationTaskFromDLQ": 1, + "RangeDeleteReplicationTaskFromDLQ": 1, + "AppendHistoryNodes": 1, + "AppendRawHistoryNodes": 1, + "ReadHistoryBranch": 1, + "ReadHistoryBranchByBatch": 1, + "ReadHistoryBranchReverse": 1, + "ReadRawHistoryBranch": 1, + "ForkHistoryBranch": 1, + "DeleteHistoryBranch": 1, + "TrimHistoryBranch": 1, + "GetHistoryTree": 1, + "GetAllHistoryTreeBranches": 1, + + "CreateTaskQueue": 1, + "UpdateTaskQueue": 1, + "GetTaskQueue": 1, + "ListTaskQueue": 1, + "DeleteTaskQueue": 1, + "CreateTasks": 1, + "GetTasks": 1, + "CompleteTask": 1, + "CompleteTasksLessThan": 1, + + "CreateNamespace": 1, + "GetNamespace": 1, + "UpdateNamespace": 1, + "RenameNamespace": 1, + "DeleteNamespace": 1, + "DeleteNamespaceByName": 1, + "ListNamespaces": 1, + "GetMetadata": 1, + 
"InitializeSystemNamespaces": 1, + + "GetClusterMembers": 1, + "UpsertClusterMembership": 1, + "PruneClusterMembership": 1, + "ListClusterMetadata": 1, + "GetCurrentClusterMetadata": 1, + "GetClusterMetadata": 1, + "SaveClusterMetadata": 1, + "DeleteClusterMetadata": 1, + }, } RequestPrioritiesOrdered = []int{0, 1} @@ -49,7 +174,7 @@ func NewPriorityRateLimiter( return quotas.NewPriorityRateLimiter( func(req quotas.Request) int { - if priority, ok := CallerTypeHeaderToPriority[req.Caller]; ok { + if priority, ok := CallerTypeAndAPIToPriority[req.CallerType][req.API]; ok { return priority } diff --git a/common/persistence/persistenceRateLimitedClients.go b/common/persistence/persistenceRateLimitedClients.go index c32ae7f72d1..4368f688f51 100644 --- a/common/persistence/persistenceRateLimitedClients.go +++ b/common/persistence/persistenceRateLimitedClients.go @@ -153,7 +153,7 @@ func (p *shardRateLimitedPersistenceClient) GetOrCreateShard( ctx context.Context, request *GetOrCreateShardRequest, ) (*GetOrCreateShardResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "GetOrCreateShard", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -165,7 +165,7 @@ func (p *shardRateLimitedPersistenceClient) UpdateShard( ctx context.Context, request *UpdateShardRequest, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "UpdateShard", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -177,7 +177,7 @@ func (p *shardRateLimitedPersistenceClient) AssertShardOwnership( ctx context.Context, request *AssertShardOwnershipRequest, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "AssertShardOwnership", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -197,7 +197,7 @@ func (p *executionRateLimitedPersistenceClient) CreateWorkflowExecution( ctx context.Context, request *CreateWorkflowExecutionRequest, ) (*CreateWorkflowExecutionResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "CreateWorkflowExecution", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -209,7 +209,7 @@ func (p *executionRateLimitedPersistenceClient) GetWorkflowExecution( ctx context.Context, request *GetWorkflowExecutionRequest, ) (*GetWorkflowExecutionResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "GetWorkflowExecution", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -221,7 +221,7 @@ func (p *executionRateLimitedPersistenceClient) SetWorkflowExecution( ctx context.Context, request *SetWorkflowExecutionRequest, ) (*SetWorkflowExecutionResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "SetWorkflowExecution", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -233,7 +233,7 @@ func (p *executionRateLimitedPersistenceClient) UpdateWorkflowExecution( ctx context.Context, request *UpdateWorkflowExecutionRequest, ) (*UpdateWorkflowExecutionResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "UpdateWorkflowExecution", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -245,7 +245,7 @@ func (p *executionRateLimitedPersistenceClient) ConflictResolveWorkflowExecution ctx context.Context, request *ConflictResolveWorkflowExecutionRequest, ) (*ConflictResolveWorkflowExecutionResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "ConflictResolveWorkflowExecution", p.rateLimiter); !ok { return nil, 
ErrPersistenceLimitExceeded } @@ -257,7 +257,7 @@ func (p *executionRateLimitedPersistenceClient) DeleteWorkflowExecution( ctx context.Context, request *DeleteWorkflowExecutionRequest, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "DeleteWorkflowExecution", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -269,7 +269,7 @@ func (p *executionRateLimitedPersistenceClient) DeleteCurrentWorkflowExecution( ctx context.Context, request *DeleteCurrentWorkflowExecutionRequest, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "DeleteCurrentWorkflowExecution", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -281,7 +281,7 @@ func (p *executionRateLimitedPersistenceClient) GetCurrentExecution( ctx context.Context, request *GetCurrentExecutionRequest, ) (*GetCurrentExecutionResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "GetCurrentExecution", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -293,7 +293,7 @@ func (p *executionRateLimitedPersistenceClient) ListConcreteExecutions( ctx context.Context, request *ListConcreteExecutionsRequest, ) (*ListConcreteExecutionsResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "ListConcreteExecutions", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -305,7 +305,7 @@ func (p *executionRateLimitedPersistenceClient) AddHistoryTasks( ctx context.Context, request *AddHistoryTasksRequest, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "AddHistoryTasks", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -317,7 +317,7 @@ func (p *executionRateLimitedPersistenceClient) GetHistoryTask( ctx context.Context, request *GetHistoryTaskRequest, ) (*GetHistoryTaskResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "GetHistoryTask", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -329,7 +329,7 @@ func (p *executionRateLimitedPersistenceClient) GetHistoryTasks( ctx context.Context, request *GetHistoryTasksRequest, ) (*GetHistoryTasksResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "GetHistoryTasks", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -341,7 +341,7 @@ func (p *executionRateLimitedPersistenceClient) CompleteHistoryTask( ctx context.Context, request *CompleteHistoryTaskRequest, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "CompleteHistoryTask", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -353,7 +353,7 @@ func (p *executionRateLimitedPersistenceClient) RangeCompleteHistoryTasks( ctx context.Context, request *RangeCompleteHistoryTasksRequest, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "RangeCompleteHistoryTasks", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -365,7 +365,7 @@ func (p *executionRateLimitedPersistenceClient) PutReplicationTaskToDLQ( ctx context.Context, request *PutReplicationTaskToDLQRequest, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "PutReplicationTaskToDLQ", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -376,7 +376,7 @@ func (p *executionRateLimitedPersistenceClient) GetReplicationTasksFromDLQ( ctx context.Context, request *GetReplicationTasksFromDLQRequest, ) (*GetHistoryTasksResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, 
"GetReplicationTasksFromDLQ", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -387,7 +387,7 @@ func (p *executionRateLimitedPersistenceClient) DeleteReplicationTaskFromDLQ( ctx context.Context, request *DeleteReplicationTaskFromDLQRequest, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "DeleteReplicationTaskFromDLQ", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -398,7 +398,7 @@ func (p *executionRateLimitedPersistenceClient) RangeDeleteReplicationTaskFromDL ctx context.Context, request *RangeDeleteReplicationTaskFromDLQRequest, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "RangeDeleteReplicationTaskFromDLQ", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -417,7 +417,7 @@ func (p *taskRateLimitedPersistenceClient) CreateTasks( ctx context.Context, request *CreateTasksRequest, ) (*CreateTasksResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "CreateTasks", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -429,7 +429,7 @@ func (p *taskRateLimitedPersistenceClient) GetTasks( ctx context.Context, request *GetTasksRequest, ) (*GetTasksResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "GetTasks", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -441,7 +441,7 @@ func (p *taskRateLimitedPersistenceClient) CompleteTask( ctx context.Context, request *CompleteTaskRequest, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "CompleteTask", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -453,7 +453,7 @@ func (p *taskRateLimitedPersistenceClient) CompleteTasksLessThan( ctx context.Context, request *CompleteTasksLessThanRequest, ) (int, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "CompleteTasksLessThan", p.rateLimiter); !ok { return 0, ErrPersistenceLimitExceeded } return p.persistence.CompleteTasksLessThan(ctx, request) @@ -463,7 +463,7 @@ func (p *taskRateLimitedPersistenceClient) CreateTaskQueue( ctx context.Context, request *CreateTaskQueueRequest, ) (*CreateTaskQueueResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "CreateTaskQueue", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return p.persistence.CreateTaskQueue(ctx, request) @@ -473,7 +473,7 @@ func (p *taskRateLimitedPersistenceClient) UpdateTaskQueue( ctx context.Context, request *UpdateTaskQueueRequest, ) (*UpdateTaskQueueResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "UpdateTaskQueue", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return p.persistence.UpdateTaskQueue(ctx, request) @@ -483,7 +483,7 @@ func (p *taskRateLimitedPersistenceClient) GetTaskQueue( ctx context.Context, request *GetTaskQueueRequest, ) (*GetTaskQueueResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "GetTaskQueue", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return p.persistence.GetTaskQueue(ctx, request) @@ -493,7 +493,7 @@ func (p *taskRateLimitedPersistenceClient) ListTaskQueue( ctx context.Context, request *ListTaskQueueRequest, ) (*ListTaskQueueResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "ListTaskQueue", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return p.persistence.ListTaskQueue(ctx, request) @@ -503,7 +503,7 @@ func (p *taskRateLimitedPersistenceClient) 
DeleteTaskQueue( ctx context.Context, request *DeleteTaskQueueRequest, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "DeleteTaskQueue", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } return p.persistence.DeleteTaskQueue(ctx, request) @@ -521,7 +521,7 @@ func (p *metadataRateLimitedPersistenceClient) CreateNamespace( ctx context.Context, request *CreateNamespaceRequest, ) (*CreateNamespaceResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "CreateNamespace", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -533,7 +533,7 @@ func (p *metadataRateLimitedPersistenceClient) GetNamespace( ctx context.Context, request *GetNamespaceRequest, ) (*GetNamespaceResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "GetNamespace", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -545,7 +545,7 @@ func (p *metadataRateLimitedPersistenceClient) UpdateNamespace( ctx context.Context, request *UpdateNamespaceRequest, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "UpdateNamespace", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -557,7 +557,7 @@ func (p *metadataRateLimitedPersistenceClient) RenameNamespace( ctx context.Context, request *RenameNamespaceRequest, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "RenameNamespace", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -569,7 +569,7 @@ func (p *metadataRateLimitedPersistenceClient) DeleteNamespace( ctx context.Context, request *DeleteNamespaceRequest, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "DeleteNamespace", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -581,7 +581,7 @@ func (p *metadataRateLimitedPersistenceClient) DeleteNamespaceByName( ctx context.Context, request *DeleteNamespaceByNameRequest, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "DeleteNamespaceByName", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -593,7 +593,7 @@ func (p *metadataRateLimitedPersistenceClient) ListNamespaces( ctx context.Context, request *ListNamespacesRequest, ) (*ListNamespacesResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "ListNamespaces", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -604,7 +604,7 @@ func (p *metadataRateLimitedPersistenceClient) ListNamespaces( func (p *metadataRateLimitedPersistenceClient) GetMetadata( ctx context.Context, ) (*GetMetadataResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "GetMetadata", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -616,7 +616,7 @@ func (p *metadataRateLimitedPersistenceClient) InitializeSystemNamespaces( ctx context.Context, currentClusterName string, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "InitializeSystemNamespaces", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } return p.persistence.InitializeSystemNamespaces(ctx, currentClusterName) @@ -631,7 +631,7 @@ func (p *executionRateLimitedPersistenceClient) AppendHistoryNodes( ctx context.Context, request *AppendHistoryNodesRequest, ) (*AppendHistoryNodesResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "AppendHistoryNodes", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return p.persistence.AppendHistoryNodes(ctx, request) @@ -642,7 +642,7 
@@ func (p *executionRateLimitedPersistenceClient) AppendRawHistoryNodes( ctx context.Context, request *AppendRawHistoryNodesRequest, ) (*AppendHistoryNodesResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "AppendRawHistoryNodes", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return p.persistence.AppendRawHistoryNodes(ctx, request) @@ -653,7 +653,7 @@ func (p *executionRateLimitedPersistenceClient) ReadHistoryBranch( ctx context.Context, request *ReadHistoryBranchRequest, ) (*ReadHistoryBranchResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "ReadHistoryBranch", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } response, err := p.persistence.ReadHistoryBranch(ctx, request) @@ -665,7 +665,7 @@ func (p *executionRateLimitedPersistenceClient) ReadHistoryBranchReverse( ctx context.Context, request *ReadHistoryBranchReverseRequest, ) (*ReadHistoryBranchReverseResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "ReadHistoryBranchReverse", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } response, err := p.persistence.ReadHistoryBranchReverse(ctx, request) @@ -677,7 +677,7 @@ func (p *executionRateLimitedPersistenceClient) ReadHistoryBranchByBatch( ctx context.Context, request *ReadHistoryBranchRequest, ) (*ReadHistoryBranchByBatchResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "ReadHistoryBranchByBatch", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } response, err := p.persistence.ReadHistoryBranchByBatch(ctx, request) @@ -689,7 +689,7 @@ func (p *executionRateLimitedPersistenceClient) ReadRawHistoryBranch( ctx context.Context, request *ReadHistoryBranchRequest, ) (*ReadRawHistoryBranchResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "ReadRawHistoryBranch", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } response, err := p.persistence.ReadRawHistoryBranch(ctx, request) @@ -701,7 +701,7 @@ func (p *executionRateLimitedPersistenceClient) ForkHistoryBranch( ctx context.Context, request *ForkHistoryBranchRequest, ) (*ForkHistoryBranchResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "ForkHistoryBranch", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } response, err := p.persistence.ForkHistoryBranch(ctx, request) @@ -713,7 +713,7 @@ func (p *executionRateLimitedPersistenceClient) DeleteHistoryBranch( ctx context.Context, request *DeleteHistoryBranchRequest, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "DeleteHistoryBranch", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } err := p.persistence.DeleteHistoryBranch(ctx, request) @@ -725,7 +725,7 @@ func (p *executionRateLimitedPersistenceClient) TrimHistoryBranch( ctx context.Context, request *TrimHistoryBranchRequest, ) (*TrimHistoryBranchResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "TrimHistoryBranch", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } resp, err := p.persistence.TrimHistoryBranch(ctx, request) @@ -737,7 +737,7 @@ func (p *executionRateLimitedPersistenceClient) GetHistoryTree( ctx context.Context, request *GetHistoryTreeRequest, ) (*GetHistoryTreeResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "GetHistoryTree", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } response, err := 
p.persistence.GetHistoryTree(ctx, request) @@ -748,7 +748,7 @@ func (p *executionRateLimitedPersistenceClient) GetAllHistoryTreeBranches( ctx context.Context, request *GetAllHistoryTreeBranchesRequest, ) (*GetAllHistoryTreeBranchesResponse, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "GetAllHistoryTreeBranches", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } response, err := p.persistence.GetAllHistoryTreeBranches(ctx, request) @@ -759,7 +759,7 @@ func (p *queueRateLimitedPersistenceClient) EnqueueMessage( ctx context.Context, blob commonpb.DataBlob, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "EnqueueMessage", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -771,7 +771,7 @@ func (p *queueRateLimitedPersistenceClient) ReadMessages( lastMessageID int64, maxCount int, ) ([]*QueueMessage, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "ReadMessages", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -782,7 +782,7 @@ func (p *queueRateLimitedPersistenceClient) UpdateAckLevel( ctx context.Context, metadata *InternalQueueMetadata, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "UpdateAckLevel", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -792,7 +792,7 @@ func (p *queueRateLimitedPersistenceClient) UpdateAckLevel( func (p *queueRateLimitedPersistenceClient) GetAckLevels( ctx context.Context, ) (*InternalQueueMetadata, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "GetAckLevels", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -803,7 +803,7 @@ func (p *queueRateLimitedPersistenceClient) DeleteMessagesBefore( ctx context.Context, messageID int64, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "DeleteMessagesBefore", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -814,7 +814,7 @@ func (p *queueRateLimitedPersistenceClient) EnqueueMessageToDLQ( ctx context.Context, blob commonpb.DataBlob, ) (int64, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "EnqueueMessageToDLQ", p.rateLimiter); !ok { return EmptyQueueMessageID, ErrPersistenceLimitExceeded } @@ -828,7 +828,7 @@ func (p *queueRateLimitedPersistenceClient) ReadMessagesFromDLQ( pageSize int, pageToken []byte, ) ([]*QueueMessage, []byte, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "ReadMessagesFromDLQ", p.rateLimiter); !ok { return nil, nil, ErrPersistenceLimitExceeded } @@ -840,7 +840,7 @@ func (p *queueRateLimitedPersistenceClient) RangeDeleteMessagesFromDLQ( firstMessageID int64, lastMessageID int64, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "RangeDeleteMessagesFromDLQ", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -850,7 +850,7 @@ func (p *queueRateLimitedPersistenceClient) UpdateDLQAckLevel( ctx context.Context, metadata *InternalQueueMetadata, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "UpdateDLQAckLevel", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -860,7 +860,7 @@ func (p *queueRateLimitedPersistenceClient) UpdateDLQAckLevel( func (p *queueRateLimitedPersistenceClient) GetDLQAckLevels( ctx context.Context, ) (*InternalQueueMetadata, error) { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "GetDLQAckLevels", p.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } @@ -871,7 +871,7 @@ func 
(p *queueRateLimitedPersistenceClient) DeleteMessageFromDLQ( ctx context.Context, messageID int64, ) error { - if ok := allow(ctx, p.rateLimiter); !ok { + if ok := allow(ctx, "DeleteMessageFromDLQ", p.rateLimiter); !ok { return ErrPersistenceLimitExceeded } @@ -901,7 +901,7 @@ func (c *clusterMetadataRateLimitedPersistenceClient) GetClusterMembers( ctx context.Context, request *GetClusterMembersRequest, ) (*GetClusterMembersResponse, error) { - if ok := allow(ctx, c.rateLimiter); !ok { + if ok := allow(ctx, "GetClusterMembers", c.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return c.persistence.GetClusterMembers(ctx, request) @@ -911,7 +911,7 @@ func (c *clusterMetadataRateLimitedPersistenceClient) UpsertClusterMembership( ctx context.Context, request *UpsertClusterMembershipRequest, ) error { - if ok := allow(ctx, c.rateLimiter); !ok { + if ok := allow(ctx, "UpsertClusterMembership", c.rateLimiter); !ok { return ErrPersistenceLimitExceeded } return c.persistence.UpsertClusterMembership(ctx, request) @@ -921,7 +921,7 @@ func (c *clusterMetadataRateLimitedPersistenceClient) PruneClusterMembership( ctx context.Context, request *PruneClusterMembershipRequest, ) error { - if ok := allow(ctx, c.rateLimiter); !ok { + if ok := allow(ctx, "PruneClusterMembership", c.rateLimiter); !ok { return ErrPersistenceLimitExceeded } return c.persistence.PruneClusterMembership(ctx, request) @@ -931,7 +931,7 @@ func (c *clusterMetadataRateLimitedPersistenceClient) ListClusterMetadata( ctx context.Context, request *ListClusterMetadataRequest, ) (*ListClusterMetadataResponse, error) { - if ok := allow(ctx, c.rateLimiter); !ok { + if ok := allow(ctx, "ListClusterMetadata", c.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return c.persistence.ListClusterMetadata(ctx, request) @@ -940,7 +940,7 @@ func (c *clusterMetadataRateLimitedPersistenceClient) ListClusterMetadata( func (c *clusterMetadataRateLimitedPersistenceClient) GetCurrentClusterMetadata( ctx context.Context, ) (*GetClusterMetadataResponse, error) { - if ok := allow(ctx, c.rateLimiter); !ok { + if ok := allow(ctx, "GetCurrentClusterMetadata", c.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return c.persistence.GetCurrentClusterMetadata(ctx) @@ -950,7 +950,7 @@ func (c *clusterMetadataRateLimitedPersistenceClient) GetClusterMetadata( ctx context.Context, request *GetClusterMetadataRequest, ) (*GetClusterMetadataResponse, error) { - if ok := allow(ctx, c.rateLimiter); !ok { + if ok := allow(ctx, "GetClusterMetadata", c.rateLimiter); !ok { return nil, ErrPersistenceLimitExceeded } return c.persistence.GetClusterMetadata(ctx, request) @@ -960,7 +960,7 @@ func (c *clusterMetadataRateLimitedPersistenceClient) SaveClusterMetadata( ctx context.Context, request *SaveClusterMetadataRequest, ) (bool, error) { - if ok := allow(ctx, c.rateLimiter); !ok { + if ok := allow(ctx, "SaveClusterMetadata", c.rateLimiter); !ok { return false, ErrPersistenceLimitExceeded } return c.persistence.SaveClusterMetadata(ctx, request) @@ -970,7 +970,7 @@ func (c *clusterMetadataRateLimitedPersistenceClient) DeleteClusterMetadata( ctx context.Context, request *DeleteClusterMetadataRequest, ) error { - if ok := allow(ctx, c.rateLimiter); !ok { + if ok := allow(ctx, "DeleteClusterMetadata", c.rateLimiter); !ok { return ErrPersistenceLimitExceeded } return c.persistence.DeleteClusterMetadata(ctx, request) @@ -978,10 +978,11 @@ func (c *clusterMetadataRateLimitedPersistenceClient) DeleteClusterMetadata( func allow( ctx context.Context, + api 
string, rateLimiter quotas.RequestRateLimiter, ) bool { return rateLimiter.Allow(time.Now().UTC(), quotas.NewRequest( - "", // api: currently not used when calculating priority + api, RateLimitDefaultToken, "", // caller: currently not used when calculating priority headers.GetValues(ctx, headers.CallerTypeHeaderName)[0], diff --git a/service/history/shard/context_impl.go b/service/history/shard/context_impl.go index 0ff895b7a42..274c4864bf9 100644 --- a/service/history/shard/context_impl.go +++ b/service/history/shard/context_impl.go @@ -2043,7 +2043,7 @@ func (s *ContextImpl) ensureMinContextTimeout( func (s *ContextImpl) newIOContext() (context.Context, context.CancelFunc) { ctx, cancel := context.WithTimeout(s.lifecycleCtx, shardIOTimeout) - ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeSystem) + ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) return ctx, cancel } From e8f9144b8de44439f651c4ee7dc54321e9ebb629 Mon Sep 17 00:00:00 2001 From: Yichao Yang Date: Mon, 25 Jul 2022 12:01:28 -0700 Subject: [PATCH 03/13] add and fix unit tests --- common/headers/caller_info_test.go | 121 ++++++++++++++++++ common/rpc/context_test.go | 105 +++++++++++++++ common/xdc/nDCHistoryResender.go | 2 +- .../history/replication/task_executor_test.go | 25 ++-- 4 files changed, 238 insertions(+), 15 deletions(-) create mode 100644 common/headers/caller_info_test.go create mode 100644 common/rpc/context_test.go diff --git a/common/headers/caller_info_test.go b/common/headers/caller_info_test.go new file mode 100644 index 00000000000..68a74072038 --- /dev/null +++ b/common/headers/caller_info_test.go @@ -0,0 +1,121 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package headers + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc/metadata" +) + +type ( + callerInfoSuite struct { + *require.Assertions + suite.Suite + } +) + +func TestCallerInfoSuite(t *testing.T) { + suite.Run(t, &callerInfoSuite{}) +} + +func (s *callerInfoSuite) SetupTest() { + s.Assertions = require.New(s.T()) +} + +func (s *callerInfoSuite) TestSetCallerInfo_PreserveOtherValues() { + existingKey := "key" + existingValue := "value" + callerName := "callerName" + callerType := CallerTypeAPI + + ctx := metadata.NewIncomingContext( + context.Background(), + metadata.Pairs(existingKey, existingValue), + ) + + ctx = SetCallerInfo(ctx, callerName, callerType) + + md, ok := metadata.FromIncomingContext(ctx) + s.True(ok) + s.Equal(existingValue, md.Get(existingKey)[0]) + s.Equal(callerName, md.Get(CallerNameHeaderName)[0]) + s.Equal(callerType, md.Get(CallerTypeHeaderName)[0]) + s.Len(md, 3) +} + +func (s *callerInfoSuite) TestSetCallerInfo_NoExistingCallerInfo() { + callerName := "callerName" + callerType := CallerTypeAPI + + ctx := SetCallerInfo(context.Background(), callerName, callerType) + + md, ok := metadata.FromIncomingContext(ctx) + s.True(ok) + s.Equal(callerName, md.Get(CallerNameHeaderName)[0]) + s.Equal(callerType, md.Get(CallerTypeHeaderName)[0]) + s.Len(md, 2) +} + +func (s *callerInfoSuite) TestSetCallerInfo_WithExistingCallerInfo() { + callerName := "callerName" + callerType := CallerTypeAPI + + ctx := SetCallerInfo(context.Background(), callerName, callerType) + + ctx = SetCallerInfo(ctx, "another caller", CallerTypeBackground) + + md, ok := metadata.FromIncomingContext(ctx) + s.True(ok) + s.Equal(callerName, md.Get(CallerNameHeaderName)[0]) + s.Equal(callerType, md.Get(CallerTypeHeaderName)[0]) + s.Len(md, 2) +} + +func (s *callerInfoSuite) TestSetCallerInfo_WithPartialCallerInfo() { + callerName := "callerName" + callerType := CallerTypeAPI + + ctx := SetCallerInfo(context.Background(), callerName, "") + ctx = SetCallerInfo(ctx, "another caller", callerType) + + md, ok := metadata.FromIncomingContext(ctx) + s.True(ok) + s.Equal(callerName, md.Get(CallerNameHeaderName)[0]) + s.Equal(callerType, md.Get(CallerTypeHeaderName)[0]) + s.Len(md, 2) + + ctx = SetCallerInfo(context.Background(), "", callerType) + ctx = SetCallerInfo(ctx, callerName, "") + + md, ok = metadata.FromIncomingContext(ctx) + s.True(ok) + s.Equal(callerName, md.Get(CallerNameHeaderName)[0]) + s.Equal(callerType, md.Get(CallerTypeHeaderName)[0]) + s.Len(md, 2) +} diff --git a/common/rpc/context_test.go b/common/rpc/context_test.go new file mode 100644 index 00000000000..14130205509 --- /dev/null +++ b/common/rpc/context_test.go @@ -0,0 +1,105 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package rpc + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc/metadata" +) + +type ( + contextSuite struct { + *require.Assertions + suite.Suite + } +) + +func TestContextSuite(t *testing.T) { + suite.Run(t, &contextSuite{}) +} + +func (s *contextSuite) SetupTest() { + s.Assertions = require.New(s.T()) +} + +func (s *contextSuite) TestCopyContextValues_ValueCopied() { + strKey := "key" + strValue := "value" + + structKey := struct{}{} + structValue := struct{}{} + + metadataKey := "header-key" + metadataValue := "header-value" + + ctx := context.Background() + ctx = context.WithValue(ctx, strKey, strValue) + ctx = context.WithValue(ctx, structKey, structValue) + ctx = metadata.NewIncomingContext(ctx, metadata.Pairs(metadataKey, metadataValue)) + + newDeadline := time.Now().Add(time.Hour) + newContext, _ := context.WithDeadline(context.Background(), newDeadline) + + newContext = CopyContextValues(newContext, ctx) + + s.Equal(strValue, newContext.Value(strKey)) + s.Equal(structValue, newContext.Value(structKey)) + md, ok := metadata.FromIncomingContext(newContext) + s.True(ok) + s.Equal(metadataValue, md[metadataKey][0]) +} + +func (s *contextSuite) TestCopyContextValue_DeadlineSeparated() { + deadline := time.Now().Add(time.Minute) + ctx, cancel := context.WithDeadline(context.Background(), deadline) + + newDeadline := time.Now().Add(time.Hour) + newContext, newCancel := context.WithDeadline(context.Background(), newDeadline) + defer newCancel() + + newContext = CopyContextValues(newContext, ctx) + + cancel() + s.NotNil(ctx.Err()) + s.Nil(newContext.Err()) +} + +func (s *contextSuite) TestCopyContextValue_ValueNotOverWritten() { + key := struct{}{} + value := "value" + ctx := context.WithValue(context.Background(), key, value) + + newValue := "newValue" + newContext := context.WithValue(context.Background(), key, newValue) + + newContext = CopyContextValues(newContext, ctx) + + s.Equal(newValue, newContext.Value(key)) +} diff --git a/common/xdc/nDCHistoryResender.go b/common/xdc/nDCHistoryResender.go index 51f8bdac4cf..65003cd2e7e 100644 --- a/common/xdc/nDCHistoryResender.go +++ b/common/xdc/nDCHistoryResender.go @@ -132,7 +132,7 @@ func (n *NDCHistoryResenderImpl) SendSingleWorkflowHistory( defer cancel() } } - rpc.CopyContextValues(resendCtx, ctx) + resendCtx = rpc.CopyContextValues(resendCtx, ctx) historyIterator := collection.NewPagingIterator(n.getPaginationFn( resendCtx, diff --git a/service/history/replication/task_executor_test.go b/service/history/replication/task_executor_test.go index e0dda1d4ea3..50bf667cf8c 100644 --- a/service/history/replication/task_executor_test.go +++ b/service/history/replication/task_executor_test.go @@ -109,10 +109,12 @@ func (s *taskExecutorSuite) SetupTest() { s.mockNamespaceCache = s.mockResource.NamespaceCache s.clusterMetadata = s.mockResource.ClusterMetadata s.nDCHistoryResender = xdc.NewMockNDCHistoryResender(s.controller) - s.historyClient = 
historyservicemock.NewMockHistoryServiceClient(s.controller) - s.clusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() s.workflowCache = workflow.NewMockCache(s.controller) + + s.clusterMetadata.EXPECT().GetCurrentClusterName().Return(cluster.TestCurrentClusterName).AnyTimes() + s.mockNamespaceCache.EXPECT().GetNamespaceName(gomock.Any()).Return(tests.Namespace, nil).AnyTimes() + s.replicationTaskExecutor = NewTaskExecutor( s.remoteCluster, s.mockShard, @@ -379,21 +381,16 @@ func (s *taskExecutorSuite) TestProcessTaskOnce_SyncWorkflowStateTask() { task := &replicationspb.ReplicationTask{ TaskType: enumsspb.REPLICATION_TASK_TYPE_SYNC_WORKFLOW_STATE_TASK, Attributes: &replicationspb.ReplicationTask_SyncWorkflowStateTaskAttributes{ - SyncWorkflowStateTaskAttributes: &replicationspb.SyncWorkflowStateTaskAttributes{}, + SyncWorkflowStateTaskAttributes: &replicationspb.SyncWorkflowStateTaskAttributes{ + WorkflowState: &persistencespb.WorkflowMutableState{ + ExecutionInfo: &persistencespb.WorkflowExecutionInfo{ + NamespaceId: namespaceID.String(), + }, + }, + }, }, } - s.mockNamespaceCache.EXPECT(). - GetNamespaceByID(namespaceID). - Return(namespace.NewGlobalNamespaceForTest( - nil, - nil, - &persistencespb.NamespaceReplicationConfig{Clusters: []string{ - cluster.TestCurrentClusterName, - cluster.TestAlternativeClusterName, - }}, - 0, - ), nil).AnyTimes() s.mockEngine.EXPECT().ReplicateWorkflowState(gomock.Any(), gomock.Any()).Return(nil) _, err := s.replicationTaskExecutor.Execute(task, true) From ebab5d8653579f6878463f3848188ddda101261e Mon Sep 17 00:00:00 2001 From: Yichao Yang Date: Mon, 25 Jul 2022 12:58:39 -0700 Subject: [PATCH 04/13] bug fixes --- common/headers/caller_info.go | 2 +- common/headers/caller_info_test.go | 2 +- service/history/transferQueueProcessorBase.go | 2 +- service/matching/forwarder.go | 3 +++ service/matching/taskReader.go | 5 +++++ service/matching/taskWriter.go | 3 +++ 6 files changed, 14 insertions(+), 3 deletions(-) diff --git a/common/headers/caller_info.go b/common/headers/caller_info.go index d75b4d8d11a..a2021cfbb51 100644 --- a/common/headers/caller_info.go +++ b/common/headers/caller_info.go @@ -58,7 +58,7 @@ func SetCallerInfo( } if values[1] == "" { - mdIncoming.Set(CallerTypeHeaderName, CallerTypeAPI) + mdIncoming.Set(CallerTypeHeaderName, callerType) } return metadata.NewIncomingContext(ctx, mdIncoming) diff --git a/common/headers/caller_info_test.go b/common/headers/caller_info_test.go index 68a74072038..c079ce2750c 100644 --- a/common/headers/caller_info_test.go +++ b/common/headers/caller_info_test.go @@ -99,7 +99,7 @@ func (s *callerInfoSuite) TestSetCallerInfo_WithExistingCallerInfo() { func (s *callerInfoSuite) TestSetCallerInfo_WithPartialCallerInfo() { callerName := "callerName" - callerType := CallerTypeAPI + callerType := CallerTypeBackground ctx := SetCallerInfo(context.Background(), callerName, "") ctx = SetCallerInfo(ctx, "another caller", callerType) diff --git a/service/history/transferQueueProcessorBase.go b/service/history/transferQueueProcessorBase.go index 1fa783c209c..3f4aec01eca 100644 --- a/service/history/transferQueueProcessorBase.go +++ b/service/history/transferQueueProcessorBase.go @@ -132,6 +132,6 @@ func newTransferTaskScheduler( func newQueueIOContext() (context.Context, context.CancelFunc) { ctx, cancel := context.WithTimeout(context.Background(), queueIOTimeout) - headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) + ctx = headers.SetCallerInfo(ctx, "", 
headers.CallerTypeBackground) return ctx, cancel } diff --git a/service/matching/forwarder.go b/service/matching/forwarder.go index 2bee3400c3d..e7272d41952 100644 --- a/service/matching/forwarder.go +++ b/service/matching/forwarder.go @@ -36,6 +36,7 @@ import ( "go.temporal.io/api/workflowservice/v1" "go.temporal.io/server/api/matchingservice/v1" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/primitives/timestamp" "go.temporal.io/server/common/quotas" ) @@ -142,6 +143,7 @@ func (fwdr *Forwarder) ForwardTask(ctx context.Context, task *internalTask) erro } } + ctx = headers.SetCallerInfo(ctx, task.event.Data.GetNamespaceId(), "") switch fwdr.taskQueueID.taskType { case enumspb.TASK_QUEUE_TYPE_WORKFLOW: _, err = fwdr.client.AddWorkflowTask(ctx, &matchingservice.AddWorkflowTaskRequest{ @@ -194,6 +196,7 @@ func (fwdr *Forwarder) ForwardQueryTask( return nil, errNoParent } + ctx = headers.SetCallerInfo(ctx, task.query.request.GetNamespaceId(), "") resp, err := fwdr.client.QueryWorkflow(ctx, &matchingservice.QueryWorkflowRequest{ NamespaceId: task.query.request.GetNamespaceId(), TaskQueue: &taskqueuepb.TaskQueue{ diff --git a/service/matching/taskReader.go b/service/matching/taskReader.go index 079f298c6a8..cf215c04879 100644 --- a/service/matching/taskReader.go +++ b/service/matching/taskReader.go @@ -32,6 +32,7 @@ import ( enumsspb "go.temporal.io/server/api/enums/v1" persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/common" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" @@ -102,6 +103,8 @@ func (tr *taskReader) Signal() { } func (tr *taskReader) dispatchBufferedTasks(ctx context.Context) error { + ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) + dispatchLoop: for { select { @@ -133,6 +136,8 @@ dispatchLoop: } func (tr *taskReader) getTasksPump(ctx context.Context) error { + ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) + if err := tr.tlMgr.WaitUntilInitialized(ctx); err != nil { return err } diff --git a/service/matching/taskWriter.go b/service/matching/taskWriter.go index 78777a74b1d..db5c8e668d6 100644 --- a/service/matching/taskWriter.go +++ b/service/matching/taskWriter.go @@ -37,6 +37,7 @@ import ( persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/common" "go.temporal.io/server/common/backoff" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" @@ -216,6 +217,8 @@ func (w *taskWriter) appendTasks( } func (w *taskWriter) taskWriterLoop(ctx context.Context) error { + ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) + err := w.initReadWriteState(ctx) w.tlMgr.initializedError.Set(struct{}{}, err) if err != nil { From c0047e943b94c1a2f761bb020631a80bfddc82b7 Mon Sep 17 00:00:00 2001 From: Yichao Yang Date: Mon, 25 Jul 2022 13:11:16 -0700 Subject: [PATCH 05/13] mark worker ctx as background --- service/worker/archiver/client_worker.go | 2 ++ service/worker/batcher/batcher.go | 2 ++ service/worker/parentclosepolicy/processor.go | 3 +++ service/worker/scanner/scanner.go | 6 +++++- service/worker/scheduler/fx.go | 6 ++++++ service/worker/worker.go | 3 +++ 6 files changed, 21 insertions(+), 1 deletion(-) diff --git a/service/worker/archiver/client_worker.go b/service/worker/archiver/client_worker.go index 8393647cbad..42cbd8ad907 100644 --- 
a/service/worker/archiver/client_worker.go +++ b/service/worker/archiver/client_worker.go @@ -35,6 +35,7 @@ import ( "go.temporal.io/server/common" "go.temporal.io/server/common/archiver/provider" "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" @@ -106,6 +107,7 @@ func NewClientWorker(container *BootstrapContainer) ClientWorker { globalMetricsClient = container.MetricsClient globalConfig = container.Config actCtx := context.WithValue(context.Background(), bootstrapContainerKey, container) + actCtx = headers.SetCallerInfo(actCtx, "", headers.CallerTypeBackground) sdkClient := container.SdkClientFactory.GetSystemClient(container.Logger) wo := worker.Options{ diff --git a/service/worker/batcher/batcher.go b/service/worker/batcher/batcher.go index acb9a9edb72..f6363923dda 100644 --- a/service/worker/batcher/batcher.go +++ b/service/worker/batcher/batcher.go @@ -32,6 +32,7 @@ import ( "go.temporal.io/sdk/workflow" "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" @@ -76,6 +77,7 @@ func New( func (s *Batcher) Start() error { // start worker for batch operation workflows ctx := context.WithValue(context.Background(), batcherContextKey, s) + ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) workerOpts := worker.Options{ MaxConcurrentActivityExecutionSize: s.cfg.MaxConcurrentActivityExecutionSize(), diff --git a/service/worker/parentclosepolicy/processor.go b/service/worker/parentclosepolicy/processor.go index 7de234c6669..339fd8491dc 100644 --- a/service/worker/parentclosepolicy/processor.go +++ b/service/worker/parentclosepolicy/processor.go @@ -34,6 +34,7 @@ import ( "go.temporal.io/server/client" "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" @@ -102,6 +103,8 @@ func (s *Processor) Start() error { func getWorkerOptions(p *Processor) worker.Options { ctx := context.WithValue(context.Background(), processorContextKey, p) + ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) + return worker.Options{ MaxConcurrentActivityExecutionSize: p.cfg.MaxConcurrentActivityExecutionSize(), MaxConcurrentWorkflowTaskExecutionSize: p.cfg.MaxConcurrentWorkflowTaskExecutionSize(), diff --git a/service/worker/scanner/scanner.go b/service/worker/scanner/scanner.go index 60e4f4936a4..b3eb4ae6f33 100644 --- a/service/worker/scanner/scanner.go +++ b/service/worker/scanner/scanner.go @@ -36,6 +36,7 @@ import ( "go.temporal.io/server/api/historyservice/v1" "go.temporal.io/server/common/config" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/persistence" @@ -119,13 +120,16 @@ func New( // Start starts the scanner func (s *Scanner) Start() error { + ctx := context.WithValue(context.Background(), scannerContextKey, s.context) + ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) + workerOpts := worker.Options{ MaxConcurrentActivityExecutionSize: s.context.cfg.MaxConcurrentActivityExecutionSize(), MaxConcurrentWorkflowTaskExecutionSize: s.context.cfg.MaxConcurrentWorkflowTaskExecutionSize(), MaxConcurrentActivityTaskPollers: 
s.context.cfg.MaxConcurrentActivityTaskPollers(), MaxConcurrentWorkflowTaskPollers: s.context.cfg.MaxConcurrentWorkflowTaskPollers(), - BackgroundActivityContext: context.WithValue(context.Background(), scannerContextKey, s.context), + BackgroundActivityContext: ctx, } var workerTaskQueueNames []string diff --git a/service/worker/scheduler/fx.go b/service/worker/scheduler/fx.go index 90ed7388e4d..b208a24ef58 100644 --- a/service/worker/scheduler/fx.go +++ b/service/worker/scheduler/fx.go @@ -25,6 +25,8 @@ package scheduler import ( + "context" + "go.uber.org/fx" "go.temporal.io/api/workflowservice/v1" @@ -32,6 +34,7 @@ import ( "go.temporal.io/sdk/workflow" "go.temporal.io/server/api/historyservice/v1" "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/namespace" @@ -88,6 +91,9 @@ func (s *workerComponent) DedicatedWorkerOptions(ns *namespace.Namespace) *worke Enabled: s.enabledForNs(ns.Name().String()), TaskQueue: TaskQueueName, NumWorkers: s.numWorkers(ns.Name().String()), + Options: sdkworker.Options{ + BackgroundActivityContext: headers.SetCallerInfo(context.Background(), "", headers.CallerTypeBackground), + }, } } diff --git a/service/worker/worker.go b/service/worker/worker.go index ab1dd92e4c5..cfba3851dde 100644 --- a/service/worker/worker.go +++ b/service/worker/worker.go @@ -25,12 +25,14 @@ package worker import ( + "context" "sync/atomic" sdkworker "go.temporal.io/sdk/worker" "go.uber.org/fx" "go.temporal.io/server/common" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/sdk" @@ -76,6 +78,7 @@ func (wm *workerManager) Start() { defaultWorkerOptions := sdkworker.Options{ // TODO: add dynamic config for worker options + BackgroundActivityContext: headers.SetCallerInfo(context.Background(), "", headers.CallerTypeBackground), } sdkClient := wm.sdkClientFactory.GetSystemClient(wm.logger) defaultWorker := sdkworker.New(sdkClient, DefaultWorkerTaskQueue, defaultWorkerOptions) From 3b092a31959896d6f6c3f3c72b7cffc18fd931bd Mon Sep 17 00:00:00 2001 From: Yichao Yang Date: Mon, 25 Jul 2022 13:25:15 -0700 Subject: [PATCH 06/13] set more caller info --- common/namespace/registry.go | 7 +++++-- service/matching/matchingEngine_test.go | 10 +++++----- service/matching/taskReader.go | 10 +++++----- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/common/namespace/registry.go b/common/namespace/registry.go index 3aeb7388105..fb182949950 100644 --- a/common/namespace/registry.go +++ b/common/namespace/registry.go @@ -40,6 +40,7 @@ import ( "go.temporal.io/server/common/cache" "go.temporal.io/server/common/clock" "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" @@ -219,11 +220,13 @@ func (r *registry) Start() { defer atomic.StoreInt32(&r.status, running) // initialize the cache by initial scan - err := r.refreshNamespaces(context.Background()) + ctx := headers.SetCallerInfo(context.Background(), "", headers.CallerTypeBackground) + + err := r.refreshNamespaces(ctx) if err != nil { r.logger.Fatal("Unable to initialize namespace cache", tag.Error(err)) } - r.refresher = goro.NewHandle(context.Background()).Go(r.refreshLoop) + r.refresher = goro.NewHandle(ctx).Go(r.refreshLoop) } // Stop the background refresh 
of Namespace data diff --git a/service/matching/matchingEngine_test.go b/service/matching/matchingEngine_test.go index e233d7cdce3..b230ae3bcd8 100644 --- a/service/matching/matchingEngine_test.go +++ b/service/matching/matchingEngine_test.go @@ -1646,14 +1646,14 @@ func (s *matchingEngineSuite) TestTaskQueueManagerGetTaskBatch() { // setReadLevel should NEVER be called without updating ackManager.outstandingTasks // This is only for unit test purpose tlMgr.taskAckManager.setReadLevel(tlMgr.taskWriter.GetMaxReadLevel()) - tasks, readLevel, isReadBatchDone, err := tlMgr.taskReader.getTaskBatch() + tasks, readLevel, isReadBatchDone, err := tlMgr.taskReader.getTaskBatch(context.Background()) s.Nil(err) s.EqualValues(0, len(tasks)) s.EqualValues(tlMgr.taskWriter.GetMaxReadLevel(), readLevel) s.True(isReadBatchDone) tlMgr.taskAckManager.setReadLevel(0) - tasks, readLevel, isReadBatchDone, err = tlMgr.taskReader.getTaskBatch() + tasks, readLevel, isReadBatchDone, err = tlMgr.taskReader.getTaskBatch(context.Background()) s.Nil(err) s.EqualValues(rangeSize, len(tasks)) s.EqualValues(rangeSize, readLevel) @@ -1686,7 +1686,7 @@ func (s *matchingEngineSuite) TestTaskQueueManagerGetTaskBatch() { } } s.EqualValues(taskCount-rangeSize, s.taskManager.getTaskCount(tlID)) - tasks, _, isReadBatchDone, err = tlMgr.taskReader.getTaskBatch() + tasks, _, isReadBatchDone, err = tlMgr.taskReader.getTaskBatch(context.Background()) s.Nil(err) s.True(0 < len(tasks) && len(tasks) <= rangeSize) s.True(isReadBatchDone) @@ -1716,14 +1716,14 @@ func (s *matchingEngineSuite) TestTaskQueueManagerGetTaskBatch_ReadBatchDone() { tlMgr.taskAckManager.setReadLevel(0) atomic.StoreInt64(&tlMgr.taskWriter.maxReadLevel, maxReadLevel) - tasks, readLevel, isReadBatchDone, err := tlMgr.taskReader.getTaskBatch() + tasks, readLevel, isReadBatchDone, err := tlMgr.taskReader.getTaskBatch(context.Background()) s.Empty(tasks) s.Equal(int64(rangeSize*10), readLevel) s.False(isReadBatchDone) s.NoError(err) tlMgr.taskAckManager.setReadLevel(readLevel) - tasks, readLevel, isReadBatchDone, err = tlMgr.taskReader.getTaskBatch() + tasks, readLevel, isReadBatchDone, err = tlMgr.taskReader.getTaskBatch(context.Background()) s.Empty(tasks) s.Equal(maxReadLevel, readLevel) s.True(isReadBatchDone) diff --git a/service/matching/taskReader.go b/service/matching/taskReader.go index cf215c04879..8863b81dfd0 100644 --- a/service/matching/taskReader.go +++ b/service/matching/taskReader.go @@ -160,7 +160,7 @@ Loop: return nil case <-tr.notifyC: - tasks, readLevel, isReadBatchDone, err := tr.getTaskBatch() + tasks, readLevel, isReadBatchDone, err := tr.getTaskBatch(ctx) tr.tlMgr.signalIfFatal(err) if err != nil { tr.Signal() // re-enqueue the event @@ -197,10 +197,10 @@ Loop: } } -func (tr *taskReader) getTaskBatchWithRange(readLevel int64, maxReadLevel int64) ([]*persistencespb.AllocatedTaskInfo, error) { +func (tr *taskReader) getTaskBatchWithRange(ctx context.Context, readLevel int64, maxReadLevel int64) ([]*persistencespb.AllocatedTaskInfo, error) { var response *persistence.GetTasksResponse var err error - err = executeWithRetry(context.TODO(), func(ctx context.Context) error { + err = executeWithRetry(ctx, func(ctx context.Context) error { response, err = tr.tlMgr.db.GetTasks(ctx, readLevel+1, maxReadLevel+1, tr.tlMgr.config.GetTasksBatchSize()) return err }) @@ -213,7 +213,7 @@ func (tr *taskReader) getTaskBatchWithRange(readLevel int64, maxReadLevel int64) // Returns a batch of tasks from persistence starting form current read level. 
// Also return a number that can be used to update readLevel // Also return a bool to indicate whether read is finished -func (tr *taskReader) getTaskBatch() ([]*persistencespb.AllocatedTaskInfo, int64, bool, error) { +func (tr *taskReader) getTaskBatch(ctx context.Context) ([]*persistencespb.AllocatedTaskInfo, int64, bool, error) { var tasks []*persistencespb.AllocatedTaskInfo readLevel := tr.tlMgr.taskAckManager.getReadLevel() maxReadLevel := tr.tlMgr.taskWriter.GetMaxReadLevel() @@ -224,7 +224,7 @@ func (tr *taskReader) getTaskBatch() ([]*persistencespb.AllocatedTaskInfo, int64 if upper > maxReadLevel { upper = maxReadLevel } - tasks, err := tr.getTaskBatchWithRange(readLevel, upper) + tasks, err := tr.getTaskBatchWithRange(ctx, readLevel, upper) if err != nil { return nil, readLevel, true, err } From 763a44cd17af1dd44c4f8462232a4bb917484b76 Mon Sep 17 00:00:00 2001 From: Yichao Yang Date: Mon, 25 Jul 2022 16:16:45 -0700 Subject: [PATCH 07/13] only call type --- common/headers/caller_info.go | 42 ++++++++----- common/headers/caller_info_test.go | 59 ++++++++++--------- common/headers/headers.go | 6 +- common/namespace/registry.go | 2 +- .../persistenceRateLimitedClients.go | 2 +- common/rpc/interceptor/caller_info.go | 3 +- service/history/queues/executable.go | 2 +- service/history/replication/task_executor.go | 4 +- service/history/replication/task_fetcher.go | 2 +- service/history/replication/task_processor.go | 2 +- service/history/shard/context_impl.go | 2 +- service/history/transferQueueProcessorBase.go | 2 +- service/matching/forwarder.go | 3 - service/matching/taskReader.go | 4 +- service/matching/taskWriter.go | 2 +- service/worker/archiver/client_worker.go | 2 +- service/worker/batcher/batcher.go | 2 +- service/worker/parentclosepolicy/processor.go | 2 +- service/worker/scanner/scanner.go | 2 +- service/worker/scheduler/fx.go | 2 +- service/worker/worker.go | 2 +- 21 files changed, 80 insertions(+), 69 deletions(-) diff --git a/common/headers/caller_info.go b/common/headers/caller_info.go index a2021cfbb51..749346e0dcb 100644 --- a/common/headers/caller_info.go +++ b/common/headers/caller_info.go @@ -35,31 +35,45 @@ const ( CallerTypeBackground = "background" ) +type CallerInfo struct { + CallerType string + + // TODO: add fields for CallerName and CallerInitiation +} + +func NewCallerInfo( + callerType string, +) CallerInfo { + return CallerInfo{ + CallerType: callerType, + } +} + // SetCallerInfo sets callerName and callerType value in incoming context // if not already exists. 
+// TODO: consider only set the caller info to golang context instead of grpc metadata +// and propagate to grpc outgoing context upon making an rpc call func SetCallerInfo( ctx context.Context, - callerName string, - callerType string, + info CallerInfo, ) context.Context { mdIncoming, ok := metadata.FromIncomingContext(ctx) if !ok { mdIncoming = metadata.MD{} } - values := GetValues( - ctx, - CallerNameHeaderName, - CallerTypeHeaderName, - ) - - if values[0] == "" { - mdIncoming.Set(CallerNameHeaderName, callerName) - } - - if values[1] == "" { - mdIncoming.Set(CallerTypeHeaderName, callerType) + if len(mdIncoming.Get(callerTypeHeaderName)) == 0 { + mdIncoming.Set(callerTypeHeaderName, string(info.CallerType)) } return metadata.NewIncomingContext(ctx, mdIncoming) } + +func GetCallerInfo( + ctx context.Context, +) CallerInfo { + values := GetValues(ctx, callerTypeHeaderName) + return CallerInfo{ + CallerType: values[1], + } +} diff --git a/common/headers/caller_info_test.go b/common/headers/caller_info_test.go index c079ce2750c..c2e3f14ee50 100644 --- a/common/headers/caller_info_test.go +++ b/common/headers/caller_info_test.go @@ -51,7 +51,6 @@ func (s *callerInfoSuite) SetupTest() { func (s *callerInfoSuite) TestSetCallerInfo_PreserveOtherValues() { existingKey := "key" existingValue := "value" - callerName := "callerName" callerType := CallerTypeAPI ctx := metadata.NewIncomingContext( @@ -59,63 +58,69 @@ func (s *callerInfoSuite) TestSetCallerInfo_PreserveOtherValues() { metadata.Pairs(existingKey, existingValue), ) - ctx = SetCallerInfo(ctx, callerName, callerType) + ctx = SetCallerInfo(ctx, NewCallerInfo(callerType)) md, ok := metadata.FromIncomingContext(ctx) s.True(ok) s.Equal(existingValue, md.Get(existingKey)[0]) - s.Equal(callerName, md.Get(CallerNameHeaderName)[0]) - s.Equal(callerType, md.Get(CallerTypeHeaderName)[0]) - s.Len(md, 3) + s.Equal(callerType, md.Get(callerTypeHeaderName)[0]) + s.Len(md, 2) } func (s *callerInfoSuite) TestSetCallerInfo_NoExistingCallerInfo() { - callerName := "callerName" callerType := CallerTypeAPI - ctx := SetCallerInfo(context.Background(), callerName, callerType) + ctx := SetCallerInfo(context.Background(), CallerInfo{ + CallerType: callerType, + }) md, ok := metadata.FromIncomingContext(ctx) s.True(ok) - s.Equal(callerName, md.Get(CallerNameHeaderName)[0]) - s.Equal(callerType, md.Get(CallerTypeHeaderName)[0]) - s.Len(md, 2) + s.Equal(callerType, md.Get(callerTypeHeaderName)[0]) + s.Len(md, 1) } func (s *callerInfoSuite) TestSetCallerInfo_WithExistingCallerInfo() { - callerName := "callerName" callerType := CallerTypeAPI - ctx := SetCallerInfo(context.Background(), callerName, callerType) + ctx := SetCallerInfo(context.Background(), CallerInfo{ + CallerType: callerType, + }) - ctx = SetCallerInfo(ctx, "another caller", CallerTypeBackground) + ctx = SetCallerInfo(ctx, CallerInfo{ + CallerType: CallerTypeBackground, + }) md, ok := metadata.FromIncomingContext(ctx) s.True(ok) - s.Equal(callerName, md.Get(CallerNameHeaderName)[0]) - s.Equal(callerType, md.Get(CallerTypeHeaderName)[0]) - s.Len(md, 2) + s.Equal(callerType, md.Get(callerTypeHeaderName)[0]) + s.Len(md, 1) } func (s *callerInfoSuite) TestSetCallerInfo_WithPartialCallerInfo() { - callerName := "callerName" callerType := CallerTypeBackground - ctx := SetCallerInfo(context.Background(), callerName, "") - ctx = SetCallerInfo(ctx, "another caller", callerType) + ctx := SetCallerInfo(context.Background(), CallerInfo{ + CallerType: "", + }) + ctx = SetCallerInfo(ctx, CallerInfo{ + CallerType: 
callerType, + }) md, ok := metadata.FromIncomingContext(ctx) s.True(ok) - s.Equal(callerName, md.Get(CallerNameHeaderName)[0]) - s.Equal(callerType, md.Get(CallerTypeHeaderName)[0]) - s.Len(md, 2) + s.Equal(callerType, md.Get(callerTypeHeaderName)[0]) + s.Len(md, 1) - ctx = SetCallerInfo(context.Background(), "", callerType) - ctx = SetCallerInfo(ctx, callerName, "") + ctx = SetCallerInfo(context.Background(), CallerInfo{ + CallerType: callerType, + }) + ctx = SetCallerInfo(ctx, CallerInfo{ + CallerType: "", + }) md, ok = metadata.FromIncomingContext(ctx) s.True(ok) - s.Equal(callerName, md.Get(CallerNameHeaderName)[0]) - s.Equal(callerType, md.Get(CallerTypeHeaderName)[0]) - s.Len(md, 2) + s.Equal(callerType, md.Get(callerTypeHeaderName)[0]) + s.Len(md, 1) } diff --git a/common/headers/headers.go b/common/headers/headers.go index c1a9f8e9e14..9df56b897c6 100644 --- a/common/headers/headers.go +++ b/common/headers/headers.go @@ -37,8 +37,7 @@ const ( SupportedFeaturesHeaderName = "supported-features" SupportedFeaturesHeaderDelim = "," - CallerNameHeaderName = "caller-name" - CallerTypeHeaderName = "caller-type" + callerTypeHeaderName = "caller-type" ) var ( @@ -48,8 +47,7 @@ var ( ClientVersionHeaderName, SupportedServerVersionsHeaderName, SupportedFeaturesHeaderName, - CallerNameHeaderName, - CallerTypeHeaderName, + callerTypeHeaderName, } ) diff --git a/common/namespace/registry.go b/common/namespace/registry.go index fb182949950..8219a11dc25 100644 --- a/common/namespace/registry.go +++ b/common/namespace/registry.go @@ -220,7 +220,7 @@ func (r *registry) Start() { defer atomic.StoreInt32(&r.status, running) // initialize the cache by initial scan - ctx := headers.SetCallerInfo(context.Background(), "", headers.CallerTypeBackground) + ctx := headers.SetCallerInfo(context.Background(), headers.NewCallerInfo(headers.CallerTypeBackground)) err := r.refreshNamespaces(ctx) if err != nil { diff --git a/common/persistence/persistenceRateLimitedClients.go b/common/persistence/persistenceRateLimitedClients.go index 4368f688f51..08d85354d8e 100644 --- a/common/persistence/persistenceRateLimitedClients.go +++ b/common/persistence/persistenceRateLimitedClients.go @@ -985,6 +985,6 @@ func allow( api, RateLimitDefaultToken, "", // caller: currently not used when calculating priority - headers.GetValues(ctx, headers.CallerTypeHeaderName)[0], + headers.GetCallerInfo(ctx).CallerType, )) } diff --git a/common/rpc/interceptor/caller_info.go b/common/rpc/interceptor/caller_info.go index b79feafbbca..eee7740e8ba 100644 --- a/common/rpc/interceptor/caller_info.go +++ b/common/rpc/interceptor/caller_info.go @@ -57,8 +57,7 @@ func (i *CallerInfoInterceptor) Intercept( return handler( headers.SetCallerInfo( ctx, - GetNamespace(i.namespaceRegistry, req).String(), - headers.CallerTypeAPI, + headers.NewCallerInfo(headers.CallerTypeAPI), ), req, ) diff --git a/service/history/queues/executable.go b/service/history/queues/executable.go index 3d1a45e6628..b3f74b9752f 100644 --- a/service/history/queues/executable.go +++ b/service/history/queues/executable.go @@ -161,7 +161,7 @@ func (e *executableImpl) Execute() error { } ctx := metrics.AddMetricsContext(context.Background()) - ctx = headers.SetCallerInfo(ctx, e.GetNamespaceID(), headers.CallerTypeBackground) + ctx = headers.SetCallerInfo(ctx, headers.NewCallerInfo(headers.CallerTypeBackground)) startTime := e.timeSource.Now() var err error diff --git a/service/history/replication/task_executor.go b/service/history/replication/task_executor.go index 
85a0c60c14e..5ea9ed6209e 100644 --- a/service/history/replication/task_executor.go +++ b/service/history/replication/task_executor.go @@ -348,9 +348,7 @@ func (e *taskExecutorImpl) newTaskContext( namespaceID string, ) (context.Context, context.CancelFunc) { ctx, cancel := context.WithTimeout(context.Background(), replicationTimeout) - - namespace, _ := e.namespaceRegistry.GetNamespaceName(namespace.ID(namespaceID)) - ctx = headers.SetCallerInfo(ctx, namespace.String(), headers.CallerTypeBackground) + ctx = headers.SetCallerInfo(ctx, headers.NewCallerInfo(headers.CallerTypeBackground)) return ctx, cancel } diff --git a/service/history/replication/task_fetcher.go b/service/history/replication/task_fetcher.go index 3ebdeca56a4..abce8a8489e 100644 --- a/service/history/replication/task_fetcher.go +++ b/service/history/replication/task_fetcher.go @@ -419,7 +419,7 @@ func (f *replicationTaskFetcherWorker) getMessages() error { ctx, cancel := rpc.NewContextWithTimeoutAndVersionHeaders(fetchTaskRequestTimeout) defer cancel() - ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) + ctx = headers.SetCallerInfo(ctx, headers.NewCallerInfo(headers.CallerTypeBackground)) request := &adminservice.GetReplicationMessagesRequest{ Tokens: tokens, diff --git a/service/history/replication/task_processor.go b/service/history/replication/task_processor.go index 484d50e1604..9236cd20dd4 100644 --- a/service/history/replication/task_processor.go +++ b/service/history/replication/task_processor.go @@ -303,7 +303,7 @@ func (p *taskProcessorImpl) handleSyncShardStatus( p.metricsClient.Scope(metrics.HistorySyncShardStatusScope).IncCounter(metrics.SyncShardFromRemoteCounter) ctx, cancel := context.WithTimeout(context.Background(), replicationTimeout) defer cancel() - ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) + ctx = headers.SetCallerInfo(ctx, headers.NewCallerInfo(headers.CallerTypeBackground)) return p.historyEngine.SyncShardStatus(ctx, &historyservice.SyncShardStatusRequest{ SourceCluster: p.sourceCluster, diff --git a/service/history/shard/context_impl.go b/service/history/shard/context_impl.go index 274c4864bf9..e2a4eccb041 100644 --- a/service/history/shard/context_impl.go +++ b/service/history/shard/context_impl.go @@ -2043,7 +2043,7 @@ func (s *ContextImpl) ensureMinContextTimeout( func (s *ContextImpl) newIOContext() (context.Context, context.CancelFunc) { ctx, cancel := context.WithTimeout(s.lifecycleCtx, shardIOTimeout) - ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) + ctx = headers.SetCallerInfo(ctx, headers.NewCallerInfo(headers.CallerTypeBackground)) return ctx, cancel } diff --git a/service/history/transferQueueProcessorBase.go b/service/history/transferQueueProcessorBase.go index 3f4aec01eca..e45d41b8c52 100644 --- a/service/history/transferQueueProcessorBase.go +++ b/service/history/transferQueueProcessorBase.go @@ -132,6 +132,6 @@ func newTransferTaskScheduler( func newQueueIOContext() (context.Context, context.CancelFunc) { ctx, cancel := context.WithTimeout(context.Background(), queueIOTimeout) - ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) + ctx = headers.SetCallerInfo(ctx, headers.NewCallerInfo(headers.CallerTypeBackground)) return ctx, cancel } diff --git a/service/matching/forwarder.go b/service/matching/forwarder.go index e7272d41952..2bee3400c3d 100644 --- a/service/matching/forwarder.go +++ b/service/matching/forwarder.go @@ -36,7 +36,6 @@ import ( "go.temporal.io/api/workflowservice/v1" 
"go.temporal.io/server/api/matchingservice/v1" - "go.temporal.io/server/common/headers" "go.temporal.io/server/common/primitives/timestamp" "go.temporal.io/server/common/quotas" ) @@ -143,7 +142,6 @@ func (fwdr *Forwarder) ForwardTask(ctx context.Context, task *internalTask) erro } } - ctx = headers.SetCallerInfo(ctx, task.event.Data.GetNamespaceId(), "") switch fwdr.taskQueueID.taskType { case enumspb.TASK_QUEUE_TYPE_WORKFLOW: _, err = fwdr.client.AddWorkflowTask(ctx, &matchingservice.AddWorkflowTaskRequest{ @@ -196,7 +194,6 @@ func (fwdr *Forwarder) ForwardQueryTask( return nil, errNoParent } - ctx = headers.SetCallerInfo(ctx, task.query.request.GetNamespaceId(), "") resp, err := fwdr.client.QueryWorkflow(ctx, &matchingservice.QueryWorkflowRequest{ NamespaceId: task.query.request.GetNamespaceId(), TaskQueue: &taskqueuepb.TaskQueue{ diff --git a/service/matching/taskReader.go b/service/matching/taskReader.go index 8863b81dfd0..0f494deab0c 100644 --- a/service/matching/taskReader.go +++ b/service/matching/taskReader.go @@ -103,7 +103,7 @@ func (tr *taskReader) Signal() { } func (tr *taskReader) dispatchBufferedTasks(ctx context.Context) error { - ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) + ctx = headers.SetCallerInfo(ctx, headers.NewCallerInfo(headers.CallerTypeBackground)) dispatchLoop: for { @@ -136,7 +136,7 @@ dispatchLoop: } func (tr *taskReader) getTasksPump(ctx context.Context) error { - ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) + ctx = headers.SetCallerInfo(ctx, headers.NewCallerInfo(headers.CallerTypeBackground)) if err := tr.tlMgr.WaitUntilInitialized(ctx); err != nil { return err diff --git a/service/matching/taskWriter.go b/service/matching/taskWriter.go index db5c8e668d6..b14a7baaba4 100644 --- a/service/matching/taskWriter.go +++ b/service/matching/taskWriter.go @@ -217,7 +217,7 @@ func (w *taskWriter) appendTasks( } func (w *taskWriter) taskWriterLoop(ctx context.Context) error { - ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) + ctx = headers.SetCallerInfo(ctx, headers.NewCallerInfo(headers.CallerTypeBackground)) err := w.initReadWriteState(ctx) w.tlMgr.initializedError.Set(struct{}{}, err) diff --git a/service/worker/archiver/client_worker.go b/service/worker/archiver/client_worker.go index 42cbd8ad907..b23aca37f16 100644 --- a/service/worker/archiver/client_worker.go +++ b/service/worker/archiver/client_worker.go @@ -107,7 +107,7 @@ func NewClientWorker(container *BootstrapContainer) ClientWorker { globalMetricsClient = container.MetricsClient globalConfig = container.Config actCtx := context.WithValue(context.Background(), bootstrapContainerKey, container) - actCtx = headers.SetCallerInfo(actCtx, "", headers.CallerTypeBackground) + actCtx = headers.SetCallerInfo(actCtx, headers.NewCallerInfo(headers.CallerTypeBackground)) sdkClient := container.SdkClientFactory.GetSystemClient(container.Logger) wo := worker.Options{ diff --git a/service/worker/batcher/batcher.go b/service/worker/batcher/batcher.go index f6363923dda..e55e93877d9 100644 --- a/service/worker/batcher/batcher.go +++ b/service/worker/batcher/batcher.go @@ -77,7 +77,7 @@ func New( func (s *Batcher) Start() error { // start worker for batch operation workflows ctx := context.WithValue(context.Background(), batcherContextKey, s) - ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) + ctx = headers.SetCallerInfo(ctx, headers.NewCallerInfo(headers.CallerTypeBackground)) workerOpts := worker.Options{ 
MaxConcurrentActivityExecutionSize: s.cfg.MaxConcurrentActivityExecutionSize(), diff --git a/service/worker/parentclosepolicy/processor.go b/service/worker/parentclosepolicy/processor.go index 339fd8491dc..b54a257d7e4 100644 --- a/service/worker/parentclosepolicy/processor.go +++ b/service/worker/parentclosepolicy/processor.go @@ -103,7 +103,7 @@ func (s *Processor) Start() error { func getWorkerOptions(p *Processor) worker.Options { ctx := context.WithValue(context.Background(), processorContextKey, p) - ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) + ctx = headers.SetCallerInfo(ctx, headers.NewCallerInfo(headers.CallerTypeBackground)) return worker.Options{ MaxConcurrentActivityExecutionSize: p.cfg.MaxConcurrentActivityExecutionSize(), diff --git a/service/worker/scanner/scanner.go b/service/worker/scanner/scanner.go index b3eb4ae6f33..c52b621f473 100644 --- a/service/worker/scanner/scanner.go +++ b/service/worker/scanner/scanner.go @@ -121,7 +121,7 @@ func New( // Start starts the scanner func (s *Scanner) Start() error { ctx := context.WithValue(context.Background(), scannerContextKey, s.context) - ctx = headers.SetCallerInfo(ctx, "", headers.CallerTypeBackground) + ctx = headers.SetCallerInfo(ctx, headers.NewCallerInfo(headers.CallerTypeBackground)) workerOpts := worker.Options{ MaxConcurrentActivityExecutionSize: s.context.cfg.MaxConcurrentActivityExecutionSize(), diff --git a/service/worker/scheduler/fx.go b/service/worker/scheduler/fx.go index b208a24ef58..40a37166457 100644 --- a/service/worker/scheduler/fx.go +++ b/service/worker/scheduler/fx.go @@ -92,7 +92,7 @@ func (s *workerComponent) DedicatedWorkerOptions(ns *namespace.Namespace) *worke TaskQueue: TaskQueueName, NumWorkers: s.numWorkers(ns.Name().String()), Options: sdkworker.Options{ - BackgroundActivityContext: headers.SetCallerInfo(context.Background(), "", headers.CallerTypeBackground), + BackgroundActivityContext: headers.SetCallerInfo(context.Background(), headers.NewCallerInfo(headers.CallerTypeBackground)), }, } } diff --git a/service/worker/worker.go b/service/worker/worker.go index cfba3851dde..62509e6cb79 100644 --- a/service/worker/worker.go +++ b/service/worker/worker.go @@ -78,7 +78,7 @@ func (wm *workerManager) Start() { defaultWorkerOptions := sdkworker.Options{ // TODO: add dynamic config for worker options - BackgroundActivityContext: headers.SetCallerInfo(context.Background(), "", headers.CallerTypeBackground), + BackgroundActivityContext: headers.SetCallerInfo(context.Background(), headers.NewCallerInfo(headers.CallerTypeBackground)), } sdkClient := wm.sdkClientFactory.GetSystemClient(wm.logger) defaultWorker := sdkworker.New(sdkClient, DefaultWorkerTaskQueue, defaultWorkerOptions) From 4178645e7e74e8afc9f43747fe7bc2d87659ef1a Mon Sep 17 00:00:00 2001 From: Yichao Yang Date: Mon, 25 Jul 2022 16:21:55 -0700 Subject: [PATCH 08/13] fix lint --- common/rpc/context_test.go | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/common/rpc/context_test.go b/common/rpc/context_test.go index 14130205509..db0dd2b148a 100644 --- a/common/rpc/context_test.go +++ b/common/rpc/context_test.go @@ -50,18 +50,14 @@ func (s *contextSuite) SetupTest() { } func (s *contextSuite) TestCopyContextValues_ValueCopied() { - strKey := "key" - strValue := "value" - - structKey := struct{}{} - structValue := struct{}{} + key := struct{}{} + value := "value" metadataKey := "header-key" metadataValue := "header-value" ctx := context.Background() - ctx = context.WithValue(ctx, strKey, 
strValue) - ctx = context.WithValue(ctx, structKey, structValue) + ctx = context.WithValue(ctx, key, value) ctx = metadata.NewIncomingContext(ctx, metadata.Pairs(metadataKey, metadataValue)) newDeadline := time.Now().Add(time.Hour) @@ -69,8 +65,7 @@ func (s *contextSuite) TestCopyContextValues_ValueCopied() { newContext = CopyContextValues(newContext, ctx) - s.Equal(strValue, newContext.Value(strKey)) - s.Equal(structValue, newContext.Value(structKey)) + s.Equal(value, newContext.Value(key)) md, ok := metadata.FromIncomingContext(newContext) s.True(ok) s.Equal(metadataValue, md[metadataKey][0]) From 4497e97c0eaf3b0db3fb8af44ab1cda2b14d669f Mon Sep 17 00:00:00 2001 From: Yichao Yang Date: Mon, 25 Jul 2022 16:37:03 -0700 Subject: [PATCH 09/13] fix tests --- common/headers/caller_info.go | 2 +- common/headers/caller_info_test.go | 28 ---------------------------- 2 files changed, 1 insertion(+), 29 deletions(-) diff --git a/common/headers/caller_info.go b/common/headers/caller_info.go index 749346e0dcb..7d00a25595e 100644 --- a/common/headers/caller_info.go +++ b/common/headers/caller_info.go @@ -74,6 +74,6 @@ func GetCallerInfo( ) CallerInfo { values := GetValues(ctx, callerTypeHeaderName) return CallerInfo{ - CallerType: values[1], + CallerType: values[0], } } diff --git a/common/headers/caller_info_test.go b/common/headers/caller_info_test.go index c2e3f14ee50..e99adf25de2 100644 --- a/common/headers/caller_info_test.go +++ b/common/headers/caller_info_test.go @@ -96,31 +96,3 @@ func (s *callerInfoSuite) TestSetCallerInfo_WithExistingCallerInfo() { s.Equal(callerType, md.Get(callerTypeHeaderName)[0]) s.Len(md, 1) } - -func (s *callerInfoSuite) TestSetCallerInfo_WithPartialCallerInfo() { - callerType := CallerTypeBackground - - ctx := SetCallerInfo(context.Background(), CallerInfo{ - CallerType: "", - }) - ctx = SetCallerInfo(ctx, CallerInfo{ - CallerType: callerType, - }) - - md, ok := metadata.FromIncomingContext(ctx) - s.True(ok) - s.Equal(callerType, md.Get(callerTypeHeaderName)[0]) - s.Len(md, 1) - - ctx = SetCallerInfo(context.Background(), CallerInfo{ - CallerType: callerType, - }) - ctx = SetCallerInfo(ctx, CallerInfo{ - CallerType: "", - }) - - md, ok = metadata.FromIncomingContext(ctx) - s.True(ok) - s.Equal(callerType, md.Get(callerTypeHeaderName)[0]) - s.Len(md, 1) -} From f396d6e3d155be10d925002d72d6a3bf50adf954 Mon Sep 17 00:00:00 2001 From: Yichao Yang Date: Mon, 25 Jul 2022 16:48:18 -0700 Subject: [PATCH 10/13] simplify code --- common/persistence/client/quotas.go | 142 +++------------------------- 1 file changed, 14 insertions(+), 128 deletions(-) diff --git a/common/persistence/client/quotas.go b/common/persistence/client/quotas.go index 01433becf3e..c780d9e1f2c 100644 --- a/common/persistence/client/quotas.go +++ b/common/persistence/client/quotas.go @@ -30,135 +30,17 @@ import ( ) var ( - CallerTypeAndAPIToPriority = map[string]map[string]int{ - headers.CallerTypeAPI: { - "GetOrCreateShard": 0, - "UpdateShard": 0, - "AssertShardOwnership": 0, - - "CreateWorkflowExecution": 0, - "UpdateWorkflowExecution": 0, - "ConflictResolveWorkflowExecution": 0, - "DeleteWorkflowExecution": 0, - "DeleteCurrentWorkflowExecution": 0, - "GetCurrentExecution": 0, - "GetWorkflowExecution": 0, - "SetWorkflowExecution": 0, - "ListConcreteExecutions": 0, - "AddHistoryTasks": 0, - "GetHistoryTask": 0, - "GetHistoryTasks": 0, - "CompleteHistoryTask": 0, - "RangeCompleteHistoryTasks": 0, - "PutReplicationTaskToDLQ": 0, - "GetReplicationTasksFromDLQ": 0, - "DeleteReplicationTaskFromDLQ": 0, - 
"RangeDeleteReplicationTaskFromDLQ": 0, - "AppendHistoryNodes": 0, - "AppendRawHistoryNodes": 0, - "ReadHistoryBranch": 0, - "ReadHistoryBranchByBatch": 0, - "ReadHistoryBranchReverse": 0, - "ReadRawHistoryBranch": 0, - "ForkHistoryBranch": 0, - "DeleteHistoryBranch": 0, - "TrimHistoryBranch": 0, - "GetHistoryTree": 0, - "GetAllHistoryTreeBranches": 0, - - "CreateTaskQueue": 0, - "UpdateTaskQueue": 0, - "GetTaskQueue": 0, - "ListTaskQueue": 0, - "DeleteTaskQueue": 0, - "CreateTasks": 0, - "GetTasks": 0, - "CompleteTask": 0, - "CompleteTasksLessThan": 0, - - "CreateNamespace": 0, - "GetNamespace": 0, - "UpdateNamespace": 0, - "RenameNamespace": 0, - "DeleteNamespace": 0, - "DeleteNamespaceByName": 0, - "ListNamespaces": 0, - "GetMetadata": 0, - "InitializeSystemNamespaces": 0, - - "GetClusterMembers": 0, - "UpsertClusterMembership": 0, - "PruneClusterMembership": 0, - "ListClusterMetadata": 0, - "GetCurrentClusterMetadata": 0, - "GetClusterMetadata": 0, - "SaveClusterMetadata": 0, - "DeleteClusterMetadata": 0, - }, - headers.CallerTypeBackground: { - "GetOrCreateShard": 0, - "UpdateShard": 0, - "AssertShardOwnership": 1, - - "CreateWorkflowExecution": 1, - "UpdateWorkflowExecution": 1, - "ConflictResolveWorkflowExecution": 1, - "DeleteWorkflowExecution": 1, - "DeleteCurrentWorkflowExecution": 1, - "GetCurrentExecution": 1, - "GetWorkflowExecution": 1, - "SetWorkflowExecution": 1, - "ListConcreteExecutions": 1, - "AddHistoryTasks": 1, - "GetHistoryTask": 1, - "GetHistoryTasks": 1, - "CompleteHistoryTask": 1, - "RangeCompleteHistoryTasks": 0, // this is a preprequisite for updating ack level - "PutReplicationTaskToDLQ": 1, - "GetReplicationTasksFromDLQ": 1, - "DeleteReplicationTaskFromDLQ": 1, - "RangeDeleteReplicationTaskFromDLQ": 1, - "AppendHistoryNodes": 1, - "AppendRawHistoryNodes": 1, - "ReadHistoryBranch": 1, - "ReadHistoryBranchByBatch": 1, - "ReadHistoryBranchReverse": 1, - "ReadRawHistoryBranch": 1, - "ForkHistoryBranch": 1, - "DeleteHistoryBranch": 1, - "TrimHistoryBranch": 1, - "GetHistoryTree": 1, - "GetAllHistoryTreeBranches": 1, - - "CreateTaskQueue": 1, - "UpdateTaskQueue": 1, - "GetTaskQueue": 1, - "ListTaskQueue": 1, - "DeleteTaskQueue": 1, - "CreateTasks": 1, - "GetTasks": 1, - "CompleteTask": 1, - "CompleteTasksLessThan": 1, + CallerTypePriority = map[string]int{ + headers.CallerTypeAPI: 0, + headers.CallerTypeBackground: 1, + } - "CreateNamespace": 1, - "GetNamespace": 1, - "UpdateNamespace": 1, - "RenameNamespace": 1, - "DeleteNamespace": 1, - "DeleteNamespaceByName": 1, - "ListNamespaces": 1, - "GetMetadata": 1, - "InitializeSystemNamespaces": 1, + APIPriorityOverride = map[string]int{ + "GetOrCreateShard": 0, + "UpdateShard": 0, - "GetClusterMembers": 1, - "UpsertClusterMembership": 1, - "PruneClusterMembership": 1, - "ListClusterMetadata": 1, - "GetCurrentClusterMetadata": 1, - "GetClusterMetadata": 1, - "SaveClusterMetadata": 1, - "DeleteClusterMetadata": 1, - }, + // this is a preprequisite for checkpoint queue process progress + "RangeCompleteHistoryTasks": 0, } RequestPrioritiesOrdered = []int{0, 1} @@ -174,7 +56,11 @@ func NewPriorityRateLimiter( return quotas.NewPriorityRateLimiter( func(req quotas.Request) int { - if priority, ok := CallerTypeAndAPIToPriority[req.CallerType][req.API]; ok { + if priority, ok := APIPriorityOverride[req.API]; ok { + return priority + } + + if priority, ok := CallerTypePriority[req.CallerType]; ok { return priority } From ce0d660bccd32c98831b65b267ad1051e005ddf8 Mon Sep 17 00:00:00 2001 From: Yichao Yang Date: Tue, 26 Jul 2022 
10:49:37 -0700 Subject: [PATCH 11/13] pr comments --- common/membership/rpMonitor.go | 37 +++++++---- common/persistence/client/quotas_test.go | 78 ++++++++++++++++++++++++ service/matching/taskQueueManager.go | 23 ++++++- service/matching/taskReader.go | 6 +- 4 files changed, 126 insertions(+), 18 deletions(-) create mode 100644 common/persistence/client/quotas_test.go diff --git a/common/membership/rpMonitor.go b/common/membership/rpMonitor.go index 9d13280318c..60ea4b15064 100644 --- a/common/membership/rpMonitor.go +++ b/common/membership/rpMonitor.go @@ -36,6 +36,7 @@ import ( "time" "go.temporal.io/server/common/convert" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/primitives" "github.com/pborman/uuid" @@ -57,6 +58,9 @@ const ( type ringpopMonitor struct { status int32 + lifecycleCtx context.Context + lifecycleCancel context.CancelFunc + serviceName string services map[string]int rp *RingPop @@ -79,15 +83,22 @@ func NewRingpopMonitor( broadcastHostPortResolver func() (string, error), ) Monitor { + lifecycleCtx, lifecycleCancel := context.WithCancel(context.Background()) + lifecycleCtx = headers.SetCallerInfo(lifecycleCtx, headers.NewCallerInfo(headers.CallerTypeBackground)) + rpo := &ringpopMonitor{ - broadcastHostPortResolver: broadcastHostPortResolver, - metadataManager: metadataManager, - status: common.DaemonStatusInitialized, + status: common.DaemonStatusInitialized, + + lifecycleCtx: lifecycleCtx, + lifecycleCancel: lifecycleCancel, + serviceName: serviceName, services: services, rp: rp, - logger: logger, rings: make(map[string]*ringpopServiceResolver), + logger: logger, + metadataManager: metadataManager, + broadcastHostPortResolver: broadcastHostPortResolver, hostID: uuid.NewUUID(), } for service, port := range services { @@ -119,7 +130,7 @@ func (rpo *ringpopMonitor) Start() { } rpo.rp.Start( - func() ([]string, error) { return fetchCurrentBootstrapHostports(rpo.metadataManager, rpo.logger) }, + func() ([]string, error) { return rpo.fetchCurrentBootstrapHostports() }, healthyHostLastHeartbeatCutoff/2) labels, err := rpo.rp.Labels() @@ -191,7 +202,7 @@ func SplitHostPortTyped(hostPort string) (net.IP, uint16, error) { func (rpo *ringpopMonitor) startHeartbeat(broadcastHostport string) error { // Start by cleaning up expired records to avoid growth - err := rpo.metadataManager.PruneClusterMembership(context.TODO(), &persistence.PruneClusterMembershipRequest{MaxRecordsPruned: 10}) + err := rpo.metadataManager.PruneClusterMembership(rpo.lifecycleCtx, &persistence.PruneClusterMembershipRequest{MaxRecordsPruned: 10}) if err != nil { return err } @@ -224,7 +235,7 @@ func (rpo *ringpopMonitor) startHeartbeat(broadcastHostport string) error { // Expire in 48 hours to allow for inspection of table by humans for debug scenarios. // For bootstrapping, we filter to a much shorter duration on the // read side by filtering on the last time a heartbeat was seen. 
- err = rpo.upsertMyMembership(context.TODO(), req) + err = rpo.upsertMyMembership(rpo.lifecycleCtx, req) if err == nil { rpo.logger.Info("Membership heartbeat upserted successfully", tag.Address(broadcastAddress.String()), @@ -237,15 +248,15 @@ func (rpo *ringpopMonitor) startHeartbeat(broadcastHostport string) error { return err } -func fetchCurrentBootstrapHostports(manager persistence.ClusterMetadataManager, log log.Logger) ([]string, error) { +func (rpo *ringpopMonitor) fetchCurrentBootstrapHostports() ([]string, error) { pageSize := 1000 set := make(map[string]struct{}) var nextPageToken []byte for { - resp, err := manager.GetClusterMembers( - context.TODO(), + resp, err := rpo.metadataManager.GetClusterMembers( + rpo.lifecycleCtx, &persistence.GetClusterMembersRequest{ LastHeartbeatWithin: healthyHostLastHeartbeatCutoff, PageSize: pageSize, @@ -268,7 +279,7 @@ func fetchCurrentBootstrapHostports(manager persistence.ClusterMetadataManager, bootstrapHostPorts = append(bootstrapHostPorts, k) } - log.Info("bootstrap hosts fetched", tag.BootstrapHostPorts(strings.Join(bootstrapHostPorts, ","))) + rpo.logger.Info("bootstrap hosts fetched", tag.BootstrapHostPorts(strings.Join(bootstrapHostPorts, ","))) return bootstrapHostPorts, nil } @@ -278,7 +289,7 @@ func fetchCurrentBootstrapHostports(manager persistence.ClusterMetadataManager, func (rpo *ringpopMonitor) startHeartbeatUpsertLoop(request *persistence.UpsertClusterMembershipRequest) { loopUpsertMembership := func() { for { - err := rpo.upsertMyMembership(context.TODO(), request) + err := rpo.upsertMyMembership(rpo.lifecycleCtx, request) if err != nil { rpo.logger.Error("Membership upsert failed.", tag.Error(err)) @@ -301,6 +312,8 @@ func (rpo *ringpopMonitor) Stop() { return } + rpo.lifecycleCancel() + for _, ring := range rpo.rings { ring.Stop() } diff --git a/common/persistence/client/quotas_test.go b/common/persistence/client/quotas_test.go new file mode 100644 index 00000000000..3132f0239e6 --- /dev/null +++ b/common/persistence/client/quotas_test.go @@ -0,0 +1,78 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package client + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "golang.org/x/exp/slices" +) + +type ( + quotasSuite struct { + suite.Suite + *require.Assertions + } +) + +func TestQuotasSuite(t *testing.T) { + s := new(quotasSuite) + suite.Run(t, s) +} + +func (s *quotasSuite) SetupSuite() { +} + +func (s *quotasSuite) TearDownSuite() { +} + +func (s *quotasSuite) SetupTest() { + s.Assertions = require.New(s.T()) +} + +func (s *quotasSuite) TearDownTest() { +} + +func (s *quotasSuite) TestCallerTypePriorityMapping() { + for _, priority := range CallerTypePriority { + index := slices.Index(RequestPrioritiesOrdered, priority) + s.NotEqual(-1, index) + } +} + +func (s *quotasSuite) TestAPIPriorityOverrideMapping() { + for _, priority := range APIPriorityOverride { + index := slices.Index(RequestPrioritiesOrdered, priority) + s.NotEqual(-1, index) + } +} + +func (s *quotasSuite) TestRequestPrioritiesOrdered() { + for idx := range RequestPrioritiesOrdered[1:] { + s.True(RequestPrioritiesOrdered[idx] < RequestPrioritiesOrdered[idx+1]) + } +} diff --git a/service/matching/taskQueueManager.go b/service/matching/taskQueueManager.go index 4e4a42d9223..b5a551ab71a 100644 --- a/service/matching/taskQueueManager.go +++ b/service/matching/taskQueueManager.go @@ -43,6 +43,7 @@ import ( "go.temporal.io/server/common/clock" "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/future" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/util" persistencespb "go.temporal.io/server/api/persistence/v1" @@ -62,6 +63,8 @@ const ( // Fake Task ID to wrap a task for syncmatch syncMatchTaskId = -137 + + ioTimeout = 5 * time.Second ) type ( @@ -265,8 +268,11 @@ func (c *taskQueueManagerImpl) Stop() { // metadata. UpdateState would fail on the lease check, but don't even bother calling it. 
ackLevel := c.taskAckManager.getAckLevel() if ackLevel >= 0 { - c.db.UpdateState(context.TODO(), ackLevel) - c.taskGC.RunNow(context.TODO(), ackLevel) + ctx, cancel := newIOContext() + defer cancel() + + c.db.UpdateState(ctx, ackLevel) + c.taskGC.RunNow(ctx, ackLevel) } c.liveness.Stop() c.taskWriter.Stop() @@ -525,7 +531,11 @@ func (c *taskQueueManagerImpl) completeTask(task *persistencespb.AllocatedTaskIn } ackLevel := c.taskAckManager.completeTask(task.GetTaskId()) - c.taskGC.Run(context.TODO(), ackLevel) // TODO: completeTaskFunc and task.finish() should take in a context + + // TODO: completeTaskFunc and task.finish() should take in a context + ctx, cancel := newIOContext() + defer cancel() + c.taskGC.Run(ctx, ackLevel) } func rangeIDToTaskIDBlock(rangeID int64, rangeSize int64) taskIDBlock { @@ -605,3 +615,10 @@ func (c *taskQueueManagerImpl) QueueID() *taskQueueID { func (c *taskQueueManagerImpl) TaskQueueKind() enumspb.TaskQueueKind { return c.taskQueueKind } + +func newIOContext() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithTimeout(context.Background(), ioTimeout) + ctx = headers.SetCallerInfo(ctx, headers.NewCallerInfo(headers.CallerTypeBackground)) + + return ctx, cancel +} diff --git a/service/matching/taskReader.go b/service/matching/taskReader.go index 0f494deab0c..5f9e09c4c7b 100644 --- a/service/matching/taskReader.go +++ b/service/matching/taskReader.go @@ -183,7 +183,7 @@ Loop: tr.Signal() case <-updateAckTimer.C: - err := tr.persistAckLevel() + err := tr.persistAckLevel(ctx) tr.tlMgr.signalIfFatal(err) if err != nil { tr.logger().Error("Persistent store operation failure", @@ -269,10 +269,10 @@ func (tr *taskReader) addSingleTaskToBuffer( } } -func (tr *taskReader) persistAckLevel() error { +func (tr *taskReader) persistAckLevel(ctx context.Context) error { ackLevel := tr.tlMgr.taskAckManager.getAckLevel() tr.emitTaskLagMetric(ackLevel) - return tr.tlMgr.db.UpdateState(context.TODO(), ackLevel) + return tr.tlMgr.db.UpdateState(ctx, ackLevel) } func (tr *taskReader) isTaskAddedRecently(lastAddTime time.Time) bool { From 0e293684d0a847f5a4427c55e9522e10160cb821 Mon Sep 17 00:00:00 2001 From: Yichao Yang Date: Tue, 26 Jul 2022 22:49:51 -0700 Subject: [PATCH 12/13] add dynamicconfig --- common/cluster/metadata.go | 9 +-- common/dynamicconfig/constants.go | 8 +++ common/persistence/client/fx.go | 28 +++++---- common/persistence/client/quotas.go | 13 +++++ common/searchattribute/manager.go | 5 +- service/frontend/fx.go | 12 ++-- service/frontend/service.go | 18 +++--- service/frontend/versionChecker.go | 4 +- service/fx.go | 21 +++++++ service/history/configs/config.go | 24 ++++---- service/history/fx.go | 12 ++-- .../replication/task_processor_manager.go | 9 ++- service/history/transferQueueProcessor.go | 6 +- service/matching/config.go | 58 ++++++++++--------- service/matching/fx.go | 13 +++-- service/worker/fx.go | 13 +++-- service/worker/service.go | 23 +++++--- temporal/fx.go | 13 +++-- temporal/server_impl.go | 13 +++-- 19 files changed, 196 insertions(+), 106 deletions(-) diff --git a/common/cluster/metadata.go b/common/cluster/metadata.go index 4da53126c1a..4439044c0fd 100644 --- a/common/cluster/metadata.go +++ b/common/cluster/metadata.go @@ -33,10 +33,10 @@ import ( "sync/atomic" "time" - "go.temporal.io/server/common/dynamicconfig" - "go.temporal.io/server/common" "go.temporal.io/server/common/collection" + "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" 
"go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/persistence" @@ -227,11 +227,12 @@ func (m *metadataImpl) Start() { return } - err := m.refreshClusterMetadata(context.Background()) + ctx := headers.SetCallerInfo(context.TODO(), headers.NewCallerInfo(headers.CallerTypeBackground)) + err := m.refreshClusterMetadata(ctx) if err != nil { m.logger.Fatal("Unable to initialize cluster metadata cache", tag.Error(err)) } - m.refresher = goro.NewHandle(context.Background()).Go(m.refreshLoop) + m.refresher = goro.NewHandle(ctx).Go(m.refreshLoop) } func (m *metadataImpl) Stop() { diff --git a/common/dynamicconfig/constants.go b/common/dynamicconfig/constants.go index 3108edd8d6f..6247fcc5101 100644 --- a/common/dynamicconfig/constants.go +++ b/common/dynamicconfig/constants.go @@ -148,6 +148,8 @@ const ( FrontendPersistenceMaxQPS = "frontend.persistenceMaxQPS" // FrontendPersistenceGlobalMaxQPS is the max qps frontend cluster can query DB FrontendPersistenceGlobalMaxQPS = "frontend.persistenceGlobalMaxQPS" + // FrontendEnablePersistencePriorityRateLimiting indicates if priority rate limiting is enabled in frontend persistence client + FrontendEnablePersistencePriorityRateLimiting = "frontend.enablePersistencePriorityRateLimiting" // FrontendVisibilityMaxPageSize is default max size for ListWorkflowExecutions in one page FrontendVisibilityMaxPageSize = "frontend.visibilityMaxPageSize" // FrontendESIndexMaxResultWindow is ElasticSearch index setting max_result_window @@ -242,6 +244,8 @@ const ( MatchingPersistenceMaxQPS = "matching.persistenceMaxQPS" // MatchingPersistenceGlobalMaxQPS is the max qps matching cluster can query DB MatchingPersistenceGlobalMaxQPS = "matching.persistenceGlobalMaxQPS" + // MatchingEnablePersistencePriorityRateLimiting indicates if priority rate limiting is enabled in matching persistence client + MatchingEnablePersistencePriorityRateLimiting = "matching.enablePersistencePriorityRateLimiting" // MatchingMinTaskThrottlingBurstSize is the minimum burst size for task queue throttling MatchingMinTaskThrottlingBurstSize = "matching.minTaskThrottlingBurstSize" // MatchingGetTasksBatchSize is the maximum batch size to fetch from the task buffer @@ -287,6 +291,8 @@ const ( HistoryPersistenceMaxQPS = "history.persistenceMaxQPS" // HistoryPersistenceGlobalMaxQPS is the max qps history cluster can query DB HistoryPersistenceGlobalMaxQPS = "history.persistenceGlobalMaxQPS" + // HistoryEnablePersistencePriorityRateLimiting indicates if priority rate limiting is enabled in history persistence client + HistoryEnablePersistencePriorityRateLimiting = "history.enablePersistencePriorityRateLimiting" // HistoryLongPollExpirationInterval is the long poll expiration interval in the history service HistoryLongPollExpirationInterval = "history.longPollExpirationInterval" // HistoryCacheInitialSize is initial size of history cache @@ -562,6 +568,8 @@ const ( WorkerPersistenceMaxQPS = "worker.persistenceMaxQPS" // WorkerPersistenceGlobalMaxQPS is the max qps worker cluster can query DB WorkerPersistenceGlobalMaxQPS = "worker.persistenceGlobalMaxQPS" + // WorkerEnablePersistencePriorityRateLimiting indicates if priority rate limiting is enabled in worker persistence client + WorkerEnablePersistencePriorityRateLimiting = "worker.enablePersistencePriorityRateLimiting" // WorkerIndexerConcurrency is the max concurrent messages to be processed at any given time WorkerIndexerConcurrency = "worker.indexerConcurrency" // WorkerESProcessorNumOfWorkers is num of workers for esProcessor diff 
--git a/common/persistence/client/fx.go b/common/persistence/client/fx.go index 926f55500c6..e708984f950 100644 --- a/common/persistence/client/fx.go +++ b/common/persistence/client/fx.go @@ -37,18 +37,20 @@ import ( ) type ( - PersistenceMaxQps dynamicconfig.IntPropertyFn - ClusterName string + PersistenceMaxQps dynamicconfig.IntPropertyFn + PriorityRateLimiting dynamicconfig.BoolPropertyFn + ClusterName string NewFactoryParams struct { fx.In - DataStoreFactory DataStoreFactory - Cfg *config.Persistence - PersistenceMaxQPS PersistenceMaxQps - ClusterName ClusterName - MetricsClient metrics.Client - Logger log.Logger + DataStoreFactory DataStoreFactory + Cfg *config.Persistence + PersistenceMaxQPS PersistenceMaxQps + PriorityRateLimiting PriorityRateLimiting + ClusterName ClusterName + MetricsClient metrics.Client + Logger log.Logger } FactoryProviderFn func(NewFactoryParams) Factory @@ -69,9 +71,13 @@ func FactoryProvider( ) Factory { var requestRatelimiter quotas.RequestRateLimiter if params.PersistenceMaxQPS != nil && params.PersistenceMaxQPS() > 0 { - requestRatelimiter = NewPriorityRateLimiter( - func() float64 { return float64(params.PersistenceMaxQPS()) }, - ) + rateFn := func() float64 { return float64(params.PersistenceMaxQPS()) } + + if params.PriorityRateLimiting != nil && params.PriorityRateLimiting() { + requestRatelimiter = NewPriorityRateLimiter(rateFn) + } else { + requestRatelimiter = NewNoopPriorityRateLimiter(rateFn) + } } return NewFactory( diff --git a/common/persistence/client/quotas.go b/common/persistence/client/quotas.go index c780d9e1f2c..ddb42110e08 100644 --- a/common/persistence/client/quotas.go +++ b/common/persistence/client/quotas.go @@ -70,3 +70,16 @@ func NewPriorityRateLimiter( rateLimiters, ) } + +func NewNoopPriorityRateLimiter( + rateFn quotas.RateFn, +) quotas.RequestRateLimiter { + priority := RequestPrioritiesOrdered[0] + + return quotas.NewPriorityRateLimiter( + func(_ quotas.Request) int { return priority }, + map[int]quotas.RateLimiter{ + priority: quotas.NewDefaultOutgoingRateLimiter(rateFn), + }, + ) +} diff --git a/common/searchattribute/manager.go b/common/searchattribute/manager.go index 6ea37a14410..17af1f5f03b 100644 --- a/common/searchattribute/manager.go +++ b/common/searchattribute/manager.go @@ -36,6 +36,7 @@ import ( persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/common/clock" "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/persistence" ) @@ -122,7 +123,9 @@ func (m *managerImpl) needRefreshCache(saCache cache, forceRefreshCache bool, no } func (m *managerImpl) refreshCache(saCache cache, now time.Time) (cache, error) { - clusterMetadata, err := m.clusterMetadataManager.GetCurrentClusterMetadata(context.TODO()) + ctx := headers.SetCallerInfo(context.TODO(), headers.NewCallerInfo(headers.CallerTypeBackground)) + + clusterMetadata, err := m.clusterMetadataManager.GetCurrentClusterMetadata(ctx) if err != nil { switch err.(type) { case *serviceerror.NotFound: diff --git a/service/frontend/fx.go b/service/frontend/fx.go index 895353c09ce..ce0cdfc6c66 100644 --- a/service/frontend/fx.go +++ b/service/frontend/fx.go @@ -83,7 +83,7 @@ var Module = fx.Options( fx.Provide(GrpcServerOptionsProvider), fx.Provide(VisibilityManagerProvider), fx.Provide(ThrottledLoggerRpsFnProvider), - fx.Provide(PersistenceMaxQpsProvider), + fx.Provide(PersistenceRateLimitingParamsProvider), fx.Provide(FEReplicatorNamespaceReplicationQueueProvider), 
fx.Provide(func(so []grpc.ServerOption) *grpc.Server { return grpc.NewServer(so...) }), fx.Provide(healthServerProvider), @@ -312,10 +312,14 @@ func CallerInfoInterceptorProvider( return interceptor.NewCallerInfoInterceptor(namespaceRegistry) } -func PersistenceMaxQpsProvider( +func PersistenceRateLimitingParamsProvider( serviceConfig *Config, -) persistenceClient.PersistenceMaxQps { - return service.PersistenceMaxQpsFn(serviceConfig.PersistenceMaxQPS, serviceConfig.PersistenceGlobalMaxQPS) +) service.PersistenceRateLimitingParams { + return service.NewPersistenceRateLimitingParams( + serviceConfig.PersistenceMaxQPS, + serviceConfig.PersistenceGlobalMaxQPS, + serviceConfig.EnablePersistencePriorityRateLimiting, + ) } func VisibilityManagerProvider( diff --git a/service/frontend/service.go b/service/frontend/service.go index 449a4b6e779..1c234a8cd40 100644 --- a/service/frontend/service.go +++ b/service/frontend/service.go @@ -55,10 +55,11 @@ import ( // Config represents configuration for frontend service type Config struct { - NumHistoryShards int32 - ESIndexName string - PersistenceMaxQPS dynamicconfig.IntPropertyFn - PersistenceGlobalMaxQPS dynamicconfig.IntPropertyFn + NumHistoryShards int32 + ESIndexName string + PersistenceMaxQPS dynamicconfig.IntPropertyFn + PersistenceGlobalMaxQPS dynamicconfig.IntPropertyFn + EnablePersistencePriorityRateLimiting dynamicconfig.BoolPropertyFn StandardVisibilityPersistenceMaxReadQPS dynamicconfig.IntPropertyFn StandardVisibilityPersistenceMaxWriteQPS dynamicconfig.IntPropertyFn @@ -149,10 +150,11 @@ type Config struct { // NewConfig returns new service config with default values func NewConfig(dc *dynamicconfig.Collection, numHistoryShards int32, esIndexName string, enableReadFromES bool) *Config { return &Config{ - NumHistoryShards: numHistoryShards, - ESIndexName: esIndexName, - PersistenceMaxQPS: dc.GetIntProperty(dynamicconfig.FrontendPersistenceMaxQPS, 2000), - PersistenceGlobalMaxQPS: dc.GetIntProperty(dynamicconfig.FrontendPersistenceGlobalMaxQPS, 0), + NumHistoryShards: numHistoryShards, + ESIndexName: esIndexName, + PersistenceMaxQPS: dc.GetIntProperty(dynamicconfig.FrontendPersistenceMaxQPS, 2000), + PersistenceGlobalMaxQPS: dc.GetIntProperty(dynamicconfig.FrontendPersistenceGlobalMaxQPS, 0), + EnablePersistencePriorityRateLimiting: dc.GetBoolProperty(dynamicconfig.FrontendEnablePersistencePriorityRateLimiting, true), StandardVisibilityPersistenceMaxReadQPS: dc.GetIntProperty(dynamicconfig.StandardVisibilityPersistenceMaxReadQPS, 9000), StandardVisibilityPersistenceMaxWriteQPS: dc.GetIntProperty(dynamicconfig.StandardVisibilityPersistenceMaxWriteQPS, 9000), diff --git a/service/frontend/versionChecker.go b/service/frontend/versionChecker.go index 62cc3005ae7..790ad81f86b 100644 --- a/service/frontend/versionChecker.go +++ b/service/frontend/versionChecker.go @@ -72,7 +72,9 @@ func NewVersionChecker( func (vc *VersionChecker) Start() { if vc.config.EnableServerVersionCheck() { vc.startOnce.Do(func() { - go vc.versionCheckLoop(context.TODO()) + ctx := headers.SetCallerInfo(context.TODO(), headers.NewCallerInfo(headers.CallerTypeBackground)) + + go vc.versionCheckLoop(ctx) }) } } diff --git a/service/fx.go b/service/fx.go index dbc27ebafd4..50f635bfb66 100644 --- a/service/fx.go +++ b/service/fx.go @@ -25,6 +25,7 @@ package service import ( + "go.uber.org/fx" "google.golang.org/grpc" "go.temporal.io/server/common" @@ -38,6 +39,26 @@ import ( "go.temporal.io/server/common/telemetry" ) +type ( + PersistenceRateLimitingParams struct { + fx.Out + 
+ PersistenceMaxQps persistenceClient.PersistenceMaxQps + PriorityRateLimiting persistenceClient.PriorityRateLimiting + } +) + +func NewPersistenceRateLimitingParams( + maxQps dynamicconfig.IntPropertyFn, + globalMaxQps dynamicconfig.IntPropertyFn, + priorityRateLimiting dynamicconfig.BoolPropertyFn, +) PersistenceRateLimitingParams { + return PersistenceRateLimitingParams{ + PersistenceMaxQps: PersistenceMaxQpsFn(maxQps, globalMaxQps), + PriorityRateLimiting: persistenceClient.PriorityRateLimiting(priorityRateLimiting), + } +} + func PersistenceMaxQpsFn( maxQps dynamicconfig.IntPropertyFn, globalMaxQps dynamicconfig.IntPropertyFn, diff --git a/service/history/configs/config.go b/service/history/configs/config.go index 3395d38d44a..bc537c18e93 100644 --- a/service/history/configs/config.go +++ b/service/history/configs/config.go @@ -40,10 +40,11 @@ type Config struct { NumberOfShards int32 DefaultVisibilityIndexName string - RPS dynamicconfig.IntPropertyFn - MaxIDLengthLimit dynamicconfig.IntPropertyFn - PersistenceMaxQPS dynamicconfig.IntPropertyFn - PersistenceGlobalMaxQPS dynamicconfig.IntPropertyFn + RPS dynamicconfig.IntPropertyFn + MaxIDLengthLimit dynamicconfig.IntPropertyFn + PersistenceMaxQPS dynamicconfig.IntPropertyFn + PersistenceGlobalMaxQPS dynamicconfig.IntPropertyFn + EnablePersistencePriorityRateLimiting dynamicconfig.BoolPropertyFn StandardVisibilityPersistenceMaxReadQPS dynamicconfig.IntPropertyFn StandardVisibilityPersistenceMaxWriteQPS dynamicconfig.IntPropertyFn @@ -276,13 +277,14 @@ func NewConfig(dc *dynamicconfig.Collection, numberOfShards int32, isAdvancedVis NumberOfShards: numberOfShards, DefaultVisibilityIndexName: defaultVisibilityIndex, - RPS: dc.GetIntProperty(dynamicconfig.HistoryRPS, 3000), - MaxIDLengthLimit: dc.GetIntProperty(dynamicconfig.MaxIDLengthLimit, 1000), - PersistenceMaxQPS: dc.GetIntProperty(dynamicconfig.HistoryPersistenceMaxQPS, 9000), - PersistenceGlobalMaxQPS: dc.GetIntProperty(dynamicconfig.HistoryPersistenceGlobalMaxQPS, 0), - ShutdownDrainDuration: dc.GetDurationProperty(dynamicconfig.HistoryShutdownDrainDuration, 0), - MaxAutoResetPoints: dc.GetIntPropertyFilteredByNamespace(dynamicconfig.HistoryMaxAutoResetPoints, DefaultHistoryMaxAutoResetPoints), - DefaultWorkflowTaskTimeout: dc.GetDurationPropertyFilteredByNamespace(dynamicconfig.DefaultWorkflowTaskTimeout, common.DefaultWorkflowTaskTimeout), + RPS: dc.GetIntProperty(dynamicconfig.HistoryRPS, 3000), + MaxIDLengthLimit: dc.GetIntProperty(dynamicconfig.MaxIDLengthLimit, 1000), + PersistenceMaxQPS: dc.GetIntProperty(dynamicconfig.HistoryPersistenceMaxQPS, 9000), + PersistenceGlobalMaxQPS: dc.GetIntProperty(dynamicconfig.HistoryPersistenceGlobalMaxQPS, 0), + EnablePersistencePriorityRateLimiting: dc.GetBoolProperty(dynamicconfig.HistoryEnablePersistencePriorityRateLimiting, true), + ShutdownDrainDuration: dc.GetDurationProperty(dynamicconfig.HistoryShutdownDrainDuration, 0), + MaxAutoResetPoints: dc.GetIntPropertyFilteredByNamespace(dynamicconfig.HistoryMaxAutoResetPoints, DefaultHistoryMaxAutoResetPoints), + DefaultWorkflowTaskTimeout: dc.GetDurationPropertyFilteredByNamespace(dynamicconfig.DefaultWorkflowTaskTimeout, common.DefaultWorkflowTaskTimeout), StandardVisibilityPersistenceMaxReadQPS: dc.GetIntProperty(dynamicconfig.StandardVisibilityPersistenceMaxReadQPS, 9000), StandardVisibilityPersistenceMaxWriteQPS: dc.GetIntProperty(dynamicconfig.StandardVisibilityPersistenceMaxWriteQPS, 9000), diff --git a/service/history/fx.go b/service/history/fx.go index da018f0c1e5..f7efac23fdb 
100644 --- a/service/history/fx.go +++ b/service/history/fx.go @@ -71,7 +71,7 @@ var Module = fx.Options( fx.Provide(ESProcessorConfigProvider), fx.Provide(VisibilityManagerProvider), fx.Provide(ThrottledLoggerRpsFnProvider), - fx.Provide(PersistenceMaxQpsProvider), + fx.Provide(PersistenceRateLimitingParamsProvider), fx.Provide(ServiceResolverProvider), fx.Provide(EventNotifierProvider), fx.Provide(ArchivalClientProvider), @@ -198,10 +198,14 @@ func ESProcessorConfigProvider( } } -func PersistenceMaxQpsProvider( +func PersistenceRateLimitingParamsProvider( serviceConfig *configs.Config, -) persistenceClient.PersistenceMaxQps { - return service.PersistenceMaxQpsFn(serviceConfig.PersistenceMaxQPS, serviceConfig.PersistenceGlobalMaxQPS) +) service.PersistenceRateLimitingParams { + return service.NewPersistenceRateLimitingParams( + serviceConfig.PersistenceMaxQPS, + serviceConfig.PersistenceGlobalMaxQPS, + serviceConfig.EnablePersistencePriorityRateLimiting, + ) } func VisibilityManagerProvider( diff --git a/service/history/replication/task_processor_manager.go b/service/history/replication/task_processor_manager.go index 9d165f807b1..69056d85fa0 100644 --- a/service/history/replication/task_processor_manager.go +++ b/service/history/replication/task_processor_manager.go @@ -28,12 +28,14 @@ import ( "context" "sync" "sync/atomic" + "time" "go.temporal.io/server/api/historyservice/v1" "go.temporal.io/server/client" "go.temporal.io/server/client/history" "go.temporal.io/server/common" "go.temporal.io/server/common/cluster" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" @@ -238,8 +240,13 @@ func (r *taskProcessorManagerImpl) cleanupReplicationTasks() error { metrics.ReplicationTasksLag, int(r.shard.GetQueueExclusiveHighReadWatermark(tasks.CategoryReplication, currentCluster).Prev().TaskID-*minAckedTaskID), ) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + ctx = headers.SetCallerInfo(ctx, headers.NewCallerInfo(headers.CallerTypeBackground)) + defer cancel() + err := r.shard.GetExecutionManager().RangeCompleteHistoryTasks( - context.TODO(), + ctx, &persistence.RangeCompleteHistoryTasksRequest{ ShardID: r.shard.GetShardID(), TaskCategory: tasks.CategoryReplication, diff --git a/service/history/transferQueueProcessor.go b/service/history/transferQueueProcessor.go index c6dd92a29e4..39c2378f061 100644 --- a/service/history/transferQueueProcessor.go +++ b/service/history/transferQueueProcessor.go @@ -25,7 +25,6 @@ package history import ( - "context" "fmt" "sync" "sync/atomic" @@ -340,7 +339,10 @@ func (t *transferQueueProcessorImpl) completeTransfer() error { t.metricsClient.IncCounter(metrics.TransferQueueProcessorScope, metrics.TaskBatchCompleteCounter) if lowerAckLevel < upperAckLevel { - err := t.shard.GetExecutionManager().RangeCompleteHistoryTasks(context.TODO(), &persistence.RangeCompleteHistoryTasksRequest{ + ctx, cancel := newQueueIOContext() + defer cancel() + + err := t.shard.GetExecutionManager().RangeCompleteHistoryTasks(ctx, &persistence.RangeCompleteHistoryTasksRequest{ ShardID: t.shard.GetShardID(), TaskCategory: tasks.CategoryTransfer, InclusiveMinTaskKey: tasks.NewImmediateKey(lowerAckLevel + 1), diff --git a/service/matching/config.go b/service/matching/config.go index 7c54d76759c..a8373c91bc0 100644 --- a/service/matching/config.go +++ b/service/matching/config.go @@ -35,11 +35,12 @@ import ( type ( // Config represents configuration for matching 
service Config struct { - PersistenceMaxQPS dynamicconfig.IntPropertyFn - PersistenceGlobalMaxQPS dynamicconfig.IntPropertyFn - SyncMatchWaitDuration dynamicconfig.DurationPropertyFnWithTaskQueueInfoFilters - RPS dynamicconfig.IntPropertyFn - ShutdownDrainDuration dynamicconfig.DurationPropertyFn + PersistenceMaxQPS dynamicconfig.IntPropertyFn + PersistenceGlobalMaxQPS dynamicconfig.IntPropertyFn + EnablePersistencePriorityRateLimiting dynamicconfig.BoolPropertyFn + SyncMatchWaitDuration dynamicconfig.DurationPropertyFnWithTaskQueueInfoFilters + RPS dynamicconfig.IntPropertyFn + ShutdownDrainDuration dynamicconfig.DurationPropertyFn // taskQueueManager configuration @@ -106,29 +107,30 @@ type ( // NewConfig returns new service config with default values func NewConfig(dc *dynamicconfig.Collection) *Config { return &Config{ - PersistenceMaxQPS: dc.GetIntProperty(dynamicconfig.MatchingPersistenceMaxQPS, 3000), - PersistenceGlobalMaxQPS: dc.GetIntProperty(dynamicconfig.MatchingPersistenceGlobalMaxQPS, 0), - SyncMatchWaitDuration: dc.GetDurationPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingSyncMatchWaitDuration, 200*time.Millisecond), - RPS: dc.GetIntProperty(dynamicconfig.MatchingRPS, 1200), - RangeSize: 100000, - GetTasksBatchSize: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingGetTasksBatchSize, 1000), - UpdateAckInterval: dc.GetDurationPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingUpdateAckInterval, 1*time.Minute), - IdleTaskqueueCheckInterval: dc.GetDurationPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingIdleTaskqueueCheckInterval, 5*time.Minute), - MaxTaskqueueIdleTime: dc.GetDurationPropertyFilteredByTaskQueueInfo(dynamicconfig.MaxTaskqueueIdleTime, 5*time.Minute), - LongPollExpirationInterval: dc.GetDurationPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingLongPollExpirationInterval, time.Minute), - MinTaskThrottlingBurstSize: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingMinTaskThrottlingBurstSize, 1), - MaxTaskDeleteBatchSize: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingMaxTaskDeleteBatchSize, 100), - OutstandingTaskAppendsThreshold: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingOutstandingTaskAppendsThreshold, 250), - MaxTaskBatchSize: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingMaxTaskBatchSize, 100), - ThrottledLogRPS: dc.GetIntProperty(dynamicconfig.MatchingThrottledLogRPS, 20), - NumTaskqueueWritePartitions: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingNumTaskqueueWritePartitions, dynamicconfig.DefaultNumTaskQueuePartitions), - NumTaskqueueReadPartitions: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingNumTaskqueueReadPartitions, dynamicconfig.DefaultNumTaskQueuePartitions), - ForwarderMaxOutstandingPolls: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingForwarderMaxOutstandingPolls, 1), - ForwarderMaxOutstandingTasks: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingForwarderMaxOutstandingTasks, 1), - ForwarderMaxRatePerSecond: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingForwarderMaxRatePerSecond, 10), - ForwarderMaxChildrenPerNode: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingForwarderMaxChildrenPerNode, 20), - ShutdownDrainDuration: dc.GetDurationProperty(dynamicconfig.MatchingShutdownDrainDuration, 0), - MaxVersionGraphSize: dc.GetIntProperty(dynamicconfig.VersionGraphNodeLimit, 1000), + PersistenceMaxQPS: dc.GetIntProperty(dynamicconfig.MatchingPersistenceMaxQPS, 
3000), + PersistenceGlobalMaxQPS: dc.GetIntProperty(dynamicconfig.MatchingPersistenceGlobalMaxQPS, 0), + EnablePersistencePriorityRateLimiting: dc.GetBoolProperty(dynamicconfig.MatchingEnablePersistencePriorityRateLimiting, true), + SyncMatchWaitDuration: dc.GetDurationPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingSyncMatchWaitDuration, 200*time.Millisecond), + RPS: dc.GetIntProperty(dynamicconfig.MatchingRPS, 1200), + RangeSize: 100000, + GetTasksBatchSize: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingGetTasksBatchSize, 1000), + UpdateAckInterval: dc.GetDurationPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingUpdateAckInterval, 1*time.Minute), + IdleTaskqueueCheckInterval: dc.GetDurationPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingIdleTaskqueueCheckInterval, 5*time.Minute), + MaxTaskqueueIdleTime: dc.GetDurationPropertyFilteredByTaskQueueInfo(dynamicconfig.MaxTaskqueueIdleTime, 5*time.Minute), + LongPollExpirationInterval: dc.GetDurationPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingLongPollExpirationInterval, time.Minute), + MinTaskThrottlingBurstSize: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingMinTaskThrottlingBurstSize, 1), + MaxTaskDeleteBatchSize: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingMaxTaskDeleteBatchSize, 100), + OutstandingTaskAppendsThreshold: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingOutstandingTaskAppendsThreshold, 250), + MaxTaskBatchSize: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingMaxTaskBatchSize, 100), + ThrottledLogRPS: dc.GetIntProperty(dynamicconfig.MatchingThrottledLogRPS, 20), + NumTaskqueueWritePartitions: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingNumTaskqueueWritePartitions, dynamicconfig.DefaultNumTaskQueuePartitions), + NumTaskqueueReadPartitions: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingNumTaskqueueReadPartitions, dynamicconfig.DefaultNumTaskQueuePartitions), + ForwarderMaxOutstandingPolls: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingForwarderMaxOutstandingPolls, 1), + ForwarderMaxOutstandingTasks: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingForwarderMaxOutstandingTasks, 1), + ForwarderMaxRatePerSecond: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingForwarderMaxRatePerSecond, 10), + ForwarderMaxChildrenPerNode: dc.GetIntPropertyFilteredByTaskQueueInfo(dynamicconfig.MatchingForwarderMaxChildrenPerNode, 20), + ShutdownDrainDuration: dc.GetDurationProperty(dynamicconfig.MatchingShutdownDrainDuration, 0), + MaxVersionGraphSize: dc.GetIntProperty(dynamicconfig.VersionGraphNodeLimit, 1000), AdminNamespaceToPartitionDispatchRate: dc.GetFloatPropertyFilteredByNamespace(dynamicconfig.AdminMatchingNamespaceToPartitionDispatchRate, 10000), AdminNamespaceTaskqueueToPartitionDispatchRate: dc.GetFloatPropertyFilteredByTaskQueueInfo(dynamicconfig.AdminMatchingNamespaceTaskqueueToPartitionDispatchRate, 1000), diff --git a/service/matching/fx.go b/service/matching/fx.go index 3988a938178..75ff5c822fe 100644 --- a/service/matching/fx.go +++ b/service/matching/fx.go @@ -38,7 +38,6 @@ import ( "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence" - persistenceClient "go.temporal.io/server/common/persistence/client" "go.temporal.io/server/common/resource" "go.temporal.io/server/common/rpc/interceptor" "go.temporal.io/server/service" @@ -48,7 +47,7 @@ import ( var Module = fx.Options( 
fx.Provide(dynamicconfig.NewCollection), fx.Provide(NewConfig), - fx.Provide(PersistenceMaxQpsProvider), + fx.Provide(PersistenceRateLimitingParamsProvider), fx.Provide(ThrottledLoggerRpsFnProvider), fx.Provide(TelemetryInterceptorProvider), fx.Provide(RateLimitInterceptorProvider), @@ -88,10 +87,14 @@ func RateLimitInterceptorProvider( // This function is the same between services but uses different config sources. // if-case comes from resourceImpl.New. -func PersistenceMaxQpsProvider( +func PersistenceRateLimitingParamsProvider( serviceConfig *Config, -) persistenceClient.PersistenceMaxQps { - return service.PersistenceMaxQpsFn(serviceConfig.PersistenceMaxQPS, serviceConfig.PersistenceGlobalMaxQPS) +) service.PersistenceRateLimitingParams { + return service.NewPersistenceRateLimitingParams( + serviceConfig.PersistenceMaxQPS, + serviceConfig.PersistenceGlobalMaxQPS, + serviceConfig.EnablePersistencePriorityRateLimiting, + ) } func ServiceResolverProvider(membershipMonitor membership.Monitor) (membership.ServiceResolver, error) { diff --git a/service/worker/fx.go b/service/worker/fx.go index 3b752520e31..3d38b4992ca 100644 --- a/service/worker/fx.go +++ b/service/worker/fx.go @@ -34,7 +34,6 @@ import ( "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/log" "go.temporal.io/server/common/metrics" - persistenceClient "go.temporal.io/server/common/persistence/client" "go.temporal.io/server/common/persistence/visibility" "go.temporal.io/server/common/persistence/visibility/manager" esclient "go.temporal.io/server/common/persistence/visibility/store/elasticsearch/client" @@ -58,7 +57,7 @@ var Module = fx.Options( fx.Provide(dynamicconfig.NewCollection), fx.Provide(ThrottledLoggerRpsFnProvider), fx.Provide(ConfigProvider), - fx.Provide(PersistenceMaxQpsProvider), + fx.Provide(PersistenceRateLimitingParamsProvider), fx.Provide(NewService), fx.Provide(NewWorkerManager), fx.Provide(NewPerNamespaceWorkerManager), @@ -69,10 +68,14 @@ func ThrottledLoggerRpsFnProvider(serviceConfig *Config) resource.ThrottledLogge return func() float64 { return float64(serviceConfig.ThrottledLogRPS()) } } -func PersistenceMaxQpsProvider( +func PersistenceRateLimitingParamsProvider( serviceConfig *Config, -) persistenceClient.PersistenceMaxQps { - return service.PersistenceMaxQpsFn(serviceConfig.PersistenceMaxQPS, serviceConfig.PersistenceGlobalMaxQPS) +) service.PersistenceRateLimitingParams { + return service.NewPersistenceRateLimitingParams( + serviceConfig.PersistenceMaxQPS, + serviceConfig.PersistenceGlobalMaxQPS, + serviceConfig.EnablePersistencePriorityRateLimiting, + ) } func ConfigProvider( diff --git a/service/worker/service.go b/service/worker/service.go index b0eedeb3fd1..ced6c8e35cc 100644 --- a/service/worker/service.go +++ b/service/worker/service.go @@ -100,15 +100,16 @@ type ( // Config contains all the service config for worker Config struct { - ArchiverConfig *archiver.Config - ScannerCfg *scanner.Config - ParentCloseCfg *parentclosepolicy.Config - BatcherCfg *batcher.Config - ThrottledLogRPS dynamicconfig.IntPropertyFn - PersistenceMaxQPS dynamicconfig.IntPropertyFn - PersistenceGlobalMaxQPS dynamicconfig.IntPropertyFn - EnableBatcher dynamicconfig.BoolPropertyFn - EnableParentClosePolicyWorker dynamicconfig.BoolPropertyFn + ArchiverConfig *archiver.Config + ScannerCfg *scanner.Config + ParentCloseCfg *parentclosepolicy.Config + BatcherCfg *batcher.Config + ThrottledLogRPS dynamicconfig.IntPropertyFn + PersistenceMaxQPS dynamicconfig.IntPropertyFn + PersistenceGlobalMaxQPS 
dynamicconfig.IntPropertyFn + EnablePersistencePriorityRateLimiting dynamicconfig.BoolPropertyFn + EnableBatcher dynamicconfig.BoolPropertyFn + EnableParentClosePolicyWorker dynamicconfig.BoolPropertyFn StandardVisibilityPersistenceMaxReadQPS dynamicconfig.IntPropertyFn StandardVisibilityPersistenceMaxWriteQPS dynamicconfig.IntPropertyFn @@ -308,6 +309,10 @@ func NewConfig(dc *dynamicconfig.Collection, persistenceConfig *config.Persisten dynamicconfig.WorkerPersistenceGlobalMaxQPS, 0, ), + EnablePersistencePriorityRateLimiting: dc.GetBoolProperty( + dynamicconfig.WorkerEnablePersistencePriorityRateLimiting, + true, + ), StandardVisibilityPersistenceMaxReadQPS: dc.GetIntProperty(dynamicconfig.StandardVisibilityPersistenceMaxReadQPS, 9000), StandardVisibilityPersistenceMaxWriteQPS: dc.GetIntProperty(dynamicconfig.StandardVisibilityPersistenceMaxWriteQPS, 9000), diff --git a/temporal/fx.go b/temporal/fx.go index 1a77854a09c..f6b2d7b30e3 100644 --- a/temporal/fx.go +++ b/temporal/fx.go @@ -594,12 +594,13 @@ func ApplyClusterMetadataConfigProvider( nil, ) factory := persistenceFactoryProvider(persistenceClient.NewFactoryParams{ - DataStoreFactory: dataStoreFactory, - Cfg: &config.Persistence, - PersistenceMaxQPS: nil, - ClusterName: persistenceClient.ClusterName(config.ClusterMetadata.CurrentClusterName), - MetricsClient: nil, - Logger: logger, + DataStoreFactory: dataStoreFactory, + Cfg: &config.Persistence, + PersistenceMaxQPS: nil, + PriorityRateLimiting: nil, + ClusterName: persistenceClient.ClusterName(config.ClusterMetadata.CurrentClusterName), + MetricsClient: nil, + Logger: logger, }) defer factory.Close() diff --git a/temporal/server_impl.go b/temporal/server_impl.go index 6a2388a87f8..b54f3863bf3 100644 --- a/temporal/server_impl.go +++ b/temporal/server_impl.go @@ -168,12 +168,13 @@ func initSystemNamespaces( nil, ) factory := persistenceFactoryProvider(persistenceClient.NewFactoryParams{ - DataStoreFactory: dataStoreFactory, - Cfg: cfg, - PersistenceMaxQPS: nil, - ClusterName: persistenceClient.ClusterName(currentClusterName), - MetricsClient: nil, - Logger: logger, + DataStoreFactory: dataStoreFactory, + Cfg: cfg, + PersistenceMaxQPS: nil, + PriorityRateLimiting: nil, + ClusterName: persistenceClient.ClusterName(currentClusterName), + MetricsClient: nil, + Logger: logger, }) defer factory.Close() From ffe5dd4df9e1d9767d23acfee206e36f779b2087 Mon Sep 17 00:00:00 2001 From: Yichao Yang Date: Wed, 27 Jul 2022 12:08:30 -0700 Subject: [PATCH 13/13] add comments --- common/cluster/metadata.go | 1 + common/searchattribute/manager.go | 1 + service/frontend/versionChecker.go | 1 + 3 files changed, 3 insertions(+) diff --git a/common/cluster/metadata.go b/common/cluster/metadata.go index 4439044c0fd..3b038fca3b4 100644 --- a/common/cluster/metadata.go +++ b/common/cluster/metadata.go @@ -227,6 +227,7 @@ func (m *metadataImpl) Start() { return } + // TODO: specify a timeout for the context ctx := headers.SetCallerInfo(context.TODO(), headers.NewCallerInfo(headers.CallerTypeBackground)) err := m.refreshClusterMetadata(ctx) if err != nil { diff --git a/common/searchattribute/manager.go b/common/searchattribute/manager.go index 17af1f5f03b..2a779acf700 100644 --- a/common/searchattribute/manager.go +++ b/common/searchattribute/manager.go @@ -123,6 +123,7 @@ func (m *managerImpl) needRefreshCache(saCache cache, forceRefreshCache bool, no } func (m *managerImpl) refreshCache(saCache cache, now time.Time) (cache, error) { + // TODO: specify a timeout for the context ctx := 
headers.SetCallerInfo(context.TODO(), headers.NewCallerInfo(headers.CallerTypeBackground)) clusterMetadata, err := m.clusterMetadataManager.GetCurrentClusterMetadata(ctx) diff --git a/service/frontend/versionChecker.go b/service/frontend/versionChecker.go index 790ad81f86b..b2644551cdc 100644 --- a/service/frontend/versionChecker.go +++ b/service/frontend/versionChecker.go @@ -72,6 +72,7 @@ func NewVersionChecker( func (vc *VersionChecker) Start() { if vc.config.EnableServerVersionCheck() { vc.startOnce.Do(func() { + // TODO: specify a timeout for the context ctx := headers.SetCallerInfo(context.TODO(), headers.NewCallerInfo(headers.CallerTypeBackground)) go vc.versionCheckLoop(ctx)
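
Taken together, patches 10-12 above collapse the per-API priority maps into a small caller-type mapping plus a handful of API-level overrides, gated per service by the new Enable*PersistencePriorityRateLimiting dynamic configs. What follows is a minimal, self-contained sketch of that priority selection only: the map values are copied from the quotas.go hunks above, while the lowercase identifiers, the requestPriority helper, and the fallback to the lowest priority (the default return sits outside the hunk shown) are illustrative assumptions, not code from the patch.

package main

import "fmt"

// Caller type values copied from common/headers in the patch.
const (
	callerTypeAPI        = "api"
	callerTypeBackground = "background"
)

// Lower number means higher priority; RequestPrioritiesOrdered in the patch is []int{0, 1}.
var (
	callerTypePriority = map[string]int{
		callerTypeAPI:        0,
		callerTypeBackground: 1,
	}
	apiPriorityOverride = map[string]int{
		"GetOrCreateShard":          0,
		"UpdateShard":               0,
		"RangeCompleteHistoryTasks": 0, // prerequisite for checkpointing queue progress
	}
	requestPrioritiesOrdered = []int{0, 1}
)

// requestPriority mirrors the selection function passed to quotas.NewPriorityRateLimiter:
// API-level overrides win first, then the caller type decides. The fallback for unknown
// callers to the lowest priority is an assumption, since that branch is not in the hunk.
func requestPriority(api, callerType string) int {
	if p, ok := apiPriorityOverride[api]; ok {
		return p
	}
	if p, ok := callerTypePriority[callerType]; ok {
		return p
	}
	return requestPrioritiesOrdered[len(requestPrioritiesOrdered)-1]
}

func main() {
	fmt.Println(requestPriority("GetWorkflowExecution", callerTypeAPI))             // 0
	fmt.Println(requestPriority("GetWorkflowExecution", callerTypeBackground))      // 1
	fmt.Println(requestPriority("RangeCompleteHistoryTasks", callerTypeBackground)) // 0, API override wins
}

When the dynamic config disables priority rate limiting, the factory in fx.go instead builds NewNoopPriorityRateLimiter, which routes every request to the single highest priority, so all callers share one plain rate limit as before.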