diff --git a/.golangci.yml b/.golangci.yml index 4fbc0405..f93ef23b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -7,41 +7,35 @@ linters: enable: - - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification [fast: true, auto-fix: true] - - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports [fast: true, auto-fix: true] - - gosec # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases [fast: true, auto-fix: false] - - misspell # Finds commonly misspelled English words in comments [fast: true, auto-fix: true] - - deadcode # Finds unused code [fast: true, auto-fix: false] - - revive # Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes [fast: true, auto-fix: false] - - unconvert # Remove unnecessary type conversions [fast: true, auto-fix: false] - - disable: - # TODO(ross): fix errors reported by these checkers and enable them - bodyclose # checks whether HTTP response body is closed successfully [fast: false, auto-fix: false] - depguard # Go linter that checks if package imports are in a list of acceptable packages [fast: true, auto-fix: false] - - dupl # Tool for code clone detection [fast: true, auto-fix: false] - errcheck # Inspects source code for security problems [fast: true, auto-fix: false] - - gochecknoglobals # Checks that no globals are present in Go code [fast: true, auto-fix: false] - - gochecknoinits # Checks that no init functions are present in Go code [fast: true, auto-fix: false] - - goconst # Finds repeated strings that could be replaced by a constant [fast: true, auto-fix: false] - gocritic # The most opinionated Go source code linter [fast: true, auto-fix: false] - gocyclo # Computes and checks the cyclomatic complexity of functions [fast: true, auto-fix: false] + - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification [fast: true, auto-fix: true] + - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports [fast: true, auto-fix: true] + - gosec # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases [fast: true, auto-fix: false] - gosimple # Linter for Go source code that specializes in simplifying a code [fast: false, auto-fix: false] - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string [fast: false, auto-fix: false] - ineffassign # Detects when assignments to existing variables are not used [fast: true, auto-fix: false] - - interfacer # Linter that suggests narrower interface types [fast: false, auto-fix: false] - - lll # Reports long lines [fast: true, auto-fix: false] - - maligned # Tool to detect Go structs that would take less memory if their fields were sorted [fast: true, auto-fix: false] + - misspell # Finds commonly misspelled English words in comments [fast: true, auto-fix: true] - nakedret # Finds naked returns in functions greater than a specified function length [fast: true, auto-fix: false] - prealloc # Finds slice declarations that could potentially be preallocated [fast: true, auto-fix: false] - - scopelint # Scopelint checks for unpinned variables in go programs [fast: true, auto-fix: false] + - revive # Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes [fast: true, auto-fix: false] - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks [fast: false, auto-fix: false] - - structcheck # Finds unused struct fields [fast: true, auto-fix: false] - stylecheck # Stylecheck is a replacement for golint [fast: false, auto-fix: false] - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code [fast: true, auto-fix: false] + - unconvert # Remove unnecessary type conversions [fast: true, auto-fix: false] - unparam # Reports unused function parameters [fast: false, auto-fix: false] - unused # Checks Go code for unused constants, variables, functions and types [fast: false, auto-fix: false] - - varcheck # Finds unused global variables and constants [fast: true, auto-fix: false] + + disable: + # TODO(ross): fix errors reported by these checkers and enable them + - dupl # Tool for code clone detection [fast: true, auto-fix: false] + - gochecknoglobals # Checks that no globals are present in Go code [fast: true, auto-fix: false] + - gochecknoinits # Checks that no init functions are present in Go code [fast: true, auto-fix: false] + - goconst # Finds repeated strings that could be replaced by a constant [fast: true, auto-fix: false] + - lll # Reports long lines [fast: true, auto-fix: false] linters-settings: goimports: local-prefixes: github.com/crewjam/saml diff --git a/example/idp/idp.go b/example/idp/idp.go index 6069d379..4e47a56a 100644 --- a/example/idp/idp.go +++ b/example/idp/idp.go @@ -1,3 +1,4 @@ +// Package main contains an example identity provider implementation. package main import ( diff --git a/example/service.go b/example/service.go index c153b65f..5b6ddb27 100644 --- a/example/service.go +++ b/example/service.go @@ -32,7 +32,7 @@ type Link struct { } // CreateLink handles requests to create links -func CreateLink(c web.C, w http.ResponseWriter, r *http.Request) { +func CreateLink(_ web.C, w http.ResponseWriter, r *http.Request) { account := r.Header.Get("X-Remote-User") l := Link{ ShortLink: uniuri.New(), @@ -42,22 +42,20 @@ func CreateLink(c web.C, w http.ResponseWriter, r *http.Request) { links[l.ShortLink] = l fmt.Fprintf(w, "%s\n", l.ShortLink) - return } // ServeLink handles requests to redirect to a link -func ServeLink(c web.C, w http.ResponseWriter, r *http.Request) { +func ServeLink(_ web.C, w http.ResponseWriter, r *http.Request) { l, ok := links[strings.TrimPrefix(r.URL.Path, "/")] if !ok { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } http.Redirect(w, r, l.Target, http.StatusFound) - return } // ListLinks returns a list of the current user's links -func ListLinks(c web.C, w http.ResponseWriter, r *http.Request) { +func ListLinks(_ web.C, w http.ResponseWriter, r *http.Request) { account := r.Header.Get("X-Remote-User") for _, l := range links { if l.Owner == account { @@ -145,14 +143,24 @@ func main() { spURL := *idpMetadataURL spURL.Path = "/services/sp" - http.Post(spURL.String(), "text/xml", bytes.NewReader(spMetadataBuf)) + resp, err := http.Post(spURL.String(), "text/xml", bytes.NewReader(spMetadataBuf)) + + if err != nil { + panic(err) + } + + if err := resp.Body.Close(); err != nil { + panic(err) + } goji.Handle("/saml/*", samlSP) authMux := web.New() authMux.Use(samlSP.RequireAccount) authMux.Get("/whoami", func(w http.ResponseWriter, r *http.Request) { - pretty.Fprintf(w, "%# v", r) + if _, err := pretty.Fprintf(w, "%# v", r); err != nil { + panic(err) + } }) authMux.Post("/", CreateLink) authMux.Get("/", ListLinks) diff --git a/example/trivial/trivial.go b/example/trivial/trivial.go index 81ae051b..e7661833 100644 --- a/example/trivial/trivial.go +++ b/example/trivial/trivial.go @@ -1,3 +1,4 @@ +// Package main contains an example service provider implementation. package main import ( @@ -70,6 +71,7 @@ func main() { }) app := http.HandlerFunc(hello) slo := http.HandlerFunc(logout) + http.Handle("/hello", samlMiddleware.RequireAccount(app)) http.Handle("/saml/", samlMiddleware) http.Handle("/logout", slo) diff --git a/identity_provider.go b/identity_provider.go index 1427fafe..e822cb24 100644 --- a/identity_provider.go +++ b/identity_provider.go @@ -198,10 +198,13 @@ func (idp *IdentityProvider) Handler() http.Handler { } // ServeMetadata is an http.HandlerFunc that serves the IDP metadata -func (idp *IdentityProvider) ServeMetadata(w http.ResponseWriter, r *http.Request) { +func (idp *IdentityProvider) ServeMetadata(w http.ResponseWriter, _ *http.Request) { buf, _ := xml.MarshalIndent(idp.Metadata(), "", " ") w.Header().Set("Content-Type", "application/samlmetadata+xml") - w.Write(buf) + if _, err := w.Write(buf); err != nil { + idp.Logger.Printf("ERROR: %s", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } } // ServeSSO handles SAML auth requests. @@ -718,9 +721,7 @@ func (DefaultAssertionMaker) MakeAssertion(req *IdpAuthnRequest, session *Sessio }) } - for _, ca := range session.CustomAttributes { - attributes = append(attributes, ca) - } + attributes = append(attributes, session.CustomAttributes...) if len(session.Groups) != 0 { groupMemberAttributeValues := []AttributeValue{} diff --git a/identity_provider_go117_test.go b/identity_provider_go117_test.go index 536587d6..0ce6a1a7 100644 --- a/identity_provider_go117_test.go +++ b/identity_provider_go117_test.go @@ -43,8 +43,10 @@ func TestIDPHTTPCanHandleSSORequest(t *testing.T) { d := bytes.Replace(c, []byte("]]"), 1) f := bytes.Buffer{} e, _ := flate.NewWriter(&f, flate.DefaultCompression) - e.Write(d) - e.Close() + _, err := e.Write(d) + assert.Check(t, err) + err = e.Close() + assert.Check(t, err) g := base64.StdEncoding.EncodeToString(f.Bytes()) invalidRequest := url.QueryEscape(g) diff --git a/identity_provider_test.go b/identity_provider_test.go index d79cf8a5..372cd0bb 100644 --- a/identity_provider_test.go +++ b/identity_provider_test.go @@ -763,7 +763,7 @@ func TestIDPIDPInitiatedNewSession(t *testing.T) { r, _ := http.NewRequest("GET", "https://idp.example.com/services/sp/whoami", nil) test.IDP.ServeIDPInitiated(w, r, test.SP.MetadataURL.String(), "ThisIsTheRelayState") assert.Check(t, is.Equal(200, w.Code)) - assert.Check(t, is.Equal("RelayState: ThisIsTheRelayState", string(w.Body.Bytes()))) + assert.Check(t, is.Equal("RelayState: ThisIsTheRelayState", w.Body.String())) } func TestIDPIDPInitiatedExistingSession(t *testing.T) { @@ -1029,18 +1029,18 @@ func TestIDPRejectDecompressionBomb(t *testing.T) { }, } - //w := httptest.NewRecorder() - data := bytes.Repeat([]byte("a"), 768*1024*1024) var compressed bytes.Buffer w, _ := flate.NewWriter(&compressed, flate.BestCompression) - w.Write(data) - w.Close() + _, err := w.Write(data) + assert.Check(t, err) + err = w.Close() + assert.Check(t, err) encoded := base64.StdEncoding.EncodeToString(compressed.Bytes()) r, _ := http.NewRequest("GET", "/dontcare?"+url.Values{ "SAMLRequest": {encoded}, }.Encode(), nil) - _, err := NewIdpAuthnRequest(&test.IDP, r) + _, err = NewIdpAuthnRequest(&test.IDP, r) assert.Error(t, err, "cannot decompress request: flate: uncompress limit exceeded (10485760 bytes)") } diff --git a/logger/logger.go b/logger/logger.go index c211aba6..03bb0bde 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -1,3 +1,4 @@ +// Package logger provides a logging interface. package logger import ( diff --git a/metadata.go b/metadata.go index 74eeb763..f2a25ee5 100644 --- a/metadata.go +++ b/metadata.go @@ -65,7 +65,7 @@ type EntityDescriptor struct { } // MarshalXML implements xml.Marshaler -func (m EntityDescriptor) MarshalXML(e *xml.Encoder, start xml.StartElement) error { +func (m EntityDescriptor) MarshalXML(e *xml.Encoder, _ xml.StartElement) error { type Alias EntityDescriptor aux := &struct { ValidUntil RelaxedTime `xml:"validUntil,attr,omitempty"` diff --git a/samlidp/samlidp.go b/samlidp/samlidp.go index 0ed28b2f..2141ca89 100644 --- a/samlidp/samlidp.go +++ b/samlidp/samlidp.go @@ -47,9 +47,9 @@ type Server struct { // New returns a new Server func New(opts Options) (*Server, error) { metadataURL := opts.URL - metadataURL.Path = metadataURL.Path + "/metadata" + metadataURL.Path += "/metadata" ssoURL := opts.URL - ssoURL.Path = ssoURL.Path + "/sso" + ssoURL.Path += "/sso" logr := opts.Logger if logr == nil { logr = logger.DefaultLogger diff --git a/samlidp/samlidp_test.go b/samlidp/samlidp_test.go index 078239f5..e5b2dafb 100644 --- a/samlidp/samlidp_test.go +++ b/samlidp/samlidp_test.go @@ -124,8 +124,8 @@ func TestHTTPCanHandleMetadataRequest(t *testing.T) { test.Server.ServeHTTP(w, r) assert.Check(t, is.Equal(http.StatusOK, w.Code)) assert.Check(t, - strings.HasPrefix(string(w.Body.Bytes()), "

