Skip to content

Commit

Permalink
Support tenant header propagation in query service
Browse files Browse the repository at this point in the history
Signed-off-by: Pavol Loffay <[email protected]>
  • Loading branch information
pavolloffay committed Jan 9, 2023
1 parent b71616c commit e8fe0bb
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 15 deletions.
5 changes: 5 additions & 0 deletions cmd/query/app/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ const (
queryStaticFiles = "query.static-files"
queryUIConfig = "query.ui-config"
queryTokenPropagation = "query.bearer-token-propagation"
queryTenantPropagation = "query.tenant-propagation"
queryAdditionalHeaders = "query.additional-headers"
queryMaxClockSkewAdjust = "query.max-clock-skew-adjustment"
)
Expand Down Expand Up @@ -72,6 +73,8 @@ type QueryOptions struct {
UIConfig string
// BearerTokenPropagation activate/deactivate bearer token propagation to storage
BearerTokenPropagation bool
// TenantHeaderPropagation activate/deactivate propagation of tenant header
TenantHeaderPropagation string
// TLSGRPC configures secure transport (Consumer to Query service GRPC API)
TLSGRPC tlscfg.Options
// TLSHTTP configures secure transport (Consumer to Query service HTTP API)
Expand All @@ -93,6 +96,7 @@ func AddFlags(flagSet *flag.FlagSet) {
flagSet.String(queryStaticFiles, "", "The directory path override for the static assets for the UI")
flagSet.String(queryUIConfig, "", "The path to the UI configuration file in JSON format")
flagSet.Bool(queryTokenPropagation, false, "Allow propagation of bearer token to be used by storage plugins")
flagSet.String(queryTenantPropagation, "", "Enables propagation of a tenant header")
flagSet.Duration(queryMaxClockSkewAdjust, 0, "The maximum delta by which span timestamps may be adjusted in the UI due to clock skew; set to 0s to disable clock skew adjustments")
tlsGRPCFlagsConfig.AddFlags(flagSet)
tlsHTTPFlagsConfig.AddFlags(flagSet)
Expand All @@ -116,6 +120,7 @@ func (qOpts *QueryOptions) InitFromViper(v *viper.Viper, logger *zap.Logger) (*Q
qOpts.StaticAssets = v.GetString(queryStaticFiles)
qOpts.UIConfig = v.GetString(queryUIConfig)
qOpts.BearerTokenPropagation = v.GetBool(queryTokenPropagation)
qOpts.TenantHeaderPropagation = v.GetString(queryTenantPropagation)

qOpts.MaxClockSkewAdjust = v.GetDuration(queryMaxClockSkewAdjust)
stringSlice := v.GetStringSlice(queryAdditionalHeaders)
Expand Down
2 changes: 2 additions & 0 deletions cmd/query/app/flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func TestQueryBuilderFlags(t *testing.T) {
"--query.additional-headers=access-control-allow-origin:blerg",
"--query.additional-headers=whatever:thing",
"--query.max-clock-skew-adjustment=10s",
"--query.tenant-propagation=x-scope-orgid",
})
qOpts, err := new(QueryOptions).InitFromViper(v, zap.NewNop())
require.NoError(t, err)
Expand All @@ -55,6 +56,7 @@ func TestQueryBuilderFlags(t *testing.T) {
"Whatever": []string{"thing"},
}, qOpts.AdditionalHeaders)
assert.Equal(t, 10*time.Second, qOpts.MaxClockSkewAdjust)
assert.Equal(t, "x-scope-orgid", qOpts.TenantHeaderPropagation)
}

