Skip to content

Commit

Permalink
fix: suppress jwt code duplication and use internal jwt
Browse files Browse the repository at this point in the history
  • Loading branch information
shaj13 committed Feb 28, 2021
1 parent 18d8af4 commit 751bf61
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 81 deletions.
9 changes: 5 additions & 4 deletions auth/strategies/jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@ import (
func GetAuthenticateFunc(s SecretsKeeper, opts ...auth.Option) token.AuthenticateFunc {
t := newAccessToken(s, opts...)
return func(ctx context.Context, r *http.Request, tk string) (auth.Info, time.Time, error) {
c, err := t.parse(tk)
c, info, err := t.parse(tk)
if err != nil {
return nil, time.Time{}, err
}

if len(c.Scopes) > 0 {
token.WithNamedScopes(c.UserInfo, c.Scopes...)
if len(c.Scope) > 0 {
token.WithNamedScopes(info, c.Scope.Split()...)
}
return c.UserInfo, c.Expiry.Time(), err

return info, time.Time(*c.ExpiresAt), nil
}
}

Expand Down
108 changes: 37 additions & 71 deletions auth/strategies/jwt/token.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
package jwt

import (
"errors"
"fmt"
"time"

"gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"

"github.com/shaj13/go-guardian/v2/auth"
"github.com/shaj13/go-guardian/v2/auth/claims"
"github.com/shaj13/go-guardian/v2/auth/internal/jwt"
)

const headerKID = "kid"

const (
// EdDSA signature algorithm.
EdDSA = "EdDSA"
Expand All @@ -29,11 +26,11 @@ const (
RS512 = "RS512"
// ES256 signature algorithm -- ECDSA using P-256 and SHA-256.
ES256 = "ES256"
// ES384 signature algorithm -- ECDSA using P-384 and SHA-384.
// ES384 signature algorithm -- ECDSA using P-384 and SHA-384.
ES384 = "ES384"
// ES512 signature algorithm -- ECDSA using P-521 and SHA-512.
// ES512 signature algorithm -- ECDSA using P-521 and SHA-512.
ES512 = "ES512"
// PS256 signature algorithm -- RSASSA-PSS using SHA256 and MGF1-SHA256.
// PS256 signature algorithm -- RSASSA-PSS using SHA256 and MGF1-SHA256.
PS256 = "PS256"
// PS384 signature algorithm -- RSASSA-PSS using SHA384 and MGF1-SHA384.
PS384 = "PS384"
Expand All @@ -44,24 +41,18 @@ const (
var (
// ErrMissingKID is returned by Authenticate Strategy method,
// when failed to retrieve kid from token header.
ErrMissingKID = errors.New("strategies/jwt: Token missing " + headerKID + " header")
ErrMissingKID = jwt.ErrMissingKID

// ErrInvalidAlg is returned by Authenticate Strategy method,
// when jwt token alg header does not match key algorithm.
ErrInvalidAlg = errors.New("strategies/jwt: Invalid signing algorithm, token alg header does not match key algorithm")
ErrInvalidAlg = jwt.ErrInvalidAlg
)

// IssueAccessToken issue jwt access token for the provided user info.
func IssueAccessToken(info auth.Info, s SecretsKeeper, opts ...auth.Option) (string, error) {
return newAccessToken(s, opts...).issue(info)
}

type claims struct {
UserInfo auth.Info `json:"info"`
Scopes []string `json:"scp"`
jwt.Claims
}

type accessToken struct {
keeper SecretsKeeper
dur time.Duration
Expand All @@ -71,76 +62,51 @@ type accessToken struct {
}

func (at accessToken) issue(info auth.Info) (string, error) {
kid := at.keeper.KID()
secret, alg, err := at.keeper.Get(kid)
if err != nil {
return "", err
}

opt := (&jose.SignerOptions{}).WithType("JWT").WithHeader(headerKID, kid)
key := jose.SigningKey{Algorithm: jose.SignatureAlgorithm(alg), Key: secret}
sig, err := jose.NewSigner(key, opt)

if err != nil {
return "", err
}

now := time.Now().UTC()
now := time.Now().UTC().Add(-claims.DefaultLeeway)
exp := now.Add(at.dur)

c := claims{
UserInfo: info,
Scopes: at.scp,
Claims: jwt.Claims{
Subject: info.GetID(),
Issuer: at.iss,
Audience: jwt.Audience{at.aud},
Expiry: jwt.NewNumericDate(exp),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
c := claims.Standard{
Subject: info.GetID(),
Issuer: at.iss,
Audience: claims.StringOrList{at.aud},
ExpiresAt: (*claims.Time)(&exp),
IssuedAt: (*claims.Time)(&now),
NotBefore: (*claims.Time)(&now),
Scope: at.scp,
}

return jwt.Signed(sig).Claims(c).CompactSerialize()
}

func (at accessToken) parse(tstr string) (*claims, error) {
c := &claims{
UserInfo: auth.NewUserInfo("", "", nil, nil),
}

jt, err := jwt.ParseSigned(tstr)
str, err := jwt.IssueToken(at.keeper, c, info)
if err != nil {
return nil, err
return "", fmt.Errorf("strategies/jwt: %w", err)
}

if len(jt.Headers) == 0 {
return nil, errors.New("strategies/jwt: : No headers found in JWT token")
}
return str, nil
}

if len(jt.Headers[0].KeyID) == 0 {
return nil, ErrMissingKID
func (at accessToken) parse(tstr string) (claims.Standard, auth.Info, error) {
fail := func(err error) (claims.Standard, auth.Info, error) {
return claims.Standard{}, nil, fmt.Errorf("strategies/jwt: %w", err)
}

secret, alg, err := at.keeper.Get(jt.Headers[0].KeyID)

if err != nil {
return nil, err
info := auth.NewUserInfo("", "", nil, make(auth.Extensions))
c := claims.Standard{}
opts := claims.VerifyOptions{
Audience: claims.StringOrList{at.aud},
Issuer: at.iss,
Time: func() (t time.Time) {
return time.Now().UTC().Add(-claims.DefaultLeeway)
},
}

if jt.Headers[0].Algorithm != alg {
return nil, ErrInvalidAlg
if err := jwt.ParseToken(at.keeper, tstr, &c, info); err != nil {
return fail(err)
}

if err := jt.Claims(secret, c); err != nil {
return nil, err
if err := c.Verify(opts); err != nil {
return fail(err)
}

return c, c.Validate(jwt.Expected{
Time: time.Now().UTC(),
Issuer: at.iss,
Audience: jwt.Audience{at.aud},
})
return c, info, nil
}

func newAccessToken(s SecretsKeeper, opts ...auth.Option) *accessToken {
Expand Down
12 changes: 6 additions & 6 deletions auth/strategies/jwt/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ func TestToken(t *testing.T) {
str, err := tk.issue(info)
assert.NoError(t, err)

c, err := tk.parse(str)
_, u, err := tk.parse(str)
assert.NoError(t, err)
assert.Equal(t, c.UserInfo, info)
assert.Equal(t, u, info)
}

func TestTokenAlg(t *testing.T) {
Expand All @@ -68,15 +68,15 @@ func TestTokenAlg(t *testing.T) {
assert.NoError(t, err)

tk.keeper = hs256
_, err = tk.parse(str)
assert.Equal(t, ErrInvalidAlg, err)
_, _, err = tk.parse(str)
assert.Contains(t, err.Error(), ErrInvalidAlg.Error())
}

func TestTokenKID(t *testing.T) {
str := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.P4Lqll22jQQJ1eMJikvNg5HKG-cKB0hUZA9BZFIG7Jk"
tk := newAccessToken(nil)
_, err := tk.parse(str)
assert.Equal(t, ErrMissingKID, err)
_, _, err := tk.parse(str)
assert.Contains(t, err.Error(), ErrMissingKID.Error())
}

func TestNewToken(t *testing.T) {
Expand Down

0 comments on commit 751bf61

Please sign in to comment.