diff --git a/auth/strategies/jwt/jwt.go b/auth/strategies/jwt/jwt.go index ba3822f..3521b4d 100644 --- a/auth/strategies/jwt/jwt.go +++ b/auth/strategies/jwt/jwt.go @@ -16,15 +16,16 @@ import ( func GetAuthenticateFunc(s SecretsKeeper, opts ...auth.Option) token.AuthenticateFunc { t := newAccessToken(s, opts...) return func(ctx context.Context, r *http.Request, tk string) (auth.Info, time.Time, error) { - c, err := t.parse(tk) + c, info, err := t.parse(tk) if err != nil { return nil, time.Time{}, err } - if len(c.Scopes) > 0 { - token.WithNamedScopes(c.UserInfo, c.Scopes...) + if len(c.Scope) > 0 { + token.WithNamedScopes(info, c.Scope.Split()...) } - return c.UserInfo, c.Expiry.Time(), err + + return info, time.Time(*c.ExpiresAt), nil } } diff --git a/auth/strategies/jwt/token.go b/auth/strategies/jwt/token.go index 3ebefd8..3ebb4b3 100644 --- a/auth/strategies/jwt/token.go +++ b/auth/strategies/jwt/token.go @@ -1,17 +1,14 @@ package jwt import ( - "errors" + "fmt" "time" - "gopkg.in/square/go-jose.v2" - "gopkg.in/square/go-jose.v2/jwt" - "github.com/shaj13/go-guardian/v2/auth" + "github.com/shaj13/go-guardian/v2/auth/claims" + "github.com/shaj13/go-guardian/v2/auth/internal/jwt" ) -const headerKID = "kid" - const ( // EdDSA signature algorithm. EdDSA = "EdDSA" @@ -29,11 +26,11 @@ const ( RS512 = "RS512" // ES256 signature algorithm -- ECDSA using P-256 and SHA-256. ES256 = "ES256" - // ES384 signature algorithm -- ECDSA using P-384 and SHA-384. + // ES384 signature algorithm -- ECDSA using P-384 and SHA-384. ES384 = "ES384" - // ES512 signature algorithm -- ECDSA using P-521 and SHA-512. + // ES512 signature algorithm -- ECDSA using P-521 and SHA-512. ES512 = "ES512" - // PS256 signature algorithm -- RSASSA-PSS using SHA256 and MGF1-SHA256. + // PS256 signature algorithm -- RSASSA-PSS using SHA256 and MGF1-SHA256. PS256 = "PS256" // PS384 signature algorithm -- RSASSA-PSS using SHA384 and MGF1-SHA384. PS384 = "PS384" @@ -44,11 +41,11 @@ const ( var ( // ErrMissingKID is returned by Authenticate Strategy method, // when failed to retrieve kid from token header. - ErrMissingKID = errors.New("strategies/jwt: Token missing " + headerKID + " header") + ErrMissingKID = jwt.ErrMissingKID // ErrInvalidAlg is returned by Authenticate Strategy method, // when jwt token alg header does not match key algorithm. - ErrInvalidAlg = errors.New("strategies/jwt: Invalid signing algorithm, token alg header does not match key algorithm") + ErrInvalidAlg = jwt.ErrInvalidAlg ) // IssueAccessToken issue jwt access token for the provided user info. @@ -56,12 +53,6 @@ func IssueAccessToken(info auth.Info, s SecretsKeeper, opts ...auth.Option) (str return newAccessToken(s, opts...).issue(info) } -type claims struct { - UserInfo auth.Info `json:"info"` - Scopes []string `json:"scp"` - jwt.Claims -} - type accessToken struct { keeper SecretsKeeper dur time.Duration @@ -71,76 +62,51 @@ type accessToken struct { } func (at accessToken) issue(info auth.Info) (string, error) { - kid := at.keeper.KID() - secret, alg, err := at.keeper.Get(kid) - if err != nil { - return "", err - } - - opt := (&jose.SignerOptions{}).WithType("JWT").WithHeader(headerKID, kid) - key := jose.SigningKey{Algorithm: jose.SignatureAlgorithm(alg), Key: secret} - sig, err := jose.NewSigner(key, opt) - - if err != nil { - return "", err - } - - now := time.Now().UTC() + now := time.Now().UTC().Add(-claims.DefaultLeeway) exp := now.Add(at.dur) - c := claims{ - UserInfo: info, - Scopes: at.scp, - Claims: jwt.Claims{ - Subject: info.GetID(), - Issuer: at.iss, - Audience: jwt.Audience{at.aud}, - Expiry: jwt.NewNumericDate(exp), - IssuedAt: jwt.NewNumericDate(now), - NotBefore: jwt.NewNumericDate(now), - }, + c := claims.Standard{ + Subject: info.GetID(), + Issuer: at.iss, + Audience: claims.StringOrList{at.aud}, + ExpiresAt: (*claims.Time)(&exp), + IssuedAt: (*claims.Time)(&now), + NotBefore: (*claims.Time)(&now), + Scope: at.scp, } - return jwt.Signed(sig).Claims(c).CompactSerialize() -} - -func (at accessToken) parse(tstr string) (*claims, error) { - c := &claims{ - UserInfo: auth.NewUserInfo("", "", nil, nil), - } - - jt, err := jwt.ParseSigned(tstr) + str, err := jwt.IssueToken(at.keeper, c, info) if err != nil { - return nil, err + return "", fmt.Errorf("strategies/jwt: %w", err) } - if len(jt.Headers) == 0 { - return nil, errors.New("strategies/jwt: : No headers found in JWT token") - } + return str, nil +} - if len(jt.Headers[0].KeyID) == 0 { - return nil, ErrMissingKID +func (at accessToken) parse(tstr string) (claims.Standard, auth.Info, error) { + fail := func(err error) (claims.Standard, auth.Info, error) { + return claims.Standard{}, nil, fmt.Errorf("strategies/jwt: %w", err) } - secret, alg, err := at.keeper.Get(jt.Headers[0].KeyID) - - if err != nil { - return nil, err + info := auth.NewUserInfo("", "", nil, make(auth.Extensions)) + c := claims.Standard{} + opts := claims.VerifyOptions{ + Audience: claims.StringOrList{at.aud}, + Issuer: at.iss, + Time: func() (t time.Time) { + return time.Now().UTC().Add(-claims.DefaultLeeway) + }, } - if jt.Headers[0].Algorithm != alg { - return nil, ErrInvalidAlg + if err := jwt.ParseToken(at.keeper, tstr, &c, info); err != nil { + return fail(err) } - if err := jt.Claims(secret, c); err != nil { - return nil, err + if err := c.Verify(opts); err != nil { + return fail(err) } - return c, c.Validate(jwt.Expected{ - Time: time.Now().UTC(), - Issuer: at.iss, - Audience: jwt.Audience{at.aud}, - }) + return c, info, nil } func newAccessToken(s SecretsKeeper, opts ...auth.Option) *accessToken { diff --git a/auth/strategies/jwt/token_test.go b/auth/strategies/jwt/token_test.go index db6ead3..956b700 100644 --- a/auth/strategies/jwt/token_test.go +++ b/auth/strategies/jwt/token_test.go @@ -42,9 +42,9 @@ func TestToken(t *testing.T) { str, err := tk.issue(info) assert.NoError(t, err) - c, err := tk.parse(str) + _, u, err := tk.parse(str) assert.NoError(t, err) - assert.Equal(t, c.UserInfo, info) + assert.Equal(t, u, info) } func TestTokenAlg(t *testing.T) { @@ -68,15 +68,15 @@ func TestTokenAlg(t *testing.T) { assert.NoError(t, err) tk.keeper = hs256 - _, err = tk.parse(str) - assert.Equal(t, ErrInvalidAlg, err) + _, _, err = tk.parse(str) + assert.Contains(t, err.Error(), ErrInvalidAlg.Error()) } func TestTokenKID(t *testing.T) { str := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.P4Lqll22jQQJ1eMJikvNg5HKG-cKB0hUZA9BZFIG7Jk" tk := newAccessToken(nil) - _, err := tk.parse(str) - assert.Equal(t, ErrMissingKID, err) + _, _, err := tk.parse(str) + assert.Contains(t, err.Error(), ErrMissingKID.Error()) } func TestNewToken(t *testing.T) {