From a21e94519416cc7801995b0804696348b18fa844 Mon Sep 17 00:00:00 2001 From: Arne Luenser Date: Thu, 3 Aug 2023 12:36:34 +0200 Subject: [PATCH] fix: only query access tokens by hashed signature --- persistence/sql/persister_nid_test.go | 12 ++- persistence/sql/persister_oauth2.go | 137 +++++++++++++++----------- x/audit_test.go | 2 - x/clean_sql.go | 3 +- 4 files changed, 85 insertions(+), 69 deletions(-) diff --git a/persistence/sql/persister_nid_test.go b/persistence/sql/persister_nid_test.go index 6ad1c937aec..83fad7c1452 100644 --- a/persistence/sql/persister_nid_test.go +++ b/persistence/sql/persister_nid_test.go @@ -40,14 +40,16 @@ import ( type PersisterTestSuite struct { suite.Suite registries map[string]driver.Registry - clean func(*testing.T) t1 context.Context t2 context.Context t1NID uuid.UUID t2NID uuid.UUID } -var _ PersisterTestSuite = PersisterTestSuite{} +var _ interface { + suite.SetupAllSuite + suite.TearDownTestSuite +} = (*PersisterTestSuite)(nil) func (s *PersisterTestSuite) SetupSuite() { s.registries = map[string]driver.Registry{ @@ -55,7 +57,7 @@ func (s *PersisterTestSuite) SetupSuite() { } if !testing.Short() { - s.registries["postgres"], s.registries["mysql"], s.registries["cockroach"], s.clean = internal.ConnectDatabases(s.T(), true, &contextx.Default{}) + s.registries["postgres"], s.registries["mysql"], s.registries["cockroach"], _ = internal.ConnectDatabases(s.T(), true, &contextx.Default{}) } s.t1NID, s.t2NID = uuid.Must(uuid.NewV4()), uuid.Must(uuid.NewV4()) @@ -558,11 +560,11 @@ func (s *PersisterTestSuite) DeleteAccessTokenSession() { require.NoError(t, r.Persister().DeleteAccessTokenSession(s.t2, sig)) actual := persistencesql.OAuth2RequestSQL{Table: "access"} - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, sig)) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig))) require.Equal(t, s.t1NID, actual.NID) require.NoError(t, r.Persister().DeleteAccessTokenSession(s.t1, sig)) - require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, sig)) + require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig))) }) } } diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index fb2faba6c0c..c49c9c7f823 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -67,7 +67,7 @@ func (r OAuth2RequestSQL) TableName() string { return "hydra_oauth2_" + string(r.Table) } -func (p *Persister) sqlSchemaFromRequest(ctx context.Context, rawSignature string, r fosite.Requester, table tableName) (*OAuth2RequestSQL, error) { +func (p *Persister) sqlSchemaFromRequest(ctx context.Context, signature string, r fosite.Requester, table tableName) (*OAuth2RequestSQL, error) { subject := "" if r.GetSession() == nil { p.l.Debugf("Got an empty session in sqlSchemaFromRequest") @@ -101,7 +101,7 @@ func (p *Persister) sqlSchemaFromRequest(ctx context.Context, rawSignature strin return &OAuth2RequestSQL{ Request: r.GetID(), ConsentChallenge: challenge, - ID: p.hashSignature(ctx, rawSignature, table), + ID: signature, RequestedAt: r.GetRequestedAt(), Client: r.GetClient().GetID(), Scopes: strings.Join(r.GetRequestedScopes(), "|"), @@ -160,20 +160,6 @@ func (r *OAuth2RequestSQL) toRequest(ctx context.Context, session fosite.Session }, nil } -// SignatureHash hashes the signature to prevent errors where the signature is -// longer than 128 characters (and thus doesn't fit into the pk). -func SignatureHash(signature string) string { - return fmt.Sprintf("%x", sha512.Sum384([]byte(signature))) -} - -// hashSignature prevents errors where the signature is longer than 128 characters (and thus doesn't fit into the pk). -func (p *Persister) hashSignature(_ context.Context, signature string, table tableName) string { - if table == sqlTableAccess { - return SignatureHash(signature) - } - return signature -} - func (p *Persister) ClientAssertionJWTValid(ctx context.Context, jti string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ClientAssertionJWTValid") defer otelx.End(span, &err) @@ -228,7 +214,7 @@ func (p *Persister) SetClientAssertionJWTRaw(ctx context.Context, jti *oauth2.Bl return sqlcon.HandleError(p.CreateWithNetwork(ctx, jti)) } -func (p *Persister) createSession(ctx context.Context, signature string, requester fosite.Requester, table tableName) (err error) { +func (p *Persister) createSession(ctx context.Context, signature string, requester fosite.Requester, table tableName) error { req, err := p.sqlSchemaFromRequest(ctx, signature, requester, table) if err != nil { return err @@ -242,28 +228,21 @@ func (p *Persister) createSession(ctx context.Context, signature string, request return nil } -func (p *Persister) findSessionBySignature(ctx context.Context, rawSignature string, session fosite.Session, table tableName) (_ fosite.Requester, err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.findSessionBySignature") - defer otelx.End(span, &err) - +func (p *Persister) findSessionBySignature(ctx context.Context, signature string, session fosite.Session, table tableName) (fosite.Requester, error) { r := OAuth2RequestSQL{Table: table} - - // We look for the signature as well as the hash of the signature here. - // This is because we now always store the hash of the signature in the database, - // regardless of the type of the signature. In previous versions, we only stored - // the hash of the signature for JWT tokens. - // - // This code will be removed in a future version. - err = p.QueryWithNetwork(ctx).Where("signature IN (?, ?)", rawSignature, SignatureHash(rawSignature)).First(&r) + err := p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r) if errors.Is(err, sql.ErrNoRows) { return nil, errorsx.WithStack(fosite.ErrNotFound) - } else if err != nil { + } + if err != nil { return nil, sqlcon.HandleError(err) - } else if !r.Active { + } + if !r.Active { fr, err := r.toRequest(ctx, session, p) if err != nil { return nil, err - } else if table == sqlTableCode { + } + if table == sqlTableCode { return fr, errorsx.WithStack(fosite.ErrInvalidatedAuthorizeCode) } return fr, errorsx.WithStack(fosite.ErrInactiveToken) @@ -272,46 +251,35 @@ func (p *Persister) findSessionBySignature(ctx context.Context, rawSignature str return r.toRequest(ctx, session, p) } -func (p *Persister) deleteSessionBySignature(ctx context.Context, signature string, table tableName) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deleteSessionBySignature") - defer otelx.End(span, &err) - - signature = p.hashSignature(ctx, signature, table) - - // We look for the signature as well as the hash of the signature here. - // This is because we now always store the hash of the signature in the database, - // regardless of the type of the signature. In previous versions, we only stored - // the hash of the signature for JWT tokens. - // - // This code will be removed in a future version. - err = sqlcon.HandleError( +func (p *Persister) deleteSessionBySignature(ctx context.Context, signature string, table tableName) error { + err := sqlcon.HandleError( p.QueryWithNetwork(ctx). - Where("signature IN (?, ?)", signature, SignatureHash(signature)). + Where("signature = ?", signature). Delete(&OAuth2RequestSQL{Table: table})) - if errors.Is(err, sqlcon.ErrNoRows) { return errorsx.WithStack(fosite.ErrNotFound) - } else if errors.Is(err, sqlcon.ErrConcurrentUpdate) { + } + if errors.Is(err, sqlcon.ErrConcurrentUpdate) { return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) - } else if err != nil { - return err } - return nil + return err } func (p *Persister) deleteSessionByRequestID(ctx context.Context, id string, table tableName) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deleteSessionByRequestID") defer otelx.End(span, &err) - /* #nosec G201 table is static */ - if err := p.QueryWithNetwork(ctx). + err = p.QueryWithNetwork(ctx). Where("request_id=?", id). - Delete(&OAuth2RequestSQL{Table: table}); errors.Is(err, sql.ErrNoRows) { + Delete(&OAuth2RequestSQL{Table: table}) + if errors.Is(err, sql.ErrNoRows) { return errorsx.WithStack(fosite.ErrNotFound) - } else if err := sqlcon.HandleError(err); err != nil { + } + if err := sqlcon.HandleError(err); err != nil { if errors.Is(err, sqlcon.ErrConcurrentUpdate) { return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) - } else if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock? + } + if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock? return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) } return err @@ -356,7 +324,7 @@ func (p *Persister) InvalidateAuthorizeCodeSession(ctx context.Context, signatur return sqlcon.HandleError( p.Connection(ctx). RawQuery( - fmt.Sprintf("UPDATE %s SET active=false WHERE signature=? AND nid = ?", OAuth2RequestSQL{Table: sqlTableCode}.TableName()), + fmt.Sprintf("UPDATE %s SET active = false WHERE signature = ? AND nid = ?", OAuth2RequestSQL{Table: sqlTableCode}.TableName()), signature, p.NetworkID(ctx), ). @@ -364,6 +332,12 @@ func (p *Persister) InvalidateAuthorizeCodeSession(ctx context.Context, signatur ) } +// SignatureHash hashes the signature to prevent errors where the signature is +// longer than 128 characters (and thus doesn't fit into the pk). +func SignatureHash(signature string) string { + return fmt.Sprintf("%x", sha512.Sum384([]byte(signature))) +} + func (p *Persister) CreateAccessTokenSession(ctx context.Context, signature string, requester fosite.Requester) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateAccessTokenSession") defer otelx.End(span, &err) @@ -372,19 +346,62 @@ func (p *Persister) CreateAccessTokenSession(ctx context.Context, signature stri append(toEventOptions(requester), events.WithGrantType(requester.GetRequestForm().Get("grant_type")))..., ) - return p.createSession(ctx, signature, requester, sqlTableAccess) + return p.createSession(ctx, SignatureHash(signature), requester, sqlTableAccess) } func (p *Persister) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetAccessTokenSession") defer otelx.End(span, &err) - return p.findSessionBySignature(ctx, signature, session, sqlTableAccess) + + r := OAuth2RequestSQL{Table: sqlTableAccess} + err = p.QueryWithNetwork(ctx).Where("signature = ?", SignatureHash(signature)).First(&r) + if errors.Is(err, sql.ErrNoRows) { + // Backwards compatibility: we previously did not always hash the + // signature before inserting. In case there are still very old (but + // valid) access tokens in the database, this should get them. + err = p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r) + if errors.Is(err, sql.ErrNoRows) { + return nil, errorsx.WithStack(fosite.ErrNotFound) + } + } + if err != nil { + return nil, sqlcon.HandleError(err) + } + if !r.Active { + fr, err := r.toRequest(ctx, session, p) + if err != nil { + return nil, err + } + return fr, errorsx.WithStack(fosite.ErrInactiveToken) + } + + return r.toRequest(ctx, session, p) } func (p *Persister) DeleteAccessTokenSession(ctx context.Context, signature string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteAccessTokenSession") defer otelx.End(span, &err) - return p.deleteSessionBySignature(ctx, signature, sqlTableAccess) + + err = sqlcon.HandleError( + p.QueryWithNetwork(ctx). + Where("signature = ?", SignatureHash(signature)). + Delete(&OAuth2RequestSQL{Table: sqlTableAccess})) + if errors.Is(err, sqlcon.ErrNoRows) { + // Backwards compatibility: we previously did not always hash the + // signature before inserting. In case there are still very old (but + // valid) access tokens in the database, this should get them. + err = sqlcon.HandleError( + p.QueryWithNetwork(ctx). + Where("signature = ?", signature). + Delete(&OAuth2RequestSQL{Table: sqlTableAccess})) + if errors.Is(err, sqlcon.ErrNoRows) { + return errorsx.WithStack(fosite.ErrNotFound) + } + } + if errors.Is(err, sqlcon.ErrConcurrentUpdate) { + return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) + } + return err } func toEventOptions(requester fosite.Requester) []trace.EventOption { diff --git a/x/audit_test.go b/x/audit_test.go index ef563c04a53..0a4061551d2 100644 --- a/x/audit_test.go +++ b/x/audit_test.go @@ -43,8 +43,6 @@ func TestLogAudit(t *testing.T) { l.Logger.Out = buf LogAudit(r, tc.message, l) - t.Logf("%s", buf.String()) - assert.Contains(t, buf.String(), "audience=audit") for _, expectContain := range tc.expectContains { assert.Contains(t, buf.String(), expectContain) diff --git a/x/clean_sql.go b/x/clean_sql.go index 59628fb3f97..a02a9a054ce 100644 --- a/x/clean_sql.go +++ b/x/clean_sql.go @@ -10,7 +10,6 @@ import ( ) func DeleteHydraRows(t *testing.T, c *pop.Connection) { - t.Logf("Deleting hydra rows in database: %s", c.Dialect.Name()) for _, tb := range []string{ "hydra_oauth2_access", "hydra_oauth2_refresh", @@ -57,7 +56,7 @@ func CleanSQLPop(t *testing.T, c *pop.Connection) { "schema_migration", } { if err := c.RawQuery("DROP TABLE IF EXISTS " + tb).Exec(); err != nil { - t.Logf(`Unable to clean up table "%s": %s`, tb, err) + t.Fatalf(`Unable to clean up table "%s": %s`, tb, err) } } t.Logf("Successfully cleaned up database: %s", c.Dialect.Name())