"), - string(w.Body.Bytes())) + strings.HasPrefix(w.Body.String(), "

"), + w.Body.String()) golden.Assert(t, w.Body.String(), "http_sso_response.html") } diff --git a/samlidp/service.go b/samlidp/service.go index 5c2cc659..0b62cd3b 100644 --- a/samlidp/service.go +++ b/samlidp/service.go @@ -25,7 +25,7 @@ type Service struct { // service provider ID, which is typically the service provider's // metadata URL. If an appropriate service provider cannot be found then // the returned error must be os.ErrNotExist. -func (s *Server) GetServiceProvider(r *http.Request, serviceProviderID string) (*saml.EntityDescriptor, error) { +func (s *Server) GetServiceProvider(_ *http.Request, serviceProviderID string) (*saml.EntityDescriptor, error) { s.idpConfigMu.RLock() defer s.idpConfigMu.RUnlock() rv, ok := s.serviceProviders[serviceProviderID] @@ -37,7 +37,7 @@ func (s *Server) GetServiceProvider(r *http.Request, serviceProviderID string) ( // HandleListServices handles the `GET /services/` request and responds with a JSON formatted list // of service names. -func (s *Server) HandleListServices(c web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleListServices(_ web.C, w http.ResponseWriter, _ *http.Request) { services, err := s.Store.List("/services/") if err != nil { s.logger.Printf("ERROR: %s", err) @@ -45,14 +45,18 @@ func (s *Server) HandleListServices(c web.C, w http.ResponseWriter, r *http.Requ return } - json.NewEncoder(w).Encode(struct { + err = json.NewEncoder(w).Encode(struct { Services []string `json:"services"` }{Services: services}) + if err != nil { + s.logger.Printf("ERROR: %s", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } } // HandleGetService handles the `GET /services/:id` request and responds with the service // metadata in XML format. -func (s *Server) HandleGetService(c web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleGetService(c web.C, w http.ResponseWriter, _ *http.Request) { service := Service{} err := s.Store.Get(fmt.Sprintf("/services/%s", c.URLParams["id"]), &service) if err != nil { @@ -60,7 +64,11 @@ func (s *Server) HandleGetService(c web.C, w http.ResponseWriter, r *http.Reques http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } - xml.NewEncoder(w).Encode(service.Metadata) + err = xml.NewEncoder(w).Encode(service.Metadata) + if err != nil { + s.logger.Printf("ERROR: %s", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } } // HandlePutService handles the `PUT /shortcuts/:id` request. It accepts the XML-formatted @@ -92,7 +100,7 @@ func (s *Server) HandlePutService(c web.C, w http.ResponseWriter, r *http.Reques } // HandleDeleteService handles the `DELETE /services/:id` request. -func (s *Server) HandleDeleteService(c web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleDeleteService(c web.C, w http.ResponseWriter, _ *http.Request) { service := Service{} err := s.Store.Get(fmt.Sprintf("/services/%s", c.URLParams["id"]), &service) if err != nil { diff --git a/samlidp/service_test.go b/samlidp/service_test.go index 57e7c4d4..7ee4df83 100644 --- a/samlidp/service_test.go +++ b/samlidp/service_test.go @@ -18,7 +18,7 @@ func TestServicesCrud(t *testing.T) { r, _ := http.NewRequest("GET", "https://idp.example.com/services/", nil) test.Server.ServeHTTP(w, r) assert.Check(t, is.Equal(http.StatusOK, w.Code)) - assert.Check(t, is.Equal("{\"services\":[]}\n", string(w.Body.Bytes()))) + assert.Check(t, is.Equal("{\"services\":[]}\n", w.Body.String())) w = httptest.NewRecorder() r, _ = http.NewRequest("PUT", "https://idp.example.com/services/sp", @@ -36,7 +36,7 @@ func TestServicesCrud(t *testing.T) { r, _ = http.NewRequest("GET", "https://idp.example.com/services/", nil) test.Server.ServeHTTP(w, r) assert.Check(t, is.Equal(http.StatusOK, w.Code)) - assert.Check(t, is.Equal("{\"services\":[\"sp\"]}\n", string(w.Body.Bytes()))) + assert.Check(t, is.Equal("{\"services\":[\"sp\"]}\n", w.Body.String())) assert.Check(t, is.Len(test.Server.serviceProviders, 2)) @@ -49,6 +49,6 @@ func TestServicesCrud(t *testing.T) { r, _ = http.NewRequest("GET", "https://idp.example.com/services/", nil) test.Server.ServeHTTP(w, r) assert.Check(t, is.Equal(http.StatusOK, w.Code)) - assert.Check(t, is.Equal("{\"services\":[]}\n", string(w.Body.Bytes()))) + assert.Check(t, is.Equal("{\"services\":[]}\n", w.Body.String())) assert.Check(t, is.Len(test.Server.serviceProviders, 1)) } diff --git a/samlidp/session.go b/samlidp/session.go index ba3bd65b..8ffae2ba 100644 --- a/samlidp/session.go +++ b/samlidp/session.go @@ -48,12 +48,13 @@ func (s *Server) GetSession(w http.ResponseWriter, r *http.Request, req *saml.Id } session := &saml.Session{ - ID: base64.StdEncoding.EncodeToString(randomBytes(32)), - NameID: user.Email, - CreateTime: saml.TimeNow(), - ExpireTime: saml.TimeNow().Add(sessionMaxAge), - Index: hex.EncodeToString(randomBytes(32)), - UserName: user.Name, + ID: base64.StdEncoding.EncodeToString(randomBytes(32)), + NameID: user.Email, + CreateTime: saml.TimeNow(), + ExpireTime: saml.TimeNow().Add(sessionMaxAge), + Index: hex.EncodeToString(randomBytes(32)), + UserName: user.Name, + // nolint:gocritic // Groups should be a slice here. Groups: user.Groups[:], UserEmail: user.Email, UserCommonName: user.CommonName, @@ -102,7 +103,7 @@ func (s *Server) GetSession(w http.ResponseWriter, r *http.Request, req *saml.Id // sendLoginForm produces a form which requests a username and password and directs the user // back to the IDP authorize URL to restart the SAML login flow, this time establishing a // session based on the credentials that were provided. -func (s *Server) sendLoginForm(w http.ResponseWriter, r *http.Request, req *saml.IdpAuthnRequest, toast string) { +func (s *Server) sendLoginForm(w http.ResponseWriter, _ *http.Request, req *saml.IdpAuthnRequest, toast string) { tmpl := template.Must(template.New("saml-post-form").Parse(`` + `` + `

