Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support more claims in password grant #3864

Merged
merged 11 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions consent/strategy_oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ func TestStrategyLoginConsentNext(t *testing.T) {
makeRequestAndExpectCode(t, hc, c, url.Values{})

// Make request with additional scope and prompt none, which fails
makeRequestAndExpectError(t, hc, c, url.Values{"prompt": {"none"}, "scope": {"openid"}},
makeRequestAndExpectError(t, hc, c, url.Values{"prompt": {"none"}, "scope": {"openid"}, "redirect_uri": {c.RedirectURIs[0]}},
"Prompt 'none' was requested, but no previous consent was found")
})

Expand Down Expand Up @@ -930,11 +930,11 @@ func TestStrategyLoginConsentNext(t *testing.T) {
}{
{
d: "check all the sub claims",
values: url.Values{"scope": {"openid"}},
values: url.Values{"scope": {"openid"}, "redirect_uri": {c.RedirectURIs[0]}},
},
{
d: "works with id_token_hint",
values: url.Values{"scope": {"openid"}, "id_token_hint": {testhelpers.NewIDToken(t, reg, hash)}},
values: url.Values{"scope": {"openid"}, "redirect_uri": {c.RedirectURIs[0]}, "id_token_hint": {testhelpers.NewIDToken(t, reg, hash)}},
},
} {
t.Run("case="+tc.d, func(t *testing.T) {
Expand Down Expand Up @@ -974,7 +974,7 @@ func TestStrategyLoginConsentNext(t *testing.T) {
}),
acceptConsentHandler(t, &hydra.AcceptOAuth2ConsentRequest{GrantScope: []string{"openid"}}))

code := makeRequestAndExpectCode(t, nil, c, url.Values{})
code := makeRequestAndExpectCode(t, nil, c, url.Values{"redirect_uri": {c.RedirectURIs[0]}})

conf := oauth2Config(t, c)
token, err := conf.Exchange(context.Background(), code)
Expand Down
7 changes: 5 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ replace github.com/ory/hydra-client-go/v2 => ./internal/httpclient

replace github.com/gobuffalo/pop/v6 => github.com/ory/pop/v6 v6.2.0

// Bump Fosite to https://github.com/ory/fosite/tree/hperl/v0.47.0%2B168636f, which contains
// https://github.com/ory/fosite/commit/b40b1cbb1997e2160eaaf97fb6f73960db4c6118 on top of the latest release.
replace github.com/ory/fosite => github.com/ory/fosite v0.47.1-0.20241030092116-b40b1cbb1997

require (
github.com/ThalesIgnite/crypto11 v1.2.5
github.com/bradleyjkemp/cupaloy/v2 v2.8.0
Expand Down Expand Up @@ -69,8 +73,6 @@ require (
golang.org/x/tools v0.23.0
)

require github.com/hashicorp/go-cleanhttp v0.5.2 // indirect

require (
code.dny.dev/ssrf v0.2.0 // indirect
dario.cat/mergo v1.0.0 // indirect
Expand Down Expand Up @@ -147,6 +149,7 @@ require (
github.com/gorilla/websocket v1.5.0 // indirect
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/huandu/xstrings v1.4.0 // indirect
github.com/imdario/mergo v0.3.16 // indirect
Expand Down
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,12 @@ github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d h1:By96ZSVuH5
github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d/go.mod h1:F2FIjwwAk6CsNAs//B8+aPFQF0t84pbM8oliyNXwQrk=
github.com/ory/fosite v0.47.0 h1:Iqu5uhx54JqZQPn2hRhqjESrmRRyQb00uJjfEi1a1QI=
github.com/ory/fosite v0.47.0/go.mod h1:5U6c9nOLxyTdD/qrFv7N88TSxkdk5Wq8NzvB7UViDP0=
github.com/ory/fosite v0.47.1-0.20241029112424-62f07ce22e57 h1:/eMox8UstN3u1r6YfVpIdiXhuz9y+ESPBUzlEHsK4AU=
github.com/ory/fosite v0.47.1-0.20241029112424-62f07ce22e57/go.mod h1:LC+0FyghTTjdSAznGVbtj0yK2nq0LAElh6TbMck8diA=
github.com/ory/fosite v0.47.1-0.20241029134014-168636ff33c7 h1:QyLWLIUgC32pPrHoeW82xlkDiIL2j2o2vq64y5SsLRM=
github.com/ory/fosite v0.47.1-0.20241029134014-168636ff33c7/go.mod h1:LC+0FyghTTjdSAznGVbtj0yK2nq0LAElh6TbMck8diA=
github.com/ory/fosite v0.47.1-0.20241030092116-b40b1cbb1997 h1:dryAvfoAFa1hYn6C0SPmISglYn+S775XOZgCCm54tbw=
github.com/ory/fosite v0.47.1-0.20241030092116-b40b1cbb1997/go.mod h1:5U6c9nOLxyTdD/qrFv7N88TSxkdk5Wq8NzvB7UViDP0=
github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe h1:rvu4obdvqR0fkSIJ8IfgzKOWwZ5kOT2UNfLq81Qk7rc=
github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe/go.mod h1:z4n3u6as84LbV4YmgjHhnwtccQqzf4cZlSk9f1FhygI=
github.com/ory/go-convenience v0.1.0 h1:zouLKfF2GoSGnJwGq+PE/nJAE6dj2Zj5QlTgmMTsTS8=
Expand Down
14 changes: 8 additions & 6 deletions internal/kratos/fake_kratos.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"

"github.com/ory/fosite"
client "github.com/ory/kratos-client-go"
)

type (
Expand All @@ -17,9 +18,10 @@ type (
)

const (
FakeSessionID = "fake-kratos-session-id"
FakeUsername = "fake-kratos-username"
FakePassword = "fake-kratos-password" // nolint: gosec
FakeSessionID = "fake-kratos-session-id"
FakeUsername = "fake-kratos-username"
FakePassword = "fake-kratos-password" // nolint: gosec
FakeIdentityID = "fake-kratos-identity-id"
)

var _ Client = new(FakeKratos)
Expand All @@ -35,11 +37,11 @@ func (f *FakeKratos) DisableSession(_ context.Context, identityProviderSessionID
return nil
}

func (f *FakeKratos) Authenticate(_ context.Context, username, password string) error {
func (f *FakeKratos) Authenticate(_ context.Context, username, password string) (*client.Session, error) {
if username == FakeUsername && password == FakePassword {
return nil
return &client.Session{Identity: &client.Identity{Id: FakeIdentityID}}, nil
}
return fosite.ErrNotFound
return nil, fosite.ErrNotFound
}

func (f *FakeKratos) Reset() {
Expand Down
14 changes: 7 additions & 7 deletions internal/kratos/kratos.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type (
}
Client interface {
DisableSession(ctx context.Context, identityProviderSessionID string) error
Authenticate(ctx context.Context, name, secret string) error
Authenticate(ctx context.Context, name, secret string) (*client.Session, error)
}
Default struct {
dependencies
Expand All @@ -42,7 +42,7 @@ func New(d dependencies) Client {
return &Default{dependencies: d}
}

func (k *Default) Authenticate(ctx context.Context, name, secret string) (err error) {
func (k *Default) Authenticate(ctx context.Context, name, secret string) (session *client.Session, err error) {
ctx, span := k.Tracer(ctx).Tracer().Start(ctx, "kratos.Authenticate")
otelx.End(span, &err)

Expand All @@ -52,28 +52,28 @@ func (k *Default) Authenticate(ctx context.Context, name, secret string) (err er
span.SetAttributes(attribute.Bool("skipped", true))
span.SetAttributes(attribute.String("reason", "kratos public url not set"))

return errors.New("kratos public url not set")
return nil, errors.New("kratos public url not set")
}

kratos := k.newKratosClient(ctx, publicURL)

flow, _, err := kratos.FrontendAPI.CreateNativeLoginFlow(ctx).Execute()
if err != nil {
return err
return nil, err
}

_, _, err = kratos.FrontendAPI.UpdateLoginFlow(ctx).Flow(flow.Id).UpdateLoginFlowBody(client.UpdateLoginFlowBody{
res, _, err := kratos.FrontendAPI.UpdateLoginFlow(ctx).Flow(flow.Id).UpdateLoginFlowBody(client.UpdateLoginFlowBody{
UpdateLoginFlowWithPasswordMethod: &client.UpdateLoginFlowWithPasswordMethod{
Method: "password",
Identifier: name,
Password: secret,
},
}).Execute()
if err != nil {
return fosite.ErrNotFound.WithWrap(err)
return nil, fosite.ErrNotFound.WithWrap(err)
}

return nil
return &res.Session, nil
}

func (k *Default) DisableSession(ctx context.Context, identityProviderSessionID string) (err error) {
Expand Down
17 changes: 15 additions & 2 deletions oauth2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,8 @@ func (h *Handler) oauth2TokenExchange(w http.ResponseWriter, r *http.Request) {
}

if accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeClientCredentials)) ||
accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeJWTBearer)) {
accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeJWTBearer)) ||
accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypePassword)) {
var accessTokenKeyID string
if h.c.AccessTokenStrategy(ctx, client.AccessTokenStrategySource(accessRequest.GetClient())) == "jwt" {
accessTokenKeyID, err = h.r.AccessTokenJWTStrategy().GetPublicKeyID(ctx)
Expand All @@ -975,9 +976,21 @@ func (h *Handler) oauth2TokenExchange(w http.ResponseWriter, r *http.Request) {
}

// only for client_credentials, otherwise Authentication is included in session
if accessRequest.GetGrantTypes().ExactOne("client_credentials") {
if accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeClientCredentials)) {
session.Subject = accessRequest.GetClient().GetID()
}
// only for password grant, otherwise Authentication is included in session
if accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypePassword)) {
if sess, ok := accessRequest.GetSession().(fosite.ExtraClaimsSession); ok {
sess.GetExtraClaims()["username"] = accessRequest.GetRequestForm().Get("username")
session.DefaultSession.Username = accessRequest.GetRequestForm().Get("username")
}

// Also add audience claims
for _, aud := range accessRequest.GetClient().GetAudience() {
accessRequest.GrantAudience(aud)
}
}
session.ClientID = accessRequest.GetClient().GetID()
session.KID = accessTokenKeyID
session.DefaultSession.Claims.Issuer = h.c.IssuerURL(ctx).String()
Expand Down
106 changes: 101 additions & 5 deletions oauth2/oauth2_rop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,30 @@ package oauth2_test

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"

"github.com/ory/fosite/compose"
"github.com/ory/fosite/token/jwt"
hydra "github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/driver/config"
"github.com/ory/hydra/v2/flow"
"github.com/ory/hydra/v2/fositex"
"github.com/ory/hydra/v2/internal"
"github.com/ory/hydra/v2/internal/kratos"
"github.com/ory/hydra/v2/internal/testhelpers"
hydraoauth2 "github.com/ory/hydra/v2/oauth2"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/contextx"
"github.com/ory/x/sqlxx"
)

func TestResourceOwnerPasswordGrant(t *testing.T) {
Expand All @@ -27,12 +37,19 @@ func TestResourceOwnerPasswordGrant(t *testing.T) {
reg := internal.NewMockedRegistry(t, &contextx.Default{})
reg.WithKratos(fakeKratos)
reg.WithExtraFositeFactories([]fositex.Factory{compose.OAuth2ResourceOwnerPasswordCredentialsFactory})
_, adminTS := testhelpers.NewOAuth2Server(ctx, t, reg)
publicTS, adminTS := testhelpers.NewOAuth2Server(ctx, t, reg)

secret := uuid.New().String()
audience := sqlxx.StringSliceJSONFormat{"https://aud.example.com"}
client := &hydra.Client{
Secret: secret,
GrantTypes: []string{"password"},
GrantTypes: []string{"password", "refresh_token"},
Scope: "offline",
Audience: audience,
Lifespans: hydra.Lifespans{
PasswordGrantAccessTokenLifespan: x.NullDuration{Duration: 1 * time.Hour, Valid: true},
PasswordGrantRefreshTokenLifespan: x.NullDuration{Duration: 1 * time.Hour, Valid: true},
},
}
require.NoError(t, reg.ClientManager().CreateClient(ctx, client))

Expand All @@ -44,15 +61,94 @@ func TestResourceOwnerPasswordGrant(t *testing.T) {
TokenURL: reg.Config().OAuth2TokenURL(ctx).String(),
AuthStyle: oauth2.AuthStyleInHeader,
},
Scopes: []string{"offline"},
}

hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8")
assert.Equal(t, r.Header.Get("Authorization"), "Bearer secret value")

var hookReq hydraoauth2.TokenHookRequest
require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq))
assert.NotEmpty(t, hookReq.Session)
assert.NotEmpty(t, hookReq.Request)

claims := hookReq.Session.Extra
claims["hooked"] = true
if hookReq.Request.GrantTypes[0] == "refresh_token" {
claims["refreshed"] = true
}

hookResp := hydraoauth2.TokenHookResponse{
Session: flow.AcceptOAuth2ConsentRequestSession{
AccessToken: claims,
IDToken: claims,
},
}

w.WriteHeader(http.StatusOK)
require.NoError(t, json.NewEncoder(w).Encode(&hookResp))
}))
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyTokenHook, &config.HookConfig{
URL: hs.URL,
Auth: &config.Auth{
Type: "api_key",
Config: config.AuthConfig{
In: "header",
Name: "Authorization",
Value: "Bearer secret value",
},
},
})
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "jwt")

