diff --git a/handler/openid/flow_hybrid.go b/handler/openid/flow_hybrid.go index df0f2f647..8de6fe05a 100644 --- a/handler/openid/flow_hybrid.go +++ b/handler/openid/flow_hybrid.go @@ -23,7 +23,6 @@ package openid import ( "context" - "encoding/base64" "time" "github.com/ory/x/errorsx" @@ -122,11 +121,11 @@ func (c *OpenIDConnectHybridHandler) HandleAuthorizeEndpointRequest(ctx context. resp.AddParameter("code", code) ar.SetResponseTypeHandled("code") - hash, err := c.Enigma.Hash(ctx, []byte(resp.GetParameters().Get("code"))) + hash, err := c.IDTokenHandleHelper.ComputeHash(ctx, sess, resp.GetParameters().Get("code")) if err != nil { return err } - claims.CodeHash = base64.RawURLEncoding.EncodeToString([]byte(hash[:c.Enigma.GetSigningMethodLength()/2])) + claims.CodeHash = hash if ar.GetGrantedScopes().Has("openid") { if err := c.OpenIDConnectRequestStorage.CreateOpenIDConnectSession(ctx, resp.GetCode(), ar.Sanitize(oidcParameters)); err != nil { @@ -143,11 +142,11 @@ func (c *OpenIDConnectHybridHandler) HandleAuthorizeEndpointRequest(ctx context. } ar.SetResponseTypeHandled("token") - hash, err := c.Enigma.Hash(ctx, []byte(resp.GetParameters().Get("access_token"))) + hash, err := c.IDTokenHandleHelper.ComputeHash(ctx, sess, resp.GetParameters().Get("access_token")) if err != nil { return err } - claims.AccessTokenHash = base64.RawURLEncoding.EncodeToString([]byte(hash[:c.Enigma.GetSigningMethodLength()/2])) + claims.AccessTokenHash = hash } if resp.GetParameters().Get("state") == "" { diff --git a/handler/openid/flow_implicit.go b/handler/openid/flow_implicit.go index f3e0639e1..12e860414 100644 --- a/handler/openid/flow_implicit.go +++ b/handler/openid/flow_implicit.go @@ -23,7 +23,6 @@ package openid import ( "context" - "encoding/base64" "github.com/ory/x/errorsx" @@ -93,12 +92,12 @@ func (c *OpenIDConnectImplicitHandler) HandleAuthorizeEndpointRequest(ctx contex } ar.SetResponseTypeHandled("token") - hash, err := c.RS256JWTStrategy.Hash(ctx, []byte(resp.GetParameters().Get("access_token"))) + hash, err := c.ComputeHash(ctx, sess, resp.GetParameters().Get("access_token")) if err != nil { return err } - claims.AccessTokenHash = base64.RawURLEncoding.EncodeToString([]byte(hash[:c.RS256JWTStrategy.GetSigningMethodLength()/2])) + claims.AccessTokenHash = hash } else { resp.AddParameter("state", ar.GetState()) } diff --git a/handler/openid/helper.go b/handler/openid/helper.go index ef30c0b98..d18f804e0 100644 --- a/handler/openid/helper.go +++ b/handler/openid/helper.go @@ -25,7 +25,9 @@ import ( "bytes" "context" "crypto/sha256" + "crypto/sha512" "encoding/base64" + "strconv" "github.com/ory/fosite" ) @@ -36,6 +38,16 @@ type IDTokenHandleHelper struct { func (i *IDTokenHandleHelper) GetAccessTokenHash(ctx context.Context, requester fosite.AccessRequester, responder fosite.AccessResponder) string { token := responder.GetAccessToken() + // The session should always be a openid.Session but best to safely cast + if session, ok := requester.GetSession().(Session); ok { + val, err := i.ComputeHash(ctx, session, token) + if err != nil { + // this should never happen + panic(err) + } + + return val + } buffer := bytes.NewBufferString(token) hash := sha256.New() @@ -76,3 +88,27 @@ func (i *IDTokenHandleHelper) IssueExplicitIDToken(ctx context.Context, ar fosit resp.SetExtra("id_token", token) return nil } + +// ComputeHash computes the hash using the alg defined in the id_token header +func (i *IDTokenHandleHelper) ComputeHash(ctx context.Context, sess Session, token string) (string, error) { + var err error + hash := sha256.New() + if alg, ok := sess.IDTokenHeaders().Get("alg").(string); ok && len(alg) > 2 { + if hashSize, err := strconv.Atoi(alg[2:]); err == nil { + if hashSize == 384 { + hash = sha512.New384() + } else if hashSize == 512 { + hash = sha512.New() + } + } + } + + buffer := bytes.NewBufferString(token) + _, err = hash.Write(buffer.Bytes()) + if err != nil { + return "", err + } + hashBuf := bytes.NewBuffer(hash.Sum([]byte{})) + + return base64.RawURLEncoding.EncodeToString(hashBuf.Bytes()[:hashBuf.Len()/2]), nil +} diff --git a/handler/openid/helper_test.go b/handler/openid/helper_test.go index 035924c62..b90653747 100644 --- a/handler/openid/helper_test.go +++ b/handler/openid/helper_test.go @@ -130,6 +130,70 @@ func TestGetAccessTokenHash(t *testing.T) { defer ctrl.Finish() + req.EXPECT().GetSession().Return(nil) + resp.EXPECT().GetAccessToken().Return("7a35f818-9164-48cb-8c8f-e1217f44228431c41102-d410-4ed5-9276-07ba53dfdcd8") + + h := &IDTokenHandleHelper{IDTokenStrategy: strat} + + hash := h.GetAccessTokenHash(nil, req, resp) + assert.Equal(t, "Zfn_XBitThuDJiETU3OALQ", hash) +} + +func TestGetAccessTokenHashWithDifferentKeyLength(t *testing.T) { + ctrl := gomock.NewController(t) + req := internal.NewMockAccessRequester(ctrl) + resp := internal.NewMockAccessResponder(ctrl) + + defer ctrl.Finish() + + headers := &jwt.Headers{ + Extra: map[string]interface{}{ + "alg": "RS384", + }, + } + req.EXPECT().GetSession().Return(&DefaultSession{Headers: headers}) + resp.EXPECT().GetAccessToken().Return("7a35f818-9164-48cb-8c8f-e1217f44228431c41102-d410-4ed5-9276-07ba53dfdcd8") + + h := &IDTokenHandleHelper{IDTokenStrategy: strat} + + hash := h.GetAccessTokenHash(nil, req, resp) + assert.Equal(t, "VNX38yiOyeqBPheW5jDsWQKa6IjJzK66", hash) +} + +func TestGetAccessTokenHashWithBadAlg(t *testing.T) { + ctrl := gomock.NewController(t) + req := internal.NewMockAccessRequester(ctrl) + resp := internal.NewMockAccessResponder(ctrl) + + defer ctrl.Finish() + + headers := &jwt.Headers{ + Extra: map[string]interface{}{ + "alg": "R", + }, + } + req.EXPECT().GetSession().Return(&DefaultSession{Headers: headers}) + resp.EXPECT().GetAccessToken().Return("7a35f818-9164-48cb-8c8f-e1217f44228431c41102-d410-4ed5-9276-07ba53dfdcd8") + + h := &IDTokenHandleHelper{IDTokenStrategy: strat} + + hash := h.GetAccessTokenHash(nil, req, resp) + assert.Equal(t, "Zfn_XBitThuDJiETU3OALQ", hash) +} + +func TestGetAccessTokenHashWithMissingKeyLength(t *testing.T) { + ctrl := gomock.NewController(t) + req := internal.NewMockAccessRequester(ctrl) + resp := internal.NewMockAccessResponder(ctrl) + + defer ctrl.Finish() + + headers := &jwt.Headers{ + Extra: map[string]interface{}{ + "alg": "RS", + }, + } + req.EXPECT().GetSession().Return(&DefaultSession{Headers: headers}) resp.EXPECT().GetAccessToken().Return("7a35f818-9164-48cb-8c8f-e1217f44228431c41102-d410-4ed5-9276-07ba53dfdcd8") h := &IDTokenHandleHelper{IDTokenStrategy: strat}