From e8fe0bbd804d214c9c578be46679aefbd5a3b44d Mon Sep 17 00:00:00 2001
From: Pavol Loffay
Date: Mon, 9 Jan 2023 16:00:38 +0100
Subject: [PATCH] Support tenant header propagation in query service
Signed-off-by: Pavol Loffay
---
cmd/query/app/flags.go | 5 ++++
cmd/query/app/flags_test.go | 2 ++
cmd/query/app/server.go | 2 +-
pkg/bearertoken/context.go | 27 ++++++++++++++++++----
pkg/bearertoken/context_test.go | 9 ++++++++
pkg/bearertoken/http.go | 11 ++++++++-
pkg/bearertoken/http_test.go | 18 ++++++++++++++-
plugin/storage/grpc/shared/grpc_client.go | 28 ++++++++++++++++-------
8 files changed, 87 insertions(+), 15 deletions(-)
diff --git a/cmd/query/app/flags.go b/cmd/query/app/flags.go
index 7a729fd8ec1..9b8a02d1229 100644
--- a/cmd/query/app/flags.go
+++ b/cmd/query/app/flags.go
@@ -44,6 +44,7 @@ const (
queryStaticFiles = "query.static-files"
queryUIConfig = "query.ui-config"
queryTokenPropagation = "query.bearer-token-propagation"
+ queryTenantPropagation = "query.tenant-propagation"
queryAdditionalHeaders = "query.additional-headers"
queryMaxClockSkewAdjust = "query.max-clock-skew-adjustment"
)
@@ -72,6 +73,8 @@ type QueryOptions struct {
UIConfig string
// BearerTokenPropagation activate/deactivate bearer token propagation to storage
BearerTokenPropagation bool
+ // TenantHeaderPropagation activate/deactivate propagation of tenant header
+ TenantHeaderPropagation string
// TLSGRPC configures secure transport (Consumer to Query service GRPC API)
TLSGRPC tlscfg.Options
// TLSHTTP configures secure transport (Consumer to Query service HTTP API)
@@ -93,6 +96,7 @@ func AddFlags(flagSet *flag.FlagSet) {
flagSet.String(queryStaticFiles, "", "The directory path override for the static assets for the UI")
flagSet.String(queryUIConfig, "", "The path to the UI configuration file in JSON format")
flagSet.Bool(queryTokenPropagation, false, "Allow propagation of bearer token to be used by storage plugins")
+ flagSet.String(queryTenantPropagation, "", "Enables propagation of a tenant header")
flagSet.Duration(queryMaxClockSkewAdjust, 0, "The maximum delta by which span timestamps may be adjusted in the UI due to clock skew; set to 0s to disable clock skew adjustments")
tlsGRPCFlagsConfig.AddFlags(flagSet)
tlsHTTPFlagsConfig.AddFlags(flagSet)
@@ -116,6 +120,7 @@ func (qOpts *QueryOptions) InitFromViper(v *viper.Viper, logger *zap.Logger) (*Q
qOpts.StaticAssets = v.GetString(queryStaticFiles)
qOpts.UIConfig = v.GetString(queryUIConfig)
qOpts.BearerTokenPropagation = v.GetBool(queryTokenPropagation)
+ qOpts.TenantHeaderPropagation = v.GetString(queryTenantPropagation)
qOpts.MaxClockSkewAdjust = v.GetDuration(queryMaxClockSkewAdjust)
stringSlice := v.GetStringSlice(queryAdditionalHeaders)
diff --git a/cmd/query/app/flags_test.go b/cmd/query/app/flags_test.go
index 2afbb873b5e..f4437983f54 100644
--- a/cmd/query/app/flags_test.go
+++ b/cmd/query/app/flags_test.go
@@ -42,6 +42,7 @@ func TestQueryBuilderFlags(t *testing.T) {
"--query.additional-headers=access-control-allow-origin:blerg",
"--query.additional-headers=whatever:thing",
"--query.max-clock-skew-adjustment=10s",
+ "--query.tenant-propagation=x-scope-orgid",
})
qOpts, err := new(QueryOptions).InitFromViper(v, zap.NewNop())
require.NoError(t, err)
@@ -55,6 +56,7 @@ func TestQueryBuilderFlags(t *testing.T) {
"Whatever": []string{"thing"},
}, qOpts.AdditionalHeaders)
assert.Equal(t, 10*time.Second, qOpts.MaxClockSkewAdjust)
+ assert.Equal(t, "x-scope-orgid", qOpts.TenantHeaderPropagation)
}
func TestQueryBuilderBadHeadersFlags(t *testing.T) {
diff --git a/cmd/query/app/server.go b/cmd/query/app/server.go
index 70e2988e7e4..15bea821936 100644
--- a/cmd/query/app/server.go
+++ b/cmd/query/app/server.go
@@ -178,7 +178,7 @@ func createHTTPServer(querySvc *querysvc.QueryService, metricsQuerySvc querysvc.
var handler http.Handler = r
handler = additionalHeadersHandler(handler, queryOpts.AdditionalHeaders)
if queryOpts.BearerTokenPropagation {
- handler = bearertoken.PropagationHandler(logger, handler)
+ handler = bearertoken.PropagationHandler(logger, handler, queryOpts.TenantHeaderPropagation)
}
handler = handlers.CompressHandler(handler)
recoveryHandler := recoveryhandler.NewRecoveryHandler(logger, true)
diff --git a/pkg/bearertoken/context.go b/pkg/bearertoken/context.go
index 4f9221fe631..6a9a90f6f71 100644
--- a/pkg/bearertoken/context.go
+++ b/pkg/bearertoken/context.go
@@ -14,11 +14,16 @@
package bearertoken
-import "context"
+import (
+ "context"
+)
type contextKeyType int
-const contextKey = contextKeyType(iota)
+const (
+ bearerTokenContextKey = contextKeyType(iota)
+ tenantHeaderContextKey
+)
// StoragePropagationKey is a key for viper configuration to pass this option to storage plugins.
const StoragePropagationKey = "storage.propagate.token"
@@ -28,11 +33,25 @@ func ContextWithBearerToken(ctx context.Context, token string) context.Context {
if token == "" {
return ctx
}
- return context.WithValue(ctx, contextKey, token)
+ return context.WithValue(ctx, bearerTokenContextKey, token)
}
// GetBearerToken from context, or empty string if there is no token.
func GetBearerToken(ctx context.Context) (string, bool) {
- val, ok := ctx.Value(contextKey).(string)
+ val, ok := ctx.Value(bearerTokenContextKey).(string)
+ return val, ok
+}
+
+// ContextWithTenant sets tenant into context.
+func ContextWithTenant(ctx context.Context, tenant string) context.Context {
+ if tenant == "" {
+ return ctx
+ }
+ return context.WithValue(ctx, tenantHeaderContextKey, tenant)
+}
+
+// GetTenant returns tenant, or empty string if there is no tenant.
+func GetTenant(ctx context.Context) (string, bool) {
+ val, ok := ctx.Value(tenantHeaderContextKey).(string)
return val, ok
}
diff --git a/pkg/bearertoken/context_test.go b/pkg/bearertoken/context_test.go
index 7a3f5184f20..c3b46462e15 100644
--- a/pkg/bearertoken/context_test.go
+++ b/pkg/bearertoken/context_test.go
@@ -29,3 +29,12 @@ func Test_GetBearerToken(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, contextToken, token)
}
+
+func Test_GetTenant(t *testing.T) {
+ const tenant = "jdoe"
+ ctx := context.Background()
+ ctx = ContextWithTenant(ctx, tenant)
+ contextTenant, ok := GetTenant(ctx)
+ assert.True(t, ok)
+ assert.Equal(t, contextTenant, tenant)
+}
diff --git a/pkg/bearertoken/http.go b/pkg/bearertoken/http.go
index 76cb0014bec..9588885baa2 100644
--- a/pkg/bearertoken/http.go
+++ b/pkg/bearertoken/http.go
@@ -24,9 +24,18 @@ import (
// PropagationHandler returns a http.Handler containing the logic to extract
// the Bearer token from the Authorization header of the http.Request and insert it into request.Context
// for propagation. The token can be accessed via GetBearerToken.
-func PropagationHandler(logger *zap.Logger, h http.Handler) http.Handler {
+// The handler as well extracts tenant header which can be accessed via GetTenantHeader.
+func PropagationHandler(logger *zap.Logger, h http.Handler, tenantHeader string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
+
+ if tenantHeader != "" {
+ header := r.Header.Get(tenantHeader)
+ if header != "" {
+ ctx = ContextWithTenant(ctx, header)
+ }
+ }
+
authHeaderValue := r.Header.Get("Authorization")
// If no Authorization header is present, try with X-Forwarded-Access-Token
if authHeaderValue == "" {
diff --git a/pkg/bearertoken/http_test.go b/pkg/bearertoken/http_test.go
index 88d2d4618b7..1b3b72fe163 100644
--- a/pkg/bearertoken/http_test.go
+++ b/pkg/bearertoken/http_test.go
@@ -32,6 +32,8 @@ func Test_PropagationHandler(t *testing.T) {
logger := zap.NewNop()
const bearerToken = "blah"
+ const tenantName = "jdoe"
+ const tenantHeader = "x-scope-orgid"
validTokenHandler := func(stop *sync.WaitGroup) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
@@ -43,11 +45,24 @@ func Test_PropagationHandler(t *testing.T) {
}
}
+ validTenantHandler := func(stop *sync.WaitGroup) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+ tenant, ok := GetTenant(ctx)
+ assert.True(t, ok)
+ assert.Equal(t, tenantName, tenant)
+ stop.Done()
+ }
+ }
+
emptyHandler := func(stop *sync.WaitGroup) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
token, _ := GetBearerToken(ctx)
assert.Empty(t, token, bearerToken)
+ assert.Empty(t, token)
+ tenant, _ := GetTenant(ctx)
+ assert.Empty(t, tenant)
stop.Done()
}
}
@@ -65,13 +80,14 @@ func Test_PropagationHandler(t *testing.T) {
{name: "Basic Auth", sendHeader: true, headerName: "Authorization", headerValue: "Basic " + bearerToken, handler: emptyHandler},
{name: "X-Forwarded-Access-Token", headerName: "X-Forwarded-Access-Token", sendHeader: true, headerValue: "Bearer " + bearerToken, handler: validTokenHandler},
{name: "Invalid header", headerName: "X-Forwarded-Access-Token", sendHeader: true, headerValue: "Bearer " + bearerToken + " another stuff", handler: emptyHandler},
+ {name: "Tenant header", sendHeader: true, headerName: tenantHeader, headerValue: tenantName, handler: validTenantHandler},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
stop := sync.WaitGroup{}
stop.Add(1)
- r := PropagationHandler(logger, testCase.handler(&stop))
+ r := PropagationHandler(logger, testCase.handler(&stop), tenantHeader)
server := httptest.NewServer(r)
defer server.Close()
req, err := http.NewRequest("GET", server.URL, nil)
diff --git a/plugin/storage/grpc/shared/grpc_client.go b/plugin/storage/grpc/shared/grpc_client.go
index 4d230cec9af..16825f55859 100644
--- a/plugin/storage/grpc/shared/grpc_client.go
+++ b/plugin/storage/grpc/shared/grpc_client.go
@@ -32,8 +32,12 @@ import (
"github.com/jaegertracing/jaeger/storage/spanstore"
)
-// BearerTokenKey is the key name for the bearer token context value.
-const BearerTokenKey = "bearer.token"
+const (
+ // BearerTokenKey is the key name for the bearer token context value.
+ BearerTokenKey = "bearer.token"
+ // TenantKey is the key name for the tenant name.
+ TenantKey = "tenant.name"
+)
var (
_ StoragePlugin = (*grpcClient)(nil)
@@ -85,16 +89,24 @@ func composeContextUpgradeFuncs(funcs ...ContextUpgradeFunc) ContextUpgradeFunc
// in the request metadata, if the original context has bearer token attached.
// Otherwise returns original context.
func upgradeContextWithBearerToken(ctx context.Context) context.Context {
+ md, ok := metadata.FromOutgoingContext(ctx)
+ if !ok {
+ md = metadata.New(nil)
+ }
+
+ headerValue, hasHeader := bearertoken.GetTenant(ctx)
+ if hasHeader {
+ md.Set(TenantKey, headerValue)
+ }
bearerToken, hasToken := bearertoken.GetBearerToken(ctx)
if hasToken {
- md, ok := metadata.FromOutgoingContext(ctx)
- if !ok {
- md = metadata.New(nil)
- }
md.Set(BearerTokenKey, bearerToken)
- return metadata.NewOutgoingContext(ctx, md)
}
- return ctx
+
+ if !hasHeader && !hasToken {
+ return ctx
+ }
+ return metadata.NewOutgoingContext(ctx, md)
}
// DependencyReader implements shared.StoragePlugin.