diff --git a/server/auth/jwt.go b/server/auth/jwt.go index dfaefa59ecc..a797794494b 100644 --- a/server/auth/jwt.go +++ b/server/auth/jwt.go @@ -42,7 +42,7 @@ func (t *tokenJWT) info(ctx context.Context, token string, rev uint64) (*AuthInf // rev isn't used in JWT, it is only used in simple token var ( username string - revision uint64 + revision float64 ) parsed, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) { @@ -74,10 +74,19 @@ func (t *tokenJWT) info(ctx context.Context, token string, rev uint64) (*AuthInf return nil, false } - username = claims["username"].(string) - revision = uint64(claims["revision"].(float64)) + username, ok = claims["username"].(string) + if !ok { + t.lg.Warn("failed to obtain user claims from jwt token") + return nil, false + } + + revision, ok = claims["revision"].(float64) + if !ok { + t.lg.Warn("failed to obtain revision claims from jwt token") + return nil, false + } - return &AuthInfo{Username: username, Revision: revision}, true + return &AuthInfo{Username: username, Revision: uint64(revision)}, true } func (t *tokenJWT) assign(ctx context.Context, username string, revision uint64) (string, error) { diff --git a/server/auth/jwt_test.go b/server/auth/jwt_test.go index a3983cc5a56..df431c30b09 100644 --- a/server/auth/jwt_test.go +++ b/server/auth/jwt_test.go @@ -18,7 +18,10 @@ import ( "context" "fmt" "testing" + "time" + "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/require" "go.uber.org/zap" ) @@ -202,3 +205,75 @@ func TestJWTBad(t *testing.T) { func testJWTOpts() string { return fmt.Sprintf("%s,pub-key=%s,priv-key=%s,sign-method=RS256", tokenTypeJWT, jwtRSAPubKey, jwtRSAPrivKey) } + +func TestJWTTokenWithMissingFields(t *testing.T) { + testCases := []struct { + name string + username string // An empty string means not present + revision uint64 // 0 means not present + expectValid bool + }{ + { + name: "valid token", + username: "hello", + revision: 100, + expectValid: true, + }, + { + name: "no username", + username: "", + revision: 100, + expectValid: false, + }, + { + name: "no revision", + username: "hello", + revision: 0, + expectValid: false, + }, + } + + for _, tc := range testCases { + tc := tc + optsMap := map[string]string{ + "priv-key": jwtRSAPrivKey, + "sign-method": "RS256", + "ttl": "1h", + } + + t.Run(tc.name, func(t *testing.T) { + // prepare claims + claims := jwt.MapClaims{ + "exp": time.Now().Add(time.Hour).Unix(), + } + if tc.username != "" { + claims["username"] = tc.username + } + if tc.revision != 0 { + claims["revision"] = tc.revision + } + + // generate a JWT token with the given claims + var opts jwtOptions + err := opts.ParseWithDefaults(optsMap) + require.NoError(t, err) + key, err := opts.Key() + require.NoError(t, err) + + tk := jwt.NewWithClaims(opts.SignMethod, claims) + token, err := tk.SignedString(key) + require.NoError(t, err) + + // verify the token + jwtProvider, err := newTokenProviderJWT(zap.NewNop(), optsMap) + require.NoError(t, err) + ai, ok := jwtProvider.info(context.TODO(), token, 123) + + require.Equal(t, tc.expectValid, ok) + if ok { + require.Equal(t, tc.username, ai.Username) + require.Equal(t, tc.revision, ai.Revision) + } + }) + } +}