diff --git a/aws/request/connection_reset_error.go b/aws/request/connection_reset_error.go index 2d13754cfe1..d9b37f4d32a 100644 --- a/aws/request/connection_reset_error.go +++ b/aws/request/connection_reset_error.go @@ -5,5 +5,14 @@ import ( ) func isErrConnectionReset(err error) bool { - return strings.Contains(err.Error(), "connection reset") + if strings.Contains(err.Error(), "read: connection reset") { + return false + } + + if strings.Contains(err.Error(), "connection reset") || + strings.Contains(err.Error(), "broken pipe") { + return true + } + + return false } diff --git a/aws/request/connection_reset_error_test.go b/aws/request/connection_reset_error_test.go index cae48a3a7e2..102461426b6 100644 --- a/aws/request/connection_reset_error_test.go +++ b/aws/request/connection_reset_error_test.go @@ -36,14 +36,22 @@ func TestSerializationErrConnectionReset_accept(t *testing.T) { Err error ExpectAttempts int }{ - "with temporary": { + "accept with temporary": { Err: errAcceptConnectionResetStub, ExpectAttempts: 6, }, - "not temporary": { + "read not temporary": { Err: errReadConnectionResetStub, ExpectAttempts: 1, }, + "write with temporary": { + Err: errWriteConnectionResetStub, + ExpectAttempts: 6, + }, + "write broken pipe with temporary": { + Err: errWriteBrokenPipeStub, + ExpectAttempts: 6, + }, "generic connection reset": { Err: errConnectionResetStub, ExpectAttempts: 6, @@ -86,6 +94,7 @@ func TestSerializationErrConnectionReset_accept(t *testing.T) { } cfg := unit.Session.Config.Copy() cfg.MaxRetries = aws.Int(5) + cfg.SleepDelay = func(time.Duration) {} req := request.New( *cfg, diff --git a/aws/request/http_request_retry_test.go b/aws/request/http_request_retry_test.go index 2e057f69934..fcdd1ce819b 100644 --- a/aws/request/http_request_retry_test.go +++ b/aws/request/http_request_retry_test.go @@ -1,5 +1,3 @@ -// +build go1.5 - package request_test import ( diff --git a/aws/request/request.go b/aws/request/request.go index 19da3fcd826..0c46b7d2c31 100644 --- a/aws/request/request.go +++ b/aws/request/request.go @@ -231,6 +231,10 @@ func (r *Request) WillRetry() bool { return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries() } +func fmtAttemptCount(retryCount, maxRetries int) string { + return fmt.Sprintf("attempt %v/%v", retryCount, maxRetries) +} + // ParamsFilled returns if the request's parameters have been populated // and the parameters are valid. False is returned if no parameters are // provided or invalid. @@ -330,16 +334,17 @@ func getPresignedURL(r *Request, expire time.Duration) (string, http.Header, err return r.HTTPRequest.URL.String(), r.SignedHeaderVals, nil } -func debugLogReqError(r *Request, stage string, retrying bool, err error) { +const ( + willRetry = "will retry" + notRetrying = "not retrying" + retryCount = "retry %v/%v" +) + +func debugLogReqError(r *Request, stage, retryStr string, err error) { if !r.Config.LogLevel.Matches(aws.LogDebugWithRequestErrors) { return } - retryStr := "not retrying" - if retrying { - retryStr = "will retry" - } - r.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v", stage, r.ClientInfo.ServiceName, r.Operation.Name, retryStr, err)) } @@ -358,12 +363,12 @@ func (r *Request) Build() error { if !r.built { r.Handlers.Validate.Run(r) if r.Error != nil { - debugLogReqError(r, "Validate Request", false, r.Error) + debugLogReqError(r, "Validate Request", notRetrying, r.Error) return r.Error } r.Handlers.Build.Run(r) if r.Error != nil { - debugLogReqError(r, "Build Request", false, r.Error) + debugLogReqError(r, "Build Request", notRetrying, r.Error) return r.Error } r.built = true @@ -379,7 +384,7 @@ func (r *Request) Build() error { func (r *Request) Sign() error { r.Build() if r.Error != nil { - debugLogReqError(r, "Build Request", false, r.Error) + debugLogReqError(r, "Build Request", notRetrying, r.Error) return r.Error } @@ -473,7 +478,7 @@ func (r *Request) Send() error { r.AttemptTime = time.Now() if err := r.Sign(); err != nil { - debugLogReqError(r, "Sign Request", false, err) + debugLogReqError(r, "Sign Request", notRetrying, err) return err } @@ -520,7 +525,9 @@ func (r *Request) sendRequest() (sendErr error) { r.Retryable = nil r.Handlers.Send.Run(r) if r.Error != nil { - debugLogReqError(r, "Send Request", r.WillRetry(), r.Error) + debugLogReqError(r, "Send Request", + fmtAttemptCount(r.RetryCount, r.MaxRetries()), + r.Error) return r.Error } @@ -528,13 +535,17 @@ func (r *Request) sendRequest() (sendErr error) { r.Handlers.ValidateResponse.Run(r) if r.Error != nil { r.Handlers.UnmarshalError.Run(r) - debugLogReqError(r, "Validate Response", r.WillRetry(), r.Error) + debugLogReqError(r, "Validate Response", + fmtAttemptCount(r.RetryCount, r.MaxRetries()), + r.Error) return r.Error } r.Handlers.Unmarshal.Run(r) if r.Error != nil { - debugLogReqError(r, "Unmarshal Response", r.WillRetry(), r.Error) + debugLogReqError(r, "Unmarshal Response", + fmtAttemptCount(r.RetryCount, r.MaxRetries()), + r.Error) return r.Error } @@ -565,8 +576,8 @@ type temporary interface { Temporary() bool } -func shouldRetryCancel(err error) bool { - switch err := err.(type) { +func shouldRetryCancel(origErr error) bool { + switch err := origErr.(type) { case awserr.Error: if err.Code() == CanceledErrorCode { return false @@ -585,7 +596,7 @@ func shouldRetryCancel(err error) bool { case temporary: // If the error is temporary, we want to allow continuation of the // retry process - return err.Temporary() + return err.Temporary() || isErrConnectionReset(origErr) case nil: // `awserr.Error.OrigErr()` can be nil, meaning there was an error but // because we don't know the cause, it is marked as retryable. See diff --git a/aws/request/request_test.go b/aws/request/request_test.go index 286ee36325b..0449c47a33f 100644 --- a/aws/request/request_test.go +++ b/aws/request/request_test.go @@ -43,10 +43,26 @@ func (e *tempNetworkError) Error() string { var ( // net.OpError accept, are always temporary - errAcceptConnectionResetStub = &tempNetworkError{isTemp: true, op: "accept", msg: "connection reset"} + errAcceptConnectionResetStub = &tempNetworkError{ + isTemp: true, op: "accept", msg: "connection reset", + } // net.OpError read for ECONNRESET is not temporary. - errReadConnectionResetStub = &tempNetworkError{isTemp: false, op: "read", msg: "connection reset"} + errReadConnectionResetStub = &tempNetworkError{ + isTemp: false, op: "read", msg: "connection reset", + } + + // net.OpError write for ECONNRESET may not be temporary, but is treaded as + // temporary by the SDK. + errWriteConnectionResetStub = &tempNetworkError{ + isTemp: false, op: "write", msg: "connection reset", + } + + // net.OpError write for broken pipe may not be temporary, but is treaded as + // temporary by the SDK. + errWriteBrokenPipeStub = &tempNetworkError{ + isTemp: false, op: "write", msg: "broken pipe", + } // Generic connection reset error errConnectionResetStub = errors.New("connection reset") diff --git a/service/s3/s3manager/upload_test.go b/service/s3/s3manager/upload_test.go index 7f729735639..810c340e81e 100644 --- a/service/s3/s3manager/upload_test.go +++ b/service/s3/s3manager/upload_test.go @@ -1,3 +1,5 @@ +// +build go1.8 + package s3manager_test import ( @@ -7,12 +9,15 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "os" "reflect" "regexp" "sort" + "strconv" "strings" "sync" "testing" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" @@ -1030,3 +1035,271 @@ func TestUploadMaxPartsEOF(t *testing.T) { t.Errorf("expect %v ops, got %v", e, a) } } + +func createTempFile(t *testing.T, size int64) (*os.File, func(*testing.T), error) { + file, err := ioutil.TempFile(os.TempDir(), aws.SDKName+t.Name()) + if err != nil { + return nil, nil, err + } + filename := file.Name() + if err := file.Truncate(size); err != nil { + return nil, nil, err + } + + return file, + func(t *testing.T) { + if err := file.Close(); err != nil { + t.Errorf("failed to close temp file, %s, %v", filename, err) + } + if err := os.Remove(filename); err != nil { + t.Errorf("failed to remove temp file, %s, %v", filename, err) + } + }, + nil +} + +func buildFailHandlers(tb testing.TB, parts, retry int) []http.Handler { + handlers := make([]http.Handler, parts) + for i := 0; i < len(handlers); i++ { + handlers[i] = &failPartHandler{ + tb: tb, + failsRemaining: retry, + successHandler: successPartHandler{tb: tb}, + } + } + + return handlers +} + +func TestUploadRetry(t *testing.T) { + const numParts, retries = 3, 10 + + testFile, testFileCleanup, err := createTempFile(t, s3manager.DefaultUploadPartSize*numParts) + if err != nil { + t.Fatalf("failed to create test file, %v", err) + } + defer testFileCleanup(t) + + cases := map[string]struct { + Body io.Reader + PartHandlers func(testing.TB) []http.Handler + }{ + "bytes.Buffer": { + Body: bytes.NewBuffer(make([]byte, s3manager.DefaultUploadPartSize*numParts)), + PartHandlers: func(tb testing.TB) []http.Handler { + return buildFailHandlers(tb, numParts, retries) + }, + }, + "bytes.Reader": { + Body: bytes.NewReader(make([]byte, s3manager.DefaultUploadPartSize*numParts)), + PartHandlers: func(tb testing.TB) []http.Handler { + return buildFailHandlers(tb, numParts, retries) + }, + }, + "os.File": { + Body: testFile, + PartHandlers: func(tb testing.TB) []http.Handler { + return buildFailHandlers(tb, numParts, retries) + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + mux := newMockS3UploadServer(t, c.PartHandlers(t)) + server := httptest.NewServer(mux) + defer server.Close() + + sess := unit.Session.Copy(&aws.Config{ + Endpoint: aws.String(server.URL), + S3ForcePathStyle: aws.Bool(true), + DisableSSL: aws.Bool(true), + Logger: t, + MaxRetries: aws.Int(retries + 1), + SleepDelay: func(time.Duration) {}, + + LogLevel: aws.LogLevel( + aws.LogDebugWithRequestErrors | aws.LogDebugWithRequestRetries, + ), + //Credentials: credentials.AnonymousCredentials, + }) + + uploader := s3manager.NewUploader(sess, func(u *s3manager.Uploader) { + // u.Concurrency = 1 + }) + _, err := uploader.Upload(&s3manager.UploadInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + Body: c.Body, + }) + + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + }) + } +} + +type mockS3UploadServer struct { + *http.ServeMux + + tb testing.TB + partHandler []http.Handler +} + +func newMockS3UploadServer(tb testing.TB, partHandler []http.Handler) *mockS3UploadServer { + s := &mockS3UploadServer{ + ServeMux: http.NewServeMux(), + partHandler: partHandler, + tb: tb, + } + + s.HandleFunc("/", s.handleRequest) + + return s +} + +func (s mockS3UploadServer) handleRequest(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + + _, hasUploads := r.URL.Query()["uploads"] + + switch { + case r.Method == "POST" && hasUploads: + // CreateMultipartUpload + w.Header().Set("Content-Length", strconv.Itoa(len(createUploadResp))) + w.Write([]byte(createUploadResp)) + + case r.Method == "PUT": + // UploadPart + partNumStr := r.URL.Query().Get("partNumber") + id, err := strconv.Atoi(partNumStr) + if err != nil { + failRequest(w, 400, "BadRequest", + fmt.Sprintf("unable to parse partNumber, %q, %v", + partNumStr, err)) + return + } + id-- + if id < 0 || id >= len(s.partHandler) { + failRequest(w, 400, "BadRequest", + fmt.Sprintf("invalid partNumber %v", id)) + return + } + s.partHandler[id].ServeHTTP(w, r) + + case r.Method == "POST": + // CompleteMultipartUpload + w.Header().Set("Content-Length", strconv.Itoa(len(completeUploadResp))) + w.Write([]byte(completeUploadResp)) + + case r.Method == "DELETE": + // AbortMultipartUpload + w.Header().Set("Content-Length", strconv.Itoa(len(abortUploadResp))) + w.WriteHeader(200) + w.Write([]byte(abortUploadResp)) + + default: + failRequest(w, 400, "BadRequest", + fmt.Sprintf("invalid request %v %v", r.Method, r.URL)) + } +} + +func failRequest(w http.ResponseWriter, status int, code, msg string) { + msg = fmt.Sprintf(baseRequestErrorResp, code, msg) + w.Header().Set("Content-Length", strconv.Itoa(len(msg))) + w.WriteHeader(status) + w.Write([]byte(msg)) +} + +type successPartHandler struct { + tb testing.TB +} + +func (h successPartHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + + n, err := io.Copy(ioutil.Discard, r.Body) + if err != nil { + failRequest(w, 400, "BadRequest", + fmt.Sprintf("failed to read body, %v", err)) + return + } + + contLenStr := r.Header.Get("Content-Length") + expectLen, err := strconv.ParseInt(contLenStr, 10, 64) + if err != nil { + h.tb.Logf("expect content-length, got %q, %v", contLenStr, err) + failRequest(w, 400, "BadRequest", + fmt.Sprintf("unable to get content-length %v", err)) + return + } + if e, a := expectLen, n; e != a { + h.tb.Logf("expect %v read, got %v", e, a) + failRequest(w, 400, "BadRequest", + fmt.Sprintf( + "content-length and body do not match, %v, %v", e, a)) + return + } + + w.Header().Set("Content-Length", strconv.Itoa(len(uploadPartResp))) + w.Write([]byte(uploadPartResp)) +} + +type failPartHandler struct { + tb testing.TB + + failsRemaining int + successHandler http.Handler +} + +func (h *failPartHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + + if h.failsRemaining == 0 && h.successHandler != nil { + h.successHandler.ServeHTTP(w, r) + return + } + + io.Copy(ioutil.Discard, r.Body) + + failRequest(w, 500, "InternalException", + fmt.Sprintf("mock error, partNumber %v", r.URL.Query().Get("partNumber"))) + + h.failsRemaining-- +} + +const createUploadResp = ` + + bucket + key + abc123 + +` +const uploadPartResp = ` + + key + +` +const baseRequestErrorResp = ` + + %s + %s + request-id + host-id + +` +const completeUploadResp = ` + + bucket + key + key + https://bucket.us-west-2.amazonaws.com/key + abc123 + +` + +const abortUploadResp = ` + + +`