From ced8b621b29fdca9ca8a07a74473fadb7dcd4d31 Mon Sep 17 00:00:00 2001 From: Nikos Date: Fri, 15 Nov 2024 13:03:28 +0200 Subject: [PATCH] refactor: update device session persistence logic --- consent/handler_test.go | 21 +- consent/manager.go | 3 - consent/strategy_default.go | 14 +- ..._token_hook_if_configured-hook=legacy.json | 3 +- ...esh_token_hook_if_configured-hook=new.json | 3 +- ..._token_hook_if_configured-hook=legacy.json | 3 +- ...esh_token_hook_if_configured-hook=new.json | 3 +- ..._token_hook_if_configured-hook=legacy.json | 3 +- ...esh_token_hook_if_configured-hook=new.json | 3 +- ..._token_hook_if_configured-hook=legacy.json | 3 +- ...esh_token_hook_if_configured-hook=new.json | 3 +- ..._token_hook_if_configured-hook=legacy.json | 3 +- ...esh_token_hook_if_configured-hook=new.json | 3 +- ..._token_hook_if_configured-hook=legacy.json | 3 +- ...esh_token_hook_if_configured-hook=new.json | 3 +- .../TestUnmarshalSession-v1.11.8.json | 3 +- .../TestUnmarshalSession-v1.11.9.json | 3 +- oauth2/fixtures/v1.11.8-session.json | 3 +- oauth2/fixtures/v1.11.9-session.json | 3 +- oauth2/handler.go | 20 +- oauth2/oauth2_device_code_test.go | 17 +- oauth2/session.go | 9 - oauth2/session_test.go | 1 - persistence/sql/persister_device.go | 296 ++++++++++++++++++ persistence/sql/persister_oauth2.go | 149 +-------- x/clean_sql.go | 6 +- x/fosite_storer.go | 6 +- 27 files changed, 358 insertions(+), 232 deletions(-) create mode 100644 persistence/sql/persister_device.go diff --git a/consent/handler_test.go b/consent/handler_test.go index 73745bfd46d..f11a9aabd09 100644 --- a/consent/handler_test.go +++ b/consent/handler_test.go @@ -341,12 +341,13 @@ func TestAcceptDeviceRequest(t *testing.T) { DefaultSession: &openid.DefaultSession{ Headers: &jwt.Headers{}, }, - BrowserFlowCompleted: false, }, ) + _, deviceCodesig, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(ctx) + require.NoError(t, err) userCode, sig, err := reg.RFC8628HMACStrategy().GenerateUserCode(ctx) require.NoError(t, err) - reg.OAuth2Storage().CreateUserCodeSession(ctx, sig, deviceRequest) + reg.OAuth2Storage().CreateDeviceAuthSession(ctx, deviceCodesig, sig, deviceRequest) require.NoError(t, err) acceptUserCode := &hydra.AcceptDeviceUserCodeRequest{UserCode: &userCode} @@ -405,12 +406,13 @@ func TestAcceptDuplicateDeviceRequest(t *testing.T) { DefaultSession: &openid.DefaultSession{ Headers: &jwt.Headers{}, }, - BrowserFlowCompleted: false, }, ) + _, deviceCodesig, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(ctx) + require.NoError(t, err) userCode, sig, err := reg.RFC8628HMACStrategy().GenerateUserCode(ctx) require.NoError(t, err) - reg.OAuth2Storage().CreateUserCodeSession(ctx, sig, deviceRequest) + reg.OAuth2Storage().CreateDeviceAuthSession(ctx, deviceCodesig, sig, deviceRequest) require.NoError(t, err) acceptUserCode := &hydra.AcceptDeviceUserCodeRequest{UserCode: &userCode} @@ -490,7 +492,6 @@ func TestAcceptCodeDeviceRequestFailure(t *testing.T) { DefaultSession: &openid.DefaultSession{ Headers: &jwt.Headers{}, }, - BrowserFlowCompleted: false, }, ) userCode, _, err := reg.RFC8628HMACStrategy().GenerateUserCode(ctx) @@ -514,7 +515,6 @@ func TestAcceptCodeDeviceRequestFailure(t *testing.T) { DefaultSession: &openid.DefaultSession{ Headers: &jwt.Headers{}, }, - BrowserFlowCompleted: false, }, ) userCode := "" @@ -537,7 +537,6 @@ func TestAcceptCodeDeviceRequestFailure(t *testing.T) { DefaultSession: &openid.DefaultSession{ Headers: &jwt.Headers{}, }, - BrowserFlowCompleted: false, }, ) userCode, _, err := reg.RFC8628HMACStrategy().GenerateUserCode(ctx) @@ -561,7 +560,6 @@ func TestAcceptCodeDeviceRequestFailure(t *testing.T) { DefaultSession: &openid.DefaultSession{ Headers: &jwt.Headers{}, }, - BrowserFlowCompleted: false, }, ) userCode, _, err := reg.RFC8628HMACStrategy().GenerateUserCode(ctx) @@ -585,9 +583,10 @@ func TestAcceptCodeDeviceRequestFailure(t *testing.T) { DefaultSession: &openid.DefaultSession{ Headers: &jwt.Headers{}, }, - BrowserFlowCompleted: false, }, ) + _, deviceCodesig, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(ctx) + require.NoError(t, err) userCode, sig, err := reg.RFC8628HMACStrategy().GenerateUserCode(ctx) require.NoError(t, err) deviceRequest.SetSession( @@ -595,12 +594,11 @@ func TestAcceptCodeDeviceRequestFailure(t *testing.T) { DefaultSession: &openid.DefaultSession{ Headers: &jwt.Headers{}, }, - BrowserFlowCompleted: false, }, ) exp := time.Now().UTC() deviceRequest.Session.SetExpiresAt(fosite.UserCode, exp) - err = reg.OAuth2Storage().CreateUserCodeSession(ctx, sig, deviceRequest) + err = reg.OAuth2Storage().CreateDeviceAuthSession(ctx, deviceCodesig, sig, deviceRequest) require.NoError(t, err) return json.Marshal(&hydra.AcceptDeviceUserCodeRequest{UserCode: &userCode}) }, @@ -624,7 +622,6 @@ func TestAcceptCodeDeviceRequestFailure(t *testing.T) { DefaultSession: &openid.DefaultSession{ Headers: &jwt.Headers{}, }, - BrowserFlowCompleted: false, }, ) userCode, _, err := reg.RFC8628HMACStrategy().GenerateUserCode(ctx) diff --git a/consent/manager.go b/consent/manager.go index 577fffa27f1..f09c803c06b 100644 --- a/consent/manager.go +++ b/consent/manager.go @@ -6,7 +6,6 @@ package consent import ( "context" - "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/ory/hydra/v2/client" @@ -66,8 +65,6 @@ type ( GetDeviceUserAuthRequest(ctx context.Context, challenge string) (*flow.DeviceUserAuthRequest, error) HandleDeviceUserAuthRequest(ctx context.Context, f *flow.Flow, challenge string, r *flow.HandledDeviceUserAuthRequest) (*flow.DeviceUserAuthRequest, error) VerifyAndInvalidateDeviceUserAuthRequest(ctx context.Context, verifier string) (*flow.HandledDeviceUserAuthRequest, error) - - Transaction(context.Context, func(ctx context.Context, c *pop.Connection) error) error } ManagerProvider interface { diff --git a/consent/strategy_default.go b/consent/strategy_default.go index 027c644ce2e..7b7ef300336 100644 --- a/consent/strategy_default.go +++ b/consent/strategy_default.go @@ -14,7 +14,6 @@ import ( "strings" "time" - "github.com/gobuffalo/pop/v6" "github.com/gorilla/sessions" "github.com/hashicorp/go-retryablehttp" "github.com/pborman/uuid" @@ -1245,15 +1244,10 @@ func (s *DefaultStrategy) HandleOAuth2DeviceAuthorizationRequest( var consentSession *flow.AcceptOAuth2ConsentRequest var f *flow.Flow - err = s.r.ConsentManager().Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { - consentSession, f, err = s.verifyConsent(ctx, w, r, consentVerifier) - if err != nil { - return err - } - err = s.r.OAuth2Storage().UpdateAndInvalidateUserCodeSessionByRequestID(ctx, string(f.DeviceCodeRequestID), f.ID) - - return err - }) + consentSession, f, err = s.verifyConsent(ctx, w, r, consentVerifier) + if err != nil { + return nil, nil, err + } return consentSession, f, err } diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=legacy.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=legacy.json index a687c788d99..61dfba78726 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=legacy.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=legacy.json @@ -30,8 +30,7 @@ "consent_challenge": "", "exclude_not_before_claim": false, "allowed_top_level_claims": [], - "mirror_top_level_claims": true, - "browser_flow_completed": false + "mirror_top_level_claims": true }, "requester": { "client_id": "app-client", diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json index 4be660d13cc..3748c3744f1 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json @@ -31,8 +31,7 @@ "consent_challenge": "", "exclude_not_before_claim": false, "allowed_top_level_claims": [], - "mirror_top_level_claims": true, - "browser_flow_completed": false + "mirror_top_level_claims": true }, "request": { "client_id": "app-client", diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=legacy.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=legacy.json index a687c788d99..61dfba78726 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=legacy.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=legacy.json @@ -30,8 +30,7 @@ "consent_challenge": "", "exclude_not_before_claim": false, "allowed_top_level_claims": [], - "mirror_top_level_claims": true, - "browser_flow_completed": false + "mirror_top_level_claims": true }, "requester": { "client_id": "app-client", diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json index 4be660d13cc..3748c3744f1 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json @@ -31,8 +31,7 @@ "consent_challenge": "", "exclude_not_before_claim": false, "allowed_top_level_claims": [], - "mirror_top_level_claims": true, - "browser_flow_completed": false + "mirror_top_level_claims": true }, "request": { "client_id": "app-client", diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=legacy.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=legacy.json index a687c788d99..61dfba78726 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=legacy.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=legacy.json @@ -30,8 +30,7 @@ "consent_challenge": "", "exclude_not_before_claim": false, "allowed_top_level_claims": [], - "mirror_top_level_claims": true, - "browser_flow_completed": false + "mirror_top_level_claims": true }, "requester": { "client_id": "app-client", diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json index 4be660d13cc..3748c3744f1 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json @@ -31,8 +31,7 @@ "consent_challenge": "", "exclude_not_before_claim": false, "allowed_top_level_claims": [], - "mirror_top_level_claims": true, - "browser_flow_completed": false + "mirror_top_level_claims": true }, "request": { "client_id": "app-client", diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=legacy.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=legacy.json index a687c788d99..61dfba78726 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=legacy.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=legacy.json @@ -30,8 +30,7 @@ "consent_challenge": "", "exclude_not_before_claim": false, "allowed_top_level_claims": [], - "mirror_top_level_claims": true, - "browser_flow_completed": false + "mirror_top_level_claims": true }, "requester": { "client_id": "app-client", diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json index 4be660d13cc..3748c3744f1 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json @@ -31,8 +31,7 @@ "consent_challenge": "", "exclude_not_before_claim": false, "allowed_top_level_claims": [], - "mirror_top_level_claims": true, - "browser_flow_completed": false + "mirror_top_level_claims": true }, "request": { "client_id": "app-client", diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=legacy.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=legacy.json index a687c788d99..61dfba78726 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=legacy.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=legacy.json @@ -30,8 +30,7 @@ "consent_challenge": "", "exclude_not_before_claim": false, "allowed_top_level_claims": [], - "mirror_top_level_claims": true, - "browser_flow_completed": false + "mirror_top_level_claims": true }, "requester": { "client_id": "app-client", diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json index 4be660d13cc..3748c3744f1 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json @@ -31,8 +31,7 @@ "consent_challenge": "", "exclude_not_before_claim": false, "allowed_top_level_claims": [], - "mirror_top_level_claims": true, - "browser_flow_completed": false + "mirror_top_level_claims": true }, "request": { "client_id": "app-client", diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=legacy.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=legacy.json index a687c788d99..61dfba78726 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=legacy.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=legacy.json @@ -30,8 +30,7 @@ "consent_challenge": "", "exclude_not_before_claim": false, "allowed_top_level_claims": [], - "mirror_top_level_claims": true, - "browser_flow_completed": false + "mirror_top_level_claims": true }, "requester": { "client_id": "app-client", diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json index 4be660d13cc..3748c3744f1 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json @@ -31,8 +31,7 @@ "consent_challenge": "", "exclude_not_before_claim": false, "allowed_top_level_claims": [], - "mirror_top_level_claims": true, - "browser_flow_completed": false + "mirror_top_level_claims": true }, "request": { "client_id": "app-client", diff --git a/oauth2/.snapshots/TestUnmarshalSession-v1.11.8.json b/oauth2/.snapshots/TestUnmarshalSession-v1.11.8.json index 341c3556d09..03e8881ee72 100644 --- a/oauth2/.snapshots/TestUnmarshalSession-v1.11.8.json +++ b/oauth2/.snapshots/TestUnmarshalSession-v1.11.8.json @@ -47,6 +47,5 @@ "zone", "login_session_id" ], - "mirror_top_level_claims": false, - "browser_flow_completed": false + "mirror_top_level_claims": false } diff --git a/oauth2/.snapshots/TestUnmarshalSession-v1.11.9.json b/oauth2/.snapshots/TestUnmarshalSession-v1.11.9.json index 341c3556d09..03e8881ee72 100644 --- a/oauth2/.snapshots/TestUnmarshalSession-v1.11.9.json +++ b/oauth2/.snapshots/TestUnmarshalSession-v1.11.9.json @@ -47,6 +47,5 @@ "zone", "login_session_id" ], - "mirror_top_level_claims": false, - "browser_flow_completed": false + "mirror_top_level_claims": false } diff --git a/oauth2/fixtures/v1.11.8-session.json b/oauth2/fixtures/v1.11.8-session.json index 8f7f9a13125..4608026d74e 100644 --- a/oauth2/fixtures/v1.11.8-session.json +++ b/oauth2/fixtures/v1.11.8-session.json @@ -44,6 +44,5 @@ "market", "zone", "login_session_id" - ], - "BrowserFlowCompleted": false + ] } diff --git a/oauth2/fixtures/v1.11.9-session.json b/oauth2/fixtures/v1.11.9-session.json index 10bd3ec8d87..9636d07b8d6 100644 --- a/oauth2/fixtures/v1.11.9-session.json +++ b/oauth2/fixtures/v1.11.9-session.json @@ -44,6 +44,5 @@ "market", "zone", "login_session_id" - ], - "browser_flow_completed": false + ] } diff --git a/oauth2/handler.go b/oauth2/handler.go index 0dea7c2b26f..0bc1ec85341 100644 --- a/oauth2/handler.go +++ b/oauth2/handler.go @@ -729,6 +729,15 @@ func (h *Handler) getOidcUserInfo(w http.ResponseWriter, r *http.Request) { func (h *Handler) performOAuth2DeviceVerificationFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { ctx := r.Context() + // When this endpoint is called with a valid consent_verifier (meaning that the login flow completed successfully) + // there are 3 writes happening to the database: + // - The flow is created + // - The device auth session is updated (user_code is marked as accepted) + // - The OpenID session is created + // If there were multiple flows created for the same user_code then we may end up with multiple flow objects + // persisted to the database, while only one of them was actually used to validate the user_code + // (see https://github.com/ory/hydra/pull/3851#discussion_r1843678761) + // TODO: We should wrap these queries in a transaction consentSession, f, err := h.r.ConsentStrategy().HandleOAuth2DeviceAuthorizationRequest(ctx, w, r) if errors.Is(err, consent.ErrAbortOAuth2Request) { x.LogAudit(r, nil, h.r.AuditLogger()) @@ -747,26 +756,26 @@ func (h *Handler) performOAuth2DeviceVerificationFlow(w http.ResponseWriter, r * return } - req, err := h.r.OAuth2Storage().GetDeviceCodeSessionByRequestID(ctx, f.DeviceCodeRequestID.String(), &Session{}) + req, sig, err := h.r.OAuth2Storage().GetDeviceCodeSessionByRequestID(ctx, f.DeviceCodeRequestID.String(), &Session{}) if err != nil { x.LogError(r, err, h.r.Logger()) h.r.Writer().WriteError(w, r, err) return } + req.SetUserCodeState(fosite.UserCodeAccepted) session, err := h.updateSessionWithRequest(ctx, consentSession, f, r, req, req.GetSession().(*Session)) if err != nil { h.r.Writer().WriteError(w, r, err) return } - session.SetBrowserFlowCompleted(true) - req.SetSession(session) // Update the device code session with // - the claims for which the user gave consent // - the granted scopes // - the granted audiences + // - the user_code_state set to accepted // This marks it as ready to be used for the token exchange endpoint. - err = h.r.OAuth2Storage().UpdateDeviceCodeSessionByRequestID(ctx, f.DeviceCodeRequestID.String(), req) + err = h.r.OAuth2Storage().UpdateDeviceCodeSessionBySignature(ctx, sig, req) if err != nil { x.LogError(r, err, h.r.Logger()) h.r.Writer().WriteError(w, r, err) @@ -775,7 +784,7 @@ func (h *Handler) performOAuth2DeviceVerificationFlow(w http.ResponseWriter, r * // Update the OpenID Connect session if "openid" scope is granted if req.GetGrantedScopes().Has("openid") { - err = h.r.OAuth2Storage().CreateOpenIDConnectSession(ctx, req.GetID(), req.Sanitize([]string{"grant_type", + err = h.r.OAuth2Storage().CreateOpenIDConnectSession(ctx, sig, req.Sanitize([]string{"grant_type", "max_age", "prompt", "acr_values", @@ -868,7 +877,6 @@ func (h *Handler) oAuth2DeviceFlow(w http.ResponseWriter, r *http.Request) { DefaultSession: &openid.DefaultSession{ Headers: &jwt.Headers{}, }, - BrowserFlowCompleted: false, } resp, err := h.r.OAuth2Provider().NewDeviceResponse(ctx, request, session) diff --git a/oauth2/oauth2_device_code_test.go b/oauth2/oauth2_device_code_test.go index d8ed58a3bf3..23e2dcccf97 100644 --- a/oauth2/oauth2_device_code_test.go +++ b/oauth2/oauth2_device_code_test.go @@ -128,14 +128,15 @@ func TestDeviceTokenRequest(t *testing.T) { testCases := []struct { description string - setUp func(signature string) + setUp func(signature, userCodeSignature string) check func(t *testing.T, token *oauth2.Token, err error) cleanUp func() }{ { description: "should pass with refresh token", - setUp: func(signature string) { + setUp: func(signature, userCodeSignature string) { authreq := &fosite.DeviceRequest{ + UserCodeState: fosite.UserCodeAccepted, Request: fosite.Request{ Client: &fosite.DefaultClient{ ID: c.GetID(), @@ -152,13 +153,12 @@ func TestDeviceTokenRequest(t *testing.T) { fosite.DeviceCode: time.Now().Add(time.Hour).UTC(), }, }, - BrowserFlowCompleted: true, }, RequestedAt: time.Now(), }, } - require.NoError(t, reg.OAuth2Storage().CreateDeviceCodeSession(context.TODO(), signature, authreq)) + require.NoError(t, reg.OAuth2Storage().CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq)) }, check: func(t *testing.T, token *oauth2.Token, err error) { assert.NotEmpty(t, token.AccessToken) @@ -167,8 +167,9 @@ func TestDeviceTokenRequest(t *testing.T) { }, { description: "should pass with ID token", - setUp: func(signature string) { + setUp: func(signature, userCodeSignature string) { authreq := &fosite.DeviceRequest{ + UserCodeState: fosite.UserCodeAccepted, Request: fosite.Request{ Client: &fosite.DefaultClient{ ID: c.GetID(), @@ -185,13 +186,12 @@ func TestDeviceTokenRequest(t *testing.T) { fosite.DeviceCode: time.Now().Add(time.Hour).UTC(), }, }, - BrowserFlowCompleted: true, }, RequestedAt: time.Now(), }, } - require.NoError(t, reg.OAuth2Storage().CreateDeviceCodeSession(context.TODO(), signature, authreq)) + require.NoError(t, reg.OAuth2Storage().CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq)) require.NoError(t, reg.OAuth2Storage().CreateOpenIDConnectSession(context.TODO(), signature, authreq)) }, check: func(t *testing.T, token *oauth2.Token, err error) { @@ -205,10 +205,11 @@ func TestDeviceTokenRequest(t *testing.T) { for _, testCase := range testCases { t.Run("case="+testCase.description, func(t *testing.T) { code, signature, err := reg.RFC8628HMACStrategy().GenerateDeviceCode(context.TODO()) + _, userCodeSignature, err := reg.RFC8628HMACStrategy().GenerateUserCode(context.TODO()) require.NoError(t, err) if testCase.setUp != nil { - testCase.setUp(signature) + testCase.setUp(signature, userCodeSignature) } var token *oauth2.Token diff --git a/oauth2/session.go b/oauth2/session.go index cc067916a34..0630cb09142 100644 --- a/oauth2/session.go +++ b/oauth2/session.go @@ -33,7 +33,6 @@ type Session struct { ExcludeNotBeforeClaim bool `json:"exclude_not_before_claim"` AllowedTopLevelClaims []string `json:"allowed_top_level_claims"` MirrorTopLevelClaims bool `json:"mirror_top_level_claims"` - BrowserFlowCompleted bool `json:"browser_flow_completed"` Flow *flow.Flow `json:"-"` } @@ -209,11 +208,3 @@ func (s *Session) GetExtraClaims() map[string]interface{} { return s.Extra } - -func (s *Session) GetBrowserFlowCompleted() bool { - return s.BrowserFlowCompleted -} - -func (s *Session) SetBrowserFlowCompleted(flag bool) { - s.BrowserFlowCompleted = flag -} diff --git a/oauth2/session_test.go b/oauth2/session_test.go index a5094b4d9cd..461d753581a 100644 --- a/oauth2/session_test.go +++ b/oauth2/session_test.go @@ -77,7 +77,6 @@ func TestUnmarshalSession(t *testing.T) { "zone", "login_session_id", }, - BrowserFlowCompleted: false, } t.Run("v1.11.8", func(t *testing.T) { diff --git a/persistence/sql/persister_device.go b/persistence/sql/persister_device.go new file mode 100644 index 00000000000..50675646b45 --- /dev/null +++ b/persistence/sql/persister_device.go @@ -0,0 +1,296 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package sql + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "net/url" + "strings" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/tidwall/gjson" + + "github.com/ory/fosite" + "github.com/ory/hydra/v2/oauth2" + "github.com/ory/x/errorsx" + "github.com/ory/x/otelx" + "github.com/ory/x/sqlcon" + "github.com/ory/x/sqlxx" + "github.com/ory/x/stringsx" +) + +const ( + sqlTableDeviceAuthCodes tableName = "hydra_oauth2_device_auth_codes" +) + +type DeviceRequestSQL struct { + ID string `db:"device_code_signature"` + UserCodeID string `db:"user_code_signature"` + NID uuid.UUID `db:"nid"` + Request string `db:"request_id"` + ConsentChallenge sql.NullString `db:"challenge_id"` + RequestedAt time.Time `db:"requested_at"` + Client string `db:"client_id"` + Scopes string `db:"scope"` + GrantedScope string `db:"granted_scope"` + RequestedAudience string `db:"requested_audience"` + GrantedAudience string `db:"granted_audience"` + Form string `db:"form_data"` + Subject string `db:"subject"` + DeviceCodeActive bool `db:"device_code_active"` + UserCodeState fosite.UserCodeState `db:"user_code_state"` + Session []byte `db:"session_data"` + // InternalExpiresAt denormalizes the expiry from the session to additionally store it as a row. + InternalExpiresAt sqlxx.NullTime `db:"expires_at" json:"-"` +} + +func (r DeviceRequestSQL) TableName() string { + return string(sqlTableDeviceAuthCodes) +} + +func (r *DeviceRequestSQL) toRequest(ctx context.Context, session fosite.Session, p *Persister) (_ *fosite.DeviceRequest, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeviceRequestSQL.toRequest") + defer otelx.End(span, &err) + + sess := r.Session + if !gjson.ValidBytes(sess) { + var err error + sess, err = p.r.KeyCipher().Decrypt(ctx, string(sess), nil) + if err != nil { + return nil, errorsx.WithStack(err) + } + } + + if session != nil { + if err := json.Unmarshal(sess, session); err != nil { + return nil, errorsx.WithStack(err) + } + } else { + p.l.Debugf("Got an empty session in toRequest") + } + + c, err := p.GetClient(ctx, r.Client) + if err != nil { + return nil, err + } + + val, err := url.ParseQuery(r.Form) + if err != nil { + return nil, errorsx.WithStack(err) + } + + return &fosite.DeviceRequest{ + UserCodeState: fosite.UserCodeState(r.UserCodeState), + Request: fosite.Request{ + ID: r.Request, + RequestedAt: r.RequestedAt, + // ExpiresAt does not need to be populated as we get the expiry time from the session. + Client: c, + RequestedScope: stringsx.Splitx(r.Scopes, "|"), + GrantedScope: stringsx.Splitx(r.GrantedScope, "|"), + RequestedAudience: stringsx.Splitx(r.RequestedAudience, "|"), + GrantedAudience: stringsx.Splitx(r.GrantedAudience, "|"), + Form: val, + Session: session, + }, + }, nil +} + +func (p *Persister) sqlDeviceSchemaFromRequest(ctx context.Context, deviceCodeSignature, userCodeSignature string, r fosite.DeviceRequester, expiresAt time.Time) (*DeviceRequestSQL, error) { + subject := "" + if r.GetSession() == nil { + p.l.Debugf("Got an empty session in sqlSchemaFromRequest") + } else { + subject = r.GetSession().GetSubject() + } + + session, err := json.Marshal(r.GetSession()) + if err != nil { + return nil, errorsx.WithStack(err) + } + + if p.config.EncryptSessionData(ctx) { + ciphertext, err := p.r.KeyCipher().Encrypt(ctx, session, nil) + if err != nil { + return nil, errorsx.WithStack(err) + } + session = []byte(ciphertext) + } + + var challenge sql.NullString + rr, ok := r.GetSession().(*oauth2.Session) + if !ok && r.GetSession() != nil { + return nil, errors.Errorf("Expected request to be of type *Session, but got: %T", r.GetSession()) + } else if ok { + if len(rr.ConsentChallenge) > 0 { + challenge = sql.NullString{Valid: true, String: rr.ConsentChallenge} + } + } + + return &DeviceRequestSQL{ + Request: r.GetID(), + ConsentChallenge: challenge, + ID: deviceCodeSignature, + UserCodeID: userCodeSignature, + RequestedAt: r.GetRequestedAt(), + InternalExpiresAt: sqlxx.NullTime(expiresAt), + Client: r.GetClient().GetID(), + Scopes: strings.Join(r.GetRequestedScopes(), "|"), + GrantedScope: strings.Join(r.GetGrantedScopes(), "|"), + GrantedAudience: strings.Join(r.GetGrantedAudience(), "|"), + RequestedAudience: strings.Join(r.GetRequestedAudience(), "|"), + Form: r.GetRequestForm().Encode(), + Session: session, + Subject: subject, + DeviceCodeActive: true, + UserCodeState: r.GetUserCodeState(), + }, nil +} + +func (p *Persister) createDeviceAuthSession(ctx context.Context, deviceCodeSignature, userCodeSignature string, requester fosite.DeviceRequester, expiresAt time.Time) error { + req, err := p.sqlDeviceSchemaFromRequest(ctx, deviceCodeSignature, userCodeSignature, requester, expiresAt) + if err != nil { + return err + } + + if err = sqlcon.HandleError(p.CreateWithNetwork(ctx, req)); errors.Is(err, sqlcon.ErrConcurrentUpdate) { + return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + } else if err != nil { + return err + } + return nil +} + +// CreateDeviceCodeSession creates a new device code session and stores it in the database +func (p *Persister) CreateDeviceAuthSession(ctx context.Context, deviceCodeSignature, userCodeSignature string, requester fosite.DeviceRequester) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateDeviceCodeSession") + defer otelx.End(span, &err) + return p.createDeviceAuthSession(ctx, deviceCodeSignature, userCodeSignature, requester, requester.GetSession().GetExpiresAt(fosite.DeviceCode).UTC()) +} + +// UpdateDeviceCodeSessionBySignature updates a device code session by the device_code signature +func (p *Persister) UpdateDeviceCodeSessionBySignature(ctx context.Context, signature string, requester fosite.DeviceRequester) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateDeviceCodeSessionBySignature") + defer otelx.End(span, &err) + + req, err := p.sqlDeviceSchemaFromRequest(ctx, signature, "", requester, requester.GetSession().GetExpiresAt(fosite.DeviceCode).UTC()) + if err != nil { + return err + } + + stmt := fmt.Sprintf( + "UPDATE %s SET granted_scope=?, granted_audience=?, session_data=?, user_code_state=? WHERE device_code_signature=? AND nid = ?", + sqlTableDeviceAuthCodes, + ) + + /* #nosec G201 table is static */ + err = p.Connection(ctx).RawQuery(stmt, req.GrantedScope, req.GrantedAudience, req.Session, req.UserCodeState, signature, p.NetworkID(ctx)).Exec() + if err != nil { + return sqlcon.HandleError(err) + } + + return nil +} + +// GetDeviceCodeSession returns a device code session from the database +func (p *Persister) GetDeviceCodeSession(ctx context.Context, signature string, session fosite.Session) (_ fosite.DeviceRequester, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetDeviceCodeSession") + defer otelx.End(span, &err) + + r := DeviceRequestSQL{} + err = p.QueryWithNetwork(ctx).Where("device_code_signature = ?", signature).First(&r) + if errors.Is(err, sql.ErrNoRows) { + return nil, errorsx.WithStack(fosite.ErrNotFound) + } + if err != nil { + return nil, sqlcon.HandleError(err) + } + if !r.DeviceCodeActive { + fr, err := r.toRequest(ctx, session, p) + if err != nil { + return nil, err + } + return fr, errorsx.WithStack(fosite.ErrInactiveToken) + } + + return r.toRequest(ctx, session, p) +} + +// GetDeviceCodeSessionByRequestID returns a device code session from the database +func (p *Persister) GetDeviceCodeSessionByRequestID(ctx context.Context, requestID string, session fosite.Session) (_ fosite.DeviceRequester, deviceCodeSignature string, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetDeviceCodeSessionByRequestID") + defer otelx.End(span, &err) + + r := DeviceRequestSQL{} + err = p.QueryWithNetwork(ctx).Where("request_id = ?", requestID).First(&r) + if errors.Is(err, sql.ErrNoRows) { + return nil, "", errorsx.WithStack(fosite.ErrNotFound) + } + if err != nil { + return nil, "", sqlcon.HandleError(err) + } + if !r.DeviceCodeActive { + fr, err := r.toRequest(ctx, session, p) + if err != nil { + return nil, "", err + } + return fr, r.ID, errorsx.WithStack(fosite.ErrInactiveToken) + } + + fr, err := r.toRequest(ctx, session, p) + if err != nil { + return nil, "", err + } + return fr, r.ID, nil +} + +// InvalidateDeviceCodeSession invalidates a device code session +func (p *Persister) InvalidateDeviceCodeSession(ctx context.Context, signature string) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.InvalidateDeviceCodeSession") + defer otelx.End(span, &err) + + /* #nosec G201 table is static */ + return sqlcon.HandleError( + p.Connection(ctx). + RawQuery( + fmt.Sprintf("UPDATE %s SET device_code_active=false WHERE device_code_signature=? AND nid = ?", sqlTableDeviceAuthCodes), + signature, + p.NetworkID(ctx), + ). + Exec(), + ) +} + +// GetUserCodeSession returns a user code session from the database +func (p *Persister) GetUserCodeSession(ctx context.Context, signature string, session fosite.Session) (_ fosite.DeviceRequester, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetUserCodeSession") + defer otelx.End(span, &err) + + r := DeviceRequestSQL{} + if session == nil { + session = oauth2.NewSession("") + } + err = p.QueryWithNetwork(ctx).Where("user_code_signature = ?", signature).First(&r) + if errors.Is(err, sql.ErrNoRows) { + return nil, errorsx.WithStack(fosite.ErrNotFound) + } + if err != nil { + return nil, sqlcon.HandleError(err) + } + + fr, err := r.toRequest(ctx, session, p) + if err != nil { + return nil, err + } + if r.UserCodeState != fosite.UserCodeUnused { + return fr, errorsx.WithStack(fosite.ErrInactiveToken) + } + + return fr, err +} diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index 648695e5ee4..7595d3b58c4 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -65,13 +65,11 @@ type ( ) const ( - sqlTableOpenID tableName = "oidc" - sqlTableAccess tableName = "access" - sqlTableRefresh tableName = "refresh" - sqlTableCode tableName = "code" - sqlTablePKCE tableName = "pkce" - sqlTableDeviceCode tableName = "device_code" - sqlTableUserCode tableName = "user_code" + sqlTableOpenID tableName = "oidc" + sqlTableAccess tableName = "access" + sqlTableRefresh tableName = "refresh" + sqlTableCode tableName = "code" + sqlTablePKCE tableName = "pkce" ) func (r OAuth2RequestSQL) TableName() string { @@ -286,29 +284,6 @@ func (p *Persister) findSessionBySignature(ctx context.Context, signature string return r.toRequest(ctx, session, p) } -func (p *Persister) findSessionByRequestID(ctx context.Context, requestID string, session fosite.Session, table tableName) (fosite.Requester, error) { - r := OAuth2RequestSQL{Table: table} - err := p.QueryWithNetwork(ctx).Where("request_id = ?", requestID).First(&r) - if errors.Is(err, sql.ErrNoRows) { - return nil, errorsx.WithStack(fosite.ErrNotFound) - } - if err != nil { - return nil, sqlcon.HandleError(err) - } - if !r.Active { - fr, err := r.toRequest(ctx, session, p) - if err != nil { - return nil, err - } - if table == sqlTableCode { - return fr, errorsx.WithStack(fosite.ErrInvalidatedAuthorizeCode) - } - return fr, errorsx.WithStack(fosite.ErrInactiveToken) - } - - return r.toRequest(ctx, session, p) -} - func (p *Persister) deleteSessionBySignature(ctx context.Context, signature string, table tableName) error { err := sqlcon.HandleError( p.QueryWithNetwork(ctx). @@ -779,117 +754,3 @@ func (p *Persister) RotateRefreshToken(ctx context.Context, requestID string, re return handleRetryError(p.strictRefreshRotation(ctx, requestID)) } - -// CreateDeviceCodeSession creates a new device code session and stores it in the database -func (p *Persister) CreateDeviceCodeSession(ctx context.Context, signature string, requester fosite.Requester) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateDeviceCodeSession") - defer otelx.End(span, &err) - return p.createSession(ctx, signature, requester, sqlTableDeviceCode, requester.GetSession().GetExpiresAt(fosite.DeviceCode).UTC()) -} - -// UpdateDeviceCodeSessionByRequestID updates a device code session by requestID -func (p *Persister) UpdateDeviceCodeSessionByRequestID(ctx context.Context, requestID string, requester fosite.Requester) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateDeviceCodeSessionByRequestID") - defer otelx.End(span, &err) - - req, err := p.sqlSchemaFromRequest(ctx, requestID, requester, sqlTableDeviceCode, requester.GetSession().GetExpiresAt(fosite.DeviceCode).UTC()) - if err != nil { - return err - } - - stmt := fmt.Sprintf( - "UPDATE %s SET granted_scope=?, granted_audience=?, session_data=? WHERE request_id=? AND nid = ?", - OAuth2RequestSQL{Table: sqlTableDeviceCode}.TableName(), - ) - - /* #nosec G201 table is static */ - err = p.Connection(ctx).RawQuery(stmt, req.GrantedScope, req.GrantedAudience, req.Session, requestID, p.NetworkID(ctx)).Exec() - if err != nil { - return sqlcon.HandleError(err) - } - - return nil -} - -// GetDeviceCodeSession returns a device code session from the database -func (p *Persister) GetDeviceCodeSession(ctx context.Context, signature string, session fosite.Session) (_ fosite.Requester, err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetDeviceCodeSession") - defer otelx.End(span, &err) - return p.findSessionBySignature(ctx, signature, session, sqlTableDeviceCode) -} - -// GetDeviceCodeSessionByRequestID returns a device code session from the database -func (p *Persister) GetDeviceCodeSessionByRequestID(ctx context.Context, requestID string, session fosite.Session) (_ fosite.Requester, err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetDeviceCodeSessionByRequestID") - defer otelx.End(span, &err) - return p.findSessionByRequestID(ctx, requestID, session, sqlTableDeviceCode) -} - -// InvalidateDeviceCodeSession invalidates a device code session -func (p *Persister) InvalidateDeviceCodeSession(ctx context.Context, signature string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.InvalidateDeviceCodeSession") - defer otelx.End(span, &err) - - /* #nosec G201 table is static */ - return sqlcon.HandleError( - p.Connection(ctx). - RawQuery( - fmt.Sprintf("UPDATE %s SET active=false WHERE signature=? AND nid = ?", OAuth2RequestSQL{Table: sqlTableDeviceCode}.TableName()), - signature, - p.NetworkID(ctx), - ). - Exec(), - ) -} - -// CreateUserCodeSession creates a new user code session and stores it in the database -func (p *Persister) CreateUserCodeSession(ctx context.Context, signature string, requester fosite.Requester) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateUserCodeSession") - defer otelx.End(span, &err) - return p.createSession(ctx, signature, requester, sqlTableUserCode, requester.GetSession().GetExpiresAt(fosite.UserCode).UTC()) -} - -// GetUserCodeSession returns a user code session from the database -func (p *Persister) GetUserCodeSession(ctx context.Context, signature string, session fosite.Session) (_ fosite.Requester, err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetUserCodeSession") - defer otelx.End(span, &err) - if session == nil { - session = oauth2.NewSession("") - } - return p.findSessionBySignature(ctx, signature, session, sqlTableUserCode) -} - -// InvalidateUserCodeSession invalidates a user code session -func (p *Persister) InvalidateUserCodeSession(ctx context.Context, signature string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.InvalidateUserCodeSession") - defer otelx.End(span, &err) - - /* #nosec G201 table is static */ - return sqlcon.HandleError( - p.Connection(ctx). - RawQuery( - fmt.Sprintf("UPDATE %s SET active=false WHERE signature=? AND nid = ?", OAuth2RequestSQL{Table: sqlTableUserCode}.TableName()), - signature, - p.NetworkID(ctx), - ). - Exec(), - ) -} - -// UpdateAndInvalidateUserCodeSession invalidates a user code session and connects it with the device flow request ID -func (p *Persister) UpdateAndInvalidateUserCodeSessionByRequestID(ctx context.Context, request_id, challenge_id string) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateAndInvalidateUserCodeSession") - defer otelx.End(span, &err) - - if count, err := p.Connection(ctx).RawQuery( - fmt.Sprintf("UPDATE %s SET active=false, challenge_id=? WHERE request_id=? AND nid = ? AND active=true", OAuth2RequestSQL{Table: sqlTableUserCode}.TableName()), - challenge_id, - request_id, - p.NetworkID(ctx), - ).ExecWithCount(); count == 0 && err == nil { - return errorsx.WithStack(x.ErrNotFound) - } else if err != nil { - return sqlcon.HandleError(err) - } - return nil -} diff --git a/x/clean_sql.go b/x/clean_sql.go index 243d65033d8..2b51ec2cde3 100644 --- a/x/clean_sql.go +++ b/x/clean_sql.go @@ -16,8 +16,7 @@ func DeleteHydraRows(t *testing.T, c *pop.Connection) { "hydra_oauth2_code", "hydra_oauth2_oidc", "hydra_oauth2_pkce", - "hydra_oauth2_device_code", - "hydra_oauth2_user_code", + "hydra_oauth2_device_auth_codes", "hydra_oauth2_flow", "hydra_oauth2_authentication_session", "hydra_oauth2_obfuscated_authentication_session", @@ -41,8 +40,7 @@ func CleanSQLPop(t *testing.T, c *pop.Connection) { "hydra_oauth2_code", "hydra_oauth2_oidc", "hydra_oauth2_pkce", - "hydra_oauth2_device_code", - "hydra_oauth2_user_code", + "hydra_oauth2_device_auth_codes", "hydra_oauth2_flow", "hydra_oauth2_authentication_session", "hydra_oauth2_obfuscated_authentication_session", diff --git a/x/fosite_storer.go b/x/fosite_storer.go index 8b879e11da7..2313ca199d8 100644 --- a/x/fosite_storer.go +++ b/x/fosite_storer.go @@ -44,7 +44,7 @@ type FositeStorer interface { // This is duplicated from Ory Fosite to help against deprecation linting errors. DeleteOpenIDConnectSession(ctx context.Context, authorizeCode string) error - GetDeviceCodeSessionByRequestID(ctx context.Context, requestID string, requester fosite.Session) (fosite.Requester, error) - UpdateDeviceCodeSessionByRequestID(ctx context.Context, requestID string, requester fosite.Requester) error - UpdateAndInvalidateUserCodeSessionByRequestID(ctx context.Context, signature, request_id string) (err error) + GetUserCodeSession(context.Context, string, fosite.Session) (fosite.DeviceRequester, error) + GetDeviceCodeSessionByRequestID(ctx context.Context, requestID string, requester fosite.Session) (fosite.DeviceRequester, string, error) + UpdateDeviceCodeSessionBySignature(ctx context.Context, requestID string, requester fosite.DeviceRequester) error }