diff --git a/aws/request/connection_reset_error.go b/aws/request/connection_reset_error.go
index 2d13754cfe1..d9b37f4d32a 100644
--- a/aws/request/connection_reset_error.go
+++ b/aws/request/connection_reset_error.go
@@ -5,5 +5,14 @@ import (
)
func isErrConnectionReset(err error) bool {
- return strings.Contains(err.Error(), "connection reset")
+ if strings.Contains(err.Error(), "read: connection reset") {
+ return false
+ }
+
+ if strings.Contains(err.Error(), "connection reset") ||
+ strings.Contains(err.Error(), "broken pipe") {
+ return true
+ }
+
+ return false
}
diff --git a/aws/request/connection_reset_error_test.go b/aws/request/connection_reset_error_test.go
index cae48a3a7e2..102461426b6 100644
--- a/aws/request/connection_reset_error_test.go
+++ b/aws/request/connection_reset_error_test.go
@@ -36,14 +36,22 @@ func TestSerializationErrConnectionReset_accept(t *testing.T) {
Err error
ExpectAttempts int
}{
- "with temporary": {
+ "accept with temporary": {
Err: errAcceptConnectionResetStub,
ExpectAttempts: 6,
},
- "not temporary": {
+ "read not temporary": {
Err: errReadConnectionResetStub,
ExpectAttempts: 1,
},
+ "write with temporary": {
+ Err: errWriteConnectionResetStub,
+ ExpectAttempts: 6,
+ },
+ "write broken pipe with temporary": {
+ Err: errWriteBrokenPipeStub,
+ ExpectAttempts: 6,
+ },
"generic connection reset": {
Err: errConnectionResetStub,
ExpectAttempts: 6,
@@ -86,6 +94,7 @@ func TestSerializationErrConnectionReset_accept(t *testing.T) {
}
cfg := unit.Session.Config.Copy()
cfg.MaxRetries = aws.Int(5)
+ cfg.SleepDelay = func(time.Duration) {}
req := request.New(
*cfg,
diff --git a/aws/request/http_request_retry_test.go b/aws/request/http_request_retry_test.go
index 2e057f69934..fcdd1ce819b 100644
--- a/aws/request/http_request_retry_test.go
+++ b/aws/request/http_request_retry_test.go
@@ -1,5 +1,3 @@
-// +build go1.5
-
package request_test
import (
diff --git a/aws/request/request.go b/aws/request/request.go
index 19da3fcd826..0c46b7d2c31 100644
--- a/aws/request/request.go
+++ b/aws/request/request.go
@@ -231,6 +231,10 @@ func (r *Request) WillRetry() bool {
return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries()
}
+func fmtAttemptCount(retryCount, maxRetries int) string {
+ return fmt.Sprintf("attempt %v/%v", retryCount, maxRetries)
+}
+
// ParamsFilled returns if the request's parameters have been populated
// and the parameters are valid. False is returned if no parameters are
// provided or invalid.
@@ -330,16 +334,17 @@ func getPresignedURL(r *Request, expire time.Duration) (string, http.Header, err
return r.HTTPRequest.URL.String(), r.SignedHeaderVals, nil
}
-func debugLogReqError(r *Request, stage string, retrying bool, err error) {
+const (
+ willRetry = "will retry"
+ notRetrying = "not retrying"
+ retryCount = "retry %v/%v"
+)
+
+func debugLogReqError(r *Request, stage, retryStr string, err error) {
if !r.Config.LogLevel.Matches(aws.LogDebugWithRequestErrors) {
return
}
- retryStr := "not retrying"
- if retrying {
- retryStr = "will retry"
- }
-
r.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v",
stage, r.ClientInfo.ServiceName, r.Operation.Name, retryStr, err))
}
@@ -358,12 +363,12 @@ func (r *Request) Build() error {
if !r.built {
r.Handlers.Validate.Run(r)
if r.Error != nil {
- debugLogReqError(r, "Validate Request", false, r.Error)
+ debugLogReqError(r, "Validate Request", notRetrying, r.Error)
return r.Error
}
r.Handlers.Build.Run(r)
if r.Error != nil {
- debugLogReqError(r, "Build Request", false, r.Error)
+ debugLogReqError(r, "Build Request", notRetrying, r.Error)
return r.Error
}
r.built = true
@@ -379,7 +384,7 @@ func (r *Request) Build() error {
func (r *Request) Sign() error {
r.Build()
if r.Error != nil {
- debugLogReqError(r, "Build Request", false, r.Error)
+ debugLogReqError(r, "Build Request", notRetrying, r.Error)
return r.Error
}
@@ -473,7 +478,7 @@ func (r *Request) Send() error {
r.AttemptTime = time.Now()
if err := r.Sign(); err != nil {
- debugLogReqError(r, "Sign Request", false, err)
+ debugLogReqError(r, "Sign Request", notRetrying, err)
return err
}
@@ -520,7 +525,9 @@ func (r *Request) sendRequest() (sendErr error) {
r.Retryable = nil
r.Handlers.Send.Run(r)
if r.Error != nil {
- debugLogReqError(r, "Send Request", r.WillRetry(), r.Error)
+ debugLogReqError(r, "Send Request",
+ fmtAttemptCount(r.RetryCount, r.MaxRetries()),
+ r.Error)
return r.Error
}
@@ -528,13 +535,17 @@ func (r *Request) sendRequest() (sendErr error) {
r.Handlers.ValidateResponse.Run(r)
if r.Error != nil {
r.Handlers.UnmarshalError.Run(r)
- debugLogReqError(r, "Validate Response", r.WillRetry(), r.Error)
+ debugLogReqError(r, "Validate Response",
+ fmtAttemptCount(r.RetryCount, r.MaxRetries()),
+ r.Error)
return r.Error
}
r.Handlers.Unmarshal.Run(r)
if r.Error != nil {
- debugLogReqError(r, "Unmarshal Response", r.WillRetry(), r.Error)
+ debugLogReqError(r, "Unmarshal Response",
+ fmtAttemptCount(r.RetryCount, r.MaxRetries()),
+ r.Error)
return r.Error
}
@@ -565,8 +576,8 @@ type temporary interface {
Temporary() bool
}
-func shouldRetryCancel(err error) bool {
- switch err := err.(type) {
+func shouldRetryCancel(origErr error) bool {
+ switch err := origErr.(type) {
case awserr.Error:
if err.Code() == CanceledErrorCode {
return false
@@ -585,7 +596,7 @@ func shouldRetryCancel(err error) bool {
case temporary:
// If the error is temporary, we want to allow continuation of the
// retry process
- return err.Temporary()
+ return err.Temporary() || isErrConnectionReset(origErr)
case nil:
// `awserr.Error.OrigErr()` can be nil, meaning there was an error but
// because we don't know the cause, it is marked as retryable. See
diff --git a/aws/request/request_test.go b/aws/request/request_test.go
index 286ee36325b..0449c47a33f 100644
--- a/aws/request/request_test.go
+++ b/aws/request/request_test.go
@@ -43,10 +43,26 @@ func (e *tempNetworkError) Error() string {
var (
// net.OpError accept, are always temporary
- errAcceptConnectionResetStub = &tempNetworkError{isTemp: true, op: "accept", msg: "connection reset"}
+ errAcceptConnectionResetStub = &tempNetworkError{
+ isTemp: true, op: "accept", msg: "connection reset",
+ }
// net.OpError read for ECONNRESET is not temporary.
- errReadConnectionResetStub = &tempNetworkError{isTemp: false, op: "read", msg: "connection reset"}
+ errReadConnectionResetStub = &tempNetworkError{
+ isTemp: false, op: "read", msg: "connection reset",
+ }
+
+ // net.OpError write for ECONNRESET may not be temporary, but is treaded as
+ // temporary by the SDK.
+ errWriteConnectionResetStub = &tempNetworkError{
+ isTemp: false, op: "write", msg: "connection reset",
+ }
+
+ // net.OpError write for broken pipe may not be temporary, but is treaded as
+ // temporary by the SDK.
+ errWriteBrokenPipeStub = &tempNetworkError{
+ isTemp: false, op: "write", msg: "broken pipe",
+ }
// Generic connection reset error
errConnectionResetStub = errors.New("connection reset")
diff --git a/service/s3/s3manager/upload_test.go b/service/s3/s3manager/upload_test.go
index 7f729735639..810c340e81e 100644
--- a/service/s3/s3manager/upload_test.go
+++ b/service/s3/s3manager/upload_test.go
@@ -1,3 +1,5 @@
+// +build go1.8
+
package s3manager_test
import (
@@ -7,12 +9,15 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
+ "os"
"reflect"
"regexp"
"sort"
+ "strconv"
"strings"
"sync"
"testing"
+ "time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
@@ -1030,3 +1035,271 @@ func TestUploadMaxPartsEOF(t *testing.T) {
t.Errorf("expect %v ops, got %v", e, a)
}
}
+
+func createTempFile(t *testing.T, size int64) (*os.File, func(*testing.T), error) {
+ file, err := ioutil.TempFile(os.TempDir(), aws.SDKName+t.Name())
+ if err != nil {
+ return nil, nil, err
+ }
+ filename := file.Name()
+ if err := file.Truncate(size); err != nil {
+ return nil, nil, err
+ }
+
+ return file,
+ func(t *testing.T) {
+ if err := file.Close(); err != nil {
+ t.Errorf("failed to close temp file, %s, %v", filename, err)
+ }
+ if err := os.Remove(filename); err != nil {
+ t.Errorf("failed to remove temp file, %s, %v", filename, err)
+ }
+ },
+ nil
+}
+
+func buildFailHandlers(tb testing.TB, parts, retry int) []http.Handler {
+ handlers := make([]http.Handler, parts)
+ for i := 0; i < len(handlers); i++ {
+ handlers[i] = &failPartHandler{
+ tb: tb,
+ failsRemaining: retry,
+ successHandler: successPartHandler{tb: tb},
+ }
+ }
+
+ return handlers
+}
+
+func TestUploadRetry(t *testing.T) {
+ const numParts, retries = 3, 10
+
+ testFile, testFileCleanup, err := createTempFile(t, s3manager.DefaultUploadPartSize*numParts)
+ if err != nil {
+ t.Fatalf("failed to create test file, %v", err)
+ }
+ defer testFileCleanup(t)
+
+ cases := map[string]struct {
+ Body io.Reader
+ PartHandlers func(testing.TB) []http.Handler
+ }{
+ "bytes.Buffer": {
+ Body: bytes.NewBuffer(make([]byte, s3manager.DefaultUploadPartSize*numParts)),
+ PartHandlers: func(tb testing.TB) []http.Handler {
+ return buildFailHandlers(tb, numParts, retries)
+ },
+ },
+ "bytes.Reader": {
+ Body: bytes.NewReader(make([]byte, s3manager.DefaultUploadPartSize*numParts)),
+ PartHandlers: func(tb testing.TB) []http.Handler {
+ return buildFailHandlers(tb, numParts, retries)
+ },
+ },
+ "os.File": {
+ Body: testFile,
+ PartHandlers: func(tb testing.TB) []http.Handler {
+ return buildFailHandlers(tb, numParts, retries)
+ },
+ },
+ }
+
+ for name, c := range cases {
+ t.Run(name, func(t *testing.T) {
+ mux := newMockS3UploadServer(t, c.PartHandlers(t))
+ server := httptest.NewServer(mux)
+ defer server.Close()
+
+ sess := unit.Session.Copy(&aws.Config{
+ Endpoint: aws.String(server.URL),
+ S3ForcePathStyle: aws.Bool(true),
+ DisableSSL: aws.Bool(true),
+ Logger: t,
+ MaxRetries: aws.Int(retries + 1),
+ SleepDelay: func(time.Duration) {},
+
+ LogLevel: aws.LogLevel(
+ aws.LogDebugWithRequestErrors | aws.LogDebugWithRequestRetries,
+ ),
+ //Credentials: credentials.AnonymousCredentials,
+ })
+
+ uploader := s3manager.NewUploader(sess, func(u *s3manager.Uploader) {
+ // u.Concurrency = 1
+ })
+ _, err := uploader.Upload(&s3manager.UploadInput{
+ Bucket: aws.String("bucket"),
+ Key: aws.String("key"),
+ Body: c.Body,
+ })
+
+ if err != nil {
+ t.Fatalf("expect no error, got %v", err)
+ }
+ })
+ }
+}
+
+type mockS3UploadServer struct {
+ *http.ServeMux
+
+ tb testing.TB
+ partHandler []http.Handler
+}
+
+func newMockS3UploadServer(tb testing.TB, partHandler []http.Handler) *mockS3UploadServer {
+ s := &mockS3UploadServer{
+ ServeMux: http.NewServeMux(),
+ partHandler: partHandler,
+ tb: tb,
+ }
+
+ s.HandleFunc("/", s.handleRequest)
+
+ return s
+}
+
+func (s mockS3UploadServer) handleRequest(w http.ResponseWriter, r *http.Request) {
+ defer r.Body.Close()
+
+ _, hasUploads := r.URL.Query()["uploads"]
+
+ switch {
+ case r.Method == "POST" && hasUploads:
+ // CreateMultipartUpload
+ w.Header().Set("Content-Length", strconv.Itoa(len(createUploadResp)))
+ w.Write([]byte(createUploadResp))
+
+ case r.Method == "PUT":
+ // UploadPart
+ partNumStr := r.URL.Query().Get("partNumber")
+ id, err := strconv.Atoi(partNumStr)
+ if err != nil {
+ failRequest(w, 400, "BadRequest",
+ fmt.Sprintf("unable to parse partNumber, %q, %v",
+ partNumStr, err))
+ return
+ }
+ id--
+ if id < 0 || id >= len(s.partHandler) {
+ failRequest(w, 400, "BadRequest",
+ fmt.Sprintf("invalid partNumber %v", id))
+ return
+ }
+ s.partHandler[id].ServeHTTP(w, r)
+
+ case r.Method == "POST":
+ // CompleteMultipartUpload
+ w.Header().Set("Content-Length", strconv.Itoa(len(completeUploadResp)))
+ w.Write([]byte(completeUploadResp))
+
+ case r.Method == "DELETE":
+ // AbortMultipartUpload
+ w.Header().Set("Content-Length", strconv.Itoa(len(abortUploadResp)))
+ w.WriteHeader(200)
+ w.Write([]byte(abortUploadResp))
+
+ default:
+ failRequest(w, 400, "BadRequest",
+ fmt.Sprintf("invalid request %v %v", r.Method, r.URL))
+ }
+}
+
+func failRequest(w http.ResponseWriter, status int, code, msg string) {
+ msg = fmt.Sprintf(baseRequestErrorResp, code, msg)
+ w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
+ w.WriteHeader(status)
+ w.Write([]byte(msg))
+}
+
+type successPartHandler struct {
+ tb testing.TB
+}
+
+func (h successPartHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ defer r.Body.Close()
+
+ n, err := io.Copy(ioutil.Discard, r.Body)
+ if err != nil {
+ failRequest(w, 400, "BadRequest",
+ fmt.Sprintf("failed to read body, %v", err))
+ return
+ }
+
+ contLenStr := r.Header.Get("Content-Length")
+ expectLen, err := strconv.ParseInt(contLenStr, 10, 64)
+ if err != nil {
+ h.tb.Logf("expect content-length, got %q, %v", contLenStr, err)
+ failRequest(w, 400, "BadRequest",
+ fmt.Sprintf("unable to get content-length %v", err))
+ return
+ }
+ if e, a := expectLen, n; e != a {
+ h.tb.Logf("expect %v read, got %v", e, a)
+ failRequest(w, 400, "BadRequest",
+ fmt.Sprintf(
+ "content-length and body do not match, %v, %v", e, a))
+ return
+ }
+
+ w.Header().Set("Content-Length", strconv.Itoa(len(uploadPartResp)))
+ w.Write([]byte(uploadPartResp))
+}
+
+type failPartHandler struct {
+ tb testing.TB
+
+ failsRemaining int
+ successHandler http.Handler
+}
+
+func (h *failPartHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ defer r.Body.Close()
+
+ if h.failsRemaining == 0 && h.successHandler != nil {
+ h.successHandler.ServeHTTP(w, r)
+ return
+ }
+
+ io.Copy(ioutil.Discard, r.Body)
+
+ failRequest(w, 500, "InternalException",
+ fmt.Sprintf("mock error, partNumber %v", r.URL.Query().Get("partNumber")))
+
+ h.failsRemaining--
+}
+
+const createUploadResp = `
+
+ bucket
+ key
+ abc123
+
+`
+const uploadPartResp = `
+
+ key
+
+`
+const baseRequestErrorResp = `
+
+ %s
+ %s
+ request-id
+ host-id
+
+`
+const completeUploadResp = `
+
+ bucket
+ key
+ key
+ https://bucket.us-west-2.amazonaws.com/key
+ abc123
+
+`
+
+const abortUploadResp = `
+
+
+`