Skip to content

Commit

Permalink
Merge pull request #216 from mgwoj/main
Browse files Browse the repository at this point in the history
Re-sign request on retry
  • Loading branch information
manicminer authored May 9, 2024
2 parents 8db2bb6 + edadfe1 commit ca91556
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 4 deletions.
23 changes: 19 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,9 @@ type Backoff func(min, max time.Duration, attemptNum int, resp *http.Response) t
// attempted. If overriding this, be sure to close the body if needed.
type ErrorHandler func(resp *http.Response, err error, numTries int) (*http.Response, error)

// PrepareRetry is called before retry operation. It can be used for example to re-sign the request
type PrepareRetry func(req *http.Request) error

// Client is used to make HTTP requests. It adds additional functionality
// like automatic retries to tolerate minor outages.
type Client struct {
Expand Down Expand Up @@ -423,6 +426,9 @@ type Client struct {
// ErrorHandler specifies the custom error handler to use, if any
ErrorHandler ErrorHandler

// PrepareRetry can prepare the request for retry operation, for example re-sign it
PrepareRetry PrepareRetry

loggerInit sync.Once
clientInit sync.Once
}
Expand Down Expand Up @@ -653,10 +659,10 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
var resp *http.Response
var attempt int
var shouldRetry bool
var doErr, respErr, checkErr error
var doErr, respErr, checkErr, prepareErr error

for i := 0; ; i++ {
doErr, respErr = nil, nil
doErr, respErr, prepareErr = nil, nil, nil
attempt++

// Always rewind the request body when non-nil.
Expand Down Expand Up @@ -763,17 +769,26 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
// without racing against the closeBody call in persistConn.writeLoop.
httpreq := *req.Request
req.Request = &httpreq

if c.PrepareRetry != nil {
if err := c.PrepareRetry(req.Request); err != nil {
prepareErr = err
break
}
}
}

// this is the closest we have to success criteria
if doErr == nil && respErr == nil && checkErr == nil && !shouldRetry {
if doErr == nil && respErr == nil && checkErr == nil && prepareErr == nil && !shouldRetry {
return resp, nil
}

defer c.HTTPClient.CloseIdleConnections()

var err error
if checkErr != nil {
if prepareErr != nil {
err = prepareErr
} else if checkErr != nil {
err = checkErr
} else if respErr != nil {
err = respErr
Expand Down
123 changes: 123 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"net/http/httptest"
"net/http/httputil"
"net/url"
"strconv"
"strings"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -352,6 +353,128 @@ func TestClient_Do_WithResponseHandler(t *testing.T) {
}
}

func TestClient_Do_WithPrepareRetry(t *testing.T) {
// Create the client. Use short retry windows so we fail faster.
client := NewClient()
client.RetryWaitMin = 10 * time.Millisecond
client.RetryWaitMax = 10 * time.Millisecond
client.RetryMax = 2

var checks int
client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
checks++
if err != nil && strings.Contains(err.Error(), "nonretryable") {
return false, nil
}
return DefaultRetryPolicy(context.TODO(), resp, err)
}

var prepareChecks int
client.PrepareRetry = func(req *http.Request) error {
prepareChecks++
req.Header.Set("foo", strconv.Itoa(prepareChecks))
return nil
}

// Mock server which always responds 200.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer ts.Close()

var shouldSucceed bool
tests := []struct {
name string
handler ResponseHandlerFunc
expectedChecks int // often 2x number of attempts since we check twice
expectedPrepareChecks int
err string
}{
{
name: "nil handler",
handler: nil,
expectedChecks: 1,
expectedPrepareChecks: 0,
},
{
name: "handler always succeeds",
handler: func(*http.Response) error {
return nil
},
expectedChecks: 2,
expectedPrepareChecks: 0,
},
{
name: "handler always fails in a retryable way",
handler: func(*http.Response) error {
return errors.New("retryable failure")
},
expectedChecks: 6,
expectedPrepareChecks: 2,
},
{
name: "handler always fails in a nonretryable way",
handler: func(*http.Response) error {
return errors.New("nonretryable failure")
},
expectedChecks: 2,
expectedPrepareChecks: 0,
},
{
name: "handler succeeds on second attempt",
handler: func(*http.Response) error {
if shouldSucceed {
return nil
}
shouldSucceed = true
return errors.New("retryable failure")
},
expectedChecks: 4,
expectedPrepareChecks: 1,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
checks = 0
prepareChecks = 0
shouldSucceed = false
// Create the request
req, err := NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatalf("err: %v", err)
}
req.SetResponseHandler(tt.handler)

// Send the request.
_, err = client.Do(req)
if err != nil && !strings.Contains(err.Error(), tt.err) {
t.Fatalf("error does not match expectation, expected: %s, got: %s", tt.err, err.Error())
}
if err == nil && tt.err != "" {
t.Fatalf("no error, expected: %s", tt.err)
}

if checks != tt.expectedChecks {
t.Fatalf("expected %d attempts, got %d attempts", tt.expectedChecks, checks)
}

if prepareChecks != tt.expectedPrepareChecks {
t.Fatalf("expected %d attempts of prepare check, got %d attempts", tt.expectedPrepareChecks, prepareChecks)
}
header := req.Request.Header.Get("foo")
if tt.expectedPrepareChecks == 0 && header != "" {
t.Fatalf("expected no changes to request header 'foo', but got '%s'", header)
}
expectedHeader := strconv.Itoa(tt.expectedPrepareChecks)
if tt.expectedPrepareChecks != 0 && header != expectedHeader {
t.Fatalf("expected changes in request header 'foo' '%s', but got '%s'", expectedHeader, header)
}

})
}
}

func TestClient_Do_fails(t *testing.T) {
// Mock server which always responds 500.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down

0 comments on commit ca91556

Please sign in to comment.