From a0e4509a35b1c7a36782f27ed6a109d2274839d2 Mon Sep 17 00:00:00 2001 From: Mihaly Gyongyosi Date: Fri, 13 Sep 2024 18:29:47 +0200 Subject: [PATCH 1/2] Set SessionIndex on LogoutRequest if it's available --- service_provider.go | 14 +++-- service_provider_test.go | 61 +++++++++++++++---- ...ducePostLogoutRequest_NoSessionIndex_form} | 0 ...ProducePostLogoutRequest_SessionIndex_form | 1 + ...goutRequest_NoSessionIndex_decodedRequest} | 0 ...tLogoutRequest_SessionIndex_decodedRequest | 1 + 6 files changed, 60 insertions(+), 17 deletions(-) rename testdata/{TestSPCanProducePostLogoutRequest_form => TestSPCanProducePostLogoutRequest_NoSessionIndex_form} (100%) create mode 100644 testdata/TestSPCanProducePostLogoutRequest_SessionIndex_form rename testdata/{TestSPCanProduceRedirectLogoutRequest_decodedRequest => TestSPCanProduceRedirectLogoutRequest_NoSessionIndex_decodedRequest} (100%) create mode 100644 testdata/TestSPCanProduceRedirectLogoutRequest_SessionIndex_decodedRequest diff --git a/service_provider.go b/service_provider.go index 51fe618b..d7d5d1e8 100644 --- a/service_provider.go +++ b/service_provider.go @@ -1192,7 +1192,7 @@ func (sp *ServiceProvider) SignLogoutRequest(req *LogoutRequest) error { } // MakeLogoutRequest produces a new LogoutRequest object for idpURL. -func (sp *ServiceProvider) MakeLogoutRequest(idpURL, nameID string) (*LogoutRequest, error) { +func (sp *ServiceProvider) MakeLogoutRequest(idpURL, nameID, sessionIndex string) (*LogoutRequest, error) { req := LogoutRequest{ ID: fmt.Sprintf("id-%x", randomBytes(20)), @@ -1210,6 +1210,10 @@ func (sp *ServiceProvider) MakeLogoutRequest(idpURL, nameID string) (*LogoutRequ SPNameQualifier: sp.Metadata().EntityID, }, } + if sessionIndex != "" { + req.SessionIndex = &SessionIndex{sessionIndex} + } + if len(sp.SignatureMethod) > 0 { if err := sp.SignLogoutRequest(&req); err != nil { return nil, err @@ -1221,8 +1225,8 @@ func (sp *ServiceProvider) MakeLogoutRequest(idpURL, nameID string) (*LogoutRequ // MakeRedirectLogoutRequest creates a SAML authentication request using // the HTTP-Redirect binding. It returns a URL that we will redirect the user to // in order to start the auth process. -func (sp *ServiceProvider) MakeRedirectLogoutRequest(nameID, relayState string) (*url.URL, error) { - req, err := sp.MakeLogoutRequest(sp.GetSLOBindingLocation(HTTPRedirectBinding), nameID) +func (sp *ServiceProvider) MakeRedirectLogoutRequest(nameID, relayState, sessionIndex string) (*url.URL, error) { + req, err := sp.MakeLogoutRequest(sp.GetSLOBindingLocation(HTTPRedirectBinding), nameID, sessionIndex) if err != nil { return nil, err } @@ -1261,8 +1265,8 @@ func (r *LogoutRequest) Redirect(relayState string) *url.URL { // MakePostLogoutRequest creates a SAML authentication request using // the HTTP-POST binding. It returns HTML text representing an HTML form that // can be sent presented to a browser to initiate the logout process. -func (sp *ServiceProvider) MakePostLogoutRequest(nameID, relayState string) ([]byte, error) { - req, err := sp.MakeLogoutRequest(sp.GetSLOBindingLocation(HTTPPostBinding), nameID) +func (sp *ServiceProvider) MakePostLogoutRequest(nameID, relayState, sessionIndex string) ([]byte, error) { + req, err := sp.MakeLogoutRequest(sp.GetSLOBindingLocation(HTTPPostBinding), nameID, sessionIndex) if err != nil { return nil, err } diff --git a/service_provider_test.go b/service_provider_test.go index cab8c0da..b787b3f5 100644 --- a/service_provider_test.go +++ b/service_provider_test.go @@ -375,6 +375,19 @@ func TestSPFailToProduceSignedRequestWithBogusSignatureMethod(t *testing.T) { } func TestSPCanProducePostLogoutRequest(t *testing.T) { + testCases := []struct { + name string + sessionIndex string + }{ + { + name: "TestSPCanProducePostLogoutRequest_NoSessionIndex", + }, + { + name: "TestSPCanProducePostLogoutRequest_SessionIndex", + sessionIndex: "session-123", + }, + } + test := NewServiceProviderTest(t) TimeNow = func() time.Time { rv, _ := time.Parse("Mon Jan 2 15:04:05 UTC 2006", "Mon Dec 1 01:31:21 UTC 2015") @@ -390,12 +403,30 @@ func TestSPCanProducePostLogoutRequest(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - form, err := s.MakePostLogoutRequest("ros@octolabs.io", "relayState") - assert.Check(t, err) - golden.Assert(t, string(form), t.Name()+"_form") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + RandReader = &testRandomReader{} + form, err := s.MakePostLogoutRequest("ros@octolabs.io", "relayState", tc.sessionIndex) + assert.Check(t, err) + golden.Assert(t, string(form), tc.name+"_form") + }) + } } func TestSPCanProduceRedirectLogoutRequest(t *testing.T) { + testCases := []struct { + name string + sessionIndex string + }{ + { + name: "TestSPCanProduceRedirectLogoutRequest_NoSessionIndex", + }, + { + name: "TestSPCanProduceRedirectLogoutRequest_SessionIndex", + sessionIndex: "session-123", + }, + } + test := NewServiceProviderTest(t) TimeNow = func() time.Time { rv, _ := time.Parse("Mon Jan 2 15:04:05.999999999 UTC 2006", "Mon Dec 1 01:31:21.123456789 UTC 2015") @@ -412,16 +443,22 @@ func TestSPCanProduceRedirectLogoutRequest(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - redirectURL, err := s.MakeRedirectLogoutRequest("ross@octolabs.io", "relayState") - assert.Check(t, err) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + RandReader = &testRandomReader{} - decodedRequest, err := testsaml.ParseRedirectRequest(redirectURL) - assert.Check(t, err) - assert.Check(t, is.Equal("idp.testshib.org", - redirectURL.Host)) - assert.Check(t, is.Equal("/idp/profile/SAML2/Redirect/SLO", - redirectURL.Path)) - golden.Assert(t, string(decodedRequest), t.Name()+"_decodedRequest") + redirectURL, err := s.MakeRedirectLogoutRequest("ross@octolabs.io", "relayState", tc.sessionIndex) + assert.Check(t, err) + + decodedRequest, err := testsaml.ParseRedirectRequest(redirectURL) + assert.Check(t, err) + assert.Check(t, is.Equal("idp.testshib.org", + redirectURL.Host)) + assert.Check(t, is.Equal("/idp/profile/SAML2/Redirect/SLO", + redirectURL.Path)) + golden.Assert(t, string(decodedRequest), tc.name+"_decodedRequest") + }) + } } func TestSPCanProducePostLogoutResponse(t *testing.T) { diff --git a/testdata/TestSPCanProducePostLogoutRequest_form b/testdata/TestSPCanProducePostLogoutRequest_NoSessionIndex_form similarity index 100% rename from testdata/TestSPCanProducePostLogoutRequest_form rename to testdata/TestSPCanProducePostLogoutRequest_NoSessionIndex_form diff --git a/testdata/TestSPCanProducePostLogoutRequest_SessionIndex_form b/testdata/TestSPCanProducePostLogoutRequest_SessionIndex_form new file mode 100644 index 00000000..d1afb2aa --- /dev/null +++ b/testdata/TestSPCanProducePostLogoutRequest_SessionIndex_form @@ -0,0 +1 @@ +
\ No newline at end of file diff --git a/testdata/TestSPCanProduceRedirectLogoutRequest_decodedRequest b/testdata/TestSPCanProduceRedirectLogoutRequest_NoSessionIndex_decodedRequest similarity index 100% rename from testdata/TestSPCanProduceRedirectLogoutRequest_decodedRequest rename to testdata/TestSPCanProduceRedirectLogoutRequest_NoSessionIndex_decodedRequest diff --git a/testdata/TestSPCanProduceRedirectLogoutRequest_SessionIndex_decodedRequest b/testdata/TestSPCanProduceRedirectLogoutRequest_SessionIndex_decodedRequest new file mode 100644 index 00000000..0643b96f --- /dev/null +++ b/testdata/TestSPCanProduceRedirectLogoutRequest_SessionIndex_decodedRequest @@ -0,0 +1 @@ +https://15661444.ngrok.io/saml2/metadataross@octolabs.iosession-123 \ No newline at end of file From 9fa8a464d9f6bb4db54defa0c7614ceae36c2c37 Mon Sep 17 00:00:00 2001 From: Mihaly Gyongyosi Date: Mon, 16 Sep 2024 10:02:08 +0200 Subject: [PATCH 2/2] fix --- example/trivial/trivial.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/trivial/trivial.go b/example/trivial/trivial.go index 45f46080..af12b893 100644 --- a/example/trivial/trivial.go +++ b/example/trivial/trivial.go @@ -23,7 +23,7 @@ func hello(w http.ResponseWriter, r *http.Request) { func logout(w http.ResponseWriter, r *http.Request) { nameID := samlsp.AttributeFromContext(r.Context(), "urn:oasis:names:tc:SAML:attribute:subject-id") - url, err := samlMiddleware.ServiceProvider.MakeRedirectLogoutRequest(nameID, "") + url, err := samlMiddleware.ServiceProvider.MakeRedirectLogoutRequest(nameID, "", "") if err != nil { panic(err) // TODO handle error }