Skip to content

Commit

Permalink
feat(saml): update attributes mapping + remove slo
Browse files Browse the repository at this point in the history
Signed-off-by: ThibaultHerard <[email protected]>

Co-authored-by: sebferrer <[email protected]>
  • Loading branch information
ThibHrrd and sebferrer committed Feb 16, 2023
1 parent 5636dc0 commit b730cfb
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 129 deletions.
15 changes: 2 additions & 13 deletions embedx/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -446,14 +446,6 @@
"https://foo.bar.com/path/to/certificate"
]
},
"idp_logout_url": {
"title": "IDP Logout URL",
"description": "The URL of the Single Log Out (SLO) API of the IDP",
"type": "string",
"examples": [
"https://path/to/logout"
]
},
"idp_sso_url": {
"title": "IDP SSO URL",
"description": "The URL of the SSO Handler at the IDP",
Expand Down Expand Up @@ -482,19 +474,16 @@
},
"then": {
"required": [
"idp_logout_url",
"idp_certificate_path",
"idp_entity_id"
"idp_entity_id",
"idp_sso_url"
]
},
"else":{
"properties": {
"idp_certificate_path": {
"const": {}
},
"idp_logout_url": {
"const": {}
},
"idp_entity_id":{
"const":{}
},
Expand Down
7 changes: 0 additions & 7 deletions selfservice/strategy/saml/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ func TestInitSAMLWithoutProvider(t *testing.T) {
idpInformation["idp_sso_url"] = "https://samltest.id/idp/profile/SAML2/Redirect/SSO"
idpInformation["idp_entity_id"] = "https://samltest.id/saml/idp"
idpInformation["idp_certificate_path"] = "file://testdata/samlkratos.crt"
idpInformation["idp_logout_url"] = "https://samltest.id/idp/profile/SAML2/Redirect/SSO"

// Initiates without service provider
ViperSetProviderConfig(
Expand Down Expand Up @@ -75,7 +74,6 @@ func TestInitSAMLWithoutPoviderID(t *testing.T) {
idpInformation["idp_sso_url"] = "https://samltest.id/idp/profile/SAML2/Redirect/SSO"
idpInformation["idp_entity_id"] = "https://samltest.id/saml/idp"
idpInformation["idp_certificate_path"] = "file://testdata/samlkratos.crt"
idpInformation["idp_logout_url"] = "https://samltest.id/idp/profile/SAML2/Redirect/SSO"

// Initiates the service provider
ViperSetProviderConfig(
Expand Down Expand Up @@ -125,7 +123,6 @@ func TestInitSAMLWithoutPoviderLabel(t *testing.T) {
idpInformation["idp_sso_url"] = "https://samltest.id/idp/profile/SAML2/Redirect/SSO"
idpInformation["idp_entity_id"] = "https://samltest.id/saml/idp"
idpInformation["idp_certificate_path"] = "file://testdata/samlkratos.crt"
idpInformation["idp_logout_url"] = "https://samltest.id/idp/profile/SAML2/Redirect/SSO"

// Initiates the service provider
ViperSetProviderConfig(
Expand Down Expand Up @@ -174,7 +171,6 @@ func TestAttributesMapWithoutID(t *testing.T) {
idpInformation["idp_sso_url"] = "https://samltest.id/idp/profile/SAML2/Redirect/SSO"
idpInformation["idp_entity_id"] = "https://samltest.id/saml/idp"
idpInformation["idp_certificate_path"] = "file://testdata/samlkratos.crt"
idpInformation["idp_logout_url"] = "https://samltest.id/idp/profile/SAML2/Redirect/SSO"

// Initiates the service provider
ViperSetProviderConfig(
Expand Down Expand Up @@ -226,7 +222,6 @@ func TestAttributesMapWithAnExtraField(t *testing.T) {
idpInformation["idp_sso_url"] = "https://samltest.id/idp/profile/SAML2/Redirect/SSO"
idpInformation["idp_entity_id"] = "https://samltest.id/saml/idp"
idpInformation["idp_certificate_path"] = "file://testdata/idp_cert.pem"
idpInformation["idp_logout_url"] = "https://samltest.id/idp/profile/SAML2/Redirect/SSO"

// Initiates the service provider
ViperSetProviderConfig(
Expand Down Expand Up @@ -319,7 +314,6 @@ func TestInitSAMLWithMissingIDPInformationField(t *testing.T) {
idpInformation := make(map[string]string)
idpInformation["idp_sso_url"] = "https://samltest.id/idp/profile/SAML2/Redirect/SSO"
idpInformation["idp_entity_id"] = "https://samltest.id/saml/idp"
idpInformation["idp_logout_url"] = "https://samltest.id/idp/profile/SAML2/Redirect/SSO"

// Initiates the service provider
ViperSetProviderConfig(
Expand Down Expand Up @@ -369,7 +363,6 @@ func TestInitSAMLWithExtraIDPInformationField(t *testing.T) {
idpInformation["idp_sso_url"] = "https://samltest.id/idp/profile/SAML2/Redirect/SSO"
idpInformation["idp_entity_id"] = "https://samltest.id/saml/idp"
idpInformation["idp_certificate_path"] = "file://testdata/samlkratos.crt"
idpInformation["idp_logout_url"] = "https://samltest.id/idp/profile/SAML2/Redirect/SSO"
idpInformation["evil"] = "evil"

// Initiates the service provider
Expand Down
26 changes: 6 additions & 20 deletions selfservice/strategy/saml/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,13 @@ func (h *Handler) instantiateMiddleware(ctx context.Context, config config.Confi
}

// Key pair to encrypt and sign SAML requests
keyPair, err := tls.LoadX509KeyPair(strings.Replace(providerConfig.PublicCertPath, "file://", "", 1), strings.Replace(providerConfig.PrivateKeyPath, "file://", "", 1)) // TODO : Fetcher
keyPair, err := tls.LoadX509KeyPair(strings.Replace(providerConfig.PublicCertPath, "file://", "", 1), strings.Replace(providerConfig.PrivateKeyPath, "file://", "", 1))
if err != nil {
return herodot.ErrNotFound.WithTrace(err) // TODO : Replace with File not found error
return herodot.ErrInternalServerError.WithReason("An error occurred while retrieving the key pair used by SAML")
}
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
if err != nil {
return herodot.ErrNotFound.WithTrace(err)
return herodot.ErrInternalServerError.WithReason("An error occurred while using the certificate associated with SAML")
}

var idpMetadata *samlidp.EntityDescriptor
Expand Down Expand Up @@ -187,12 +187,6 @@ func (h *Handler) instantiateMiddleware(ctx context.Context, config config.Confi
return herodot.ErrNotFound.WithTrace(err)
}

// The IDP Logout URL
IDPlogoutURL, err := url.Parse(providerConfig.IDPInformation["idp_logout_url"])
if err != nil {
return herodot.ErrNotFound.WithTrace(err)
}

// The certificate of the IDP
certificateBuffer, err := fetcher.NewFetcher().Fetch(providerConfig.IDPInformation["idp_certificate_path"])
if err != nil {
Expand All @@ -212,12 +206,9 @@ func (h *Handler) instantiateMiddleware(ctx context.Context, config config.Confi

// Because the metadata file is not provided, we need to simulate an IDP to create artificial metadata from the data entered in the conf file
tempIDP := samlidp.IdentityProvider{
Key: nil,
Certificate: IDPCertificate,
Logger: nil,
MetadataURL: *entityIDURL,
SSOURL: *IDPSSOURL,
LogoutURL: *IDPlogoutURL,
}

// Now we assign our reconstructed metadata to our SP
Expand Down Expand Up @@ -282,7 +273,7 @@ func (h *Handler) instantiateMiddleware(ctx context.Context, config config.Confi
}

// Crewjam library use default route for ACS and metadata but we want to overwrite them
metadata, err := url.Parse(publicUrlString + RouteMetadata)
metadata, err := url.Parse(publicUrlString + strings.Replace(RouteMetadata, ":provider", providerConfig.ID, 1))
if err != nil {
return herodot.ErrNotFound.WithTrace(err)
}
Expand All @@ -302,7 +293,7 @@ func (h *Handler) instantiateMiddleware(ctx context.Context, config config.Confi
// Return the singleton MiddleWare
func GetMiddleware(pid string) (*samlsp.Middleware, error) {
if samlMiddlewares[pid] == nil {
return nil, errors.Errorf("An error occurred while retrieving the middeware, it is null") // TODO : Improve error message
return nil, errors.Errorf("An error occurred during the connection with SAML.")
}
return samlMiddlewares[pid], nil
}
Expand Down Expand Up @@ -342,17 +333,12 @@ func CreateSAMLProviderConfig(config config.Config, ctx context.Context, pid str
return nil, ErrInvalidSAMLConfiguration.WithReasonf("Please include your Identity Provider information in the configuration file.").WithTrace(err)
}

/**
* SAMLTODO errors
*/
// _, sso_exists := providerConfig.IDPInformation["idp_sso_url"]
_, sso_exists := providerConfig.IDPInformation["idp_sso_url"]
_, entity_id_exists := providerConfig.IDPInformation["idp_entity_id"]
_, certificate_exists := providerConfig.IDPInformation["idp_certificate_path"]
_, logout_url_exists := providerConfig.IDPInformation["idp_logout_url"]
_, metadata_exists := providerConfig.IDPInformation["idp_metadata_url"]

if (!metadata_exists && (!sso_exists || !entity_id_exists || !certificate_exists || !logout_url_exists)) || len(providerConfig.IDPInformation) > 4 {
if (!metadata_exists && (!sso_exists || !entity_id_exists || !certificate_exists)) || len(providerConfig.IDPInformation) > 3 {
return nil, ErrInvalidSAMLConfiguration.WithReason("Please check your IDP information in the configuration file").WithTrace(err)
}

Expand Down
6 changes: 3 additions & 3 deletions selfservice/strategy/saml/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestInitMiddleWareWithMetadata(t *testing.T) {
require.NoError(t, err)
assert.Check(t, middleWare != nil)
assert.Check(t, middleWare.ServiceProvider.IDPMetadata != nil)
assert.Check(t, middleWare.ServiceProvider.MetadataURL.Path == "/self-service/methods/saml/metadata/:provider")
assert.Check(t, middleWare.ServiceProvider.MetadataURL.Path == "/self-service/methods/saml/metadata/samlProvider")
assert.Check(t, middleWare.ServiceProvider.IDPMetadata.EntityID == "https://idp.testshib.org/idp/shibboleth")
}

Expand All @@ -44,7 +44,7 @@ func TestInitMiddleWareWithoutMetadata(t *testing.T) {
require.NoError(t, err)
assert.Check(t, middleWare != nil)
assert.Check(t, middleWare.ServiceProvider.IDPMetadata != nil)
assert.Check(t, middleWare.ServiceProvider.MetadataURL.Path == "/self-service/methods/saml/metadata/:provider")
assert.Check(t, middleWare.ServiceProvider.MetadataURL.Path == "/self-service/methods/saml/metadata/samlProvider")
assert.Check(t, middleWare.ServiceProvider.IDPMetadata.EntityID == "https://samltest.id/saml/idp")
}

Expand All @@ -63,7 +63,7 @@ func TestGetMiddleware(t *testing.T) {
require.NoError(t, err)
assert.Check(t, middleWare != nil)
assert.Check(t, middleWare.ServiceProvider.IDPMetadata != nil)
assert.Check(t, middleWare.ServiceProvider.MetadataURL.Path == "/self-service/methods/saml/metadata/:provider")
assert.Check(t, middleWare.ServiceProvider.MetadataURL.Path == "/self-service/methods/saml/metadata/samlProvider")
assert.Check(t, middleWare.ServiceProvider.IDPMetadata.EntityID == "https://idp.testshib.org/idp/shibboleth")
}

Expand Down
1 change: 1 addition & 0 deletions selfservice/strategy/saml/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type registrationStrategyDependencies interface {
x.WriterProvider
x.CSRFTokenGeneratorProvider
x.CSRFProvider
x.HTTPClientProvider

config.Provider

Expand Down
1 change: 0 additions & 1 deletion selfservice/strategy/saml/strategy_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ func InitTestMiddlewareWithoutMetadata(t *testing.T, idpSsoUrl string, idpEntity
idpInformation["idp_sso_url"] = idpSsoUrl
idpInformation["idp_entity_id"] = idpEntityId
idpInformation["idp_certificate_path"] = idpCertifiatePath
idpInformation["idp_logout_url"] = idpLogoutUrl

return InitTestMiddleware(t, idpInformation)
}
Expand Down
105 changes: 40 additions & 65 deletions selfservice/strategy/saml/strategy_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,18 @@ package saml

import (
"bytes"
"context"
"encoding/json"
"net/http"

"github.com/google/go-jsonnet"
"github.com/pkg/errors"

"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
"github.com/ory/x/decoderx"

"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/selfservice/flow/registration"
"github.com/ory/kratos/text"

"github.com/tidwall/gjson"
"github.com/tidwall/sjson"

"github.com/ory/kratos/x"
Expand All @@ -31,87 +27,66 @@ func (s *Strategy) RegisterRegistrationRoutes(r *x.RouterPublic) {
s.setRoutes(r)
}

func (s *Strategy) GetRegistrationIdentity(r *http.Request, ctx context.Context, provider Provider, claims *Claims, logsEnabled bool) (*identity.Identity, error) {
// Fetch fetches the file contents from the mapper file.
jn, err := s.f.Fetch(provider.Config().Mapper)
if err != nil {
return nil, err
}

func (s *Strategy) createIdentity(w http.ResponseWriter, r *http.Request, a *registration.Flow, claims *Claims, provider Provider) (*identity.Identity, error) {
var jsonClaims bytes.Buffer
if err := json.NewEncoder(&jsonClaims).Encode(claims); err != nil {
return nil, err
return nil, s.handleError(w, r, a, provider.Config().ID, nil, err)
}

// Identity Creation
i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
i := identity.NewIdentity(s.d.Config().DefaultIdentityTraitsSchemaID(r.Context()))
if err := s.setTraits(w, r, a, claims, provider, jsonClaims, i); err != nil {
return nil, s.handleError(w, r, a, provider.Config().ID, i.Traits, err)
}

vm := jsonnet.MakeVM()
vm.ExtCode("claims", jsonClaims.String())
evaluated, err := vm.EvaluateAnonymousSnippet(provider.Config().Mapper, jn.String())
s.d.Logger().
WithRequest(r).
WithField("saml_provider", provider.Config().ID).
WithSensitiveField("saml_claims", claims).
Debug("SAML Connect completed.")
return i, nil
}

func (s *Strategy) setTraits(w http.ResponseWriter, r *http.Request, a *registration.Flow, claims *Claims, provider Provider, jsonClaims bytes.Buffer, i *identity.Identity) error {

traitsMap := make(map[string]interface{})
json.Unmarshal(jsonClaims.Bytes(), &traitsMap)
delete(traitsMap, "iss")
delete(traitsMap, "email_verified")
delete(traitsMap, "sub")
traits, err := json.Marshal(traitsMap)
if err != nil {
return nil, err
} else if traits := gjson.Get(evaluated, "identity.traits"); !traits.IsObject() {
i.Traits = []byte{'{', '}'}
if logsEnabled {
s.d.Logger().
WithRequest(r).
WithField("Provider", provider.Config().ID).
WithSensitiveField("saml_claims", claims).
WithField("mapper_jsonnet_output", evaluated).
WithField("mapper_jsonnet_url", provider.Config().Mapper).
Error("SAML Jsonnet mapper did not return an object for key identity.traits. Please check your Jsonnet code!")
}
} else {
i.Traits = []byte(traits.Raw)
return s.handleError(w, r, a, provider.Config().ID, i.Traits, err)
}
i.Traits = identity.Traits(traits)

s.d.Logger().
WithRequest(r).
WithField("oidc_provider", provider.Config().ID).
WithSensitiveField("identity_traits", i.Traits).
WithField("mapper_jsonnet_url", provider.Config().Mapper).
Debug("Merged form values and OpenID Connect Jsonnet output.")
return nil
}

if logsEnabled {
s.d.Logger().
WithRequest(r).
WithField("saml_provider", provider.Config().ID).
WithSensitiveField("saml_claims", claims).
WithSensitiveField("mapper_jsonnet_output", evaluated).
WithField("mapper_jsonnet_url", provider.Config().Mapper).
Debug("SAML Jsonnet mapper completed.")

s.d.Logger().
WithRequest(r).
WithField("saml_provider", provider.Config().ID).
WithSensitiveField("identity_traits", i.Traits).
WithSensitiveField("mapper_jsonnet_output", evaluated).
WithField("mapper_jsonnet_url", provider.Config().Mapper).
Debug("Merged form values and SAML Jsonnet output.")
func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, a *registration.Flow, provider Provider, claims *Claims) error {
i, err := s.createIdentity(w, r, a, claims, provider)
if err != nil {
return s.handleError(w, r, a, provider.Config().ID, nil, err)
}

// Verify the identity
if err := s.d.IdentityValidator().Validate(ctx, i); err != nil {
return i, err
if err := s.d.IdentityValidator().Validate(r.Context(), i); err != nil {
return s.handleError(w, r, a, provider.Config().ID, nil, err)
}

// Create new uniq credentials identifier for user is database
creds, err := identity.NewCredentialsSAML(claims.Subject, provider.Config().ID)
if err != nil {
return i, err
return s.handleError(w, r, a, provider.Config().ID, nil, err)
}

// Set the identifiers to the identity
i.SetCredentials(s.ID(), *creds)

return i, nil
}

func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, a *registration.Flow, provider Provider, claims *Claims) error {

i, err := s.GetRegistrationIdentity(r, r.Context(), provider, claims, true)
if err != nil {
if i == nil {
return s.handleError(w, r, a, provider.Config().ID, nil, err)
} else {
return s.handleError(w, r, a, provider.Config().ID, i.Traits, err)
}
}

if err := s.d.RegistrationExecutor().PostRegistrationHook(w, r, identity.CredentialsTypeSAML, a, i); err != nil {
return s.handleError(w, r, a, provider.Config().ID, i.Traits, err)
}
Expand Down
20 changes: 0 additions & 20 deletions selfservice/strategy/saml/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,26 +171,6 @@ func TestCountActiveCredentials(t *testing.T) {
gotest.Check(t, count == 1)
}

func TestGetRegistrationIdentity(t *testing.T) {
if testing.Short() {
t.Skip()
}

saml.DestroyMiddlewareIfExists("samlProvider")

middleware, strategy, _, _ := InitTestMiddlewareWithMetadata(t,
"file://testdata/SP_IDPMetadata.xml")

provider, _ := strategy.Provider(context.Background(), "samlProvider")
assertion, _ := GetAndDecryptAssertion(t, "./testdata/SP_SamlResponse.xml", middleware.ServiceProvider.Key)
attributes, _ := strategy.GetAttributesFromAssertion(assertion)
claims, _ := provider.Claims(context.Background(), strategy.D().Config(), attributes, "samlProvider")

i, err := strategy.GetRegistrationIdentity(nil, context.Background(), provider, claims, false)
require.NoError(t, err)
gotest.Check(t, i != nil)
}

func TestCountActiveFirstFactorCredentials(t *testing.T) {
_, reg := internal.NewFastRegistryWithMocks(t)
strategy := saml.NewStrategy(reg)
Expand Down

0 comments on commit b730cfb

Please sign in to comment.