Skip to content

Commit

Permalink
feat: improve QueryForCredentials (ory#4181)
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr authored and malosayli committed Nov 6, 2024
1 parent 2b3b521 commit da27754
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 93 deletions.
15 changes: 7 additions & 8 deletions identity/test/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,22 @@ import (
"testing"
"time"

confighelpers "github.com/ory/kratos/driver/config/testhelpers"

"github.com/ory/x/crdbx"

"github.com/go-faker/faker/v4"
"github.com/gofrs/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"

"github.com/ory/kratos/driver/config"
confighelpers "github.com/ory/kratos/driver/config/testhelpers"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/internal/testhelpers"
"github.com/ory/kratos/persistence"
idpersistence "github.com/ory/kratos/persistence/sql/identity"
"github.com/ory/kratos/schema"
"github.com/ory/kratos/x"
"github.com/ory/x/assertx"
"github.com/ory/x/crdbx"
"github.com/ory/x/errorsx"
"github.com/ory/x/pagination/keysetpagination"
"github.com/ory/x/randx"
Expand Down Expand Up @@ -1214,21 +1213,21 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager,
t.Run("suite=credential-types", func(t *testing.T) {
for _, ct := range identity.AllCredentialTypes {
t.Run("type="+ct.String(), func(t *testing.T) {
id, err := p.FindIdentityCredentialsTypeByName(ctx, ct)
id, err := idpersistence.FindIdentityCredentialsTypeByName(p.GetConnection(ctx), ct)
require.NoError(t, err)

require.NotEqual(t, uuid.Nil, id)
name, err := p.FindIdentityCredentialsTypeByID(ctx, id)
name, err := idpersistence.FindIdentityCredentialsTypeByID(p.GetConnection(ctx), id)
require.NoError(t, err)

assert.Equal(t, ct, name)
})
}

_, err := p.FindIdentityCredentialsTypeByName(ctx, "unknown")
_, err := idpersistence.FindIdentityCredentialsTypeByName(p.GetConnection(ctx), "unknown")
require.Error(t, err)

_, err = p.FindIdentityCredentialsTypeByID(ctx, x.NewUUID())
_, err = idpersistence.FindIdentityCredentialsTypeByID(p.GetConnection(ctx), x.NewUUID())
require.Error(t, err)
})

Expand Down
162 changes: 79 additions & 83 deletions persistence/sql/identity/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"sort"
"strings"
"sync"
"time"

"github.com/ory/kratos/x/events"
Expand Down Expand Up @@ -60,17 +61,12 @@ type IdentityPersister struct {
r dependencies
c *pop.Connection
nid uuid.UUID

credentialTypesID *x.SyncMap[uuid.UUID, identity.CredentialsType]
credentialTypesName *x.SyncMap[identity.CredentialsType, uuid.UUID]
}

func NewPersister(r dependencies, c *pop.Connection) *IdentityPersister {
return &IdentityPersister{
c: c,
r: r,
credentialTypesID: x.NewSyncMap[uuid.UUID, identity.CredentialsType](),
credentialTypesName: x.NewSyncMap[identity.CredentialsType, uuid.UUID](),
c: c,
r: r,
}
}

Expand Down Expand Up @@ -313,7 +309,7 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn
cred.Config = sqlxx.JSONRawMessage("{}")
}

ct, err := p.FindIdentityCredentialsTypeByName(ctx, cred.Type)
ct, err := FindIdentityCredentialsTypeByName(conn, cred.Type)
if err != nil {
return err
}
Expand Down Expand Up @@ -344,7 +340,7 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn
"Unable to create identity credentials with missing or empty identifier."))
}

ct, err := p.FindIdentityCredentialsTypeByName(ctx, cred.Type)
ct, err := FindIdentityCredentialsTypeByName(conn, cred.Type)
if err != nil {
return err
}
Expand Down Expand Up @@ -662,10 +658,7 @@ func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i *
attribute.Stringer("network.id", p.NetworkID(ctx))))
defer otelx.End(span, &err)

var (
con = p.GetConnection(ctx)
nid = p.NetworkID(ctx)
)
nid := p.NetworkID(ctx)

