diff --git a/internal/audit/log.go b/internal/audit/log.go index 4d7ec3b..62e4487 100644 --- a/internal/audit/log.go +++ b/internal/audit/log.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "time" "github.com/rs/zerolog" ) @@ -32,8 +33,11 @@ type Entry struct { SourceIP string UserAgent string RequestedProfile string - Identity string Authorized bool + Subject string + Issuer string + Audience []string + Expiry int64 Error string Repositories []string Permissions []string @@ -49,11 +53,25 @@ func (e *Entry) MarshalZerologObject(event *zerolog.Event) { Str("sourceIP", e.SourceIP). Str("userAgent", e.UserAgent). Str("requestedProfile", e.RequestedProfile). - Str("identity", e.Identity). Bool("authorized", e.Authorized). - Str("error", e.Error). - Strs("repositories", e.Repositories). - Strs("permissions", e.Permissions) + Str("subject", e.Subject). + Str("issuer", e.Issuer). + Str("error", e.Error) + + if e.Expiry > 0 { + event.Time("expiry", time.UnixMilli(e.Expiry)) + } + if len(e.Audience) > 0 { + event.Strs("audience", e.Audience) + } + + if len(e.Repositories) > 0 { + event.Strs("repositories", e.Repositories) + } + + if len(e.Permissions) > 0 { + event.Strs("permissions", e.Permissions) + } } // Begin sets up the audit log entry for the current request with details from the request. diff --git a/internal/jwt/jwt.go b/internal/jwt/jwt.go index 7406294..1c8f5ae 100644 --- a/internal/jwt/jwt.go +++ b/internal/jwt/jwt.go @@ -9,12 +9,13 @@ import ( "time" "github.com/go-jose/go-jose/v4" + "github.com/justinas/alice" jwtmiddleware "github.com/auth0/go-jwt-middleware/v2" "github.com/auth0/go-jwt-middleware/v2/jwks" "github.com/auth0/go-jwt-middleware/v2/validator" - "github.com/rs/zerolog/log" + "github.com/jamestelfer/chinmina-bridge/internal/audit" "github.com/jamestelfer/chinmina-bridge/internal/config" ) @@ -49,21 +50,23 @@ func Middleware(cfg config.AuthorizationConfig, options ...jwtmiddleware.Option) return nil, fmt.Errorf("failed to set up the validator: %v", err) } + // Auditing of the validation process uses a combination of the error handler + // and the audit middleware. The first ensures that validation errors are marked in + // the audit log, while the second ensures that the claims are logged when the + // token is valid. + + // force the use of the audit error handler + options = append(options, jwtmiddleware.WithErrorHandler(auditErrorHandler())) + // wrap the standard validator with additional validation that ensures the // core claims (including validity periods) are present tokenValidator := registeredClaimsValidator(jwtValidator.ValidateToken) - return jwtmiddleware.New(tokenValidator, options...).CheckJWT, nil -} + validationMiddleware := jwtmiddleware.New(tokenValidator, options...).CheckJWT -func LogErrorHandler() jwtmiddleware.ErrorHandler { - return func(w http.ResponseWriter, r *http.Request, err error) { - log.Warn(). - Err(err). - Msg("JWT decode failure") + subChain := alice.New(validationMiddleware, auditClaimsMiddleware()).Then - jwtmiddleware.DefaultErrorHandler(w, r, err) - } + return subChain, nil } // ContextWithClaims returns a new context.Context with the provided validated claims @@ -103,6 +106,40 @@ func RequireBuildkiteClaimsFromContext(ctx context.Context) BuildkiteClaims { return *c } +func auditClaimsMiddleware() func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + entry := audit.Log(r.Context()) + claims := ClaimsFromContext(r.Context()) + + if claims == nil { + entry.Error = "JWT claims missing from context" + } else { + reg := claims.RegisteredClaims + entry.Authorized = true + entry.Subject = reg.Subject + entry.Issuer = reg.Issuer + entry.Audience = reg.Audience + entry.Expiry = reg.Expiry + } + + next.ServeHTTP(w, r) + }) + } +} + +func auditErrorHandler() jwtmiddleware.ErrorHandler { + return func(w http.ResponseWriter, r *http.Request, err error) { + entry := audit.Log(r.Context()) + entry.Error = fmt.Sprintf("JWT authorization failure: %s", err.Error()) + + // The default error handler will write the appropriate response status + // code. The status code is recorded centrally by the central audit + // middleware. + jwtmiddleware.DefaultErrorHandler(w, r, err) + } +} + type KeyFunc = func(ctx context.Context) (any, error) func remoteJWKS(cfg config.AuthorizationConfig) (url.URL, KeyFunc, error) { diff --git a/internal/jwt/jwt_test.go b/internal/jwt/jwt_test.go index dcd1eff..e240de5 100644 --- a/internal/jwt/jwt_test.go +++ b/internal/jwt/jwt_test.go @@ -1,6 +1,7 @@ package jwt import ( + "context" "crypto/rand" "crypto/rsa" "encoding/json" @@ -11,8 +12,10 @@ import ( "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" + "github.com/justinas/alice" jwtmiddleware "github.com/auth0/go-jwt-middleware/v2" + "github.com/jamestelfer/chinmina-bridge/internal/audit" "github.com/jamestelfer/chinmina-bridge/internal/config" "github.com/jamestelfer/chinmina-bridge/internal/testhelpers" "github.com/stretchr/testify/assert" @@ -94,18 +97,6 @@ func TestMiddleware(t *testing.T) { wantStatusCode: http.StatusUnauthorized, wantBodyText: "JWT is invalid", }, - { - name: "error handler", - claims: valid(jwt.Claims{ - Audience: []string{"audience"}, - Subject: "subject", - Issuer: "issuer", - }), - customClaims: custom("that dog ain't gonna hunt", "test-pipeline"), - wantStatusCode: http.StatusUnauthorized, - wantBodyText: "JWT is invalid", - options: []jwtmiddleware.Option{jwtmiddleware.WithErrorHandler(LogErrorHandler())}, - }, } jwk := generateJWK(t) @@ -121,9 +112,13 @@ func TestMiddleware(t *testing.T) { t.Run(test.name, func(t *testing.T) { testhelpers.SetupLogger(t) + ctx, _ := audit.Context(context.Background()) + request, err := http.NewRequest(http.MethodGet, "", nil) require.NoError(t, err) + request = request.WithContext(ctx) + audience := "an-actor-demands-an" if len(test.claims.Audience) > 0 { audience = test.claims.Audience[0] @@ -139,32 +134,35 @@ func TestMiddleware(t *testing.T) { responseRecorder := httptest.NewRecorder() - options := []jwtmiddleware.Option{ - jwtmiddleware.WithErrorHandler(errorHandler(t)), - } - - if len(test.options) > 0 { - options = append(options, test.options...) - } - - mw, err := Middleware(cfg, options...) + authMiddleware, err := Middleware(cfg, test.options...) require.NoError(t, err) - handler := mw(successHandler) + testMiddleware := alice.New(audit.Middleware(), authMiddleware) + + handler := testMiddleware.Then(successHandler) handler.ServeHTTP(responseRecorder, request) assert.Equal(t, test.wantStatusCode, responseRecorder.Code) assert.Contains(t, responseRecorder.Body.String(), test.wantBodyText) - }) - } -} - -func errorHandler(t *testing.T) jwtmiddleware.ErrorHandler { - return func(w http.ResponseWriter, r *http.Request, err error) { - t.Helper() - t.Logf("error handler called: %s, %v", err.Error(), err) - jwtmiddleware.DefaultErrorHandler(w, r, err) + // once the request has been processed, the audit log should have the necessary details + auditEntry := audit.Log(ctx) + if test.wantStatusCode == http.StatusOK { + assert.True(t, auditEntry.Authorized) + assert.Empty(t, auditEntry.Error) + assert.NotEmpty(t, auditEntry.Issuer) + assert.Equal(t, "subject", auditEntry.Subject) + assert.ElementsMatch(t, []string{"audience"}, auditEntry.Audience) + assert.NotZero(t, auditEntry.Expiry) + } else { + assert.False(t, auditEntry.Authorized) + assert.NotEmpty(t, auditEntry.Error) + assert.Empty(t, auditEntry.Issuer) + assert.Empty(t, auditEntry.Subject) + assert.Empty(t, auditEntry.Audience) + assert.Zero(t, auditEntry.Expiry) + } + }) } } @@ -240,7 +238,7 @@ func setupTestServer(t *testing.T, jwk *jose.JSONWebKey) (server *httptest.Serve } func createRequestJWT(t *testing.T, jwk *jose.JSONWebKey, issuer string, claims ...any) string { - // t.Helper() + t.Helper() key := jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), diff --git a/main.go b/main.go index 0e5a741..9b1b71d 100644 --- a/main.go +++ b/main.go @@ -9,7 +9,6 @@ import ( "strings" "time" - jwtmiddleware "github.com/auth0/go-jwt-middleware/v2" "github.com/jamestelfer/chinmina-bridge/internal/audit" "github.com/jamestelfer/chinmina-bridge/internal/buildkite" "github.com/jamestelfer/chinmina-bridge/internal/config" @@ -31,7 +30,7 @@ func configureServerRoutes(ctx context.Context, cfg config.Config) (http.Handler // configure middleware auditor := audit.Middleware() - authorizer, err := jwt.Middleware(cfg.Authorization, jwtmiddleware.WithErrorHandler(jwt.LogErrorHandler())) + authorizer, err := jwt.Middleware(cfg.Authorization) if err != nil { return nil, fmt.Errorf("authorizer configuration failed: %w", err) }