diff --git a/go.mod b/go.mod index 5500d63a..745c5c2c 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/kr/pretty v0.3.1 github.com/mattermost/xml-roundtrip-validator v0.1.0 github.com/pkg/errors v0.9.1 // indirect - github.com/russellhaering/goxmldsig v1.2.0 + github.com/russellhaering/goxmldsig v1.3.0 github.com/stretchr/testify v1.8.1 github.com/zenazn/goji v1.0.1 golang.org/x/crypto v0.0.0-20220128200615-198e4374d7ed diff --git a/go.sum b/go.sum index a2bb4b19..7ab71ea2 100644 --- a/go.sum +++ b/go.sum @@ -35,8 +35,8 @@ github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTE github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/russellhaering/goxmldsig v1.2.0 h1:Y6GTTc9Un5hCxSzVz4UIWQ/zuVwDvzJk80guqzwx6Vg= -github.com/russellhaering/goxmldsig v1.2.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw= +github.com/russellhaering/goxmldsig v1.3.0 h1:DllIWUgMy0cRUMfGiASiYEa35nsieyD3cigIwLonTPM= +github.com/russellhaering/goxmldsig v1.3.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/identity_provider.go b/identity_provider.go index 47052916..b2da5631 100644 --- a/identity_provider.go +++ b/identity_provider.go @@ -97,6 +97,7 @@ type AssertionMaker interface { // and password). type IdentityProvider struct { Key crypto.PrivateKey + Signer crypto.Signer Logger logger.Interface Certificate *x509.Certificate Intermediates []*x509.Certificate @@ -831,24 +832,8 @@ const canonicalizerPrefixList = "" // MakeAssertionEl sets `AssertionEl` to a signed, possibly encrypted, version of `Assertion`. func (req *IdpAuthnRequest) MakeAssertionEl() error { - keyPair := tls.Certificate{ - Certificate: [][]byte{req.IDP.Certificate.Raw}, - PrivateKey: req.IDP.Key, - Leaf: req.IDP.Certificate, - } - for _, cert := range req.IDP.Intermediates { - keyPair.Certificate = append(keyPair.Certificate, cert.Raw) - } - keyStore := dsig.TLSCertKeyStore(keyPair) - - signatureMethod := req.IDP.SignatureMethod - if signatureMethod == "" { - signatureMethod = dsig.RSASHA1SignatureMethod - } - - signingContext := dsig.NewDefaultSigningContext(keyStore) - signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) - if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { + signingContext, err := req.signingContext() + if err != nil { return err } @@ -1049,24 +1034,8 @@ func (req *IdpAuthnRequest) MakeResponse() error { // Sign the response element (we've already signed the Assertion element) { - keyPair := tls.Certificate{ - Certificate: [][]byte{req.IDP.Certificate.Raw}, - PrivateKey: req.IDP.Key, - Leaf: req.IDP.Certificate, - } - for _, cert := range req.IDP.Intermediates { - keyPair.Certificate = append(keyPair.Certificate, cert.Raw) - } - keyStore := dsig.TLSCertKeyStore(keyPair) - - signatureMethod := req.IDP.SignatureMethod - if signatureMethod == "" { - signatureMethod = dsig.RSASHA1SignatureMethod - } - - signingContext := dsig.NewDefaultSigningContext(keyStore) - signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) - if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { + signingContext, err := req.signingContext() + if err != nil { return err } @@ -1084,3 +1053,42 @@ func (req *IdpAuthnRequest) MakeResponse() error { req.ResponseEl = responseEl return nil } + +// signingContext will create a signing context for the request. +func (req *IdpAuthnRequest) signingContext() (signingContext *dsig.SigningContext, err error) { + // Create a cert chain based off of the IDP cert and its intermediates. + certificates := [][]byte{req.IDP.Certificate.Raw} + for _, cert := range req.IDP.Intermediates { + certificates = append(certificates, cert.Raw) + } + + // If signer is set, use it instead of the private key. + if req.IDP.Signer != nil { + signingContext, err = dsig.NewSigningContext(req.IDP.Signer, certificates) + if err != nil { + return + } + } else { + keyPair := tls.Certificate{ + Certificate: certificates, + PrivateKey: req.IDP.Key, + Leaf: req.IDP.Certificate, + } + keyStore := dsig.TLSCertKeyStore(keyPair) + + signingContext = dsig.NewDefaultSigningContext(keyStore) + } + + // Default to using SHA1 if the signature method isn't set. + signatureMethod := req.IDP.SignatureMethod + if signatureMethod == "" { + signatureMethod = dsig.RSASHA1SignatureMethod + } + + signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) + if err = signingContext.SetSignatureMethod(signatureMethod); err != nil { + return + } + + return +} diff --git a/identity_provider_go116_test.go b/identity_provider_go116_test.go index ead0a780..6d4a0a53 100644 --- a/identity_provider_go116_test.go +++ b/identity_provider_go116_test.go @@ -18,7 +18,7 @@ import ( ) func TestIDPHTTPCanHandleSSORequest(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) w := httptest.NewRecorder() const validRequest = `lJJBayoxFIX%2FypC9JhnU5wszAz7lgWCLaNtFd5fMbQ1MkmnunVb%2FfUfbUqEgdhs%2BTr5zkmLW8S5s8KVD4mzvm0Cl6FIwEciRCeCRDFuznd2sTD5Upk2Ro42NyGZEmNjFMI%2BBOo9pi%2BnVWbzfrEqxY27JSEntEPfg2waHNnpJ4JtcgiWRLfoLXYBjwDfu6p%2B8JIoiWy5K4eqBUipXIzVRUwXKKtRK53qkJ3qqQVuNPUjU4TIQQ%2BBS5EqPBzofKH2ntBn%2FMervo8jWnyX%2BuVC78FwKkT1gopNKX1JUxSklXTMIfM0gsv8xeeDL%2BPGk7%2FF0Qg0GdnwQ1cW5PDLUwFDID6uquO1Dlot1bJw9%2FPLRmia%2BzRMCYyk4dSiq6205QSDXOxfy3KAq5Pkvqt4DAAD%2F%2Fw%3D%3D` diff --git a/identity_provider_go117_test.go b/identity_provider_go117_test.go index 536587d6..c9060d94 100644 --- a/identity_provider_go117_test.go +++ b/identity_provider_go117_test.go @@ -18,7 +18,7 @@ import ( ) func TestIDPHTTPCanHandleSSORequest(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) w := httptest.NewRecorder() const validRequest = `lJJBayoxFIX%2FypC9JhnU5wszAz7lgWCLaNtFd5fMbQ1MkmnunVb%2FfUfbUqEgdhs%2BTr5zkmLW8S5s8KVD4mzvm0Cl6FIwEciRCeCRDFuznd2sTD5Upk2Ro42NyGZEmNjFMI%2BBOo9pi%2BnVWbzfrEqxY27JSEntEPfg2waHNnpJ4JtcgiWRLfoLXYBjwDfu6p%2B8JIoiWy5K4eqBUipXIzVRUwXKKtRK53qkJ3qqQVuNPUjU4TIQQ%2BBS5EqPBzofKH2ntBn%2FMervo8jWnyX%2BuVC78FwKkT1gopNKX1JUxSklXTMIfM0gsv8xeeDL%2BPGk7%2FF0Qg0GdnwQ1cW5PDLUwFDID6uquO1Dlot1bJw9%2FPLRmia%2BzRMCYyk4dSiq6205QSDXOxfy3KAq5Pkvqt4DAAD%2F%2Fw%3D%3D` diff --git a/identity_provider_test.go b/identity_provider_test.go index 0c602b53..ee5bb461 100644 --- a/identity_provider_test.go +++ b/identity_provider_test.go @@ -23,6 +23,7 @@ import ( "github.com/beevik/etree" "github.com/golang-jwt/jwt/v4" + dsig "github.com/russellhaering/goxmldsig" "github.com/crewjam/saml/logger" "github.com/crewjam/saml/testsaml" @@ -35,6 +36,7 @@ type IdentityProviderTest struct { SP ServiceProvider Key crypto.PrivateKey + Signer crypto.Signer Certificate *x509.Certificate SessionProvider SessionProvider IDP IdentityProvider @@ -48,7 +50,7 @@ func mustParseURL(s string) url.URL { return *rv } -func mustParsePrivateKey(pemStr []byte) crypto.PrivateKey { +func mustParsePrivateKey(pemStr []byte) crypto.Signer { b, _ := pem.Decode(pemStr) if b == nil { panic("cannot parse PEM") @@ -72,7 +74,28 @@ func mustParseCertificate(pemStr []byte) *x509.Certificate { return cert } -func NewIdentifyProviderTest(t *testing.T) *IdentityProviderTest { +// idpTestOpts are options that can be applied to the identity provider. +type idpTestOpts struct { + apply func(*testing.T, *IdentityProviderTest) +} + +// applyKey will set the private key for the identity provider. +var applyKey = idpTestOpts{ + apply: func(t *testing.T, test *IdentityProviderTest) { + test.Key = mustParsePrivateKey(golden.Get(t, "idp_key.pem")) + (&test.IDP).Key = test.Key + }, +} + +// applySigner will set the signer for the identity provider. +var applySigner = idpTestOpts{ + apply: func(t *testing.T, test *IdentityProviderTest) { + test.Signer = mustParsePrivateKey(golden.Get(t, "idp_key.pem")) + (&test.IDP).Signer = test.Signer + }, +} + +func NewIdentityProviderTest(t *testing.T, opts ...idpTestOpts) *IdentityProviderTest { test := IdentityProviderTest{} TimeNow = func() time.Time { rv, _ := time.Parse("Mon Jan 2 15:04:05 MST 2006", "Mon Dec 1 01:57:09 UTC 2015") @@ -92,11 +115,9 @@ func NewIdentifyProviderTest(t *testing.T) *IdentityProviderTest { IDPMetadata: &EntityDescriptor{}, } - test.Key = mustParsePrivateKey(golden.Get(t, "idp_key.pem")) test.Certificate = mustParseCertificate(golden.Get(t, "idp_cert.pem")) test.IDP = IdentityProvider{ - Key: test.Key, Certificate: test.Certificate, Logger: logger.DefaultLogger, MetadataURL: mustParseURL("https://idp.example.com/saml/metadata"), @@ -116,6 +137,11 @@ func NewIdentifyProviderTest(t *testing.T) *IdentityProviderTest { }, } + // apply the test options + for _, opt := range opts { + opt.apply(t, &test) + } + // bind the service provider and the IDP test.SP.IDPMetadata = test.IDP.Metadata() return &test @@ -138,7 +164,7 @@ func (mspp *mockServiceProviderProvider) GetServiceProvider(r *http.Request, ser } func TestIDPCanProduceMetadata(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) expected := &EntityDescriptor{ ValidUntil: TimeNow().Add(DefaultValidDuration), CacheDuration: DefaultValidDuration, @@ -199,7 +225,7 @@ func TestIDPCanProduceMetadata(t *testing.T) { } func TestIDPHTTPCanHandleMetadataRequest(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) w := httptest.NewRecorder() r, _ := http.NewRequest("GET", "https://idp.example.com/saml/metadata", nil) test.IDP.Handler().ServeHTTP(w, r) @@ -210,7 +236,7 @@ func TestIDPHTTPCanHandleMetadataRequest(t *testing.T) { } func TestIDPCanHandleRequestWithNewSession(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { fmt.Fprintf(w, "RelayState: %s\nSAMLRequest: %s", @@ -236,7 +262,7 @@ func TestIDPCanHandleRequestWithNewSession(t *testing.T) { } func TestIDPCanHandleRequestWithExistingSession(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { return &Session{ @@ -261,7 +287,7 @@ func TestIDPCanHandleRequestWithExistingSession(t *testing.T) { } func TestIDPCanHandlePostRequestWithExistingSession(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { return &Session{ @@ -290,7 +316,7 @@ func TestIDPCanHandlePostRequestWithExistingSession(t *testing.T) { } func TestIDPRejectsInvalidRequest(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { panic("not reached") @@ -311,7 +337,7 @@ func TestIDPRejectsInvalidRequest(t *testing.T) { } func TestIDPCanParse(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) r, _ := http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&SAMLRequest=lJJBayoxFIX%2FypC9JhnU5wszAz7lgWCLaNtFd5fMbQ1MkmnunVb%2FfUfbUqEgdhs%2BTr5zkmLW8S5s8KVD4mzvm0Cl6FIwEciRCeCRDFuznd2sTD5Upk2Ro42NyGZEmNjFMI%2BBOo9pi%2BnVWbzfrEqxY27JSEntEPfg2waHNnpJ4JtcgiWRLfoLXYBjwDfu6p%2B8JIoiWy5K4eqBUipXIzVRUwXKKtRK53qkJ3qqQVuNPUjU4TIQQ%2BBS5EqPBzofKH2ntBn%2FMervo8jWnyX%2BuVC78FwKkT1gopNKX1JUxSklXTMIfM0gsv8xeeDL%2BPGk7%2FF0Qg0GdnwQ1cW5PDLUwFDID6uquO1Dlot1bJw9%2FPLRmia%2BzRMCYyk4dSiq6205QSDXOxfy3KAq5Pkvqt4DAAD%2F%2Fw%3D%3D", nil) req, err := NewIdpAuthnRequest(&test.IDP, r) assert.Check(t, err) @@ -335,7 +361,7 @@ func TestIDPCanParse(t *testing.T) { } func TestIDPCanValidate(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) req := IdpAuthnRequest{ Now: TimeNow(), IDP: &test.IDP, @@ -459,7 +485,7 @@ func TestIDPCanValidate(t *testing.T) { } func TestIDPMakeAssertion(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) req := IdpAuthnRequest{ Now: TimeNow(), IDP: &test.IDP, @@ -643,7 +669,7 @@ func TestIDPMakeAssertion(t *testing.T) { } func TestIDPMarshalAssertion(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) req := IdpAuthnRequest{ Now: TimeNow(), IDP: &test.IDP, @@ -691,8 +717,19 @@ func TestIDPMarshalAssertion(t *testing.T) { golden.Assert(t, string(assertionBuffer), t.Name()+"_encrypted_assertion") } -func TestIDPMakeResponse(t *testing.T) { - test := NewIdentifyProviderTest(t) +func TestIDPMakeResponsePrivateKey(t *testing.T) { + test := NewIdentityProviderTest(t, applyKey) + + testMakeResponse(t, test) +} + +func TestIDPMakeResponseSigner(t *testing.T) { + test := NewIdentityProviderTest(t, applySigner) + + testMakeResponse(t, test) +} + +func testMakeResponse(t *testing.T, test *IdentityProviderTest) { req := IdpAuthnRequest{ Now: TimeNow(), IDP: &test.IDP, @@ -713,6 +750,16 @@ func TestIDPMakeResponse(t *testing.T) { err = req.MakeResponse() assert.Check(t, err) + certificateStore := &dsig.MemoryX509CertificateStore{ + Roots: []*x509.Certificate{ + req.IDP.Certificate, + }, + } + validationCtx := dsig.NewDefaultValidationContext(certificateStore) + validationCtx.Clock = dsig.NewFakeClockAt(req.IDP.Certificate.NotBefore) + _, err = validationCtx.Validate(req.ResponseEl) + assert.Check(t, err) + response := Response{} err = unmarshalEtreeHack(req.ResponseEl, &response) assert.Check(t, err) @@ -722,11 +769,11 @@ func TestIDPMakeResponse(t *testing.T) { doc.Indent(2) responseStr, err := doc.WriteToString() assert.Check(t, err) - golden.Assert(t, responseStr, t.Name()+"_response.xml") + golden.Assert(t, responseStr, "TestIDPMakeResponse_response.xml") } func TestIDPWriteResponse(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) req := IdpAuthnRequest{ Now: TimeNow(), IDP: &test.IDP, @@ -746,7 +793,7 @@ func TestIDPWriteResponse(t *testing.T) { } func TestIDPIDPInitiatedNewSession(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { fmt.Fprintf(w, "RelayState: %s", req.RelayState) @@ -762,7 +809,7 @@ func TestIDPIDPInitiatedNewSession(t *testing.T) { } func TestIDPIDPInitiatedExistingSession(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { return &Session{ @@ -780,7 +827,7 @@ func TestIDPIDPInitiatedExistingSession(t *testing.T) { } func TestIDPIDPInitiatedBadServiceProvider(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { return &Session{ @@ -797,7 +844,7 @@ func TestIDPIDPInitiatedBadServiceProvider(t *testing.T) { } func TestIDPCanHandleUnencryptedResponse(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { return &Session{ID: "f00df00df00d", UserName: "alice"} @@ -845,7 +892,7 @@ func TestIDPCanHandleUnencryptedResponse(t *testing.T) { } func TestIDPRequestedAttributes(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) metadata := EntityDescriptor{} err := xml.Unmarshal(golden.Get(t, "TestIDPRequestedAttributes_idp_metadata.xml"), &metadata) assert.Check(t, err) @@ -975,7 +1022,7 @@ func TestIDPRequestedAttributes(t *testing.T) { } func TestIDPNoDestination(t *testing.T) { - test := NewIdentifyProviderTest(t) + test := NewIdentityProviderTest(t, applyKey) test.IDP.SessionProvider = &mockSessionProvider{ GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session { return &Session{ID: "f00df00df00d", UserName: "alice"} diff --git a/samlidp/samlidp.go b/samlidp/samlidp.go index 0ed28b2f..a9f023eb 100644 --- a/samlidp/samlidp.go +++ b/samlidp/samlidp.go @@ -20,6 +20,7 @@ import ( type Options struct { URL url.URL Key crypto.PrivateKey + Signer crypto.Signer Logger logger.Interface Certificate *x509.Certificate Store Store @@ -59,6 +60,7 @@ func New(opts Options) (*Server, error) { serviceProviders: map[string]*saml.EntityDescriptor{}, IDP: saml.IdentityProvider{ Key: opts.Key, + Signer: opts.Signer, Logger: logr, Certificate: opts.Certificate, MetadataURL: metadataURL,