diff --git a/cmd/query/app/token_propagation_hander_test.go b/cmd/query/app/token_propagation_hander_test.go index 67d7de023422..8e1f2a5428b2 100644 --- a/cmd/query/app/token_propagation_hander_test.go +++ b/cmd/query/app/token_propagation_hander_test.go @@ -51,16 +51,19 @@ func Test_bearTokenPropagationHandler(t *testing.T) { } testCases := []struct { - name string - sendHeader bool - header string - handler func(stop *sync.WaitGroup) http.HandlerFunc + name string + sendHeader bool + headerValue string + headerName string + handler func(stop *sync.WaitGroup) http.HandlerFunc }{ - { name:"Bearer token", sendHeader: true, header: "Bearer " + bearerToken, handler:validTokenHandler}, - { name:"Invalid header",sendHeader: true, header: bearerToken, handler:emptyHandler}, - { name:"No header", sendHeader: false, handler:emptyHandler}, - { name:"Basic Auth", sendHeader: true, header: "Basic " + bearerToken, handler:emptyHandler}, - { name:"X-Forwarded-Access-Token", sendHeader: true, header: "Bearer " + bearerToken, handler:validTokenHandler}, + { name:"Bearer token", sendHeader: true, headerName:"Authorization", headerValue: "Bearer " + bearerToken, handler:validTokenHandler}, + { name:"Raw bearer token",sendHeader: true, headerName:"Authorization", headerValue: bearerToken, handler:validTokenHandler}, + { name:"No headerValue", sendHeader: false, headerName:"Authorization", handler:emptyHandler}, + { name:"Basic Auth", sendHeader: true, headerName:"Authorization", headerValue: "Basic " + bearerToken, handler:emptyHandler}, + { name:"X-Forwarded-Access-Token", headerName:"X-Forwarded-Access-Token", sendHeader: true, headerValue: "Bearer " + bearerToken, handler:validTokenHandler}, + { name:"Invalid header", headerName:"X-Forwarded-Access-Token", sendHeader: true, headerValue: "Bearer " + bearerToken + " another stuff", handler:emptyHandler}, + } for _, testCase := range testCases { @@ -73,7 +76,7 @@ func Test_bearTokenPropagationHandler(t *testing.T) { req , err := http.NewRequest("GET", server.URL, nil) assert.Nil(t,err) if testCase.sendHeader { - req.Header.Add("Authorization", testCase.header) + req.Header.Add(testCase.headerName, testCase.headerValue) } _, err = httpClient.Do(req) assert.Nil(t, err) diff --git a/cmd/query/app/token_propagation_handler.go b/cmd/query/app/token_propagation_handler.go index c16b9ece662e..74afd33f767b 100644 --- a/cmd/query/app/token_propagation_handler.go +++ b/cmd/query/app/token_propagation_handler.go @@ -27,7 +27,7 @@ func bearerTokenPropagationHandler(logger *zap.Logger, h http.Handler) http.Hand return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() authHeaderValue := r.Header.Get("Authorization") - // If no Authorization header is present, try with X-Forwarded-Access-Token + // If no Authorization headerValue is present, try with X-Forwarded-Access-Token if authHeaderValue == "" { authHeaderValue = r.Header.Get("X-Forwarded-Access-Token") } @@ -39,8 +39,11 @@ func bearerTokenPropagationHandler(logger *zap.Logger, h http.Handler) http.Hand if headerValue[0] == "Bearer" { token = headerValue[1] } + } else if len(headerValue) == 1 { + // Tread all value as a token + token = authHeaderValue } else { - logger.Warn("Invalid authorization header, skipping bearer token propagation") + logger.Warn("Invalid authorization header value, skipping token propagation") } h.ServeHTTP(w, r.WithContext(spanstore.ContextWithBearerToken(ctx, token))) } else {