t.Run("case=get ROP grant token with valid username and password", func(t *testing.T) {
token, err := oauth2Config.PasswordCredentialsToken(ctx, kratos.FakeUsername, kratos.FakePassword)
require.NoError(t, err)
require.NotEmpty(t, token.AccessToken)
i := testhelpers.IntrospectToken(t, oauth2Config, token.AccessToken, adminTS)
assert.True(t, i.Get("active").Bool(), "%s", i)
assert.EqualValues(t, oauth2Config.ClientID, i.Get("client_id").String(), "%s", i)

// Access token should have hook and identity_id claims
jwtAT, err := jwt.Parse(token.AccessToken, func(token *jwt.Token) (interface{}, error) {
return reg.AccessTokenJWTStrategy().GetPublicKey(ctx)
})
require.NoError(t, err)
assert.Equal(t, kratos.FakeUsername, jwtAT.Claims["ext"].(map[string]any)["username"])
assert.Equal(t, kratos.FakeIdentityID, jwtAT.Claims["sub"])
assert.Equal(t, publicTS.URL, jwtAT.Claims["iss"])
assert.True(t, jwtAT.Claims["ext"].(map[string]any)["hooked"].(bool))
assert.ElementsMatch(t, audience, jwtAT.Claims["aud"])

t.Run("case=introspect token", func(t *testing.T) {
// Introspected token should have hook and identity_id claims
i := testhelpers.IntrospectToken(t, oauth2Config, token.AccessToken, adminTS)
assert.True(t, i.Get("active").Bool(), "%s", i)
assert.Equal(t, kratos.FakeUsername, i.Get("ext.username").String(), "%s", i)
assert.Equal(t, kratos.FakeIdentityID, i.Get("sub").String(), "%s", i)
assert.True(t, i.Get("ext.hooked").Bool(), "%s", i)
assert.EqualValues(t, oauth2Config.ClientID, i.Get("client_id").String(), "%s", i)
})

t.Run("case=refresh token", func(t *testing.T) {
// Refreshed access token should have hook and identity_id claims
require.NotEmpty(t, token.RefreshToken)
token.Expiry = token.Expiry.Add(-time.Hour * 24)
refreshedToken, err := oauth2Config.TokenSource(context.Background(), token).Token()
require.NoError(t, err)

require.NotEqual(t, token.AccessToken, refreshedToken.AccessToken)
require.NotEqual(t, token.RefreshToken, refreshedToken.RefreshToken)

jwtAT, err := jwt.Parse(refreshedToken.AccessToken, func(token *jwt.Token) (interface{}, error) {
return reg.AccessTokenJWTStrategy().GetPublicKey(ctx)
})
require.NoError(t, err)
assert.Equal(t, kratos.FakeIdentityID, jwtAT.Claims["sub"])
assert.Equal(t, kratos.FakeUsername, jwtAT.Claims["ext"].(map[string]any)["username"])
assert.True(t, jwtAT.Claims["ext"].(map[string]any)["hooked"].(bool))
assert.True(t, jwtAT.Claims["ext"].(map[string]any)["refreshed"].(bool))
})
})

t.Run("case=access denied for invalid password", func(t *testing.T) {
Expand Down
14 changes: 14 additions & 0 deletions oauth2/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,17 @@ func (s *Session) UnmarshalJSON(original []byte) (err error) {

return nil
}

// GetExtraClaims implements ExtraClaimsSession for Session.
// The returned value can be modified in-place.
func (s *Session) GetExtraClaims() map[string]interface{} {
if s == nil {
return nil
}

if s.Extra == nil {
s.Extra = make(map[string]interface{})
}

return s.Extra
}
12 changes: 9 additions & 3 deletions persistence/sql/persister_authenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@

package sql

import "context"
import (
"context"
)

func (p *Persister) Authenticate(ctx context.Context, name, secret string) error {
return p.r.Kratos().Authenticate(ctx, name, secret)
func (p *Persister) Authenticate(ctx context.Context, name, secret string) (subject string, err error) {
session, err := p.r.Kratos().Authenticate(ctx, name, secret)
if err != nil {
return "", err
}
return session.Identity.Id, nil
}
Loading