diff --git a/client.go b/client.go index c4db765..f6f2524 100644 --- a/client.go +++ b/client.go @@ -575,8 +575,9 @@ func (c *Client) SetAuthScheme(scheme string) *Client { return c } -// SetDigestAuth method sets the Digest Access auth scheme for the client. If a server responds with 401 and sends -// a Digest challenge in the WWW-Authenticate header, requests will be resent with the appropriate Authorization header. +// SetDigestAuth method sets the Digest Auth transport with provided credentials in the client. +// If a server responds with 401 and sends a Digest challenge in the header `WWW-Authenticate`, +// the request will be resent with the appropriate digest `Authorization` header. // // For Example: To set the Digest scheme with user "Mufasa" and password "Circle Of Life" // @@ -584,24 +585,19 @@ func (c *Client) SetAuthScheme(scheme string) *Client { // // Information about Digest Access Authentication can be found in [RFC 7616]. // -// See [Request.SetDigestAuth]. +// NOTE: +// - On the QOP `auth-int` scenario, the request body is read into memory to +// compute the body hash that consumes additional memory usage. +// - It is recommended to create a dedicated client instance for digest auth, +// as it does digest auth for all the requests raised by the client. // // [RFC 7616]: https://datatracker.ietf.org/doc/html/rfc7616 func (c *Client) SetDigestAuth(username, password string) *Client { - c.lock.Lock() - oldTransport := c.httpClient.Transport - c.lock.Unlock() - c.AddRequestMiddleware(func(c *Client, _ *Request) error { - c.httpClient.Transport = &digestTransport{ - credentials: credentials{username, password}, - transport: oldTransport, - } - return nil - }) - c.AddResponseMiddleware(func(c *Client, _ *Response) error { - c.httpClient.Transport = oldTransport - return nil - }) + dt := &digestTransport{ + credentials: &credentials{username, password}, + transport: c.Transport(), + } + c.SetTransport(dt) return c } diff --git a/client_test.go b/client_test.go index 4e24741..a40c685 100644 --- a/client_test.go +++ b/client_test.go @@ -90,79 +90,6 @@ func TestClientAuthScheme(t *testing.T) { } -func TestClientDigestAuth(t *testing.T) { - conf := defaultDigestServerConf() - ts := createDigestServer(t, conf) - defer ts.Close() - - c := dcnl(). - SetBaseURL(ts.URL+"/"). - SetDigestAuth(conf.username, conf.password) - - resp, err := c.R(). - SetResult(&AuthSuccess{}). - Get(conf.uri) - assertError(t, err) - assertEqual(t, http.StatusOK, resp.StatusCode()) - - t.Logf("Result Success: %q", resp.Result().(*AuthSuccess)) - logResponse(t, resp) -} - -func TestClientDigestSession(t *testing.T) { - conf := defaultDigestServerConf() - conf.algo = "MD5-sess" - conf.qop = "auth, auth-int" - ts := createDigestServer(t, conf) - defer ts.Close() - - c := dcnl(). - SetBaseURL(ts.URL+"/"). - SetDigestAuth(conf.username, conf.password) - - resp, err := c.R(). - SetResult(&AuthSuccess{}). - Get(conf.uri) - assertError(t, err) - assertEqual(t, http.StatusOK, resp.StatusCode()) - - t.Logf("Result Success: %q", resp.Result().(*AuthSuccess)) - logResponse(t, resp) -} - -func TestClientDigestErrors(t *testing.T) { - type test struct { - mutateConf func(*digestServerConfig) - expect error - } - tests := []test{ - {mutateConf: func(c *digestServerConfig) { c.algo = "BAD_ALGO" }, expect: ErrDigestAlgNotSupported}, - {mutateConf: func(c *digestServerConfig) { c.qop = "bad-qop" }, expect: ErrDigestQopNotSupported}, - {mutateConf: func(c *digestServerConfig) { c.qop = "" }, expect: ErrDigestNoQop}, - {mutateConf: func(c *digestServerConfig) { c.charset = "utf-16" }, expect: ErrDigestCharset}, - {mutateConf: func(c *digestServerConfig) { c.uri = "/bad" }, expect: ErrDigestBadChallenge}, - {mutateConf: func(c *digestServerConfig) { c.uri = "/unknown_param" }, expect: ErrDigestBadChallenge}, - {mutateConf: func(c *digestServerConfig) { c.uri = "/missing_value" }, expect: ErrDigestBadChallenge}, - {mutateConf: func(c *digestServerConfig) { c.uri = "/unclosed_quote" }, expect: ErrDigestBadChallenge}, - {mutateConf: func(c *digestServerConfig) { c.uri = "/no_challenge" }, expect: ErrDigestBadChallenge}, - {mutateConf: func(c *digestServerConfig) { c.uri = "/status_500" }, expect: nil}, - } - - for _, tc := range tests { - conf := defaultDigestServerConf() - tc.mutateConf(conf) - ts := createDigestServer(t, conf) - - c := dcnl(). - SetBaseURL(ts.URL+"/"). - SetDigestAuth(conf.username, conf.password) - - _, err := c.R().Get(conf.uri) - assertErrorIs(t, tc.expect, err) - ts.Close() - } -} - func TestClientResponseMiddleware(t *testing.T) { ts := createGenericServer(t) defer ts.Close() diff --git a/digest.go b/digest.go index 5b55dbe..98ac6d7 100644 --- a/digest.go +++ b/digest.go @@ -8,156 +8,108 @@ package resty import ( + "bytes" "crypto/md5" "crypto/rand" "crypto/sha256" "crypto/sha512" + "encoding/hex" "errors" "fmt" "hash" "io" "net/http" + "strconv" "strings" ) var ( - ErrDigestBadChallenge = errors.New("digest: challenge is bad") - ErrDigestCharset = errors.New("digest: unsupported charset") - ErrDigestAlgNotSupported = errors.New("digest: algorithm is not supported") - ErrDigestQopNotSupported = errors.New("digest: no supported qop in list") - ErrDigestNoQop = errors.New("digest: qop must be specified") + ErrDigestBadChallenge = errors.New("resty: digest: challenge is bad") + ErrDigestInvalidCharset = errors.New("resty: digest: invalid charset") + ErrDigestAlgNotSupported = errors.New("resty: digest: algorithm is not supported") + ErrDigestQopNotSupported = errors.New("resty: digest: qop is not supported") ) -var hashFuncs = map[string]func() hash.Hash{ +// Reference: https://datatracker.ietf.org/doc/html/rfc7616#section-6.1 +var digestHashFuncs = map[string]func() hash.Hash{ "": md5.New, "MD5": md5.New, "MD5-sess": md5.New, "SHA-256": sha256.New, "SHA-256-sess": sha256.New, - "SHA-512-256": sha512.New, - "SHA-512-256-sess": sha512.New, + "SHA-512": sha512.New, + "SHA-512-sess": sha512.New, + "SHA-512-256": sha512.New512_256, + "SHA-512-256-sess": sha512.New512_256, } +const ( + qopAuth = "auth" + qopAuthInt = "auth-int" +) + type digestTransport struct { - credentials + *credentials transport http.RoundTripper } func (dt *digestTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // Copy the request, so we don't modify the input. - req2 := new(http.Request) - *req2 = *req - req2.Header = make(http.Header) - for k, s := range req.Header { - req2.Header[k] = s - } + // first request without body for all HTTP verbs + req1 := dt.cloneReq(req, true) - // Fix http: ContentLength=xxx with Body length 0 - if req2.Body == nil { - req2.ContentLength = 0 - } else if req2.GetBody != nil { - var err error - req2.Body, err = req2.GetBody() - if err != nil { - return nil, err - } + // make a request to get the 401 that contains the challenge. + res, err := dt.transport.RoundTrip(req1) + if err != nil || res.StatusCode != http.StatusUnauthorized { + return res, err } + _, _ = ioCopy(io.Discard, res.Body) + closeq(res.Body) - // Make a request to get the 401 that contains the challenge. - resp, err := dt.transport.RoundTrip(req) - if err != nil || resp.StatusCode != http.StatusUnauthorized { - return resp, err - } - chal := resp.Header.Get(hdrWwwAuthenticateKey) - if chal == "" { - return resp, ErrDigestBadChallenge + chaHdrValue := strings.TrimSpace(res.Header.Get(hdrWwwAuthenticateKey)) + if chaHdrValue == "" { + return res, ErrDigestBadChallenge } - c, err := parseChallenge(chal) + cha, err := dt.parseChallenge(chaHdrValue) if err != nil { - return resp, err + return nil, err } - // Form credentials based on the challenge - cr := dt.newCredentials(req2, c) - auth, err := cr.authorize() + // prepare second request + req2 := dt.cloneReq(req, false) + cred, err := dt.createCredentials(cha, req2) if err != nil { - return resp, err + return nil, err } - err = resp.Body.Close() + + auth, err := cred.digest(cha) if err != nil { return nil, err } - // Make authenticated request req2.Header.Set(hdrAuthorizationKey, auth) return dt.transport.RoundTrip(req2) } -func (dt *digestTransport) newCredentials(req *http.Request, c *challenge) *digestCredentials { - return &digestCredentials{ - username: dt.Username, - userhash: c.userhash, - realm: c.realm, - nonce: c.nonce, - digestURI: req.URL.RequestURI(), - algorithm: c.algorithm, - sessionAlg: strings.HasSuffix(c.algorithm, "-sess"), - opaque: c.opaque, - messageQop: c.qop, - nc: 0, - method: req.Method, - password: dt.Password, - } -} - -type challenge struct { - realm string - domain string - nonce string - opaque string - stale string - algorithm string - qop string - userhash string -} - -func (c *challenge) setValue(k, v string) error { - switch k { - case "realm": - c.realm = v - case "domain": - c.domain = v - case "nonce": - c.nonce = v - case "opaque": - c.opaque = v - case "stale": - c.stale = v - case "algorithm": - c.algorithm = v - case "qop": - c.qop = v - case "charset": - if strings.ToUpper(v) != "UTF-8" { - return ErrDigestCharset - } - case "userhash": - c.userhash = v - default: - return ErrDigestBadChallenge +func (dt *digestTransport) cloneReq(r *http.Request, first bool) *http.Request { + r1 := r.Clone(r.Context()) + if first { + r1.Body = http.NoBody + r1.ContentLength = 0 + r1.GetBody = nil } - return nil + return r1 } -func parseChallenge(input string) (*challenge, error) { +func (dt *digestTransport) parseChallenge(input string) (*digestChallenge, error) { const ws = " \n\r\t" s := strings.Trim(input, ws) if !strings.HasPrefix(s, "Digest ") { return nil, ErrDigestBadChallenge } + s = strings.Trim(s[7:], ws) - c := &challenge{} + c := &digestChallenge{} b := strings.Builder{} key := "" quoted := false @@ -187,137 +139,260 @@ func parseChallenge(input string) (*challenge, error) { b.WriteRune(r) } } + + key = strings.TrimSpace(key) if quoted || (key == "" && b.Len() > 0) { return nil, ErrDigestBadChallenge } + if key != "" { val := strings.Trim(b.String(), ws) if err := c.setValue(key, val); err != nil { return nil, err } } - return c, nil -} -type digestCredentials struct { - username string - userhash string - realm string - nonce string - digestURI string - algorithm string - sessionAlg bool - cNonce string - opaque string - messageQop string - nc int - method string - password string + return c, nil } -func (c *digestCredentials) authorize() (string, error) { - if _, ok := hashFuncs[c.algorithm]; !ok { - return "", ErrDigestAlgNotSupported +func (dt *digestTransport) createCredentials(cha *digestChallenge, req *http.Request) (*digestCredentials, error) { + cred := &digestCredentials{ + username: dt.Username, + password: dt.Password, + uri: req.URL.RequestURI(), + method: req.Method, + realm: cha.realm, + nonce: cha.nonce, + nc: cha.nc, + algorithm: cha.algorithm, + sessAlgorithm: strings.HasSuffix(cha.algorithm, "-sess"), + opaque: cha.opaque, + userHash: cha.userHash, } - if err := c.validateQop(); err != nil { - return "", err + if cha.isQopSupported(qopAuthInt) { + if err := dt.prepareBody(req); err != nil { + return nil, fmt.Errorf("resty: digest: failed to prepare body for auth-int: %w", err) + } + body, err := req.GetBody() + if err != nil { + return nil, fmt.Errorf("resty: digest: failed to get body for auth-int: %w", err) + } + if body != http.NoBody { + defer closeq(body) + h := newHashFunc(cha.algorithm) + if _, err := ioCopy(h, body); err != nil { + return nil, err + } + cred.bodyHash = hex.EncodeToString(h.Sum(nil)) + } } - resp, err := c.resp() - if err != nil { - return "", err + return cred, nil +} + +func (dt *digestTransport) prepareBody(req *http.Request) error { + if req.GetBody != nil { + return nil } - sl := make([]string, 0, 10) - if c.userhash == "true" { - // RFC 7616 3.4.4 - c.username = c.h(fmt.Sprintf("%s:%s", c.username, c.realm)) - sl = append(sl, fmt.Sprintf(`userhash=%s`, c.userhash)) + if req.Body == nil || req.Body == http.NoBody { + req.GetBody = func() (io.ReadCloser, error) { + return http.NoBody, nil + } + return nil } - sl = append(sl, fmt.Sprintf(`username="%s"`, c.username)) - sl = append(sl, fmt.Sprintf(`realm="%s"`, c.realm)) - sl = append(sl, fmt.Sprintf(`nonce="%s"`, c.nonce)) - sl = append(sl, fmt.Sprintf(`uri="%s"`, c.digestURI)) - sl = append(sl, fmt.Sprintf(`response="%s"`, resp)) - sl = append(sl, fmt.Sprintf(`algorithm=%s`, c.algorithm)) - if c.opaque != "" { - sl = append(sl, fmt.Sprintf(`opaque="%s"`, c.opaque)) + + b, err := ioReadAll(req.Body) + if err != nil { + return err } - if c.messageQop != "" { - sl = append(sl, fmt.Sprintf("qop=%s", c.messageQop)) - sl = append(sl, fmt.Sprintf("nc=%08x", c.nc)) - sl = append(sl, fmt.Sprintf(`cnonce="%s"`, c.cNonce)) + closeq(req.Body) + req.Body = io.NopCloser(bytes.NewReader(b)) + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(b)), nil } - return fmt.Sprintf("Digest %s", strings.Join(sl, ", ")), nil + return nil +} + +type digestChallenge struct { + realm string + domain string + nonce string + opaque string + stale string + algorithm string + qop []string + nc int + userHash string } -func (c *digestCredentials) validateQop() error { - // Currently only supporting auth quality of protection. TODO: add auth-int support - // NOTE: cURL support auth-int qop for requests other than POST and PUT (i.e. w/o body) by hashing an empty string - // is this applicable for resty? see: https://github.com/curl/curl/blob/307b7543ea1e73ab04e062bdbe4b5bb409eaba3a/lib/vauth/digest.c#L774 - if c.messageQop == "" { - return ErrDigestNoQop +func (dc *digestChallenge) isQopSupported(qop string) bool { + for _, v := range dc.qop { + if v == qop { + return true + } } - possibleQops := strings.Split(c.messageQop, ",") - var authSupport bool - for _, qop := range possibleQops { - qop = strings.TrimSpace(qop) - if qop == "auth" { - authSupport = true - break + return false +} + +func (dc *digestChallenge) setValue(k, v string) error { + switch k { + case "realm": + dc.realm = v + case "domain": + dc.domain = v + case "nonce": + dc.nonce = v + case "opaque": + dc.opaque = v + case "stale": + dc.stale = v + case "algorithm": + dc.algorithm = v + case "qop": + if !isStringEmpty(v) { + dc.qop = strings.Split(v, ",") + } + case "charset": + if strings.ToUpper(v) != "UTF-8" { + return ErrDigestInvalidCharset } + case "nc": + nc, err := strconv.ParseInt(v, 16, 32) + if err != nil { + return fmt.Errorf("resty: digest: invalid nc: %w", err) + } + dc.nc = int(nc) + case "userhash": + dc.userHash = v + default: + return ErrDigestBadChallenge + } + return nil +} + +type digestCredentials struct { + username string + password string + userHash string + method string + uri string + realm string + nonce string + algorithm string + sessAlgorithm bool + cnonce string + opaque string + qop string + nc int + response string + bodyHash string +} + +func (dc *digestCredentials) parseQop(cha *digestChallenge) error { + if len(cha.qop) == 0 { + return nil } - if !authSupport { - return ErrDigestQopNotSupported + + if cha.isQopSupported(qopAuth) { + dc.qop = qopAuth + return nil } - c.messageQop = "auth" + if cha.isQopSupported(qopAuthInt) { + dc.qop = qopAuthInt + return nil + } - return nil + return ErrDigestQopNotSupported } -func (c *digestCredentials) h(data string) string { - hfCtor := hashFuncs[c.algorithm] - hf := hfCtor() - _, _ = hf.Write([]byte(data)) // Hash.Write never returns an error - return fmt.Sprintf("%x", hf.Sum(nil)) +func (dc *digestCredentials) h(data string) string { + h := newHashFunc(dc.algorithm) + _, _ = h.Write([]byte(data)) + return hex.EncodeToString(h.Sum(nil)) } -func (c *digestCredentials) resp() (string, error) { - c.nc++ +func (dc *digestCredentials) digest(cha *digestChallenge) (string, error) { + if _, ok := digestHashFuncs[dc.algorithm]; !ok { + return "", ErrDigestAlgNotSupported + } - b := make([]byte, 16) - _, err := io.ReadFull(rand.Reader, b) - if err != nil { + if err := dc.parseQop(cha); err != nil { return "", err } - c.cNonce = fmt.Sprintf("%x", b)[:32] - ha1 := c.ha1() - ha2 := c.ha2() + dc.nc++ - return c.kd(ha1, fmt.Sprintf("%s:%08x:%s:%s:%s", - c.nonce, c.nc, c.cNonce, c.messageQop, ha2)), nil + b := make([]byte, 16) + _, _ = io.ReadFull(rand.Reader, b) + dc.cnonce = hex.EncodeToString(b) + + ha1 := dc.ha1() + ha2 := dc.ha2() + + var resp string + switch dc.qop { + case "": + resp = fmt.Sprintf("%s:%s:%s", ha1, dc.nonce, ha2) + case qopAuth, qopAuthInt: + resp = fmt.Sprintf("%s:%s:%08x:%s:%s:%s", + ha1, dc.nonce, dc.nc, dc.cnonce, dc.qop, ha2) + } + dc.response = dc.h(resp) + + return "Digest " + dc.String(), nil } -func (c *digestCredentials) kd(secret, data string) string { - return c.h(fmt.Sprintf("%s:%s", secret, data)) +// https://datatracker.ietf.org/doc/html/rfc7616#section-3.4.2 +func (dc *digestCredentials) ha1() string { + a1 := dc.h(fmt.Sprintf("%s:%s:%s", dc.username, dc.realm, dc.password)) + if dc.sessAlgorithm { + return dc.h(fmt.Sprintf("%s:%s:%s", a1, dc.nonce, dc.cnonce)) + } + return a1 } -// RFC 7616 3.4.2 -func (c *digestCredentials) ha1() string { - ret := c.h(fmt.Sprintf("%s:%s:%s", c.username, c.realm, c.password)) - if c.sessionAlg { - return c.h(fmt.Sprintf("%s:%s:%s", ret, c.nonce, c.cNonce)) +// https://datatracker.ietf.org/doc/html/rfc7616#section-3.4.3 +func (dc *digestCredentials) ha2() string { + if dc.qop == qopAuthInt { + return dc.h(fmt.Sprintf("%s:%s:%s", dc.method, dc.uri, dc.bodyHash)) + } + return dc.h(fmt.Sprintf("%s:%s", dc.method, dc.uri)) +} + +func (dc *digestCredentials) String() string { + sl := make([]string, 0, 10) + // https://datatracker.ietf.org/doc/html/rfc7616#section-3.4.4 + if dc.userHash == "true" { + dc.username = dc.h(fmt.Sprintf("%s:%s", dc.username, dc.realm)) + } + sl = append(sl, fmt.Sprintf(`username="%s"`, dc.username)) + sl = append(sl, fmt.Sprintf(`realm="%s"`, dc.realm)) + sl = append(sl, fmt.Sprintf(`nonce="%s"`, dc.nonce)) + sl = append(sl, fmt.Sprintf(`uri="%s"`, dc.uri)) + if dc.algorithm != "" { + sl = append(sl, fmt.Sprintf(`algorithm=%s`, dc.algorithm)) + } + if dc.opaque != "" { + sl = append(sl, fmt.Sprintf(`opaque="%s"`, dc.opaque)) + } + if dc.qop != "" { + sl = append(sl, fmt.Sprintf("qop=%s", dc.qop)) + sl = append(sl, fmt.Sprintf("nc=%08x", dc.nc)) + sl = append(sl, fmt.Sprintf(`cnonce="%s"`, dc.cnonce)) } + sl = append(sl, fmt.Sprintf(`userhash=%s`, dc.userHash)) + sl = append(sl, fmt.Sprintf(`response="%s"`, dc.response)) - return ret + return strings.Join(sl, ", ") } -// RFC 7616 3.4.3 -func (c *digestCredentials) ha2() string { - // currently no auth-int support - return c.h(fmt.Sprintf("%s:%s", c.method, c.digestURI)) +func newHashFunc(algorithm string) hash.Hash { + hf := digestHashFuncs[algorithm] + h := hf() + h.Reset() + return h } diff --git a/digest_test.go b/digest_test.go new file mode 100644 index 0000000..76f8410 --- /dev/null +++ b/digest_test.go @@ -0,0 +1,293 @@ +// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. +// resty source code and usage is governed by a MIT style +// license that can be found in the LICENSE file. +// SPDX-License-Identifier: MIT + +package resty + +import ( + "errors" + "io" + "net/http" + "strings" + "testing" +) + +type digestServerConfig struct { + realm, qop, nonce, opaque, algo, uri, charset, username, password, nc string +} + +func defaultDigestServerConf() *digestServerConfig { + return &digestServerConfig{ + realm: "testrealm@host.com", + qop: "auth", + nonce: "dcd98b7102dd2f0e8b11d0f600bfb0c093", + opaque: "5ccc069c403ebaf9f0171e9517f40e41", + algo: "MD5", + uri: "/dir/index.html", + charset: "utf-8", + username: "Mufasa", + password: "Circle Of Life", + nc: "00000001", + } +} + +func TestClientDigestAuth(t *testing.T) { + conf := *defaultDigestServerConf() + ts := createDigestServer(t, &conf) + defer ts.Close() + + c := dcnl(). + SetBaseURL(ts.URL+"/"). + SetDigestAuth(conf.username, conf.password) + + resp, err := c.R(). + SetResult(&AuthSuccess{}). + Get(conf.uri) + assertError(t, err) + assertEqual(t, http.StatusOK, resp.StatusCode()) +} + +func TestClientDigestAuthSession(t *testing.T) { + conf := *defaultDigestServerConf() + conf.algo = "MD5-sess" + conf.qop = "auth, auth-int" + ts := createDigestServer(t, &conf) + defer ts.Close() + + c := dcnl(). + SetBaseURL(ts.URL+"/"). + SetDigestAuth(conf.username, conf.password) + + resp, err := c.R(). + SetResult(&AuthSuccess{}). + Get(conf.uri) + assertError(t, err) + assertEqual(t, http.StatusOK, resp.StatusCode()) +} + +func TestClientDigestAuthErrors(t *testing.T) { + type test struct { + mutateConf func(*digestServerConfig) + expect error + } + tests := []test{ + {mutateConf: func(c *digestServerConfig) { c.algo = "BAD_ALGO" }, expect: ErrDigestAlgNotSupported}, + {mutateConf: func(c *digestServerConfig) { c.qop = "bad-qop" }, expect: ErrDigestQopNotSupported}, + {mutateConf: func(c *digestServerConfig) { c.charset = "utf-16" }, expect: ErrDigestInvalidCharset}, + {mutateConf: func(c *digestServerConfig) { c.uri = "/bad" }, expect: ErrDigestBadChallenge}, + {mutateConf: func(c *digestServerConfig) { c.uri = "/unknown_param" }, expect: ErrDigestBadChallenge}, + {mutateConf: func(c *digestServerConfig) { c.uri = "/missing_value" }, expect: ErrDigestBadChallenge}, + {mutateConf: func(c *digestServerConfig) { c.uri = "/unclosed_quote" }, expect: ErrDigestBadChallenge}, + {mutateConf: func(c *digestServerConfig) { c.uri = "/no_challenge" }, expect: ErrDigestBadChallenge}, + {mutateConf: func(c *digestServerConfig) { c.uri = "/status_500" }, expect: nil}, + } + + for _, tc := range tests { + conf := *defaultDigestServerConf() + tc.mutateConf(&conf) + ts := createDigestServer(t, &conf) + + c := dcnl(). + SetBaseURL(ts.URL+"/"). + SetDigestAuth(conf.username, conf.password) + + _, err := c.R().Get(conf.uri) + assertErrorIs(t, tc.expect, err) + ts.Close() + } +} + +func TestClientDigestAuthWithBody(t *testing.T) { + conf := *defaultDigestServerConf() + ts := createDigestServer(t, &conf) + defer ts.Close() + + c := dcnl().SetDigestAuth(conf.username, conf.password) + + resp, err := c.R(). + SetResult(&AuthSuccess{}). + SetHeader(hdrContentTypeKey, "application/json"). + SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). + Post(ts.URL + conf.uri) + + assertError(t, err) + assertEqual(t, http.StatusOK, resp.StatusCode()) +} + +func TestClientDigestAuthWithBodyQopAuthInt(t *testing.T) { + conf := *defaultDigestServerConf() + conf.qop = "auth-int" + ts := createDigestServer(t, &conf) + defer ts.Close() + + c := dcnl().SetDigestAuth(conf.username, conf.password) + + resp, err := c.R(). + SetResult(&AuthSuccess{}). + SetHeader(hdrContentTypeKey, "application/json"). + SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). + Post(ts.URL + conf.uri) + + assertError(t, err) + assertEqual(t, http.StatusOK, resp.StatusCode()) +} + +func TestClientDigestAuthWithBodyQopAuthIntIoCopyError(t *testing.T) { + conf := *defaultDigestServerConf() + conf.qop = "auth-int" + ts := createDigestServer(t, &conf) + defer ts.Close() + + c := dcnl().SetDigestAuth(conf.username, conf.password) + + errCopyMsg := "test copy error" + ioCopy = func(dst io.Writer, src io.Reader) (written int64, err error) { + return 0, errors.New(errCopyMsg) + } + t.Cleanup(func() { + ioCopy = io.Copy + }) + + resp, err := c.R(). + SetResult(&AuthSuccess{}). + SetHeader(hdrContentTypeKey, "application/json"). + SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). + Post(ts.URL + conf.uri) + + assertNotNil(t, err) + assertEqual(t, true, strings.Contains(err.Error(), errCopyMsg)) + assertEqual(t, 0, resp.StatusCode()) +} + +func TestClientDigestAuthWithBodyQopAuthIntGetBodyNil(t *testing.T) { + conf := *defaultDigestServerConf() + conf.qop = "auth-int" + ts := createDigestServer(t, &conf) + defer ts.Close() + + c := dcnl().SetDigestAuth(conf.username, conf.password) + c.SetRequestMiddlewares( + PrepareRequestMiddleware, + func(c *Client, r *Request) error { + r.RawRequest.GetBody = nil + return nil + }, + ) + + resp, err := c.R(). + SetResult(&AuthSuccess{}). + SetHeader(hdrContentTypeKey, "application/json"). + SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). + Post(ts.URL + conf.uri) + + assertError(t, err) + assertEqual(t, http.StatusOK, resp.StatusCode()) +} + +func TestClientDigestAuthWithGetBodyError(t *testing.T) { + conf := *defaultDigestServerConf() + conf.qop = "auth-int" + ts := createDigestServer(t, &conf) + defer ts.Close() + + c := dcnl().SetDigestAuth(conf.username, conf.password) + c.SetRequestMiddlewares( + PrepareRequestMiddleware, + func(c *Client, r *Request) error { + r.RawRequest.GetBody = func() (_ io.ReadCloser, _ error) { + return nil, errors.New("get body test error") + } + return nil + }, + ) + + resp, err := c.R(). + SetResult(&AuthSuccess{}). + SetHeader(hdrContentTypeKey, "application/json"). + SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). + Post(ts.URL + conf.uri) + + assertNotNil(t, err) + assertEqual(t, true, strings.Contains(err.Error(), "resty: digest: failed to get body for auth-int: get body test error")) + assertEqual(t, 0, resp.StatusCode()) +} + +func TestClientDigestAuthWithGetBodyNilReadError(t *testing.T) { + conf := *defaultDigestServerConf() + conf.qop = "auth-int" + ts := createDigestServer(t, &conf) + defer ts.Close() + + c := dcnl().SetDigestAuth(conf.username, conf.password) + c.SetRequestMiddlewares( + PrepareRequestMiddleware, + func(c *Client, r *Request) error { + r.RawRequest.GetBody = nil + return nil + }, + ) + + resp, err := c.R(). + SetResult(&AuthSuccess{}). + SetHeader(hdrContentTypeKey, "application/json"). + SetBody(&brokenReadCloser{}). + Post(ts.URL + conf.uri) + + assertNotNil(t, err) + assertEqual(t, true, strings.Contains(err.Error(), "resty: digest: failed to prepare body for auth-int: read error")) + assertEqual(t, 0, resp.StatusCode()) +} + +func TestClientDigestAuthWithNoBodyQopAuthInt(t *testing.T) { + conf := *defaultDigestServerConf() + conf.qop = "auth-int" + ts := createDigestServer(t, &conf) + defer ts.Close() + + c := dcnl().SetDigestAuth(conf.username, conf.password) + + resp, err := c.R().Get(ts.URL + conf.uri) + + assertError(t, err) + assertEqual(t, http.StatusOK, resp.StatusCode()) +} + +func TestClientDigestAuthNoQop(t *testing.T) { + conf := *defaultDigestServerConf() + conf.qop = "" + + ts := createDigestServer(t, &conf) + defer ts.Close() + + c := dcnl().SetDigestAuth(conf.username, conf.password) + + resp, err := c.R(). + SetResult(&AuthSuccess{}). + SetHeader(hdrContentTypeKey, "application/json"). + SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). + Post(ts.URL + conf.uri) + + assertNil(t, err) + assertEqual(t, "200 OK", resp.Status()) +} + +func TestClientDigestAuthWithIncorrectNcValue(t *testing.T) { + conf := *defaultDigestServerConf() + conf.nc = "1234567890" + + ts := createDigestServer(t, &conf) + defer ts.Close() + + c := dcnl().SetDigestAuth(conf.username, conf.password) + + resp, err := c.R(). + SetResult(&AuthSuccess{}). + SetHeader(hdrContentTypeKey, "application/json"). + SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). + Post(ts.URL + conf.uri) + + assertNotNil(t, err) + assertEqual(t, true, strings.Contains(err.Error(), `parsing "1234567890": value out of range`)) + assertEqual(t, "", resp.Status()) +} diff --git a/request.go b/request.go index 70f5ce8..3d759e1 100644 --- a/request.go +++ b/request.go @@ -661,36 +661,6 @@ func (r *Request) SetAuthScheme(scheme string) *Request { return r } -// SetDigestAuth method sets the Digest Access auth scheme for the HTTP request. -// If a server responds with 401 and sends a Digest challenge in the WWW-Authenticate Header, -// the request will be resent with the appropriate Authorization Header. -// -// For Example: To set the Digest scheme with username "Mufasa" and password "Circle Of Life" -// -// client.R().SetDigestAuth("Mufasa", "Circle Of Life") -// -// Information about Digest Access Authentication can be found in [RFC 7616] -// -// It overrides the digest username and password set by method [Client.SetDigestAuth]. -// -// [RFC 7616]: https://datatracker.ietf.org/doc/html/rfc7616 -func (r *Request) SetDigestAuth(username, password string) *Request { - oldTransport := r.client.httpClient.Transport - r.client.AddRequestMiddleware(func(c *Client, _ *Request) error { - c.httpClient.Transport = &digestTransport{ - credentials: credentials{username, password}, - transport: oldTransport, - } - return nil - }) - r.client.AddResponseMiddleware(func(c *Client, _ *Response) error { - c.httpClient.Transport = oldTransport - return nil - }) - - return r -} - // SetOutputFile method sets the output file for the current HTTP request. The current // HTTP response will be saved in the given file. It is similar to the `curl -o` flag. // diff --git a/request_test.go b/request_test.go index 2f0ca69..eaa1731 100644 --- a/request_test.go +++ b/request_test.go @@ -680,59 +680,6 @@ func TestRequestAuthScheme(t *testing.T) { assertEqual(t, http.StatusOK, resp.StatusCode()) } -func TestRequestDigestAuth(t *testing.T) { - conf := defaultDigestServerConf() - ts := createDigestServer(t, nil) - defer ts.Close() - - resp, err := dcnldr(). - SetDigestAuth(conf.username, conf.password). - SetResult(&AuthSuccess{}). - Get(ts.URL + conf.uri) - - assertError(t, err) - assertEqual(t, http.StatusOK, resp.StatusCode()) - - t.Logf("Result Success: %q", resp.Result().(*AuthSuccess)) - logResponse(t, resp) -} - -func TestRequestDigestAuthFail(t *testing.T) { - conf := defaultDigestServerConf() - ts := createDigestServer(t, nil) - defer ts.Close() - - resp, err := dcnldr(). - SetDigestAuth(conf.username, "wrongPassword"). - SetError(AuthError{}). - Get(ts.URL + conf.uri) - - assertError(t, err) - assertEqual(t, http.StatusUnauthorized, resp.StatusCode()) - - t.Logf("Result Error: %q", resp.Error().(*AuthError)) - logResponse(t, resp) -} - -func TestRequestDigestAuthWithBody(t *testing.T) { - conf := defaultDigestServerConf() - ts := createDigestServer(t, nil) - defer ts.Close() - - resp, err := dcnldr(). - SetDigestAuth(conf.username, conf.password). - SetResult(&AuthSuccess{}). - SetHeader(hdrContentTypeKey, "application/json"). - SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). - Post(ts.URL + conf.uri) - - assertError(t, err) - assertEqual(t, http.StatusOK, resp.StatusCode()) - - t.Logf("Result Success: %q", resp.Result().(*AuthSuccess)) - logResponse(t, resp) -} - func TestFormData(t *testing.T) { ts := createFormPostServer(t) defer ts.Close() diff --git a/resty_test.go b/resty_test.go index 25fcc65..a1fd224 100644 --- a/resty_test.go +++ b/resty_test.go @@ -10,9 +10,9 @@ import ( "compress/flate" "compress/gzip" "compress/lzw" - "crypto/md5" "crypto/tls" "encoding/base64" + "encoding/hex" "encoding/json" "encoding/xml" "errors" @@ -768,24 +768,6 @@ func createUnixSocketEchoServer(t *testing.T) string { return socketPath } -type digestServerConfig struct { - realm, qop, nonce, opaque, algo, uri, charset, username, password string -} - -func defaultDigestServerConf() *digestServerConfig { - return &digestServerConfig{ - realm: "testrealm@host.com", - qop: "auth", - nonce: "dcd98b7102dd2f0e8b11d0f600bfb0c093", - opaque: "5ccc069c403ebaf9f0171e9517f40e41", - algo: "MD5", - uri: "/dir/index.html", - charset: "utf-8", - username: "Mufasa", - password: "Circle Of Life", - } -} - func createDigestServer(t *testing.T, conf *digestServerConfig) *httptest.Server { if conf == nil { conf = defaultDigestServerConf() @@ -822,14 +804,14 @@ func createDigestServer(t *testing.T, conf *digestServerConfig) *httptest.Server w.Header().Set(hdrContentTypeKey, "application/json; charset=utf-8") - if !authorizationHeaderValid(t, r, conf) { - setWWWAuthHeader(w, - fmt.Sprintf(`Digest realm="%s", domain="%s", qop="%s", algorithm=%s, nonce="%s", opaque="%s", userhash=true, charset=%s, stale=FALSE`, - conf.realm, conf.uri, conf.qop, conf.algo, conf.nonce, conf.opaque, conf.charset)) - _, _ = w.Write([]byte(`{ "id": "unauthorized", "message": "Invalid credentials" }`)) - } else { + if authorizationHeaderValid(t, r, conf) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{ "id": "success", "message": "login successful" }`)) + } else { + setWWWAuthHeader(w, + fmt.Sprintf(`Digest realm="%s", domain="%s", qop="%s", algorithm=%s, nonce="%s", opaque="%s", userhash=true, charset=%s, stale=FALSE, nc=%s`, + conf.realm, conf.uri, conf.qop, conf.algo, conf.nonce, conf.opaque, conf.charset, conf.nc)) + _, _ = w.Write([]byte(`{ "id": "unauthorized", "message": "Invalid credentials" }`)) } }) @@ -837,20 +819,11 @@ func createDigestServer(t *testing.T, conf *digestServerConfig) *httptest.Server } func authorizationHeaderValid(t *testing.T, r *http.Request, conf *digestServerConfig) bool { - h := func(data string) (string, error) { - hf := md5.New() - - _, err := io.WriteString(hf, data) - if err != nil { - return "", err - } - - return fmt.Sprintf("%x", hf.Sum(nil)), nil - } input := r.Header.Get(hdrAuthorizationKey) if input == "" { return false } + const ws = " \n\r\t" const qs = `"` s := strings.Trim(input, ws) @@ -864,28 +837,53 @@ func authorizationHeaderValid(t *testing.T, r *http.Request, conf *digestServerC pairs[pair[0]] = strings.Trim(pair[1], qs) } - assertEqual(t, conf.opaque, pairs["opaque"]) assertEqual(t, conf.algo, pairs["algorithm"]) + h := func(data string) string { + h := newHashFunc(pairs["algorithm"]) + _, _ = h.Write([]byte(data)) + return hex.EncodeToString(h.Sum(nil)) + } + + assertEqual(t, conf.opaque, pairs["opaque"]) assertEqual(t, "true", pairs["userhash"]) - userhash, err := h(fmt.Sprintf("%s:%s", conf.username, conf.realm)) - assertError(t, err) - assertEqual(t, userhash, pairs["username"]) + userHash := h(fmt.Sprintf("%s:%s", conf.username, conf.realm)) + assertEqual(t, userHash, pairs["username"]) - ha1, err := h(fmt.Sprintf("%s:%s:%s", conf.username, conf.realm, conf.password)) - assertError(t, err) + ha1 := h(fmt.Sprintf("%s:%s:%s", conf.username, conf.realm, conf.password)) if strings.HasSuffix(conf.algo, "-sess") { - ha1, err = h(fmt.Sprintf("%s:%s:%s", ha1, pairs["nonce"], pairs["cnonce"])) - assertError(t, err) + ha1 = h(fmt.Sprintf("%s:%s:%s", ha1, pairs["nonce"], pairs["cnonce"])) } - ha2, err := h(fmt.Sprintf("%s:%s", r.Method, conf.uri)) - assertError(t, err) + ha2 := h(fmt.Sprintf("%s:%s", r.Method, conf.uri)) + + qop := pairs["qop"] + if qop == "" { + kd := h(fmt.Sprintf("%s:%s:%s", ha1, pairs["nonce"], ha2)) + return kd == pairs["response"] + } + nonceCount, err := strconv.Atoi(pairs["nc"]) assertError(t, err) - kd, err := h(fmt.Sprintf("%s:%s", ha1, fmt.Sprintf("%s:%08x:%s:%s:%s", - pairs["nonce"], nonceCount, pairs["cnonce"], pairs["qop"], ha2))) + + // auth scenario + if qop == qopAuth { + kd := h(fmt.Sprintf("%s:%s", ha1, fmt.Sprintf("%s:%08x:%s:%s:%s", + pairs["nonce"], nonceCount, pairs["cnonce"], pairs["qop"], ha2))) + return kd == pairs["response"] + } + + // auth-int scenario + body, err := io.ReadAll(r.Body) + r.Body.Close() assertError(t, err) + bodyHash := "" + if len(body) > 0 { + bodyHash = h(string(body)) + } + ha2 = h(fmt.Sprintf("%s:%s:%s", r.Method, conf.uri, bodyHash)) + kd := h(fmt.Sprintf("%s:%s", ha1, fmt.Sprintf("%s:%08x:%s:%s:%s", + pairs["nonce"], nonceCount, pairs["cnonce"], pairs["qop"], ha2))) return kd == pairs["response"] }