Skip to content

Commit

Permalink
feat(saml): fix unit tests + lints
Browse files Browse the repository at this point in the history
  • Loading branch information
alexGNX committed May 10, 2022
1 parent 6a10e28 commit 5a13d4b
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 13 deletions.
3 changes: 3 additions & 0 deletions cmd/remote/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ var statusCmd = &cobra.Command{
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
c, err := cliclient.NewClient(cmd)
if err != nil {
return err
}
state := &statusState{}
defer cmdx.PrintRow(cmd, state)

Expand Down
3 changes: 3 additions & 0 deletions cmd/remote/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ var versionCmd = &cobra.Command{
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
c, err := cliclient.NewClient(cmd)
if err != nil {
return err
}

resp, _, err := c.MetadataApi.GetVersion(cmd.Context()).Execute()
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions selfservice/flow/saml/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,9 @@ func (h *Handler) instantiateMiddleware(config config.Config) error {

// Crewjam library use default route for ACS and metadat but we want to overwrite them
metadata, err := url.Parse(publicUrlString + RouteSamlMetadata)
if err != nil {
return err
}
samlMiddleWare.ServiceProvider.MetadataURL = *metadata

// The EntityID in the AuthnRequest is the Metadata URL
Expand Down
19 changes: 15 additions & 4 deletions selfservice/flow/saml/helpertest/helpertest.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,21 +182,32 @@ func InitMiddlewareWithoutMetadata(t *testing.T, idpSsoUrl string, idpEntityId s
return InitMiddleware(t, idpInformation)
}

func GetAndDecryptAssertion(t *testing.T, samlResponseFile string, key *rsa.PrivateKey) (*crewjamsaml.Assertion, error) {
func GetAndDecryptAssertion(samlResponseFile string, key *rsa.PrivateKey) (*crewjamsaml.Assertion, error) {
// Load saml response test file
samlResponse, err := ioutil.ReadFile(samlResponseFile)
if err != nil {
return nil, err
}

// Decrypt saml response assertion
doc := etree.NewDocument()
err = doc.ReadFromBytes(samlResponse)
require.NoError(t, err)
if err != nil {
return nil, err
}

responseEl := doc.Root()
el := responseEl.FindElement("//EncryptedAssertion/EncryptedData")
plaintextAssertion, err := xmlenc.Decrypt(key, el)
if err != nil {
return nil, err
}

assertion := &crewjamsaml.Assertion{}
err = xml.Unmarshal(plaintextAssertion, assertion)
require.NoError(t, err)
if err != nil {
return nil, err
}

return assertion, err
return assertion, nil
}
4 changes: 1 addition & 3 deletions selfservice/strategy/saml/strategy/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,10 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, a *login
return nil, nil
}

// Method not used but necessary to implement the interface
func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, ss *session.Session) (i *identity.Identity, err error) {
return nil, nil
return nil, flow.ErrStrategyNotResponsible
}

// Method not used but necessary to implement the interface
func (s *Strategy) PopulateLoginMethod(r *http.Request, requestedAAL identity.AuthenticatorAssuranceLevel, l *login.Flow) error {
if l.Type != flow.TypeBrowser {
return nil
Expand Down
6 changes: 4 additions & 2 deletions selfservice/strategy/saml/strategy/strategy_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"

"github.com/google/go-jsonnet"
"github.com/pkg/errors"

"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
Expand All @@ -23,6 +24,7 @@ import (

// Implement the interface
var _ registration.Strategy = new(Strategy)
var ErrStrategyNotResponsible = errors.New("strategy is not responsible for this request")

//Call at the creation of Kratos, when Kratos implement all authentication routes
func (s *Strategy) RegisterRegistrationRoutes(r *x.RouterPublic) {
Expand Down Expand Up @@ -87,7 +89,7 @@ func (s *Strategy) GetRegistrationIdentity(r *http.Request, ctx context.Context,
return i, err
}

// Create new uniq credentials identifier for user is database
// Create new unique credentials identifier for user is database
creds, err := NewCredentialsForSAML(claims.Subject, provider.Config().ID)
if err != nil {
return i, err
Expand Down Expand Up @@ -128,5 +130,5 @@ func (s *Strategy) PopulateRegistrationMethod(r *http.Request, f *registration.F

// Method not used but necessary to implement the interface
func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registration.Flow, i *identity.Identity) (err error) {
return nil
return flow.ErrStrategyNotResponsible
}
8 changes: 5 additions & 3 deletions selfservice/strategy/saml/strategy/test/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestGetAndDecryptAssertion(t *testing.T) {
middleware, _, _, _ := helpertest.InitMiddlewareWithMetadata(t,
"file://testdata/idp_saml_metadata.xml")

assertion, err := helpertest.GetAndDecryptAssertion(t, "./testdata/SP_SamlResponse.xml", middleware.ServiceProvider.Key)
assertion, err := helpertest.GetAndDecryptAssertion("./testdata/SP_SamlResponse.xml", middleware.ServiceProvider.Key)

require.NoError(t, err)
assert.Check(t, assertion != nil)
Expand All @@ -44,7 +44,8 @@ func TestGetAttributesFromAssertion(t *testing.T) {
middleware, strategy, _, _ := helpertest.InitMiddlewareWithMetadata(t,
"file://testdata/idp_saml_metadata.xml")

assertion, _ := helpertest.GetAndDecryptAssertion(t, "./testdata/SP_SamlResponse.xml", middleware.ServiceProvider.Key)
assertion, err := helpertest.GetAndDecryptAssertion("./testdata/SP_SamlResponse.xml", middleware.ServiceProvider.Key)
require.NoError(t, err)

mapAttributes, err := strategy.GetAttributesFromAssertion(assertion)

Expand Down Expand Up @@ -181,7 +182,8 @@ func TestGetRegistrationIdentity(t *testing.T) {
"file://testdata/idp_saml_metadata.xml")

provider, _ := strategy.Provider(context.Background())
assertion, _ := helpertest.GetAndDecryptAssertion(t, "./testdata/SP_SamlResponse.xml", middleware.ServiceProvider.Key)
assertion, err := helpertest.GetAndDecryptAssertion("./testdata/SP_SamlResponse.xml", middleware.ServiceProvider.Key)
require.NoError(t, err)
attributes, _ := strategy.GetAttributesFromAssertion(assertion)
claims, _ := provider.Claims(context.Background(), strategy.D().Config(context.Background()), attributes)

Expand Down
2 changes: 1 addition & 1 deletion selfservice/strategy/saml/strategy/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type CredentialsConfig struct {
Providers []ProviderCredentialsConfig `json:"providers"`
}

//Create an uniq identifier for user in database. Its look like "id + the id of the saml provider"
//Create a unique identifier for user in database. Its look like "id + the id of the saml provider"
func NewCredentialsForSAML(subject string, provider string) (*identity.Credentials, error) {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(CredentialsConfig{
Expand Down

0 comments on commit 5a13d4b

Please sign in to comment.