diff --git a/saml/test/provider.go b/saml/test/provider.go index b4b9d7f..025577d 100644 --- a/saml/test/provider.go +++ b/saml/test/provider.go @@ -19,6 +19,9 @@ import ( "testing" "time" + "github.com/beevik/etree" + "github.com/russellhaering/gosaml2/types" + dsig "github.com/russellhaering/goxmldsig" "github.com/stretchr/testify/require" "github.com/hashicorp/cap/saml/models/core" @@ -46,8 +49,8 @@ const meta = ` ` // From https://www.samltool.com/generic_sso_res.php -const responseSigned = ` - +const ResponseSigned = ` + http://idp.example.com/metadata.php @@ -118,8 +121,9 @@ func (s *SAMLResponsePostData) PostRequest(t *testing.T) *http.Request { // TestProvider is an identity provider that can be used for testing // SAML federeation and authentication flows. type TestProvider struct { - t *testing.T - server *httptest.Server + t *testing.T + server *httptest.Server + keystore dsig.X509KeyStore metadata *metadata.EntityDescriptorIDPSSO recorder *httptest.ResponseRecorder @@ -199,9 +203,18 @@ func StartTestProvider(t *testing.T) *TestProvider { err := xml.Unmarshal([]byte(meta), &m) r.NoError(err) + keystore := dsig.RandomKeyStoreForTest() + _, cert, err := keystore.GetKeyPair() + r.NoError(err) + + b64Cert := base64.StdEncoding.EncodeToString(cert) + + m.IDPSSODescriptor[0].RoleDescriptor.KeyDescriptor[0].KeyInfo.X509Data.X509Certificates[0].Data = b64Cert + provider := &TestProvider{ t: t, metadata: &m, + keystore: keystore, } provider.defaults() @@ -272,14 +285,13 @@ func (p *TestProvider) loginHandlerPost(w http.ResponseWriter, req *http.Request relayState := req.FormValue("RelayState") r.Equal(p.expectedRelayState, relayState, "relay state doesn't match") - http.Error(w, "not implemented", http.StatusNotImplemented) samlReq := p.parseRequestPost(rawReq) p.validateRequest(samlReq) samlResponseData := &SAMLResponsePostData{ - SAMLResponse: responseSigned, + SAMLResponse: ResponseSigned, RelayState: relayState, Destination: samlReq.AssertionConsumerServiceURL, } @@ -312,7 +324,7 @@ func (p *TestProvider) loginHandlerRedirect(w http.ResponseWriter, req *http.Req p.validateRequest(samlReq) samlResponseData := &SAMLResponsePostData{ - SAMLResponse: responseSigned, + SAMLResponse: ResponseSigned, RelayState: relayState, Destination: samlReq.AssertionConsumerServiceURL, } @@ -417,3 +429,132 @@ func (p *TestProvider) parseRequestPost(request string) *core.AuthnRequest { return &req } + +type responseOptions struct { + sign bool + expired bool +} + +type ResponseOption func(*responseOptions) + +func getResponseOptions(opts ...ResponseOption) *responseOptions { + defaults := defaultResponseOptions() + for _, o := range opts { + o(defaults) + } + + return defaults +} + +func defaultResponseOptions() *responseOptions { + return &responseOptions{} +} + +func WithResponseSigned() ResponseOption { + return func(o *responseOptions) { + o.sign = true + } +} + +func WithResponseExpired() ResponseOption { + return func(o *responseOptions) { + o.expired = true + } +} + +func (p *TestProvider) SamlResponse(t *testing.T, opts ...ResponseOption) string { + r := require.New(t) + + opt := getResponseOptions(opts...) + + notOnOrAfter := "2200-01-18T06:21:48Z" + + if opt.expired { + notOnOrAfter = "2001-01-18T06:21:48Z" + } + + response := &core.Response{ + Response: types.Response{ + Destination: "http://hashicorp-cap.test/saml/acs", + ID: "test-resp-id", + InResponseTo: "test-request-id", + IssueInstant: time.Now(), + Version: "2.0", + Issuer: &types.Issuer{ + Value: "http://test.idp", + }, + Status: &types.Status{ + StatusCode: &types.StatusCode{ + Value: string(core.StatusCodeSuccess), + }, + }, + Assertions: []types.Assertion{ + { + ID: "assertion-id", + Issuer: &types.Issuer{ + Value: "http://test.idp", + }, + Subject: &types.Subject{ + NameID: &types.NameID{ + Value: "name-id", + }, + SubjectConfirmation: &types.SubjectConfirmation{ + Method: "urn:oasis:names:tc:SAML:2.0:cm:bearer", + SubjectConfirmationData: &types.SubjectConfirmationData{ + InResponseTo: "test-request-id", + Recipient: "http://hashicorp-cap.test/saml/acs", + NotOnOrAfter: notOnOrAfter, + }, + }, + }, + Conditions: &types.Conditions{ + NotBefore: "2001-01-18T06:21:48Z", + NotOnOrAfter: notOnOrAfter, + AudienceRestrictions: []types.AudienceRestriction{ + { + Audiences: []types.Audience{ + {Value: "http://hashicorp-cap.test"}, + }, + }, + }, + }, + AttributeStatement: &types.AttributeStatement{ + Attributes: []types.Attribute{ + { + Name: "mail", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:basic", + Values: []types.AttributeValue{ + { + Type: "xs:string", + Value: "user@hashicorp-cap.test", + }, + }, + }, + }, + }, + }, + }, + }, + } + + resp, err := xml.Marshal(response) + r.NoError(err) + + doc := etree.NewDocument() + err = doc.ReadFromBytes(resp) + r.NoError(err) + + if opt.sign { + signCtx := dsig.NewDefaultSigningContext(p.keystore) + + signed, err := signCtx.SignEnveloped(doc.Root()) + r.NoError(err) + + doc.SetRoot(signed) + } + + result, err := doc.WriteToString() + r.NoError(err) + + return result +}