eg, ctx := errgroup.WithContext(ctx)
if expand.Has(identity.ExpandFieldRecoveryAddresses) {
Expand All @@ -674,7 +667,7 @@ func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i *
// from complaining incorrectly.
//
// https://github.com/gobuffalo/pop/issues/723
if err := con.WithContext(ctx).
if err := p.GetConnection(ctx).WithContext(ctx).
Where("identity_id = ? AND nid = ?", i.ID, nid).
Order("id ASC").
All(&i.RecoveryAddresses); err != nil {
Expand All @@ -690,7 +683,7 @@ func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i *
// from complaining incorrectly.
//
// https://github.com/gobuffalo/pop/issues/723
if err := con.WithContext(ctx).
if err := p.GetConnection(ctx).WithContext(ctx).
Order("id ASC").
Where("identity_id = ? AND nid = ?", i.ID, nid).
All(&i.VerifiableAddresses); err != nil {
Expand All @@ -706,9 +699,9 @@ func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i *
// from complaining incorrectly.
//
// https://github.com/gobuffalo/pop/issues/723
con := con.WithContext(ctx)
creds, err := QueryForCredentials(con,
Where{"(identity_credentials.identity_id = ? AND identity_credentials.nid = ?)", []interface{}{i.ID, nid}})
creds, err := QueryForCredentials(p.GetConnection(ctx).WithContext(ctx),
Where{"identity_credentials.identity_id = ?", []interface{}{i.ID}},
Where{"identity_credentials.nid = ?", []interface{}{nid}})
if err != nil {
return err
}
Expand All @@ -733,16 +726,8 @@ func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i *
}

type queryCredentials struct {
ID uuid.UUID `db:"cred_id"`
IdentityID uuid.UUID `db:"identity_id"`
NID uuid.UUID `db:"nid"`
Type identity.CredentialsType `db:"cred_type"`
TypeID uuid.UUID `db:"cred_type_id"`
Identifier string `db:"cred_identifier"`
Config sqlxx.JSONRawMessage `db:"cred_config"`
Version int `db:"cred_version"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
Identifier string `db:"cred_identifier"`
identity.Credentials
}

func (queryCredentials) TableName() string {
Expand All @@ -756,35 +741,23 @@ type Where struct {

// QueryForCredentials queries for identity credentials with custom WHERE
// clauses, returning the results resolved by the owning identity's UUID.
func QueryForCredentials(con *pop.Connection, where ...Where) (map[uuid.UUID](map[identity.CredentialsType]identity.Credentials), error) {
ici := "identity_credential_identifiers"
switch con.Dialect.Name() {
case "cockroach":
ici += "@identity_credential_identifiers_identity_credential_id_idx"
case "sqlite3":
ici += " INDEXED BY identity_credential_identifiers_identity_credential_id_idx"
case "mysql":
ici += " USE INDEX(identity_credential_identifiers_identity_credential_id_idx)"
default:
// good luck 🤷‍♂️
}
func QueryForCredentials(con *pop.Connection, where ...Where) (credentialsPerIdentity map[uuid.UUID](map[identity.CredentialsType]identity.Credentials), err error) {
// This query has been meticulously crafted to be as fast as possible.
// If you touch it, you will likely introduce a performance regression.
q := con.Select(
"identity_credentials.id cred_id",
"identity_credentials.identity_id identity_id",
"identity_credentials.nid nid",
"ict.name cred_type",
"ict.id cred_type_id",
"COALESCE(identity_credential_identifiers.identifier, '') cred_identifier",
"identity_credentials.config cred_config",
"identity_credentials.version cred_version",
"identity_credentials.created_at created_at",
"identity_credentials.updated_at updated_at",
).InnerJoin(
"identity_credential_types ict",
"(identity_credentials.identity_credential_type_id = ict.id)",
).LeftJoin(
ici,
"identity_credentials.id",
"identity_credentials.identity_credential_type_id",
"identity_credentials.identity_id",
"identity_credentials.nid",
"identity_credentials.config",
"identity_credentials.version",
"identity_credentials.created_at",
"identity_credentials.updated_at",
).LeftJoin(identifiersTableNameWithIndexHint(con),
"identity_credential_identifiers.identity_credential_id = identity_credentials.id AND identity_credential_identifiers.nid = identity_credentials.nid",
).Order(
"identity_credentials.id ASC",
)
for _, w := range where {
q = q.Where("("+w.Condition+")", w.Args...)
Expand All @@ -793,8 +766,16 @@ func QueryForCredentials(con *pop.Connection, where ...Where) (map[uuid.UUID](ma
if err := q.All(&results); err != nil {
return nil, sqlcon.HandleError(err)
}
credentialsPerIdentity := map[uuid.UUID](map[identity.CredentialsType]identity.Credentials){}

// assemble
credentialsPerIdentity = map[uuid.UUID](map[identity.CredentialsType]identity.Credentials){}
for _, res := range results {

res.Type, err = FindIdentityCredentialsTypeByID(con, res.IdentityCredentialTypeID)
if err != nil {
return nil, err
}

credentials, ok := credentialsPerIdentity[res.IdentityID]
if !ok {
credentialsPerIdentity[res.IdentityID] = make(map[identity.CredentialsType]identity.Credentials)
Expand All @@ -807,20 +788,10 @@ func QueryForCredentials(con *pop.Connection, where ...Where) (map[uuid.UUID](ma
if identifiers == nil {
identifiers = make([]string, 0)
}
c := identity.Credentials{
ID: res.ID,
IdentityID: res.IdentityID,
NID: res.NID,
Type: res.Type,
IdentityCredentialTypeID: res.TypeID,
Identifiers: identifiers,
Config: res.Config,
Version: res.Version,
CreatedAt: res.CreatedAt,
UpdatedAt: res.UpdatedAt,
}
credentials[res.Type] = c
res.Identifiers = identifiers
credentials[res.Type] = res.Credentials
}

// We need deterministic ordering for testing, but sorting in the
// database can be expensive under certain circumstances.
for _, creds := range credentialsPerIdentity {
Expand All @@ -831,6 +802,21 @@ func QueryForCredentials(con *pop.Connection, where ...Where) (map[uuid.UUID](ma
return credentialsPerIdentity, nil
}

func identifiersTableNameWithIndexHint(con *pop.Connection) string {
ici := "identity_credential_identifiers"
switch con.Dialect.Name() {
case "cockroach":
ici += "@identity_credential_identifiers_nid_i_ici_idx"
case "sqlite3":
ici += " INDEXED BY identity_credential_identifiers_nid_i_ici_idx"
case "mysql":
ici += " USE INDEX(identity_credential_identifiers_nid_i_ici_idx)"
default:
// good luck 🤷‍♂️
}
return ici
}

func paginationAttributes(params *identity.ListIdentityParameters, paginator *keysetpagination.Paginator) []attribute.KeyValue {
attrs := []attribute.KeyValue{
attribute.StringSlice("expand", params.Expand.ToEager()),
Expand All @@ -856,7 +842,7 @@ func (p *IdentityPersister) getCredentialTypeIDs(ctx context.Context, credential
result := map[identity.CredentialsType]uuid.UUID{}

for _, ct := range credentialTypes {
typeID, err := p.FindIdentityCredentialsTypeByName(ctx, ct)
typeID, err := FindIdentityCredentialsTypeByName(p.GetConnection(ctx), ct)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -991,7 +977,7 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity.
switch e {
case identity.ExpandFieldCredentials:
creds, err := QueryForCredentials(con,
Where{"identity_credentials.nid = ?", []any{nid}},
Where{"identity_credentials.nid = ?", []interface{}{nid}},
Where{"identity_credentials.identity_id IN (?)", identityIDs})
if err != nil {
return err
Expand Down Expand Up @@ -1315,14 +1301,19 @@ func (p *IdentityPersister) InjectTraitsSchemaURL(ctx context.Context, i *identi
return nil
}

func (p *IdentityPersister) FindIdentityCredentialsTypeByID(ctx context.Context, id uuid.UUID) (identity.CredentialsType, error) {
result, found := p.credentialTypesID.Load(id)
var (
credentialTypesID = x.NewSyncMap[uuid.UUID, identity.CredentialsType]()
credentialTypesName = x.NewSyncMap[identity.CredentialsType, uuid.UUID]()
)

func FindIdentityCredentialsTypeByID(con *pop.Connection, id uuid.UUID) (identity.CredentialsType, error) {
result, found := credentialTypesID.Load(id)
if !found {
if err := p.loadCredentialTypes(ctx); err != nil {
if err := loadCredentialTypes(con); err != nil {
return "", err
}

result, found = p.credentialTypesID.Load(id)
result, found = credentialTypesID.Load(id)
}

if !found {
Expand All @@ -1332,14 +1323,14 @@ func (p *IdentityPersister) FindIdentityCredentialsTypeByID(ctx context.Context,
return result, nil
}

func (p *IdentityPersister) FindIdentityCredentialsTypeByName(ctx context.Context, ct identity.CredentialsType) (uuid.UUID, error) {
result, found := p.credentialTypesName.Load(ct)
func FindIdentityCredentialsTypeByName(con *pop.Connection, ct identity.CredentialsType) (uuid.UUID, error) {
result, found := credentialTypesName.Load(ct)
if !found {
if err := p.loadCredentialTypes(ctx); err != nil {
if err := loadCredentialTypes(con); err != nil {
return uuid.Nil, err
}

result, found = p.credentialTypesName.Load(ct)
result, found = credentialTypesName.Load(ct)
}

if !found {
Expand All @@ -1349,18 +1340,23 @@ func (p *IdentityPersister) FindIdentityCredentialsTypeByName(ctx context.Contex
return result, nil
}

func (p *IdentityPersister) loadCredentialTypes(ctx context.Context) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.identity.loadCredentialTypes")
var mux sync.Mutex

func loadCredentialTypes(con *pop.Connection) (err error) {
ctx, span := trace.SpanFromContext(con.Context()).TracerProvider().Tracer("").Start(con.Context(), "persistence.sql.identity.loadCredentialTypes")
defer otelx.End(span, &err)
_ = ctx

mux.Lock()
defer mux.Unlock()
var tt []identity.CredentialsTypeTable
if err := p.GetConnection(ctx).All(&tt); err != nil {
if err := con.WithContext(ctx).All(&tt); err != nil {
return sqlcon.HandleError(err)
}

for _, t := range tt {
p.credentialTypesID.Store(t.ID, t.Name)
p.credentialTypesName.Store(t.Name, t.ID)
credentialTypesID.Store(t.ID, t.Name)
credentialTypesName.Store(t.Name, t.ID)
}

return nil
Expand Down
5 changes: 3 additions & 2 deletions persistence/sql/migratest/migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ func testDatabase(t *testing.T, db string, c *pop.Connection) {
})
})

tm.DumpMigrations = false
require.NoError(t, tm.Down(ctx, -1))
tm.DumpMigrations = false // true for debug
err = tm.Down(ctx, -1) // for easy breakpointing
require.NoError(t, err)
}

0 comments on commit da27754

Please sign in to comment.