Skip to content

Commit

Permalink
split auth into decode and authorize + logging jwt sub
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmadelrouby committed Sep 6, 2024
1 parent 83bc148 commit c89bab4
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 21 deletions.
7 changes: 3 additions & 4 deletions cmds/core-service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,9 @@ func RunHTTPServer(ctx context.Context, ctxCanceler func(), address, locality st
multiRouter.Routers = append(multiRouter.Routers, &scdV1Router)
}

handler := logging.HTTPMiddleware(logger, *dumpRequests,
healthyEndpointMiddleware(logger,
&multiRouter,
))
handler := logging.HTTPMiddleware(logger, *dumpRequests, &multiRouter)

handler = auth.DecoderMiddleware(authorizer, healthyEndpointMiddleware(logger, handler))

httpServer := &http.Server{
Addr: address,
Expand Down
84 changes: 67 additions & 17 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,29 +185,26 @@ func (a *Authorizer) setKeys(keys []interface{}) {
// Authorize extracts and verifies bearer tokens from a http.Request.
func (a *Authorizer) Authorize(_ http.ResponseWriter, r *http.Request, authOptions []api.AuthorizationOption) api.AuthorizationResult {

tknStr, ok := getToken(r)
if !ok {
missing, ok := r.Context().Value("authTokenMissing").(bool)
if !ok || missing {
return api.AuthorizationResult{Error: stacktrace.NewErrorWithCode(dsserr.Unauthenticated, "Missing access token")}
}

a.keyGuard.RLock()
keys := a.keys
a.keyGuard.RUnlock()
validated := false
var err error
var keyClaims claims
validated, ok := r.Context().Value("authValidated").(bool)
if !ok {
validated = false
}

for _, key := range keys {
err, ok := r.Context().Value("authError").(error)
if !ok {
err = nil
}

keyClaims, ok := r.Context().Value("authClaims").(claims)
if !ok {
keyClaims = claims{}
key := key
_, err = jwt.ParseWithClaims(tknStr, &keyClaims, func(token *jwt.Token) (interface{}, error) {
return key, nil
})
if err == nil {
validated = true
break
}
}

if !validated {
return api.AuthorizationResult{Error: stacktrace.PropagateWithCode(err, dsserr.Unauthenticated, "Access token validation failed")}
}
Expand Down Expand Up @@ -300,3 +297,56 @@ func getToken(r *http.Request) (string, bool) {
}
return authHeader[7:], true
}

// Extract valid claims from a JWT token.
func (a *Authorizer) extractClaims(r *http.Request) (missing bool, validated bool, keyClaims claims, err error) {
tknStr, ok := getToken(r)

if !ok {
return true, false, claims{}, nil
}

a.keyGuard.RLock()
keys := a.keys
a.keyGuard.RUnlock()
validated = false
missing = false

for _, key := range keys {
keyClaims = claims{}
key := key
_, err = jwt.ParseWithClaims(tknStr, &keyClaims, func(token *jwt.Token) (interface{}, error) {
return key, nil
})
if err == nil {
validated = true
break
}
}

return missing, validated, keyClaims, err

}

// Decoder Middleware
func DecoderMiddleware(a *Authorizer, handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

missing, validated, keyClaims, err := a.extractClaims(r)

ctx := context.WithValue(r.Context(), "authValidated", validated)
ctx = context.WithValue(ctx, "authTokenMissing", missing)
ctx = context.WithValue(ctx, "authClaims", keyClaims)
ctx = context.WithValue(ctx, "authError", err)

if validated && err == nil {
// If the token is valid, we can extract the subject from the token and add them to the context.
ctx = context.WithValue(ctx, "authSubject", keyClaims.Subject)
}

// Create a new request with the updated context
r = r.WithContext(ctx)

handler.ServeHTTP(w, r)
})
}
6 changes: 6 additions & 0 deletions pkg/logging/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ func HTTPMiddleware(logger *zap.Logger, dump bool, handler http.Handler) http.Ha
// replace req.Body with a copy
r.Body = io.NopCloser(bytes.NewReader(reqData))
}

subject, ok := r.Context().Value("authSubject").(string)
if !ok {
subject = ""
}
logger = logger.With(zap.String("req_sub", subject))
}

handler.ServeHTTP(trw, r)
Expand Down

0 comments on commit c89bab4

Please sign in to comment.