From da2775458be0e72b379423158ebc2a61c3c24ecb Mon Sep 17 00:00:00 2001 From: Arne Luenser Date: Thu, 31 Oct 2024 17:05:16 +0100 Subject: [PATCH] feat: improve QueryForCredentials (#4181) --- identity/test/pool.go | 15 +- .../sql/identity/persister_identity.go | 162 +++++++++--------- persistence/sql/migratest/migration_test.go | 5 +- 3 files changed, 89 insertions(+), 93 deletions(-) diff --git a/identity/test/pool.go b/identity/test/pool.go index 14b41f4bf4d8..4f898917449f 100644 --- a/identity/test/pool.go +++ b/identity/test/pool.go @@ -13,10 +13,6 @@ import ( "testing" "time" - confighelpers "github.com/ory/kratos/driver/config/testhelpers" - - "github.com/ory/x/crdbx" - "github.com/go-faker/faker/v4" "github.com/gofrs/uuid" "github.com/stretchr/testify/assert" @@ -24,12 +20,15 @@ import ( "github.com/tidwall/gjson" "github.com/ory/kratos/driver/config" + confighelpers "github.com/ory/kratos/driver/config/testhelpers" "github.com/ory/kratos/identity" "github.com/ory/kratos/internal/testhelpers" "github.com/ory/kratos/persistence" + idpersistence "github.com/ory/kratos/persistence/sql/identity" "github.com/ory/kratos/schema" "github.com/ory/kratos/x" "github.com/ory/x/assertx" + "github.com/ory/x/crdbx" "github.com/ory/x/errorsx" "github.com/ory/x/pagination/keysetpagination" "github.com/ory/x/randx" @@ -1214,21 +1213,21 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager, t.Run("suite=credential-types", func(t *testing.T) { for _, ct := range identity.AllCredentialTypes { t.Run("type="+ct.String(), func(t *testing.T) { - id, err := p.FindIdentityCredentialsTypeByName(ctx, ct) + id, err := idpersistence.FindIdentityCredentialsTypeByName(p.GetConnection(ctx), ct) require.NoError(t, err) require.NotEqual(t, uuid.Nil, id) - name, err := p.FindIdentityCredentialsTypeByID(ctx, id) + name, err := idpersistence.FindIdentityCredentialsTypeByID(p.GetConnection(ctx), id) require.NoError(t, err) assert.Equal(t, ct, name) }) } - _, err := p.FindIdentityCredentialsTypeByName(ctx, "unknown") + _, err := idpersistence.FindIdentityCredentialsTypeByName(p.GetConnection(ctx), "unknown") require.Error(t, err) - _, err = p.FindIdentityCredentialsTypeByID(ctx, x.NewUUID()) + _, err = idpersistence.FindIdentityCredentialsTypeByID(p.GetConnection(ctx), x.NewUUID()) require.Error(t, err) }) diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 48db65dfb0fa..a5143736712d 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -10,6 +10,7 @@ import ( "fmt" "sort" "strings" + "sync" "time" "github.com/ory/kratos/x/events" @@ -60,17 +61,12 @@ type IdentityPersister struct { r dependencies c *pop.Connection nid uuid.UUID - - credentialTypesID *x.SyncMap[uuid.UUID, identity.CredentialsType] - credentialTypesName *x.SyncMap[identity.CredentialsType, uuid.UUID] } func NewPersister(r dependencies, c *pop.Connection) *IdentityPersister { return &IdentityPersister{ - c: c, - r: r, - credentialTypesID: x.NewSyncMap[uuid.UUID, identity.CredentialsType](), - credentialTypesName: x.NewSyncMap[identity.CredentialsType, uuid.UUID](), + c: c, + r: r, } } @@ -313,7 +309,7 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn cred.Config = sqlxx.JSONRawMessage("{}") } - ct, err := p.FindIdentityCredentialsTypeByName(ctx, cred.Type) + ct, err := FindIdentityCredentialsTypeByName(conn, cred.Type) if err != nil { return err } @@ -344,7 +340,7 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn "Unable to create identity credentials with missing or empty identifier.")) } - ct, err := p.FindIdentityCredentialsTypeByName(ctx, cred.Type) + ct, err := FindIdentityCredentialsTypeByName(conn, cred.Type) if err != nil { return err } @@ -662,10 +658,7 @@ func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i * attribute.Stringer("network.id", p.NetworkID(ctx)))) defer otelx.End(span, &err) - var ( - con = p.GetConnection(ctx) - nid = p.NetworkID(ctx) - ) + nid := p.NetworkID(ctx) eg, ctx := errgroup.WithContext(ctx) if expand.Has(identity.ExpandFieldRecoveryAddresses) { @@ -674,7 +667,7 @@ func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i * // from complaining incorrectly. // // https://github.com/gobuffalo/pop/issues/723 - if err := con.WithContext(ctx). + if err := p.GetConnection(ctx).WithContext(ctx). Where("identity_id = ? AND nid = ?", i.ID, nid). Order("id ASC"). All(&i.RecoveryAddresses); err != nil { @@ -690,7 +683,7 @@ func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i * // from complaining incorrectly. // // https://github.com/gobuffalo/pop/issues/723 - if err := con.WithContext(ctx). + if err := p.GetConnection(ctx).WithContext(ctx). Order("id ASC"). Where("identity_id = ? AND nid = ?", i.ID, nid). All(&i.VerifiableAddresses); err != nil { @@ -706,9 +699,9 @@ func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i * // from complaining incorrectly. // // https://github.com/gobuffalo/pop/issues/723 - con := con.WithContext(ctx) - creds, err := QueryForCredentials(con, - Where{"(identity_credentials.identity_id = ? AND identity_credentials.nid = ?)", []interface{}{i.ID, nid}}) + creds, err := QueryForCredentials(p.GetConnection(ctx).WithContext(ctx), + Where{"identity_credentials.identity_id = ?", []interface{}{i.ID}}, + Where{"identity_credentials.nid = ?", []interface{}{nid}}) if err != nil { return err } @@ -733,16 +726,8 @@ func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i * } type queryCredentials struct { - ID uuid.UUID `db:"cred_id"` - IdentityID uuid.UUID `db:"identity_id"` - NID uuid.UUID `db:"nid"` - Type identity.CredentialsType `db:"cred_type"` - TypeID uuid.UUID `db:"cred_type_id"` - Identifier string `db:"cred_identifier"` - Config sqlxx.JSONRawMessage `db:"cred_config"` - Version int `db:"cred_version"` - CreatedAt time.Time `db:"created_at"` - UpdatedAt time.Time `db:"updated_at"` + Identifier string `db:"cred_identifier"` + identity.Credentials } func (queryCredentials) TableName() string { @@ -756,35 +741,23 @@ type Where struct { // QueryForCredentials queries for identity credentials with custom WHERE // clauses, returning the results resolved by the owning identity's UUID. -func QueryForCredentials(con *pop.Connection, where ...Where) (map[uuid.UUID](map[identity.CredentialsType]identity.Credentials), error) { - ici := "identity_credential_identifiers" - switch con.Dialect.Name() { - case "cockroach": - ici += "@identity_credential_identifiers_identity_credential_id_idx" - case "sqlite3": - ici += " INDEXED BY identity_credential_identifiers_identity_credential_id_idx" - case "mysql": - ici += " USE INDEX(identity_credential_identifiers_identity_credential_id_idx)" - default: - // good luck 🤷‍♂️ - } +func QueryForCredentials(con *pop.Connection, where ...Where) (credentialsPerIdentity map[uuid.UUID](map[identity.CredentialsType]identity.Credentials), err error) { + // This query has been meticulously crafted to be as fast as possible. + // If you touch it, you will likely introduce a performance regression. q := con.Select( - "identity_credentials.id cred_id", - "identity_credentials.identity_id identity_id", - "identity_credentials.nid nid", - "ict.name cred_type", - "ict.id cred_type_id", "COALESCE(identity_credential_identifiers.identifier, '') cred_identifier", - "identity_credentials.config cred_config", - "identity_credentials.version cred_version", - "identity_credentials.created_at created_at", - "identity_credentials.updated_at updated_at", - ).InnerJoin( - "identity_credential_types ict", - "(identity_credentials.identity_credential_type_id = ict.id)", - ).LeftJoin( - ici, + "identity_credentials.id", + "identity_credentials.identity_credential_type_id", + "identity_credentials.identity_id", + "identity_credentials.nid", + "identity_credentials.config", + "identity_credentials.version", + "identity_credentials.created_at", + "identity_credentials.updated_at", + ).LeftJoin(identifiersTableNameWithIndexHint(con), "identity_credential_identifiers.identity_credential_id = identity_credentials.id AND identity_credential_identifiers.nid = identity_credentials.nid", + ).Order( + "identity_credentials.id ASC", ) for _, w := range where { q = q.Where("("+w.Condition+")", w.Args...) @@ -793,8 +766,16 @@ func QueryForCredentials(con *pop.Connection, where ...Where) (map[uuid.UUID](ma if err := q.All(&results); err != nil { return nil, sqlcon.HandleError(err) } - credentialsPerIdentity := map[uuid.UUID](map[identity.CredentialsType]identity.Credentials){} + + // assemble + credentialsPerIdentity = map[uuid.UUID](map[identity.CredentialsType]identity.Credentials){} for _, res := range results { + + res.Type, err = FindIdentityCredentialsTypeByID(con, res.IdentityCredentialTypeID) + if err != nil { + return nil, err + } + credentials, ok := credentialsPerIdentity[res.IdentityID] if !ok { credentialsPerIdentity[res.IdentityID] = make(map[identity.CredentialsType]identity.Credentials) @@ -807,20 +788,10 @@ func QueryForCredentials(con *pop.Connection, where ...Where) (map[uuid.UUID](ma if identifiers == nil { identifiers = make([]string, 0) } - c := identity.Credentials{ - ID: res.ID, - IdentityID: res.IdentityID, - NID: res.NID, - Type: res.Type, - IdentityCredentialTypeID: res.TypeID, - Identifiers: identifiers, - Config: res.Config, - Version: res.Version, - CreatedAt: res.CreatedAt, - UpdatedAt: res.UpdatedAt, - } - credentials[res.Type] = c + res.Identifiers = identifiers + credentials[res.Type] = res.Credentials } + // We need deterministic ordering for testing, but sorting in the // database can be expensive under certain circumstances. for _, creds := range credentialsPerIdentity { @@ -831,6 +802,21 @@ func QueryForCredentials(con *pop.Connection, where ...Where) (map[uuid.UUID](ma return credentialsPerIdentity, nil } +func identifiersTableNameWithIndexHint(con *pop.Connection) string { + ici := "identity_credential_identifiers" + switch con.Dialect.Name() { + case "cockroach": + ici += "@identity_credential_identifiers_nid_i_ici_idx" + case "sqlite3": + ici += " INDEXED BY identity_credential_identifiers_nid_i_ici_idx" + case "mysql": + ici += " USE INDEX(identity_credential_identifiers_nid_i_ici_idx)" + default: + // good luck 🤷‍♂️ + } + return ici +} + func paginationAttributes(params *identity.ListIdentityParameters, paginator *keysetpagination.Paginator) []attribute.KeyValue { attrs := []attribute.KeyValue{ attribute.StringSlice("expand", params.Expand.ToEager()), @@ -856,7 +842,7 @@ func (p *IdentityPersister) getCredentialTypeIDs(ctx context.Context, credential result := map[identity.CredentialsType]uuid.UUID{} for _, ct := range credentialTypes { - typeID, err := p.FindIdentityCredentialsTypeByName(ctx, ct) + typeID, err := FindIdentityCredentialsTypeByName(p.GetConnection(ctx), ct) if err != nil { return nil, err } @@ -991,7 +977,7 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity. switch e { case identity.ExpandFieldCredentials: creds, err := QueryForCredentials(con, - Where{"identity_credentials.nid = ?", []any{nid}}, + Where{"identity_credentials.nid = ?", []interface{}{nid}}, Where{"identity_credentials.identity_id IN (?)", identityIDs}) if err != nil { return err @@ -1315,14 +1301,19 @@ func (p *IdentityPersister) InjectTraitsSchemaURL(ctx context.Context, i *identi return nil } -func (p *IdentityPersister) FindIdentityCredentialsTypeByID(ctx context.Context, id uuid.UUID) (identity.CredentialsType, error) { - result, found := p.credentialTypesID.Load(id) +var ( + credentialTypesID = x.NewSyncMap[uuid.UUID, identity.CredentialsType]() + credentialTypesName = x.NewSyncMap[identity.CredentialsType, uuid.UUID]() +) + +func FindIdentityCredentialsTypeByID(con *pop.Connection, id uuid.UUID) (identity.CredentialsType, error) { + result, found := credentialTypesID.Load(id) if !found { - if err := p.loadCredentialTypes(ctx); err != nil { + if err := loadCredentialTypes(con); err != nil { return "", err } - result, found = p.credentialTypesID.Load(id) + result, found = credentialTypesID.Load(id) } if !found { @@ -1332,14 +1323,14 @@ func (p *IdentityPersister) FindIdentityCredentialsTypeByID(ctx context.Context, return result, nil } -func (p *IdentityPersister) FindIdentityCredentialsTypeByName(ctx context.Context, ct identity.CredentialsType) (uuid.UUID, error) { - result, found := p.credentialTypesName.Load(ct) +func FindIdentityCredentialsTypeByName(con *pop.Connection, ct identity.CredentialsType) (uuid.UUID, error) { + result, found := credentialTypesName.Load(ct) if !found { - if err := p.loadCredentialTypes(ctx); err != nil { + if err := loadCredentialTypes(con); err != nil { return uuid.Nil, err } - result, found = p.credentialTypesName.Load(ct) + result, found = credentialTypesName.Load(ct) } if !found { @@ -1349,18 +1340,23 @@ func (p *IdentityPersister) FindIdentityCredentialsTypeByName(ctx context.Contex return result, nil } -func (p *IdentityPersister) loadCredentialTypes(ctx context.Context) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.identity.loadCredentialTypes") +var mux sync.Mutex + +func loadCredentialTypes(con *pop.Connection) (err error) { + ctx, span := trace.SpanFromContext(con.Context()).TracerProvider().Tracer("").Start(con.Context(), "persistence.sql.identity.loadCredentialTypes") defer otelx.End(span, &err) + _ = ctx + mux.Lock() + defer mux.Unlock() var tt []identity.CredentialsTypeTable - if err := p.GetConnection(ctx).All(&tt); err != nil { + if err := con.WithContext(ctx).All(&tt); err != nil { return sqlcon.HandleError(err) } for _, t := range tt { - p.credentialTypesID.Store(t.ID, t.Name) - p.credentialTypesName.Store(t.Name, t.ID) + credentialTypesID.Store(t.ID, t.Name) + credentialTypesName.Store(t.Name, t.ID) } return nil diff --git a/persistence/sql/migratest/migration_test.go b/persistence/sql/migratest/migration_test.go index bf727683248b..bafe38c040d6 100644 --- a/persistence/sql/migratest/migration_test.go +++ b/persistence/sql/migratest/migration_test.go @@ -422,6 +422,7 @@ func testDatabase(t *testing.T, db string, c *pop.Connection) { }) }) - tm.DumpMigrations = false - require.NoError(t, tm.Down(ctx, -1)) + tm.DumpMigrations = false // true for debug + err = tm.Down(ctx, -1) // for easy breakpointing + require.NoError(t, err) }