{{.Toast}}

` + @@ -135,7 +136,7 @@ func (s *Server) sendLoginForm(w http.ResponseWriter, r *http.Request, req *saml // in the request body, then they are validated. For valid credentials, the response is a // 200 OK and the JSON session object. For invalid credentials, the HTML login prompt form // is sent. -func (s *Server) HandleLogin(c web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleLogin(_ web.C, w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return @@ -144,38 +145,48 @@ func (s *Server) HandleLogin(c web.C, w http.ResponseWriter, r *http.Request) { if session == nil { return } - json.NewEncoder(w).Encode(session) + if err := json.NewEncoder(w).Encode(session); err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } } // HandleListSessions handles the `GET /sessions/` request and responds with a JSON formatted list // of session names. -func (s *Server) HandleListSessions(c web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleListSessions(_ web.C, w http.ResponseWriter, _ *http.Request) { sessions, err := s.Store.List("/sessions/") if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } - json.NewEncoder(w).Encode(struct { + err = json.NewEncoder(w).Encode(struct { Sessions []string `json:"sessions"` }{Sessions: sessions}) + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } } // HandleGetSession handles the `GET /sessions/:id` request and responds with the session // object in JSON format. -func (s *Server) HandleGetSession(c web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleGetSession(c web.C, w http.ResponseWriter, _ *http.Request) { session := saml.Session{} err := s.Store.Get(fmt.Sprintf("/sessions/%s", c.URLParams["id"]), &session) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } - json.NewEncoder(w).Encode(session) + if err := json.NewEncoder(w).Encode(session); err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } } // HandleDeleteSession handles the `DELETE /sessions/:id` request. It invalidates the // specified session. -func (s *Server) HandleDeleteSession(c web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleDeleteSession(c web.C, w http.ResponseWriter, _ *http.Request) { err := s.Store.Delete(fmt.Sprintf("/sessions/%s", c.URLParams["id"])) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) diff --git a/samlidp/session_test.go b/samlidp/session_test.go index 7f5d1351..cb3d5c3d 100644 --- a/samlidp/session_test.go +++ b/samlidp/session_test.go @@ -17,7 +17,7 @@ func TestSessionsCrud(t *testing.T) { test.Server.ServeHTTP(w, r) assert.Check(t, is.Equal(http.StatusOK, w.Code)) assert.Check(t, is.Equal("{\"sessions\":[]}\n", - string(w.Body.Bytes()))) + w.Body.String())) w = httptest.NewRecorder() r, _ = http.NewRequest("PUT", "https://idp.example.com/users/alice", @@ -34,7 +34,7 @@ func TestSessionsCrud(t *testing.T) { assert.Check(t, is.Equal("session=AAIEBggKDA4QEhQWGBocHiAiJCYoKiwuMDI0Njg6PD4=; Path=/; Max-Age=3600; HttpOnly; Secure", w.Header().Get("Set-Cookie"))) assert.Check(t, is.Equal("{\"ID\":\"AAIEBggKDA4QEhQWGBocHiAiJCYoKiwuMDI0Njg6PD4=\",\"CreateTime\":\"2015-12-01T01:57:09Z\",\"ExpireTime\":\"2015-12-01T02:57:09Z\",\"Index\":\"40424446484a4c4e50525456585a5c5e60626466686a6c6e70727476787a7c7e\",\"NameID\":\"\",\"NameIDFormat\":\"\",\"SubjectID\":\"\",\"Groups\":null,\"UserName\":\"alice\",\"UserEmail\":\"\",\"UserCommonName\":\"\",\"UserSurname\":\"\",\"UserGivenName\":\"\",\"UserScopedAffiliation\":\"\",\"CustomAttributes\":null}\n", - string(w.Body.Bytes()))) + w.Body.String())) w = httptest.NewRecorder() r, _ = http.NewRequest("GET", "https://idp.example.com/login", nil) @@ -42,14 +42,14 @@ func TestSessionsCrud(t *testing.T) { test.Server.ServeHTTP(w, r) assert.Check(t, is.Equal(http.StatusOK, w.Code)) assert.Check(t, is.Equal("{\"ID\":\"AAIEBggKDA4QEhQWGBocHiAiJCYoKiwuMDI0Njg6PD4=\",\"CreateTime\":\"2015-12-01T01:57:09Z\",\"ExpireTime\":\"2015-12-01T02:57:09Z\",\"Index\":\"40424446484a4c4e50525456585a5c5e60626466686a6c6e70727476787a7c7e\",\"NameID\":\"\",\"NameIDFormat\":\"\",\"SubjectID\":\"\",\"Groups\":null,\"UserName\":\"alice\",\"UserEmail\":\"\",\"UserCommonName\":\"\",\"UserSurname\":\"\",\"UserGivenName\":\"\",\"UserScopedAffiliation\":\"\",\"CustomAttributes\":null}\n", - string(w.Body.Bytes()))) + w.Body.String())) w = httptest.NewRecorder() r, _ = http.NewRequest("GET", "https://idp.example.com/sessions/AAIEBggKDA4QEhQWGBocHiAiJCYoKiwuMDI0Njg6PD4=", nil) test.Server.ServeHTTP(w, r) assert.Check(t, is.Equal(http.StatusOK, w.Code)) assert.Check(t, is.Equal("{\"ID\":\"AAIEBggKDA4QEhQWGBocHiAiJCYoKiwuMDI0Njg6PD4=\",\"CreateTime\":\"2015-12-01T01:57:09Z\",\"ExpireTime\":\"2015-12-01T02:57:09Z\",\"Index\":\"40424446484a4c4e50525456585a5c5e60626466686a6c6e70727476787a7c7e\",\"NameID\":\"\",\"NameIDFormat\":\"\",\"SubjectID\":\"\",\"Groups\":null,\"UserName\":\"alice\",\"UserEmail\":\"\",\"UserCommonName\":\"\",\"UserSurname\":\"\",\"UserGivenName\":\"\",\"UserScopedAffiliation\":\"\",\"CustomAttributes\":null}\n", - string(w.Body.Bytes()))) + w.Body.String())) w = httptest.NewRecorder() r, _ = http.NewRequest("DELETE", "https://idp.example.com/sessions/AAIEBggKDA4QEhQWGBocHiAiJCYoKiwuMDI0Njg6PD4=", nil) @@ -61,6 +61,6 @@ func TestSessionsCrud(t *testing.T) { test.Server.ServeHTTP(w, r) assert.Check(t, is.Equal(http.StatusOK, w.Code)) assert.Check(t, is.Equal("{\"sessions\":[]}\n", - string(w.Body.Bytes()))) + w.Body.String())) } diff --git a/samlidp/shortcut.go b/samlidp/shortcut.go index 151a84ea..4c5b8650 100644 --- a/samlidp/shortcut.go +++ b/samlidp/shortcut.go @@ -31,28 +31,35 @@ type Shortcut struct { // HandleListShortcuts handles the `GET /shortcuts/` request and responds with a JSON formatted list // of shortcut names. -func (s *Server) HandleListShortcuts(c web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleListShortcuts(_ web.C, w http.ResponseWriter, _ *http.Request) { shortcuts, err := s.Store.List("/shortcuts/") if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } - json.NewEncoder(w).Encode(struct { + err = json.NewEncoder(w).Encode(struct { Shortcuts []string `json:"shortcuts"` }{Shortcuts: shortcuts}) + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } } // HandleGetShortcut handles the `GET /shortcuts/:id` request and responds with the shortcut // object in JSON format. -func (s *Server) HandleGetShortcut(c web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleGetShortcut(c web.C, w http.ResponseWriter, _ *http.Request) { shortcut := Shortcut{} err := s.Store.Get(fmt.Sprintf("/shortcuts/%s", c.URLParams["id"]), &shortcut) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } - json.NewEncoder(w).Encode(shortcut) + if err := json.NewEncoder(w).Encode(shortcut); err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } } // HandlePutShortcut handles the `PUT /shortcuts/:id` request. It accepts a JSON formatted @@ -74,7 +81,7 @@ func (s *Server) HandlePutShortcut(c web.C, w http.ResponseWriter, r *http.Reque } // HandleDeleteShortcut handles the `DELETE /shortcuts/:id` request. -func (s *Server) HandleDeleteShortcut(c web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleDeleteShortcut(c web.C, w http.ResponseWriter, _ *http.Request) { err := s.Store.Delete(fmt.Sprintf("/shortcuts/%s", c.URLParams["id"])) if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -100,7 +107,7 @@ func (s *Server) HandleIDPInitiated(c web.C, w http.ResponseWriter, r *http.Requ case shortcut.RelayState != nil: relayState = *shortcut.RelayState case shortcut.URISuffixAsRelayState: - relayState, _ = c.URLParams["*"] + relayState = c.URLParams["*"] } s.idpConfigMu.RLock() diff --git a/samlidp/shortcut_test.go b/samlidp/shortcut_test.go index 7a15f148..e74f34bf 100644 --- a/samlidp/shortcut_test.go +++ b/samlidp/shortcut_test.go @@ -17,7 +17,7 @@ func TestShortcutsCrud(t *testing.T) { test.Server.ServeHTTP(w, r) assert.Check(t, is.Equal(http.StatusOK, w.Code)) assert.Check(t, is.Equal("{\"shortcuts\":[]}\n", - string(w.Body.Bytes()))) + w.Body.String())) w = httptest.NewRecorder() r, _ = http.NewRequest("PUT", "https://idp.example.com/shortcuts/bob", @@ -30,14 +30,14 @@ func TestShortcutsCrud(t *testing.T) { test.Server.ServeHTTP(w, r) assert.Check(t, is.Equal(http.StatusOK, w.Code)) assert.Check(t, is.Equal("{\"name\":\"bob\",\"service_provider\":\"https://example.com/saml2/metadata\",\"url_suffix_as_relay_state\":true}\n", - string(w.Body.Bytes()))) + w.Body.String())) w = httptest.NewRecorder() r, _ = http.NewRequest("GET", "https://idp.example.com/shortcuts/", nil) test.Server.ServeHTTP(w, r) assert.Check(t, is.Equal(http.StatusOK, w.Code)) assert.Check(t, is.Equal("{\"shortcuts\":[\"bob\"]}\n", - string(w.Body.Bytes()))) + w.Body.String())) w = httptest.NewRecorder() r, _ = http.NewRequest("DELETE", "https://idp.example.com/shortcuts/bob", nil) @@ -49,7 +49,7 @@ func TestShortcutsCrud(t *testing.T) { test.Server.ServeHTTP(w, r) assert.Check(t, is.Equal(http.StatusOK, w.Code)) assert.Check(t, is.Equal("{\"shortcuts\":[]}\n", - string(w.Body.Bytes()))) + w.Body.String())) } func TestShortcut(t *testing.T) { @@ -78,7 +78,7 @@ func TestShortcut(t *testing.T) { r.Header.Set("Cookie", "session=AAIEBggKDA4QEhQWGBocHiAiJCYoKiwuMDI0Njg6PD4=") test.Server.ServeHTTP(w, r) assert.Check(t, is.Equal(http.StatusOK, w.Code)) - body := string(w.Body.Bytes()) + body := w.Body.String() assert.Check(t, strings.Contains(body, ""), diff --git a/samlidp/user.go b/samlidp/user.go index 46d1a964..c8c412cb 100644 --- a/samlidp/user.go +++ b/samlidp/user.go @@ -25,7 +25,7 @@ type User struct { // HandleListUsers handles the `GET /users/` request and responds with a JSON formatted list // of user names. -func (s *Server) HandleListUsers(c web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleListUsers(_ web.C, w http.ResponseWriter, _ *http.Request) { users, err := s.Store.List("/users/") if err != nil { s.logger.Printf("ERROR: %s", err) @@ -33,14 +33,19 @@ func (s *Server) HandleListUsers(c web.C, w http.ResponseWriter, r *http.Request return } - json.NewEncoder(w).Encode(struct { + err = json.NewEncoder(w).Encode(struct { Users []string `json:"users"` }{Users: users}) + if err != nil { + s.logger.Printf("ERROR: %s", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } } // HandleGetUser handles the `GET /users/:id` request and responds with the user object in JSON // format. The HashedPassword field is excluded. -func (s *Server) HandleGetUser(c web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleGetUser(c web.C, w http.ResponseWriter, _ *http.Request) { user := User{} err := s.Store.Get(fmt.Sprintf("/users/%s", c.URLParams["id"]), &user) if err != nil { @@ -49,7 +54,11 @@ func (s *Server) HandleGetUser(c web.C, w http.ResponseWriter, r *http.Request) return } user.HashedPassword = nil - json.NewEncoder(w).Encode(user) + if err := json.NewEncoder(w).Encode(user); err != nil { + s.logger.Printf("ERROR: %s", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } } // HandlePutUser handles the `PUT /users/:id` request. It accepts a JSON formatted user object in @@ -99,7 +108,7 @@ func (s *Server) HandlePutUser(c web.C, w http.ResponseWriter, r *http.Request) } // HandleDeleteUser handles the `DELETE /users/:id` request. -func (s *Server) HandleDeleteUser(c web.C, w http.ResponseWriter, r *http.Request) { +func (s *Server) HandleDeleteUser(c web.C, w http.ResponseWriter, _ *http.Request) { err := s.Store.Delete(fmt.Sprintf("/users/%s", c.URLParams["id"])) if err != nil { s.logger.Printf("ERROR: %s", err) diff --git a/samlidp/user_test.go b/samlidp/user_test.go index ecac3459..30a98a6b 100644 --- a/samlidp/user_test.go +++ b/samlidp/user_test.go @@ -16,7 +16,7 @@ func TestUsersCrud(t *testing.T) { r, _ := http.NewRequest("GET", "https://idp.example.com/users/", nil) test.Server.ServeHTTP(w, r) assert.Check(t, is.Equal(http.StatusOK, w.Code)) - assert.Check(t, is.Equal("{\"users\":[]}\n", string(w.Body.Bytes()))) + assert.Check(t, is.Equal("{\"users\":[]}\n", w.Body.String())) w = httptest.NewRecorder() r, _ = http.NewRequest("PUT", "https://idp.example.com/users/alice", @@ -28,13 +28,13 @@ func TestUsersCrud(t *testing.T) { r, _ = http.NewRequest("GET", "https://idp.example.com/users/alice", nil) test.Server.ServeHTTP(w, r) assert.Check(t, is.Equal(http.StatusOK, w.Code)) - assert.Check(t, is.Equal("{\"name\":\"alice\"}\n", string(w.Body.Bytes()))) + assert.Check(t, is.Equal("{\"name\":\"alice\"}\n", w.Body.String())) w = httptest.NewRecorder() r, _ = http.NewRequest("GET", "https://idp.example.com/users/", nil) test.Server.ServeHTTP(w, r) assert.Check(t, is.Equal(http.StatusOK, w.Code)) - assert.Check(t, is.Equal("{\"users\":[\"alice\"]}\n", string(w.Body.Bytes()))) + assert.Check(t, is.Equal("{\"users\":[\"alice\"]}\n", w.Body.String())) w = httptest.NewRecorder() r, _ = http.NewRequest("DELETE", "https://idp.example.com/users/alice", nil) @@ -45,5 +45,5 @@ func TestUsersCrud(t *testing.T) { r, _ = http.NewRequest("GET", "https://idp.example.com/users/", nil) test.Server.ServeHTTP(w, r) assert.Check(t, is.Equal(http.StatusOK, w.Code)) - assert.Check(t, is.Equal("{\"users\":[]}\n", string(w.Body.Bytes()))) + assert.Check(t, is.Equal("{\"users\":[]}\n", w.Body.String())) } diff --git a/samlsp/error.go b/samlsp/error.go index 662bce74..496faccf 100644 --- a/samlsp/error.go +++ b/samlsp/error.go @@ -14,7 +14,7 @@ type ErrorFunction func(w http.ResponseWriter, r *http.Request, err error) // DefaultOnError is the default ErrorFunction implementation. It prints // an message via the standard log package and returns a simple text // "Forbidden" message to the user. -func DefaultOnError(w http.ResponseWriter, r *http.Request, err error) { +func DefaultOnError(w http.ResponseWriter, _ *http.Request, err error) { if parseErr, ok := err.(*saml.InvalidResponseError); ok { log.Printf("WARNING: received invalid saml response: %s (now: %s) %s", parseErr.Response, parseErr.Now, parseErr.PrivateErr) diff --git a/samlsp/fetch_metadata.go b/samlsp/fetch_metadata.go index 1ef521ac..4d92503e 100644 --- a/samlsp/fetch_metadata.go +++ b/samlsp/fetch_metadata.go @@ -12,6 +12,8 @@ import ( "github.com/crewjam/httperr" xrv "github.com/mattermost/xml-roundtrip-validator" + "github.com/crewjam/saml/logger" + "github.com/crewjam/saml" ) @@ -61,7 +63,11 @@ func FetchMetadata(ctx context.Context, httpClient *http.Client, metadataURL url if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + logger.DefaultLogger.Printf("Error while closing response body during fetch metadata: %v", err) + } + }() if resp.StatusCode >= 400 { return nil, httperr.Response(*resp) } diff --git a/samlsp/fetch_metadata_go117_test.go b/samlsp/fetch_metadata_go117_test.go index be855c1a..c4edc61c 100644 --- a/samlsp/fetch_metadata_go117_test.go +++ b/samlsp/fetch_metadata_go117_test.go @@ -18,12 +18,13 @@ import ( func TestFetchMetadataRejectsInvalid(t *testing.T) { test := NewMiddlewareTest(t) - test.IDPMetadata = bytes.Replace(test.IDPMetadata, - []byte("]]"), -1) + test.IDPMetadata = bytes.ReplaceAll(test.IDPMetadata, + []byte("]]")) testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Check(t, is.Equal("/metadata", r.URL.String())) - w.Write(test.IDPMetadata) + _, err := w.Write(test.IDPMetadata) + assert.Check(t, err) })) fmt.Println(testServer.URL + "/metadata") diff --git a/samlsp/fetch_metadata_test.go b/samlsp/fetch_metadata_test.go index f1da1320..bd90dd8a 100644 --- a/samlsp/fetch_metadata_test.go +++ b/samlsp/fetch_metadata_test.go @@ -17,7 +17,8 @@ func TestFetchMetadata(t *testing.T) { testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Check(t, is.Equal("/metadata", r.URL.String())) - w.Write(test.IDPMetadata) + _, err := w.Write(test.IDPMetadata) + assert.Check(t, err) })) fmt.Println(testServer.URL + "/metadata") diff --git a/samlsp/middleware.go b/samlsp/middleware.go index 834a79c1..f5eabb16 100644 --- a/samlsp/middleware.go +++ b/samlsp/middleware.go @@ -1,6 +1,7 @@ package samlsp import ( + "bytes" "encoding/xml" "net/http" @@ -65,16 +66,22 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // ServeMetadata handles requests for the SAML metadata endpoint. -func (m *Middleware) ServeMetadata(w http.ResponseWriter, r *http.Request) { +func (m *Middleware) ServeMetadata(w http.ResponseWriter, _ *http.Request) { buf, _ := xml.MarshalIndent(m.ServiceProvider.Metadata(), "", " ") w.Header().Set("Content-Type", "application/samlmetadata+xml") - w.Write(buf) - return + if _, err := w.Write(buf); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } // ServeACS handles requests for the SAML ACS endpoint. func (m *Middleware) ServeACS(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + err := r.ParseForm() + if err != nil { + m.OnError(w, r, err) + return + } possibleRequestIDs := []string{} if m.ServiceProvider.AllowIDPInitiated { @@ -93,7 +100,6 @@ func (m *Middleware) ServeACS(w http.ResponseWriter, r *http.Request) { } m.CreateSessionFromAssertion(w, r, assertion, m.ServiceProvider.DefaultRedirectURI) - return } // RequireAccount is HTTP middleware that requires that each request be @@ -114,7 +120,6 @@ func (m *Middleware) RequireAccount(handler http.Handler) http.Handler { } m.OnError(w, r, err) - return }) } @@ -173,9 +178,14 @@ func (m *Middleware) HandleStartAuthFlow(w http.ResponseWriter, r *http.Request) "script-src 'sha256-AjPdJSbZmeWHnEc5ykvJFay8FTWeTeRbs9dutfZ0HqE='; "+ "reflected-xss block; referrer no-referrer;") w.Header().Add("Content-type", "text/html") - w.Write([]byte(``)) - w.Write(authReq.Post(relayState)) - w.Write([]byte(``)) + var buf bytes.Buffer + buf.WriteString(``) + buf.Write(authReq.Post(relayState)) + buf.WriteString(``) + if _, err := w.Write(buf.Bytes()); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } return } panic("not reached") @@ -195,7 +205,10 @@ func (m *Middleware) CreateSessionFromAssertion(w http.ResponseWriter, r *http.R return } } else { - m.RequestTracker.StopTrackingRequest(w, r, trackedRequestIndex) + if err := m.RequestTracker.StopTrackingRequest(w, r, trackedRequestIndex); err != nil { + m.OnError(w, r, err) + return + } redirectURI = trackedRequest.URI } diff --git a/samlsp/middleware_test.go b/samlsp/middleware_test.go index 8b8863b0..801aad08 100644 --- a/samlsp/middleware_test.go +++ b/samlsp/middleware_test.go @@ -426,7 +426,7 @@ func TestMiddlewareDefaultCookieDomainIPv4(t *testing.T) { req, _ := http.NewRequest("GET", "/", nil) resp := httptest.NewRecorder() - sp.CreateSession(resp, req, &saml.Assertion{}) + assert.Check(t, sp.CreateSession(resp, req, &saml.Assertion{})) assert.Check(t, strings.Contains(resp.Header().Get("Set-Cookie"), "Domain=127.0.0.1;"), @@ -445,7 +445,7 @@ func TestMiddlewareDefaultCookieDomainIPv6(t *testing.T) { req, _ := http.NewRequest("GET", "/", nil) resp := httptest.NewRecorder() - sp.CreateSession(resp, req, &saml.Assertion{}) + assert.Check(t, sp.CreateSession(resp, req, &saml.Assertion{})) assert.Check(t, strings.Contains(resp.Header().Get("Set-Cookie"), "Domain=::1;"), diff --git a/samlsp/request_tracker_jwt.go b/samlsp/request_tracker_jwt.go index 625a66e9..0ca47258 100644 --- a/samlsp/request_tracker_jwt.go +++ b/samlsp/request_tracker_jwt.go @@ -25,7 +25,7 @@ var _ TrackedRequestCodec = JWTTrackedRequestCodec{} // JWTTrackedRequestClaims represents the JWT claims for a tracked request. type JWTTrackedRequestClaims struct { - jwt.StandardClaims + jwt.RegisteredClaims TrackedRequest SAMLAuthnRequest bool `json:"saml-authn-request"` } @@ -34,12 +34,12 @@ type JWTTrackedRequestClaims struct { func (s JWTTrackedRequestCodec) Encode(value TrackedRequest) (string, error) { now := saml.TimeNow() claims := JWTTrackedRequestClaims{ - StandardClaims: jwt.StandardClaims{ - Audience: s.Audience, - ExpiresAt: now.Add(s.MaxAge).Unix(), - IssuedAt: now.Unix(), + RegisteredClaims: jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{s.Audience}, + ExpiresAt: jwt.NewNumericDate(now.Add(s.MaxAge)), + IssuedAt: jwt.NewNumericDate(now), Issuer: s.Issuer, - NotBefore: now.Unix(), // TODO(ross): correct for clock skew + NotBefore: jwt.NewNumericDate(now), // TODO(ross): correct for clock skew Subject: value.Index, }, TrackedRequest: value, @@ -67,7 +67,7 @@ func (s JWTTrackedRequestCodec) Decode(signed string) (*TrackedRequest, error) { if !claims.VerifyIssuer(s.Issuer, true) { return nil, fmt.Errorf("expected issuer %q, got %q", s.Issuer, claims.Issuer) } - if claims.SAMLAuthnRequest != true { + if !claims.SAMLAuthnRequest { return nil, fmt.Errorf("expected saml-authn-request") } claims.TrackedRequest.Index = claims.Subject diff --git a/samlsp/session_cookie_test.go b/samlsp/session_cookie_test.go index ef594ee2..74fcf2cb 100644 --- a/samlsp/session_cookie_test.go +++ b/samlsp/session_cookie_test.go @@ -26,10 +26,12 @@ func TestCookieSameSite(t *testing.T) { resp := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) err := csp.CreateSession(resp, req, &saml.Assertion{}) - assert.Check(t, err) + assert.Check(tb, err) - cookies := resp.Result().Cookies() - assert.Check(t, is.Len(cookies, 1), "Expected to have a cookie set") + result := resp.Result() + cookies := result.Cookies() + assert.Check(tb, is.Len(cookies, 1), "Expected to have a cookie set") + assert.Check(tb, result.Body.Close()) return cookies[0] } diff --git a/samlsp/session_jwt.go b/samlsp/session_jwt.go index c4531fb9..8d801e47 100644 --- a/samlsp/session_jwt.go +++ b/samlsp/session_jwt.go @@ -106,7 +106,7 @@ func (c JWTSessionCodec) Decode(signed string) (Session, error) { if !claims.VerifyIssuer(c.Issuer, true) { return nil, fmt.Errorf("expected issuer %q, got %q", c.Issuer, claims.Issuer) } - if claims.SAMLSession != true { + if !claims.SAMLSession { return nil, errors.New("expected saml-session") } return claims, nil diff --git a/schema.go b/schema.go index f81133a2..b17c949b 100644 --- a/schema.go +++ b/schema.go @@ -48,7 +48,7 @@ type AuthnRequest struct { NameIDPolicy *NameIDPolicy `xml:"urn:oasis:names:tc:SAML:2.0:protocol NameIDPolicy"` Conditions *Conditions RequestedAuthnContext *RequestedAuthnContext - //Scoping *Scoping // TODO + // Scoping *Scoping // TODO ForceAuthn *bool `xml:",attr"` IsPassive *bool `xml:",attr"` @@ -108,7 +108,7 @@ func (r *LogoutRequest) Element() *etree.Element { } // MarshalXML implements xml.Marshaler -func (r *LogoutRequest) MarshalXML(e *xml.Encoder, start xml.StartElement) error { +func (r *LogoutRequest) MarshalXML(e *xml.Encoder, _ xml.StartElement) error { type Alias LogoutRequest aux := &struct { IssueInstant RelaxedTime `xml:",attr"` @@ -209,9 +209,9 @@ func (r *AuthnRequest) Element() *etree.Element { if r.RequestedAuthnContext != nil { el.AddChild(r.RequestedAuthnContext.Element()) } - //if r.Scoping != nil { - // el.AddChild(r.Scoping.Element()) - //} + // if r.Scoping != nil { + // el.AddChild(r.Scoping.Element()) + // } if r.ForceAuthn != nil { el.CreateAttr("ForceAuthn", strconv.FormatBool(*r.ForceAuthn)) } @@ -237,7 +237,7 @@ func (r *AuthnRequest) Element() *etree.Element { } // MarshalXML implements xml.Marshaler -func (r *AuthnRequest) MarshalXML(e *xml.Encoder, start xml.StartElement) error { +func (r *AuthnRequest) MarshalXML(e *xml.Encoder, _ xml.StartElement) error { type Alias AuthnRequest aux := &struct { IssueInstant RelaxedTime `xml:",attr"` @@ -374,7 +374,7 @@ func (r *ArtifactResolve) SoapRequest() *etree.Element { } // MarshalXML implements xml.Marshaler -func (r *ArtifactResolve) MarshalXML(e *xml.Encoder, start xml.StartElement) error { +func (r *ArtifactResolve) MarshalXML(e *xml.Encoder, _ xml.StartElement) error { type Alias ArtifactResolve aux := &struct { IssueInstant RelaxedTime `xml:",attr"` @@ -448,7 +448,7 @@ func (r *ArtifactResponse) Element() *etree.Element { } // MarshalXML implements xml.Marshaler -func (r *ArtifactResponse) MarshalXML(e *xml.Encoder, start xml.StartElement) error { +func (r *ArtifactResponse) MarshalXML(e *xml.Encoder, _ xml.StartElement) error { type Alias ArtifactResponse aux := &struct { IssueInstant RelaxedTime `xml:",attr"` @@ -542,7 +542,7 @@ func (r *Response) Element() *etree.Element { } // MarshalXML implements xml.Marshaler -func (r *Response) MarshalXML(e *xml.Encoder, start xml.StartElement) error { +func (r *Response) MarshalXML(e *xml.Encoder, _ xml.StartElement) error { type Alias Response aux := &struct { IssueInstant RelaxedTime `xml:",attr"` @@ -1153,9 +1153,9 @@ func (a *SubjectLocality) Element() *etree.Element { // See http://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf ยง2.7.2.2 type AuthnContext struct { AuthnContextClassRef *AuthnContextClassRef - //AuthnContextDecl *AuthnContextDecl ... TODO - //AuthnContextDeclRef *AuthnContextDeclRef ... TODO - //AuthenticatingAuthorities []AuthenticatingAuthority... TODO + // AuthnContextDecl *AuthnContextDecl ... TODO + // AuthnContextDeclRef *AuthnContextDeclRef ... TODO + // AuthenticatingAuthorities []AuthenticatingAuthority... TODO } // Element returns an etree.Element representing the object in XML form. @@ -1292,7 +1292,7 @@ func (r *LogoutResponse) Element() *etree.Element { } // MarshalXML implements xml.Marshaler -func (r *LogoutResponse) MarshalXML(e *xml.Encoder, start xml.StartElement) error { +func (r *LogoutResponse) MarshalXML(e *xml.Encoder, _ xml.StartElement) error { type Alias LogoutResponse aux := &struct { IssueInstant RelaxedTime `xml:",attr"` diff --git a/service_provider.go b/service_provider.go index f827e611..ad21321e 100644 --- a/service_provider.go +++ b/service_provider.go @@ -23,6 +23,7 @@ import ( dsig "github.com/russellhaering/goxmldsig" "github.com/russellhaering/goxmldsig/etreeutils" + "github.com/crewjam/saml/logger" "github.com/crewjam/saml/xmlenc" ) @@ -189,13 +190,13 @@ func (sp *ServiceProvider) Metadata() *EntityDescriptor { } } - var sloEndpoints []Endpoint - for _, binding := range sp.LogoutBindings { - sloEndpoints = append(sloEndpoints, Endpoint{ + sloEndpoints := make([]Endpoint, len(sp.LogoutBindings)) + for i, binding := range sp.LogoutBindings { + sloEndpoints[i] = Endpoint{ Binding: binding, Location: sp.SloURL.String(), ResponseLocation: sp.SloURL.String(), - }) + } } return &EntityDescriptor{ @@ -245,19 +246,23 @@ func (sp *ServiceProvider) MakeRedirectAuthenticationRequest(relayState string) } // Redirect returns a URL suitable for using the redirect binding with the request -func (req *AuthnRequest) Redirect(relayState string, sp *ServiceProvider) (*url.URL, error) { +func (r *AuthnRequest) Redirect(relayState string, sp *ServiceProvider) (*url.URL, error) { w := &bytes.Buffer{} w1 := base64.NewEncoder(base64.StdEncoding, w) w2, _ := flate.NewWriter(w1, 9) doc := etree.NewDocument() - doc.SetRoot(req.Element()) + doc.SetRoot(r.Element()) if _, err := doc.WriteTo(w2); err != nil { panic(err) } - w2.Close() - w1.Close() + if err := w2.Close(); err != nil { + panic(err) + } + if err := w1.Close(); err != nil { + panic(err) + } - rv, _ := url.Parse(req.Destination) + rv, _ := url.Parse(r.Destination) // We can't depend on Query().set() as order matters for signing reqString := string(w.Bytes()) query := rv.RawQuery @@ -347,11 +352,11 @@ func (sp *ServiceProvider) getIDPSigningCerts() ([]*x509.Certificate, error) { return nil, errors.New("cannot find any signing certificate in the IDP SSO descriptor") } - var certs []*x509.Certificate + certs := make([]*x509.Certificate, len(certStrs)) // cleanup whitespace regex := regexp.MustCompile(`\s+`) - for _, certStr := range certStrs { + for i, certStr := range certStrs { certStr = regex.ReplaceAllString(certStr, "") certBytes, err := base64.StdEncoding.DecodeString(certStr) if err != nil { @@ -362,7 +367,7 @@ func (sp *ServiceProvider) getIDPSigningCerts() ([]*x509.Certificate, error) { if err != nil { return nil, err } - certs = append(certs, parsedCert) + certs[i] = parsedCert } return certs, nil @@ -434,9 +439,9 @@ func GetSigningContext(sp *ServiceProvider) (*dsig.SigningContext, error) { Leaf: sp.Certificate, } // TODO: add intermediates for SP - //for _, cert := range sp.Intermediates { - // keyPair.Certificate = append(keyPair.Certificate, cert.Raw) - //} + // for _, cert := range sp.Intermediates { + // keyPair.Certificate = append(keyPair.Certificate, cert.Raw) + // } keyStore := dsig.TLSCertKeyStore(keyPair) if sp.SignatureMethod != dsig.RSASHA1SignatureMethod && @@ -503,9 +508,9 @@ func (sp *ServiceProvider) MakePostAuthenticationRequest(relayState string) ([]b } // Post returns an HTML form suitable for using the HTTP-POST binding with the request -func (req *AuthnRequest) Post(relayState string) []byte { +func (r *AuthnRequest) Post(relayState string) []byte { doc := etree.NewDocument() - doc.SetRoot(req.Element()) + doc.SetRoot(r.Element()) reqBuf, err := doc.WriteToBytes() if err != nil { panic(err) @@ -525,7 +530,7 @@ func (req *AuthnRequest) Post(relayState string) []byte { SAMLRequest string RelayState string }{ - URL: req.Destination, + URL: r.Destination, SAMLRequest: encodedReqBuf, RelayState: relayState, } @@ -574,7 +579,7 @@ type InvalidResponseError struct { } func (ivr *InvalidResponseError) Error() string { - return fmt.Sprintf("Authentication failed") + return "Authentication failed" } // ErrBadStatus is returned when the assertion provided is valid but the @@ -601,7 +606,7 @@ func (sp *ServiceProvider) handleArtifactRequest(ctx context.Context, artifactID artifactResolveRequest, err := sp.MakeArtifactResolveRequest(artifactID) if err != nil { - retErr.PrivateErr = fmt.Errorf("Cannot generate artifact resolution request: %s", err) + retErr.PrivateErr = fmt.Errorf("cannot generate artifact resolution request: %s", err) return nil, retErr } @@ -628,7 +633,11 @@ func (sp *ServiceProvider) handleArtifactRequest(ctx context.Context, artifactID retErr.PrivateErr = fmt.Errorf("cannot resolve artifact: %s", err) return nil, retErr } - defer response.Body.Close() + defer func() { + if err := response.Body.Close(); err != nil { + logger.DefaultLogger.Printf("Error while closing response body during artifact resolution: %v", err) + } + }() if response.StatusCode != 200 { retErr.PrivateErr = fmt.Errorf("Error during artifact resolution: HTTP status %d (%s)", response.StatusCode, response.Status) return nil, retErr @@ -748,11 +757,12 @@ func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Eleme var signatureRequirement signatureRequirement sigErr := sp.validateSignature(artifactResponseEl) - if sigErr == nil { + switch sigErr { + case nil: signatureRequirement = signatureNotRequired - } else if sigErr == errSignatureElementNotPresent { + case errSignatureElementNotPresent: signatureRequirement = signatureRequired - } else { + default: retErr.PrivateErr = sigErr return nil, retErr } @@ -883,13 +893,14 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ } if signatureRequirement == signatureRequired { - if responseSignatureErr == nil { + switch responseSignatureErr { + case nil: // since the request has a signature, none of the Assertions need one signatureRequirement = signatureNotRequired - } else if responseSignatureErr == errSignatureElementNotPresent { + case errSignatureElementNotPresent: // the request has no signature, so assertions must be signed signatureRequirement = signatureRequired // nop - } else { + default: return nil, responseSignatureErr } } @@ -1073,7 +1084,7 @@ func (sp *ServiceProvider) validateAssertion(assertion *Assertion, possibleReque return nil } -var errSignatureElementNotPresent = errors.New("Signature element not present") +var errSignatureElementNotPresent = errors.New("signature element not present") // validateSignature returns nil iff the Signature embedded in the element is valid func (sp *ServiceProvider) validateSignature(el *etree.Element) error { @@ -1149,9 +1160,9 @@ func (sp *ServiceProvider) SignLogoutRequest(req *LogoutRequest) error { Leaf: sp.Certificate, } // TODO: add intermediates for SP - //for _, cert := range sp.Intermediates { - // keyPair.Certificate = append(keyPair.Certificate, cert.Raw) - //} + // for _, cert := range sp.Intermediates { + // keyPair.Certificate = append(keyPair.Certificate, cert.Raw) + // } keyStore := dsig.TLSCertKeyStore(keyPair) if sp.SignatureMethod != dsig.RSASHA1SignatureMethod && @@ -1217,22 +1228,26 @@ func (sp *ServiceProvider) MakeRedirectLogoutRequest(nameID, relayState string) } // Redirect returns a URL suitable for using the redirect binding with the request -func (req *LogoutRequest) Redirect(relayState string) *url.URL { +func (r *LogoutRequest) Redirect(relayState string) *url.URL { w := &bytes.Buffer{} w1 := base64.NewEncoder(base64.StdEncoding, w) w2, _ := flate.NewWriter(w1, 9) doc := etree.NewDocument() - doc.SetRoot(req.Element()) + doc.SetRoot(r.Element()) if _, err := doc.WriteTo(w2); err != nil { panic(err) } - w2.Close() - w1.Close() + if err := w2.Close(); err != nil { + panic(err) + } + if err := w1.Close(); err != nil { + panic(err) + } - rv, _ := url.Parse(req.Destination) + rv, _ := url.Parse(r.Destination) query := rv.Query() - query.Set("SAMLRequest", string(w.Bytes())) + query.Set("SAMLRequest", w.String()) if relayState != "" { query.Set("RelayState", relayState) } @@ -1253,9 +1268,9 @@ func (sp *ServiceProvider) MakePostLogoutRequest(nameID, relayState string) ([]b } // Post returns an HTML form suitable for using the HTTP-POST binding with the request -func (req *LogoutRequest) Post(relayState string) []byte { +func (r *LogoutRequest) Post(relayState string) []byte { doc := etree.NewDocument() - doc.SetRoot(req.Element()) + doc.SetRoot(r.Element()) reqBuf, err := doc.WriteToBytes() if err != nil { panic(err) @@ -1275,7 +1290,7 @@ func (req *LogoutRequest) Post(relayState string) []byte { SAMLRequest string RelayState string }{ - URL: req.Destination, + URL: r.Destination, SAMLRequest: encodedReqBuf, RelayState: relayState, } @@ -1327,22 +1342,26 @@ func (sp *ServiceProvider) MakeRedirectLogoutResponse(logoutRequestID, relayStat } // Redirect returns a URL suitable for using the redirect binding with the LogoutResponse. -func (resp *LogoutResponse) Redirect(relayState string) *url.URL { +func (r *LogoutResponse) Redirect(relayState string) *url.URL { w := &bytes.Buffer{} w1 := base64.NewEncoder(base64.StdEncoding, w) w2, _ := flate.NewWriter(w1, 9) doc := etree.NewDocument() - doc.SetRoot(resp.Element()) + doc.SetRoot(r.Element()) if _, err := doc.WriteTo(w2); err != nil { panic(err) } - w2.Close() - w1.Close() + if err := w2.Close(); err != nil { + panic(err) + } + if err := w1.Close(); err != nil { + panic(err) + } - rv, _ := url.Parse(resp.Destination) + rv, _ := url.Parse(r.Destination) query := rv.Query() - query.Set("SAMLResponse", string(w.Bytes())) + query.Set("SAMLResponse", w.String()) if relayState != "" { query.Set("RelayState", relayState) } @@ -1363,9 +1382,9 @@ func (sp *ServiceProvider) MakePostLogoutResponse(logoutRequestID, relayState st } // Post returns an HTML form suitable for using the HTTP-POST binding with the LogoutResponse. -func (resp *LogoutResponse) Post(relayState string) []byte { +func (r *LogoutResponse) Post(relayState string) []byte { doc := etree.NewDocument() - doc.SetRoot(resp.Element()) + doc.SetRoot(r.Element()) reqBuf, err := doc.WriteToBytes() if err != nil { panic(err) @@ -1385,7 +1404,7 @@ func (resp *LogoutResponse) Post(relayState string) []byte { SAMLResponse string RelayState string }{ - URL: resp.Destination, + URL: r.Destination, SAMLResponse: encodedReqBuf, RelayState: relayState, } @@ -1406,9 +1425,9 @@ func (sp *ServiceProvider) SignLogoutResponse(resp *LogoutResponse) error { Leaf: sp.Certificate, } // TODO: add intermediates for SP - //for _, cert := range sp.Intermediates { - // keyPair.Certificate = append(keyPair.Certificate, cert.Raw) - //} + // for _, cert := range sp.Intermediates { + // keyPair.Certificate = append(keyPair.Certificate, cert.Raw) + // } keyStore := dsig.TLSCertKeyStore(keyPair) if sp.SignatureMethod != dsig.RSASHA1SignatureMethod && @@ -1498,10 +1517,7 @@ func (sp *ServiceProvider) ValidateLogoutResponseForm(postFormData string) error retErr.PrivateErr = err return retErr } - if err := sp.validateLogoutResponse(&resp); err != nil { - return err - } - return nil + return sp.validateLogoutResponse(&resp) } // ValidateLogoutResponseRedirect returns a nil error if the logout response is valid. @@ -1594,6 +1610,7 @@ func firstSet(a, b string) string { // findChildren returns all the elements matching childNS/childTag that are direct children of parentEl. func findChildren(parentEl *etree.Element, childNS string, childTag string) ([]*etree.Element, error) { + //nolint:prealloc // We don't know how many child elements we'll actually put into this array. var rv []*etree.Element for _, childEl := range parentEl.ChildElements() { if childEl.Tag != childTag { diff --git a/service_provider_test.go b/service_provider_test.go index 1389c290..c1f2d80b 100644 --- a/service_provider_test.go +++ b/service_provider_test.go @@ -951,9 +951,10 @@ func TestSPCanProcessResponseWithoutDestination(t *testing.T) { assert.Check(t, err) } -func (test *ServiceProviderTest) responseDom() (doc *etree.Document) { +func (test *ServiceProviderTest) responseDom(t *testing.T) (doc *etree.Document) { doc = etree.NewDocument() - doc.ReadFromBytes(test.SamlResponse) + err := doc.ReadFromBytes(test.SamlResponse) + assert.Check(t, err) return doc } @@ -985,7 +986,7 @@ func TestServiceProviderMismatchedDestinationsWithSignaturePresent(t *testing.T) req := http.Request{PostForm: url.Values{}} s.AcsURL = mustParseURL("https://wrong/saml2/acs") - bytes, _ := test.responseDom().WriteToBytes() + bytes, _ := test.responseDom(t).WriteToBytes() req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, @@ -1005,7 +1006,7 @@ func TestServiceProviderMissingDestinationWithSignaturePresent(t *testing.T) { assert.Check(t, err) req := http.Request{PostForm: url.Values{}} - bytes, _ := removeDestinationFromDocument(addSignatureToDocument(test.responseDom())).WriteToBytes() + bytes, _ := removeDestinationFromDocument(addSignatureToDocument(test.responseDom(t))).WriteToBytes() req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, @@ -1026,7 +1027,7 @@ func TestSPMismatchedDestinationsWithSignaturePresent(t *testing.T) { req := http.Request{PostForm: url.Values{}} test.replaceDestination("https://wrong/saml2/acs") - bytes, _ := addSignatureToDocument(test.responseDom()).WriteToBytes() + bytes, _ := addSignatureToDocument(test.responseDom(t)).WriteToBytes() req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, @@ -1047,7 +1048,7 @@ func TestSPMismatchedDestinationsWithNoSignaturePresent(t *testing.T) { req := http.Request{PostForm: url.Values{}} test.replaceDestination("https://wrong/saml2/acs") - bytes, _ := test.responseDom().WriteToBytes() + bytes, _ := test.responseDom(t).WriteToBytes() req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, @@ -1068,7 +1069,7 @@ func TestSPMissingDestinationWithSignaturePresent(t *testing.T) { req := http.Request{PostForm: url.Values{}} test.replaceDestination("") - bytes, _ := addSignatureToDocument(test.responseDom()).WriteToBytes() + bytes, _ := addSignatureToDocument(test.responseDom(t)).WriteToBytes() req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, @@ -1113,19 +1114,19 @@ func TestSPInvalidAssertions(t *testing.T) { err = s.validateAssertion(&assertion, []string{"id-9e61753d64e928af5a7a341a97f420c9"}, TimeNow()) assert.Check(t, is.Error(err, "issuer is not \"https://idp.testshib.org/idp/shibboleth\"")) assertion = Assertion{} - xml.Unmarshal(assertionBuf, &assertion) + assert.Check(t, xml.Unmarshal(assertionBuf, &assertion)) assertion.Subject.NameID.NameQualifier = "bob" err = s.validateAssertion(&assertion, []string{"id-9e61753d64e928af5a7a341a97f420c9"}, TimeNow()) assert.Check(t, err) // not verified assertion = Assertion{} - xml.Unmarshal(assertionBuf, &assertion) + assert.Check(t, xml.Unmarshal(assertionBuf, &assertion)) assertion.Subject.NameID.SPNameQualifier = "bob" err = s.validateAssertion(&assertion, []string{"id-9e61753d64e928af5a7a341a97f420c9"}, TimeNow()) assert.Check(t, err) // not verified assertion = Assertion{} - xml.Unmarshal(assertionBuf, &assertion) + assert.Check(t, xml.Unmarshal(assertionBuf, &assertion)) err = s.validateAssertion(&assertion, []string{"any request id"}, TimeNow()) assert.Check(t, is.Error(err, "assertion SubjectConfirmation one of the possible request IDs ([any request id])")) @@ -1134,31 +1135,31 @@ func TestSPInvalidAssertions(t *testing.T) { err = s.validateAssertion(&assertion, []string{"id-9e61753d64e928af5a7a341a97f420c9"}, TimeNow()) assert.Check(t, is.Error(err, "assertion SubjectConfirmation Recipient is not https://15661444.ngrok.io/saml2/acs")) assertion = Assertion{} - xml.Unmarshal(assertionBuf, &assertion) + assert.Check(t, xml.Unmarshal(assertionBuf, &assertion)) assertion.Subject.SubjectConfirmations[0].SubjectConfirmationData.NotOnOrAfter = TimeNow().Add(-1 * time.Hour) err = s.validateAssertion(&assertion, []string{"id-9e61753d64e928af5a7a341a97f420c9"}, TimeNow()) assert.Check(t, is.Error(err, "assertion SubjectConfirmationData is expired")) assertion = Assertion{} - xml.Unmarshal(assertionBuf, &assertion) + assert.Check(t, xml.Unmarshal(assertionBuf, &assertion)) assertion.Conditions.NotBefore = TimeNow().Add(time.Hour) err = s.validateAssertion(&assertion, []string{"id-9e61753d64e928af5a7a341a97f420c9"}, TimeNow()) assert.Check(t, is.Error(err, "assertion Conditions is not yet valid")) assertion = Assertion{} - xml.Unmarshal(assertionBuf, &assertion) + assert.Check(t, xml.Unmarshal(assertionBuf, &assertion)) assertion.Conditions.NotOnOrAfter = TimeNow().Add(-1 * time.Hour) err = s.validateAssertion(&assertion, []string{"id-9e61753d64e928af5a7a341a97f420c9"}, TimeNow()) assert.Check(t, is.Error(err, "assertion Conditions is expired")) assertion = Assertion{} - xml.Unmarshal(assertionBuf, &assertion) + assert.Check(t, xml.Unmarshal(assertionBuf, &assertion)) assertion.Conditions.AudienceRestrictions[0].Audience.Value = "not/our/metadata/url" err = s.validateAssertion(&assertion, []string{"id-9e61753d64e928af5a7a341a97f420c9"}, TimeNow()) assert.Check(t, is.Error(err, "assertion Conditions AudienceRestriction does not contain \"https://15661444.ngrok.io/saml2/metadata\"")) assertion = Assertion{} - xml.Unmarshal(assertionBuf, &assertion) + assert.Check(t, xml.Unmarshal(assertionBuf, &assertion)) // Not having an audience is not an error assertion.Conditions.AudienceRestrictions = []AudienceRestriction{} @@ -1249,7 +1250,7 @@ func TestXswPermutationThreeIsRejected(t *testing.T) { // // When no assertions are valid, we return the first error encountered, which in this case is that // there is no Signature on the element. - assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "Signature element not present")) + assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "signature element not present")) } func TestXswPermutationFourIsRejected(t *testing.T) { @@ -1279,7 +1280,7 @@ func TestXswPermutationFourIsRejected(t *testing.T) { // This permutation contains a signed assertion embedded within an unsigned assertion. // I'm pretty sure this is just not allowed, so we properly decide that there are no // signed assertions at all. - assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "Signature element not present")) + assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "signature element not present")) } func TestXswPermutationFiveIsRejected(t *testing.T) { @@ -1362,7 +1363,7 @@ func TestXswPermutationSevenIsRejected(t *testing.T) { req := http.Request{PostForm: url.Values{}} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) - //It's the assertion signature that can't be verified. The error message is generic and always mentions Response + // It's the assertion signature that can't be verified. The error message is generic and always mentions Response assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "cannot validate signature on Assertion: Signature could not be verified")) } @@ -1393,7 +1394,7 @@ func TestXswPermutationEightIsRejected(t *testing.T) { req := http.Request{PostForm: url.Values{}} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) - //It's the assertion signature that can't be verified. The error message is generic and always mentions Response + // It's the assertion signature that can't be verified. The error message is generic and always mentions Response assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "cannot validate signature on Assertion: Signature could not be verified")) } @@ -1424,7 +1425,7 @@ func TestXswPermutationNineIsRejected(t *testing.T) { req := http.Request{PostForm: url.Values{}} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) - //It's the assertion signature that can't be verified. The error message is generic and always mentions Response + // It's the assertion signature that can't be verified. The error message is generic and always mentions Response assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "cannot validate signature on Assertion: Missing signature referencing the top-level element")) } diff --git a/testsaml/parse.go b/testsaml/parse.go index 6c64398a..63e3dbc5 100644 --- a/testsaml/parse.go +++ b/testsaml/parse.go @@ -1,3 +1,4 @@ +// Package testsaml contains functions for use in testing SAML requests and responses. package testsaml import ( diff --git a/util.go b/util.go index c9731b1b..eda053ee 100644 --- a/util.go +++ b/util.go @@ -21,6 +21,7 @@ var Clock *dsig.Clock // rand.Reader, but it can be replaced for testing. var RandReader = rand.Reader +//nolint:unparam // This always receives 20, but we want the option to do more or less if needed. func randomBytes(n int) []byte { rv := make([]byte, n) diff --git a/xmlenc/cbc.go b/xmlenc/cbc.go index 77ddb3b2..991ba1eb 100644 --- a/xmlenc/cbc.go +++ b/xmlenc/cbc.go @@ -31,7 +31,7 @@ func (e CBC) Algorithm() string { // Encrypt encrypts plaintext with key, which should be a []byte of length KeySize(). // It returns an xenc:EncryptedData element. -func (e CBC) Encrypt(key interface{}, plaintext []byte, nonce []byte) (*etree.Element, error) { +func (e CBC) Encrypt(key interface{}, plaintext []byte, _ []byte) (*etree.Element, error) { keyBuf, ok := key.([]byte) if !ok { return nil, ErrIncorrectKeyType("[]byte") diff --git a/xmlenc/decrypt.go b/xmlenc/decrypt.go index 93991f9f..98a575da 100644 --- a/xmlenc/decrypt.go +++ b/xmlenc/decrypt.go @@ -90,6 +90,7 @@ func validateRSAKeyIfPresent(key interface{}, encryptedKey *etree.Element) (*rsa // if the key will work, or let the service provider know which key // to use to decrypt the message. Either way, verification is not // security-critical. + //nolint:revive,staticcheck // Keep the later empty branch so that we know to address this at a later date. if el := encryptedKey.FindElement("./KeyInfo/X509Data/X509Certificate"); el != nil { certPEMbuf := el.Text() certPEMbuf = "-----BEGIN CERTIFICATE-----\n" + certPEMbuf + "\n-----END CERTIFICATE-----\n" diff --git a/xmlenc/decrypt_test.go b/xmlenc/decrypt_test.go index 2e23f8f6..a8872f22 100644 --- a/xmlenc/decrypt_test.go +++ b/xmlenc/decrypt_test.go @@ -100,6 +100,5 @@ func TestCanDecryptWithoutCertificate(t *testing.T) { el = doc.Root().FindElement("//EncryptedData") _, err = Decrypt(key, el) assert.Check(t, err) - //assertion.NotNil(t, plaintext) }) } diff --git a/xmlenc/digest.go b/xmlenc/digest.go index 801347f2..3eaaf7bc 100644 --- a/xmlenc/digest.go +++ b/xmlenc/digest.go @@ -6,6 +6,7 @@ import ( "crypto/sha512" "hash" + //nolint:staticcheck // We should support this for legacy reasons. "golang.org/x/crypto/ripemd160" )