Skip to content

Commit

Permalink
refactor: move credential configs for oidc and password
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Feb 26, 2022
1 parent 6da2370 commit 50ac851
Show file tree
Hide file tree
Showing 19 changed files with 181 additions and 108 deletions.
6 changes: 2 additions & 4 deletions cmd/identities/get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import (
"testing"

"github.com/ory/kratos/cmd/identities"
"github.com/ory/kratos/selfservice/strategy/oidc"

"github.com/ory/x/assertx"

"github.com/ory/kratos/x"
Expand Down Expand Up @@ -55,7 +53,7 @@ func TestGetCmd(t *testing.T) {

t.Run("case=gets a single identity with oidc credentials", func(t *testing.T) {
applyCredentials := func(identifier, accessToken, refreshToken, idToken string, encrypt bool) identity.Credentials {
toJson := func(c oidc.CredentialsConfig) []byte {
toJson := func(c identity.CredentialsOIDC) []byte {
out, err := json.Marshal(&c)
require.NoError(t, err)
return out
Expand All @@ -69,7 +67,7 @@ func TestGetCmd(t *testing.T) {
return identity.Credentials{
Type: identity.CredentialsTypeOIDC,
Identifiers: []string{"bar:" + identifier},
Config: toJson(oidc.CredentialsConfig{Providers: []oidc.ProviderCredentialsConfig{
Config: toJson(identity.CredentialsOIDC{Providers: []identity.CredentialsOIDCProvider{
{
Subject: "foo",
Provider: "bar",
Expand Down
57 changes: 57 additions & 0 deletions identity/credentials_oidc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package identity

import (
"bytes"
"encoding/json"
"fmt"

"github.com/pkg/errors"

"github.com/ory/kratos/x"
)

// CredentialsOIDC is contains the configuration for credentials of the type oidc.
//
// swagger:model identityCredentialsOidc
type CredentialsOIDC struct {
Providers []CredentialsOIDCProvider `json:"providers"`
}

// CredentialsOIDCProvider is contains a specific OpenID COnnect credential for a particular connection (e.g. Google).
//
// swagger:model identityCredentialsOidcProvider
type CredentialsOIDCProvider struct {
Subject string `json:"subject"`
Provider string `json:"provider"`
InitialIDToken string `json:"initial_id_token"`
InitialAccessToken string `json:"initial_access_token"`
InitialRefreshToken string `json:"initial_refresh_token"`
}

// NewCredentialsOIDC creates a new OIDC credential.
func NewCredentialsOIDC(idToken, accessToken, refreshToken, provider, subject string) (*Credentials, error) {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(CredentialsOIDC{
Providers: []CredentialsOIDCProvider{
{
Subject: subject,
Provider: provider,
InitialIDToken: idToken,
InitialAccessToken: accessToken,
InitialRefreshToken: refreshToken,
}},
}); err != nil {
return nil, errors.WithStack(x.PseudoPanic.
WithDebugf("Unable to encode password options to JSON: %s", err))
}

return &Credentials{
Type: CredentialsTypeOIDC,
Identifiers: []string{OIDCUniqueID(provider, subject)},
Config: b.Bytes(),
}, nil
}

func OIDCUniqueID(provider, subject string) string {
return fmt.Sprintf("%s:%s", provider, subject)
}
9 changes: 9 additions & 0 deletions identity/credentials_password.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package identity

// CredentialsPassword is contains the configuration for credentials of the type password.
//
// swagger:model identityCredentialsPassword
type CredentialsPassword struct {
// HashedPassword is a hash-representation of the password.
HashedPassword string `json:"hashed_password"`
}
60 changes: 59 additions & 1 deletion identity/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,58 @@ type AdminCreateIdentityBody struct {
// required: true
Traits json.RawMessage `json:"traits"`

// Credentials represents all credentials that can be used for authenticating this identity.
//
// Use this structure to import credentials for a user.
Credentials *AdminIdentityImportCredentials `json:"credentials"`

// VerifiableAddresses contains all the addresses that can be verified by the user.
//
// Use this structure to import verified addresses for an identity. Please keep in mind
// that the address needs to be represented in the Identity Schema or this field will be overwritten
// on the next identity update.
VerifiableAddresses []VerifiableAddress `json:"verifiable_addresses"`

// RecoveryAddresses contains all the addresses that can be used to recover an identity.
//
// Use this structure to import recovery addresses for an identity. Please keep in mind
// that the address needs to be represented in the Identity Schema or this field will be overwritten
// on the next identity update.
RecoveryAddresses []RecoveryAddress `json:"recovery_addresses"`

// State is the identity's state.
//
// required: false
State State `json:"state"`
}

// swagger:model adminIdentityImportCredentials
type AdminIdentityImportCredentials struct {
// Password if set will import a password credential.
Password *AdminIdentityImportCredentialsPassword `json:"password,omitempty"`

// OIDC if set will import an OIDC credential.
OIDC *AdminIdentityImportCredentialsOIDC `json:"oidc,omitempty"`
}

// swagger:model AdminCreateIdentityImportCredentialsPassword
type AdminIdentityImportCredentialsPassword struct {
// The hashed password in [PHC format]( https://www.ory.sh/docs/kratos/concepts/credentials/username-email-password#hashed-password-format)
HashedPassword string `json:"hashed_password"`

// The password in plain text if no hash is available.
Password string `json:"password"`
}

// swagger:model AdminCreateIdentityImportCredentialsOIDC
type AdminIdentityImportCredentialsOIDC struct {
// The subject (`sub`) of the OpenID Connect connection. Usually the `sub` field of the ID Token.
Subject string `json:"subject"`

// The OpenID Connect provider to link the subject to. Usually something like `google` or `github`.
Provider string `json:"provider"`
}

// swagger:route POST /identities v0alpha2 adminCreateIdentity
//
// Create an Identity
Expand Down Expand Up @@ -249,7 +295,19 @@ func (h *Handler) create(w http.ResponseWriter, r *http.Request, _ httprouter.Pa
}
state = cr.State
}
i := &Identity{SchemaID: cr.SchemaID, Traits: []byte(cr.Traits), State: state, StateChangedAt: &stateChangedAt}

i := &Identity{
SchemaID: cr.SchemaID,
Traits: []byte(cr.Traits),
State: state,
StateChangedAt: &stateChangedAt,
//Credentials: cr.Credentials,
VerifiableAddresses: cr.VerifiableAddresses,
RecoveryAddresses: cr.RecoveryAddresses,
}
//i.Traits = identity.Traits(p.Traits)
//i.SetCredentials(s.ID(), identity.Credentials{Type: s.ID(), Identifiers: []string{}, Config: co})

if err := h.r.IdentityManager().Create(r.Context(), i); err != nil {
h.r.Writer().WriteError(w, r, err)
return
Expand Down
6 changes: 2 additions & 4 deletions identity/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ import (
"testing"
"time"

"github.com/ory/kratos/selfservice/strategy/oidc"

"github.com/bxcodec/faker/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -174,7 +172,7 @@ func TestHandler(t *testing.T) {
}

iId := x.NewUUID()
toJson := func(c oidc.CredentialsConfig) []byte {
toJson := func(c identity.CredentialsOIDC) []byte {
out, err := json.Marshal(&c)
require.NoError(t, err)
return out
Expand All @@ -186,7 +184,7 @@ func TestHandler(t *testing.T) {
identity.CredentialsTypeOIDC: {
Type: identity.CredentialsTypeOIDC,
Identifiers: []string{"bar:" + identifier},
Config: toJson(oidc.CredentialsConfig{Providers: []oidc.ProviderCredentialsConfig{
Config: toJson(identity.CredentialsOIDC{Providers: []identity.CredentialsOIDCProvider{
{
Subject: "foo",
Provider: "bar",
Expand Down
17 changes: 17 additions & 0 deletions identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,23 @@ func (i *Identity) SetCredentials(t CredentialsType, c Credentials) {
i.Credentials[t] = c
}

func (i *Identity) SetCredentialsWithConfig(t CredentialsType, c Credentials, conf interface{}) (err error) {
i.lock().Lock()
defer i.lock().Unlock()
if i.Credentials == nil {
i.Credentials = make(map[CredentialsType]Credentials)
}

c.Config, err = json.Marshal(conf)
if err != nil {
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to encode %s credentials to JSON: %s", t, err))
}

c.Type = t
i.Credentials[t] = c
return nil
}

func (i *Identity) DeleteCredentialsType(t CredentialsType) {
i.lock().Lock()
defer i.lock().Unlock()
Expand Down
7 changes: 1 addition & 6 deletions selfservice/strategy/oidc/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"path/filepath"
"strings"
Expand Down Expand Up @@ -116,7 +115,7 @@ type authCodeContainer struct {
func (s *Strategy) CountActiveCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) {
for _, c := range cc {
if c.Type == s.ID() && gjson.ValidBytes(c.Config) {
var conf CredentialsConfig
var conf identity.CredentialsOIDC
if err = json.Unmarshal(c.Config, &conf); err != nil {
return 0, errors.WithStack(err)
}
Expand Down Expand Up @@ -366,10 +365,6 @@ func (s *Strategy) handleCallback(w http.ResponseWriter, r *http.Request, ps htt
}
}

func uid(provider, subject string) string {
return fmt.Sprintf("%s:%s", provider, subject)
}

func (s *Strategy) populateMethod(r *http.Request, c *container.Container, message func(provider string) *text.Message) error {
conf, err := s.Config(r.Context())
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions selfservice/strategy/oidc/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ type SubmitSelfServiceLoginFlowWithOidcMethodBody struct {
}

func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, a *login.Flow, token *oauth2.Token, claims *Claims, provider Provider, container *authCodeContainer) (*registration.Flow, error) {
i, c, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), identity.CredentialsTypeOIDC, uid(provider.Config().ID, claims.Subject))
i, c, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), identity.CredentialsTypeOIDC, identity.OIDCUniqueID(provider.Config().ID, claims.Subject))
if err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
// If no account was found we're "manually" creating a new registration flow and redirecting the browser
Expand Down Expand Up @@ -104,7 +104,7 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, a *login
return nil, s.handleError(w, r, a, provider.Config().ID, nil, err)
}

var o CredentialsConfig
var o identity.CredentialsOIDC
if err := json.NewDecoder(bytes.NewBuffer(c.Config)).Decode(&o); err != nil {
return nil, s.handleError(w, r, a, provider.Config().ID, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("The password credentials could not be decoded properly").WithDebug(err.Error())))
}
Expand Down
4 changes: 2 additions & 2 deletions selfservice/strategy/oidc/strategy_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat
}

func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, a *registration.Flow, token *oauth2.Token, claims *Claims, provider Provider, container *authCodeContainer) (*login.Flow, error) {
if _, _, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), identity.CredentialsTypeOIDC, uid(provider.Config().ID, claims.Subject)); err == nil {
if _, _, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), identity.CredentialsTypeOIDC, identity.OIDCUniqueID(provider.Config().ID, claims.Subject)); err == nil {
// If the identity already exists, we should perform the login flow instead.

// That will execute the "pre registration" hook which allows to e.g. disallow this flow. The registration
Expand Down Expand Up @@ -254,7 +254,7 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, a
return nil, s.handleError(w, r, a, provider.Config().ID, i.Traits, err)
}

creds, err := NewCredentials(it, cat, crt, provider.Config().ID, claims.Subject)
creds, err := identity.NewCredentialsOIDC(it, cat, crt, provider.Config().ID, claims.Subject)
if err != nil {
return nil, s.handleError(w, r, a, provider.Config().ID, i.Traits, err)
}
Expand Down
20 changes: 10 additions & 10 deletions selfservice/strategy/oidc/strategy_settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (s *Strategy) linkedProviders(ctx context.Context, r *http.Request, conf *C
return nil, nil
}

var available CredentialsConfig
var available identity.CredentialsOIDC
if err := json.Unmarshal(creds.Config, &available); err != nil {
return nil, errors.WithStack(err)
}
Expand Down Expand Up @@ -115,7 +115,7 @@ func (s *Strategy) linkedProviders(ctx context.Context, r *http.Request, conf *C
}

func (s *Strategy) linkableProviders(ctx context.Context, r *http.Request, conf *ConfigurationCollection, confidential *identity.Identity) ([]Provider, error) {
var available CredentialsConfig
var available identity.CredentialsOIDC
creds, ok := confidential.GetCredentials(s.ID())
if ok {
if err := json.Unmarshal(creds.Config, &available); err != nil {
Expand Down Expand Up @@ -394,18 +394,18 @@ func (s *Strategy) linkProvider(w http.ResponseWriter, r *http.Request, ctxUpdat
return s.handleSettingsError(w, r, ctxUpdate, p, err)
}

var conf CredentialsConfig
var conf identity.CredentialsOIDC
creds, err := i.ParseCredentials(s.ID(), &conf)
if errors.Is(err, herodot.ErrNotFound) {
var err error
if creds, err = NewCredentials(it, cat, crt, provider.Config().ID, claims.Subject); err != nil {
if creds, err = identity.NewCredentialsOIDC(it, cat, crt, provider.Config().ID, claims.Subject); err != nil {
return s.handleSettingsError(w, r, ctxUpdate, p, err)
}
} else if err != nil {
return s.handleSettingsError(w, r, ctxUpdate, p, err)
} else {
creds.Identifiers = append(creds.Identifiers, uid(provider.Config().ID, claims.Subject))
conf.Providers = append(conf.Providers, ProviderCredentialsConfig{
creds.Identifiers = append(creds.Identifiers, identity.OIDCUniqueID(provider.Config().ID, claims.Subject))
conf.Providers = append(conf.Providers, identity.CredentialsOIDCProvider{
Subject: claims.Subject, Provider: provider.Config().ID,
InitialAccessToken: cat,
InitialRefreshToken: crt,
Expand Down Expand Up @@ -448,20 +448,20 @@ func (s *Strategy) unlinkProvider(w http.ResponseWriter, r *http.Request, ctxUpd
return s.handleSettingsError(w, r, ctxUpdate, p, err)
}

var cc CredentialsConfig
var cc identity.CredentialsOIDC
creds, err := i.ParseCredentials(s.ID(), &cc)
if err != nil {
return s.handleSettingsError(w, r, ctxUpdate, p, errors.WithStack(UnknownConnectionValidationError))
}

var found bool
var updatedProviders []ProviderCredentialsConfig
var updatedProviders []identity.CredentialsOIDCProvider
var updatedIdentifiers []string
for _, available := range availableProviders {
if p.Unlink == available.Config().ID {
for _, link := range cc.Providers {
if link.Provider != p.Unlink {
updatedIdentifiers = append(updatedIdentifiers, uid(link.Provider, link.Subject))
updatedIdentifiers = append(updatedIdentifiers, identity.OIDCUniqueID(link.Provider, link.Subject))
updatedProviders = append(updatedProviders, link)
} else {
found = true
Expand All @@ -475,7 +475,7 @@ func (s *Strategy) unlinkProvider(w http.ResponseWriter, r *http.Request, ctxUpd
}

creds.Identifiers = updatedIdentifiers
creds.Config, err = json.Marshal(&CredentialsConfig{updatedProviders})
creds.Config, err = json.Marshal(&identity.CredentialsOIDC{updatedProviders})
if err != nil {
return s.handleSettingsError(w, r, ctxUpdate, p, errors.WithStack(err))

Expand Down
2 changes: 1 addition & 1 deletion selfservice/strategy/oidc/strategy_settings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func TestSettingsStrategy(t *testing.T) {
actual, err := reg.PrivilegedIdentityPool().GetIdentityConfidential(context.Background(), iid)
require.NoError(t, err)

var cc oidc.CredentialsConfig
var cc identity.CredentialsOIDC
creds, err := actual.ParseCredentials(identity.CredentialsTypeOIDC, &cc)
require.NoError(t, err)

Expand Down
Loading

0 comments on commit 50ac851

Please sign in to comment.