func TestQueryBuilderBadHeadersFlags(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion cmd/query/app/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func createHTTPServer(querySvc *querysvc.QueryService, metricsQuerySvc querysvc.
var handler http.Handler = r
handler = additionalHeadersHandler(handler, queryOpts.AdditionalHeaders)
if queryOpts.BearerTokenPropagation {
handler = bearertoken.PropagationHandler(logger, handler)
handler = bearertoken.PropagationHandler(logger, handler, queryOpts.TenantHeaderPropagation)
}
handler = handlers.CompressHandler(handler)
recoveryHandler := recoveryhandler.NewRecoveryHandler(logger, true)
Expand Down
27 changes: 23 additions & 4 deletions pkg/bearertoken/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,16 @@

package bearertoken

import "context"
import (
"context"
)

type contextKeyType int

const contextKey = contextKeyType(iota)
const (
bearerTokenContextKey = contextKeyType(iota)
tenantHeaderContextKey
)

// StoragePropagationKey is a key for viper configuration to pass this option to storage plugins.
const StoragePropagationKey = "storage.propagate.token"
Expand All @@ -28,11 +33,25 @@ func ContextWithBearerToken(ctx context.Context, token string) context.Context {
if token == "" {
return ctx
}
return context.WithValue(ctx, contextKey, token)
return context.WithValue(ctx, bearerTokenContextKey, token)
}

// GetBearerToken from context, or empty string if there is no token.
func GetBearerToken(ctx context.Context) (string, bool) {
val, ok := ctx.Value(contextKey).(string)
val, ok := ctx.Value(bearerTokenContextKey).(string)
return val, ok
}

// ContextWithTenant sets tenant into context.
func ContextWithTenant(ctx context.Context, tenant string) context.Context {
if tenant == "" {
return ctx
}
return context.WithValue(ctx, tenantHeaderContextKey, tenant)
}

// GetTenant returns tenant, or empty string if there is no tenant.
func GetTenant(ctx context.Context) (string, bool) {
val, ok := ctx.Value(tenantHeaderContextKey).(string)
return val, ok
}
9 changes: 9 additions & 0 deletions pkg/bearertoken/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,12 @@ func Test_GetBearerToken(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, contextToken, token)
}

func Test_GetTenant(t *testing.T) {
const tenant = "jdoe"
ctx := context.Background()
ctx = ContextWithTenant(ctx, tenant)
contextTenant, ok := GetTenant(ctx)
assert.True(t, ok)
assert.Equal(t, contextTenant, tenant)
}
11 changes: 10 additions & 1 deletion pkg/bearertoken/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,18 @@ import (
// PropagationHandler returns a http.Handler containing the logic to extract
// the Bearer token from the Authorization header of the http.Request and insert it into request.Context
// for propagation. The token can be accessed via GetBearerToken.
func PropagationHandler(logger *zap.Logger, h http.Handler) http.Handler {
// The handler as well extracts tenant header which can be accessed via GetTenantHeader.
func PropagationHandler(logger *zap.Logger, h http.Handler, tenantHeader string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

if tenantHeader != "" {
header := r.Header.Get(tenantHeader)
if header != "" {
ctx = ContextWithTenant(ctx, header)
}
}

authHeaderValue := r.Header.Get("Authorization")
// If no Authorization header is present, try with X-Forwarded-Access-Token
if authHeaderValue == "" {
Expand Down
18 changes: 17 additions & 1 deletion pkg/bearertoken/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ func Test_PropagationHandler(t *testing.T) {

logger := zap.NewNop()
const bearerToken = "blah"
const tenantName = "jdoe"
const tenantHeader = "x-scope-orgid"

validTokenHandler := func(stop *sync.WaitGroup) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -43,11 +45,24 @@ func Test_PropagationHandler(t *testing.T) {
}
}

validTenantHandler := func(stop *sync.WaitGroup) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
tenant, ok := GetTenant(ctx)
assert.True(t, ok)
assert.Equal(t, tenantName, tenant)
stop.Done()
}
}

emptyHandler := func(stop *sync.WaitGroup) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
token, _ := GetBearerToken(ctx)
assert.Empty(t, token, bearerToken)
assert.Empty(t, token)
tenant, _ := GetTenant(ctx)
assert.Empty(t, tenant)
stop.Done()
}
}
Expand All @@ -65,13 +80,14 @@ func Test_PropagationHandler(t *testing.T) {
{name: "Basic Auth", sendHeader: true, headerName: "Authorization", headerValue: "Basic " + bearerToken, handler: emptyHandler},
{name: "X-Forwarded-Access-Token", headerName: "X-Forwarded-Access-Token", sendHeader: true, headerValue: "Bearer " + bearerToken, handler: validTokenHandler},
{name: "Invalid header", headerName: "X-Forwarded-Access-Token", sendHeader: true, headerValue: "Bearer " + bearerToken + " another stuff", handler: emptyHandler},
{name: "Tenant header", sendHeader: true, headerName: tenantHeader, headerValue: tenantName, handler: validTenantHandler},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
stop := sync.WaitGroup{}
stop.Add(1)
r := PropagationHandler(logger, testCase.handler(&stop))
r := PropagationHandler(logger, testCase.handler(&stop), tenantHeader)
server := httptest.NewServer(r)
defer server.Close()
req, err := http.NewRequest("GET", server.URL, nil)
Expand Down
28 changes: 20 additions & 8 deletions plugin/storage/grpc/shared/grpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@ import (
"github.com/jaegertracing/jaeger/storage/spanstore"
)

// BearerTokenKey is the key name for the bearer token context value.
const BearerTokenKey = "bearer.token"
const (
// BearerTokenKey is the key name for the bearer token context value.
BearerTokenKey = "bearer.token"
// TenantKey is the key name for the tenant name.
TenantKey = "tenant.name"
)

var (
_ StoragePlugin = (*grpcClient)(nil)
Expand Down Expand Up @@ -85,16 +89,24 @@ func composeContextUpgradeFuncs(funcs ...ContextUpgradeFunc) ContextUpgradeFunc
// in the request metadata, if the original context has bearer token attached.
// Otherwise returns original context.
func upgradeContextWithBearerToken(ctx context.Context) context.Context {
md, ok := metadata.FromOutgoingContext(ctx)
if !ok {
md = metadata.New(nil)
}

headerValue, hasHeader := bearertoken.GetTenant(ctx)
if hasHeader {
md.Set(TenantKey, headerValue)
}
bearerToken, hasToken := bearertoken.GetBearerToken(ctx)
if hasToken {
md, ok := metadata.FromOutgoingContext(ctx)
if !ok {
md = metadata.New(nil)
}
md.Set(BearerTokenKey, bearerToken)
return metadata.NewOutgoingContext(ctx, md)
}
return ctx

if !hasHeader && !hasToken {
return ctx
}
return metadata.NewOutgoingContext(ctx, md)
}

// DependencyReader implements shared.StoragePlugin.
Expand Down

0 comments on commit e8fe0bb

Please sign in to comment.