diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index bfaf214319f..408ff6b2f47 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -13,6 +13,14 @@ Deprecations SDK Features --- +* `service/s3/s3manager`: Add Upload Buffer Provider ([#404](https://github.com/aws/aws-sdk-go-v2/pull/404)) + * Adds a new `BufferProvider` member for specifying how part data can be buffered in memory. + * Windows platforms will now default to buffering 1MB per part to reduce contention when uploading files. + * Non-Windows platforms will continue to employ a non-buffering behavior. +* `service/s3/s3manager`: Add Download Buffer Provider ([#404](https://github.com/aws/aws-sdk-go-v2/pull/404)) + * Adds a new `BufferProvider` member for specifying how part data can be buffered in memory when copying from the HTTP response body. + * Windows platforms will now default to buffering 1MB per part to reduce contention when downloading files. + * Non-Windows platforms will continue to employ a non-buffering behavior. * `service/dynamodb/dynamodbattribute`: New Encoder and Decoder Behavior for Empty Collections ([#401](https://github.com/aws/aws-sdk-go-v2/pull/401)) * The `Encoder` and `Decoder` types have been enhanced to support the marshaling of empty structures, maps, and slices to and from their respective DynamoDB AttributeValues. * This change incorporates the behavior changes introduced via a marshal option in V1 ([#2834](https://github.com/aws/aws-sdk-go/pull/2834)) @@ -24,6 +32,9 @@ SDK Enhancements * Related to [aws/aws-sdk-go#2310](https://github.com/aws/aws-sdk-go/pull/2310) * Fixes [#251](https://github.com/aws/aws-sdk-go-v2/issues/251) * `aws/request` : Retryer is now a named field on Request. ([#393](https://github.com/aws/aws-sdk-go-v2/pull/393)) +* `service/s3/s3manager`: Adds `sync.Pool` to allow reuse of part buffers for streaming payloads ([#404](https://github.com/aws/aws-sdk-go-v2/pull/404)) + * Fixes [#402](https://github.com/aws/aws-sdk-go-v2/issues/402) + * Uses the new behavior introduced in V1 [#2863](https://github.com/aws/aws-sdk-go/pull/2863) which allows the reuse of the sync.Pool across multiple Upload requests that match part sizes. SDK Bugs --- diff --git a/Makefile b/Makefile index 4ab4a61e135..990f41fb847 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,7 @@ LINTIGNOREDEPS='vendor/.+\.go' LINTIGNOREPKGCOMMENT='service/[^/]+/doc_custom.go:.+package comment should be of the form' LINTIGNOREENDPOINTS='aws/endpoints/defaults.go:.+(method|const) .+ should be ' UNIT_TEST_TAGS="example codegen awsinclude" +ALL_TAGS="example codegen awsinclude integration perftest" # SDK's Core and client packages that are compatable with Go 1.9+. SDK_CORE_PKGS=./aws/... ./private/... ./internal/...
@@ -56,11 +57,14 @@ cleanup-models: ################### # Unit/CI Testing # ################### -unit: verify +build: + go build -o /dev/null -tags ${ALL_TAGS} ${SDK_ALL_PKGS} + +unit: verify build @echo "go test SDK and vendor packages" @go test -tags ${UNIT_TEST_TAGS} ${SDK_ALL_PKGS} -unit-with-race-cover: verify +unit-with-race-cover: verify build @echo "go test SDK and vendor packages" @go test -tags ${UNIT_TEST_TAGS} -race -cpu=1,2,4 ${SDK_ALL_PKGS} diff --git a/internal/awstesting/discard.go b/internal/awstesting/discard.go new file mode 100644 index 00000000000..866dfca4e33 --- /dev/null +++ b/internal/awstesting/discard.go @@ -0,0 +1,11 @@ +package awstesting + +// DiscardAt is an io.WriteAt that discards +// the requested bytes to be written +type DiscardAt struct{} + +// WriteAt discards the given []byte slice and returns len(p) bytes +// as having been written at the given offset. It will never return an error. +func (d DiscardAt) WriteAt(p []byte, off int64) (n int, err error) { + return len(p), nil +} diff --git a/internal/awstesting/endless_reader.go b/internal/awstesting/endless_reader.go new file mode 100644 index 00000000000..da1b429571b --- /dev/null +++ b/internal/awstesting/endless_reader.go @@ -0,0 +1,12 @@ +package awstesting + +// EndlessReader is an io.Reader that will always report +// that bytes have been read. +type EndlessReader struct{} + +// Read will report that it has read len(p) bytes in p. +// The content in the []byte will be unmodified. +// This will never return an error. +func (e EndlessReader) Read(p []byte) (int, error) { + return len(p), nil +} diff --git a/internal/awstesting/integration/integration.go b/internal/awstesting/integration/integration.go index 5b1b3053d66..b60442ec766 100644 --- a/internal/awstesting/integration/integration.go +++ b/internal/awstesting/integration/integration.go @@ -8,6 +8,7 @@ import ( "crypto/rand" "fmt" "io" + "io/ioutil" "os" "github.com/aws/aws-sdk-go-v2/aws" @@ -63,3 +64,36 @@ func ConfigWithDefaultRegion(region string) aws.Config { return cfg } + +// CreateFileOfSize will return an *os.File that is of size bytes +func CreateFileOfSize(dir string, size int64) (*os.File, error) { + file, err := ioutil.TempFile(dir, "s3Bench") + if err != nil { + return nil, err + } + + err = file.Truncate(size) + if err != nil { + file.Close() + os.Remove(file.Name()) + return nil, err + } + + return file, nil +} + +// SizeToName returns a human-readable string for the given size bytes +func SizeToName(size int) string { + units := []string{"B", "KB", "MB", "GB"} + i := 0 + for size >= 1024 { + size /= 1024 + i++ + } + + if i > len(units)-1 { + i = len(units) - 1 + } + + return fmt.Sprintf("%d%s", size, units[i]) +} diff --git a/internal/awstesting/integration/performance/s3DownloadManager/README.md b/internal/awstesting/integration/performance/s3DownloadManager/README.md new file mode 100644 index 00000000000..c4066eec6de --- /dev/null +++ b/internal/awstesting/integration/performance/s3DownloadManager/README.md @@ -0,0 +1,39 @@ +## Performance Utility + +Downloads a test file from an S3 bucket using the SDK's S3 download manager. Allows passing +in custom configuration for the HTTP client and SDK's Download Manager behavior.
+ +## Build +### Standalone +```sh +go build -tags "integration perftest" -o s3DownloadManager ./awstesting/integration/performance/s3DownloadManager +``` +### Benchmarking +```sh +go test -tags "integration perftest" -c -o s3DownloadManager ./awstesting/integration/performance/s3DownloadManager +``` + +## Usage Example: +### Standalone +```sh +AWS_REGION=us-west-2 AWS_PROFILE=aws-go-sdk-team-test ./s3DownloadManager \ +-bucket aws-sdk-go-data \ +-size 10485760 \ +-client.idle-conns 1000 \ +-client.idle-conns-host 300 \ +-client.timeout.connect=1s \ +-client.timeout.response-header=1s +``` + +### Benchmarking +```sh +AWS_REGION=us-west-2 AWS_PROFILE=aws-go-sdk-team-test ./s3DownloadManager \ +-test.bench=. \ +-test.benchmem \ +-test.benchtime 1x \ +-bucket aws-sdk-go-data \ +-client.idle-conns 1000 \ +-client.idle-conns-host 300 \ +-client.timeout.connect=1s \ +-client.timeout.response-header=1s +``` diff --git a/internal/awstesting/integration/performance/s3DownloadManager/client.go b/internal/awstesting/integration/performance/s3DownloadManager/client.go new file mode 100644 index 00000000000..3b16c5f900b --- /dev/null +++ b/internal/awstesting/integration/performance/s3DownloadManager/client.go @@ -0,0 +1,32 @@ +// +build integration,perftest + +package main + +import ( + "net" + "net/http" + "time" +) + +func NewClient(cfg ClientConfig) *http.Client { + tr := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: cfg.Timeouts.Connect, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: cfg.MaxIdleConns, + MaxIdleConnsPerHost: cfg.MaxIdleConnsPerHost, + IdleConnTimeout: 90 * time.Second, + + DisableKeepAlives: !cfg.KeepAlive, + TLSHandshakeTimeout: cfg.Timeouts.TLSHandshake, + ExpectContinueTimeout: cfg.Timeouts.ExpectContinue, + ResponseHeaderTimeout: cfg.Timeouts.ResponseHeader, + } + + return &http.Client{ + Transport: tr, + } +} diff --git a/internal/awstesting/integration/performance/s3DownloadManager/config.go b/internal/awstesting/integration/performance/s3DownloadManager/config.go new file mode 100644 index 00000000000..44d5f1bc1f5 --- /dev/null +++ b/internal/awstesting/integration/performance/s3DownloadManager/config.go @@ -0,0 +1,152 @@ +// +build integration,perftest + +package main + +import ( + "flag" + "fmt" + "net/http" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/service/s3/s3manager" +) + +type Config struct { + Bucket string + Size int64 + LogVerbose bool + + SDK SDKConfig + Client ClientConfig +} + +func (c *Config) SetupFlags(prefix string, flagset *flag.FlagSet) { + flagset.StringVar(&c.Bucket, "bucket", "", + "The S3 bucket `name` to download the object from.") + flagset.Int64Var(&c.Size, "size", 0, + "The S3 object size in bytes to be first uploaded then downloaded") + flagset.BoolVar(&c.LogVerbose, "verbose", false, + "The output log will include verbose request information") + + c.SDK.SetupFlags(prefix, flagset) + c.Client.SetupFlags(prefix, flagset) +} + +func (c *Config) Validate() error { + var errs Errors + + if len(c.Bucket) == 0 || c.Size <= 0 { + errs = append(errs, fmt.Errorf("bucket and size are required")) + } + + if err := c.SDK.Validate(); err != nil { + errs = append(errs, err) + } + if err := c.Client.Validate(); err != nil { + errs = append(errs, err) + } + + if len(errs) != 0 { + return errs + } + + return nil +} + +type SDKConfig struct { + PartSize int64 + Concurrency int + BufferProvider s3manager.WriterReadFromProvider +} + +func (c *SDKConfig)
SetupFlags(prefix string, flagset *flag.FlagSet) { + prefix += "sdk." + + flagset.Int64Var(&c.PartSize, prefix+"part-size", s3manager.DefaultDownloadPartSize, + "Specifies the `size` of parts of the object to download.") + flagset.IntVar(&c.Concurrency, prefix+"concurrency", s3manager.DefaultDownloadConcurrency, + "Specifies the number of parts to download `at once`.") +} + +func (c *SDKConfig) Validate() error { + return nil +} + +type ClientConfig struct { + KeepAlive bool + Timeouts Timeouts + + MaxIdleConns int + MaxIdleConnsPerHost int +} + +func (c *ClientConfig) SetupFlags(prefix string, flagset *flag.FlagSet) { + prefix += "client." + + flagset.BoolVar(&c.KeepAlive, prefix+"http-keep-alive", true, + "Specifies if HTTP keep alive is enabled.") + + defTR := http.DefaultTransport.(*http.Transport) + + flagset.IntVar(&c.MaxIdleConns, prefix+"idle-conns", defTR.MaxIdleConns, + "Specifies max idle connection pool size.") + + flagset.IntVar(&c.MaxIdleConnsPerHost, prefix+"idle-conns-host", http.DefaultMaxIdleConnsPerHost, + "Specifies max idle connection pool per host, will be truncated by idle-conns.") + + c.Timeouts.SetupFlags(prefix, flagset) +} + +func (c *ClientConfig) Validate() error { + var errs Errors + + if err := c.Timeouts.Validate(); err != nil { + errs = append(errs, err) + } + + if len(errs) != 0 { + return errs + } + return nil +} + +type Timeouts struct { + Connect time.Duration + TLSHandshake time.Duration + ExpectContinue time.Duration + ResponseHeader time.Duration +} + +func (c *Timeouts) SetupFlags(prefix string, flagset *flag.FlagSet) { + prefix += "timeout." + + flagset.DurationVar(&c.Connect, prefix+"connect", 30*time.Second, + "The `timeout` connecting to the remote host.") + + defTR := http.DefaultTransport.(*http.Transport) + + flagset.DurationVar(&c.TLSHandshake, prefix+"tls", defTR.TLSHandshakeTimeout, + "The `timeout` waiting for the TLS handshake to complete.") + + flagset.DurationVar(&c.ExpectContinue, prefix+"expect-continue", defTR.ExpectContinueTimeout, + "The `timeout` waiting for the 100-continue response after the request headers are sent.") + + flagset.DurationVar(&c.ResponseHeader, prefix+"response-header", defTR.ResponseHeaderTimeout, + "The `timeout` waiting for the response headers to be received.") +} + +func (c *Timeouts) Validate() error { + return nil +} + +type Errors []error + +func (es Errors) Error() string { + var buf strings.Builder + for _, e := range es { + buf.WriteString(e.Error()) + } + + return buf.String() +} diff --git a/internal/awstesting/integration/performance/s3DownloadManager/main.go b/internal/awstesting/integration/performance/s3DownloadManager/main.go new file mode 100644 index 00000000000..84446b92c32 --- /dev/null +++ b/internal/awstesting/integration/performance/s3DownloadManager/main.go @@ -0,0 +1,221 @@ +// +build integration,perftest + +package main + +import ( + "context" + "flag" + "fmt" + "io" + "log" + "os" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/external" + "github.com/aws/aws-sdk-go-v2/internal/awstesting" + "github.com/aws/aws-sdk-go-v2/internal/awstesting/integration" + "github.com/aws/aws-sdk-go-v2/internal/sdkio" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/s3manager" +) + +var config Config + +func main() { + parseCommandLine() + + log.SetOutput(os.Stderr) + + log.Printf("uploading %s file to s3://%s\n", integration.SizeToName(int(config.Size)), config.Bucket) + key, err :=
setupDownloadTest(config.Bucket, config.Size) + if err != nil { + log.Fatalf("failed to setup download testing: %v", err) + } + + traces := make(chan *RequestTrace, config.SDK.Concurrency) + requestTracer := downloadRequestTracer(traces) + downloader := newDownloader(config.Client, config.SDK, requestTracer) + + metricReportDone := startTraceReceiver(traces) + + log.Println("starting download...") + start := time.Now() + _, err = downloader.Download(&awstesting.DiscardAt{}, &s3.GetObjectInput{ + Bucket: &config.Bucket, + Key: &key, + }) + if err != nil { + log.Fatalf("failed to download object, %v", err) + } + close(traces) + + dur := time.Since(start) + log.Printf("Download finished, Size: %d, Dur: %s, Throughput: %.5f GB/s", + config.Size, dur, (float64(config.Size)/(float64(dur)/float64(time.Second)))/float64(1e9), + ) + + <-metricReportDone + + log.Printf("cleaning up s3://%s/%s\n", config.Bucket, key) + if err = teardownDownloadTest(config.Bucket, key); err != nil { + log.Fatalf("failed to tear down test artifacts: %v", err) + } +} + +func parseCommandLine() { + config.SetupFlags("", flag.CommandLine) + + if err := flag.CommandLine.Parse(os.Args[1:]); err != nil { + flag.CommandLine.PrintDefaults() + log.Fatalf("failed to parse CLI commands") + } + if err := config.Validate(); err != nil { + flag.CommandLine.PrintDefaults() + log.Fatalf("invalid arguments: %v", err) + } +} + +func setupDownloadTest(bucket string, size int64) (key string, err error) { + er := &awstesting.EndlessReader{} + lr := io.LimitReader(er, size) + + key = integration.UniqueID() + + cfg, err := external.LoadDefaultAWSConfig() + if err != nil { + return "", fmt.Errorf("failed to load config: %v", err) + } + + client := s3.New(cfg) + client.Disable100Continue = true + + uploader := s3manager.NewUploaderWithClient(client, func(u *s3manager.Uploader) { + u.PartSize = 100 * sdkio.MebiByte + u.RequestOptions = append(u.RequestOptions, func(r *aws.Request) { + if r.Operation.Name != "UploadPart" && r.Operation.Name != "PutObject" { + return + } + + r.HTTPRequest.Header.Set("X-Amz-Content-Sha256", "UNSIGNED-PAYLOAD") + }) + }) + + _, err = uploader.Upload(&s3manager.UploadInput{ + Bucket: &bucket, + Body: lr, + Key: &key, + }) + if err != nil { + err = fmt.Errorf("failed to upload test object to s3: %v", err) + } + + return +} + +func teardownDownloadTest(bucket, key string) error { + cfg, err := external.LoadDefaultAWSConfig() + if err != nil { + return fmt.Errorf("failed to load config: %v", err) + } + svc := s3.New(cfg) + + resp := svc.DeleteObjectRequest(&s3.DeleteObjectInput{Bucket: &bucket, Key: &key}) + _, err = resp.Send(context.Background()) + return err +} + +func startTraceReceiver(traces <-chan *RequestTrace) <-chan struct{} { + metricReportDone := make(chan struct{}) + + go func() { + defer close(metricReportDone) + metrics := map[string]*RequestTrace{} + for trace := range traces { + curTrace, ok := metrics[trace.Operation] + if !ok { + curTrace = trace + } else { + curTrace.attempts = append(curTrace.attempts, trace.attempts...) + if len(trace.errs) != 0 { + curTrace.errs = append(curTrace.errs, trace.errs...)
+ } + curTrace.finish = trace.finish + } + + metrics[trace.Operation] = curTrace + } + + for _, name := range []string{ + "GetObject", + } { + if trace, ok := metrics[name]; ok { + printAttempts(name, trace, config.LogVerbose) + } + } + }() + + return metricReportDone +} + +func printAttempts(op string, trace *RequestTrace, verbose bool) { + if !verbose { + return + } + + log.Printf("%s: latency:%s requests:%d errors:%d", + op, + trace.finish.Sub(trace.start), + len(trace.attempts), + len(trace.errs), + ) + + for _, a := range trace.attempts { + log.Printf(" * %s", a) + } + if err := trace.Err(); err != nil { + log.Printf("Operation Errors: %v", err) + } + log.Println() +} + +func downloadRequestTracer(traces chan<- *RequestTrace) aws.Option { + tracerOption := func(r *aws.Request) { + id := "op" + if v, ok := r.Params.(*s3.GetObjectInput); ok { + if v.Range != nil { + id = *v.Range + } + } + tracer := NewRequestTrace(r.Context(), r.Operation.Name, id) + r.SetContext(tracer) + + r.Handlers.Send.PushFront(tracer.OnSendAttempt) + r.Handlers.CompleteAttempt.PushBack(tracer.OnCompleteAttempt) + r.Handlers.Complete.PushBack(tracer.OnComplete) + r.Handlers.Complete.PushBack(func(rr *aws.Request) { + traces <- tracer + }) + } + + return tracerOption +} + +func newDownloader(clientConfig ClientConfig, sdkConfig SDKConfig, options ...aws.Option) *s3manager.Downloader { + client := NewClient(clientConfig) + + cfg, err := external.LoadDefaultAWSConfig(aws.Config{HTTPClient: client}) + if err != nil { + log.Fatalf("failed to load config, %v", err) + } + + downloader := s3manager.NewDownloader(cfg, func(d *s3manager.Downloader) { + d.PartSize = sdkConfig.PartSize + d.Concurrency = sdkConfig.Concurrency + d.BufferProvider = sdkConfig.BufferProvider + + d.RequestOptions = append(d.RequestOptions, options...)
+ }) + + return downloader +} diff --git a/internal/awstesting/integration/performance/s3DownloadManager/main_test.go b/internal/awstesting/integration/performance/s3DownloadManager/main_test.go new file mode 100644 index 00000000000..c20d86ef693 --- /dev/null +++ b/internal/awstesting/integration/performance/s3DownloadManager/main_test.go @@ -0,0 +1,112 @@ +// +build integration,perftest + +package main + +import ( + "flag" + "fmt" + "io" + "io/ioutil" + "os" + "testing" + + "github.com/aws/aws-sdk-go-v2/internal/awstesting/integration" + "github.com/aws/aws-sdk-go-v2/internal/sdkio" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/s3manager" +) + +var benchConfig BenchmarkConfig + +type BenchmarkConfig struct { + bucket string + tempdir string + clientConfig ClientConfig +} + +func (b *BenchmarkConfig) SetupFlags(prefix string, flagSet *flag.FlagSet) { + flagSet.StringVar(&b.bucket, "bucket", "", "Bucket to use for benchmark") + flagSet.StringVar(&b.tempdir, "temp", os.TempDir(), "location to create temporary files") + b.clientConfig.SetupFlags(prefix, flagSet) +} + +var benchStrategies = map[string]s3manager.WriterReadFromProvider{ + "Unbuffered": nil, + "Buffered": s3manager.NewPooledBufferedWriterReadFromProvider(int(sdkio.MebiByte)), +} + +func BenchmarkDownload(b *testing.B) { + baseSdkConfig := SDKConfig{} + + // FileSizes: 5 MB, 1 GB + for _, fileSize := range []int64{5 * sdkio.MebiByte, 1 * sdkio.GibiByte} { + key, err := setupDownloadTest(benchConfig.bucket, fileSize) + if err != nil { + b.Fatalf("failed to setup download test: %v", err) + } + f, err := ioutil.TempFile(benchConfig.tempdir, "BenchmarkDownload") + if err != nil { + b.Fatalf("failed to create temporary file: %v", err) + } + b.Run(fmt.Sprintf("%s File", integration.SizeToName(int(fileSize))), func(b *testing.B) { + // Concurrency: 5, 10, 100 + for _, concurrency := range []int{s3manager.DefaultDownloadConcurrency, 2 * s3manager.DefaultDownloadConcurrency, 100} { + b.Run(fmt.Sprintf("%d Concurrency", concurrency), func(b *testing.B) { + // PartSize: 5 MB, 25 MB, 100 MB + for _, partSize := range []int64{s3manager.DefaultDownloadPartSize, 25 * sdkio.MebiByte, 100 * sdkio.MebiByte} { + if partSize > fileSize { + continue + } + b.Run(fmt.Sprintf("%s PartSize", integration.SizeToName(int(partSize))), func(b *testing.B) { + for name, strat := range benchStrategies { + b.Run(name, func(b *testing.B) { + sdkConfig := baseSdkConfig + sdkConfig.Concurrency = concurrency + sdkConfig.PartSize = partSize + sdkConfig.BufferProvider = strat + + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchDownload(b, benchConfig.bucket, key, f, sdkConfig, benchConfig.clientConfig) + _, err := f.Seek(0, io.SeekStart) + if err != nil { + b.Fatalf("failed to seek file back to beginning: %v", err) + } + } + }) + } + }) + } + }) + } + }) + + err = teardownDownloadTest(benchConfig.bucket, key) + if err != nil { + b.Fatalf("failed to cleanup test file: %v", err) + } + if err = f.Close(); err != nil { + b.Errorf("failed to close file: %v", err) + } + if err = os.Remove(f.Name()); err != nil { + b.Errorf("failed to remove file: %v", err) + } + } +} + +func benchDownload(b *testing.B, bucket, key string, body io.WriterAt, sdkConfig SDKConfig, clientConfig ClientConfig) { + downloader := newDownloader(clientConfig, sdkConfig) + _, err := downloader.Download(body, &s3.GetObjectInput{ + Bucket: &bucket, + Key: &key, + }) + if err != nil { + b.Fatalf("failed to download object, %v", 
err) + } +} + +func TestMain(m *testing.M) { + benchConfig.SetupFlags("", flag.CommandLine) + flag.Parse() + os.Exit(m.Run()) +} diff --git a/internal/awstesting/integration/performance/s3DownloadManager/metric.go b/internal/awstesting/integration/performance/s3DownloadManager/metric.go new file mode 100644 index 00000000000..13563a15100 --- /dev/null +++ b/internal/awstesting/integration/performance/s3DownloadManager/metric.go @@ -0,0 +1,204 @@ +// +build integration,perftest + +package main + +import ( + "context" + "crypto/tls" + "fmt" + "net/http/httptrace" + "strconv" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" +) + +type RequestTrace struct { + Operation string + ID string + + context.Context + + start, finish time.Time + + errs Errors + attempts []RequestAttempt + curAttempt RequestAttempt +} + +func NewRequestTrace(ctx context.Context, op, id string) *RequestTrace { + rt := &RequestTrace{ + Operation: op, + ID: id, + start: time.Now(), + attempts: []RequestAttempt{}, + curAttempt: RequestAttempt{ + ID: id, + }, + } + + trace := &httptrace.ClientTrace{ + GetConn: rt.getConn, + GotConn: rt.gotConn, + PutIdleConn: rt.putIdleConn, + GotFirstResponseByte: rt.gotFirstResponseByte, + Got100Continue: rt.got100Continue, + DNSStart: rt.dnsStart, + DNSDone: rt.dnsDone, + ConnectStart: rt.connectStart, + ConnectDone: rt.connectDone, + TLSHandshakeStart: rt.tlsHandshakeStart, + TLSHandshakeDone: rt.tlsHandshakeDone, + WroteHeaders: rt.wroteHeaders, + Wait100Continue: rt.wait100Continue, + WroteRequest: rt.wroteRequest, + } + + rt.Context = httptrace.WithClientTrace(ctx, trace) + + return rt +} + +func (rt *RequestTrace) AppendError(err error) { + rt.errs = append(rt.errs, err) +} +func (rt *RequestTrace) OnSendAttempt(r *aws.Request) { + rt.curAttempt.SendStart = time.Now() +} +func (rt *RequestTrace) OnCompleteAttempt(r *aws.Request) { + rt.curAttempt.Start = r.AttemptTime + rt.curAttempt.Finish = time.Now() + rt.curAttempt.Err = r.Error + + if r.Error != nil { + rt.AppendError(r.Error) + } + + rt.attempts = append(rt.attempts, rt.curAttempt) + rt.curAttempt = RequestAttempt{ + ID: rt.curAttempt.ID, + AttemptNum: rt.curAttempt.AttemptNum + 1, + } +} +func (rt *RequestTrace) OnComplete(r *aws.Request) { + rt.finish = time.Now() + // Last attempt includes reading the response body + if len(rt.attempts) > 0 { + rt.attempts[len(rt.attempts)-1].Finish = rt.finish + } +} + +func (rt *RequestTrace) Err() error { + if len(rt.errs) != 0 { + return rt.errs + } + return nil +} +func (rt *RequestTrace) TotalLatency() time.Duration { + return rt.finish.Sub(rt.start) +} +func (rt *RequestTrace) Attempts() []RequestAttempt { + return rt.attempts +} +func (rt *RequestTrace) Retries() int { + return len(rt.attempts) - 1 +} + +func (rt *RequestTrace) getConn(hostPort string) {} +func (rt *RequestTrace) gotConn(info httptrace.GotConnInfo) { + rt.curAttempt.Reused = info.Reused +} +func (rt *RequestTrace) putIdleConn(err error) {} +func (rt *RequestTrace) gotFirstResponseByte() { + rt.curAttempt.FirstResponseByte = time.Now() +} +func (rt *RequestTrace) got100Continue() {} +func (rt *RequestTrace) dnsStart(info httptrace.DNSStartInfo) { + rt.curAttempt.DNSStart = time.Now() +} +func (rt *RequestTrace) dnsDone(info httptrace.DNSDoneInfo) { + rt.curAttempt.DNSDone = time.Now() +} +func (rt *RequestTrace) connectStart(network, addr string) { + rt.curAttempt.ConnectStart = time.Now() +} +func (rt *RequestTrace) connectDone(network, addr string, err error) { + rt.curAttempt.ConnectDone = 
time.Now() +} +func (rt *RequestTrace) tlsHandshakeStart() { + rt.curAttempt.TLSHandshakeStart = time.Now() +} +func (rt *RequestTrace) tlsHandshakeDone(state tls.ConnectionState, err error) { + rt.curAttempt.TLSHandshakeDone = time.Now() +} +func (rt *RequestTrace) wroteHeaders() { + rt.curAttempt.WroteHeaders = time.Now() +} +func (rt *RequestTrace) wait100Continue() { + rt.curAttempt.Read100Continue = time.Now() +} +func (rt *RequestTrace) wroteRequest(info httptrace.WroteRequestInfo) { + rt.curAttempt.RequestWritten = time.Now() +} + +type RequestAttempt struct { + Start, Finish time.Time + SendStart time.Time + Err error + + Reused bool + ID string + AttemptNum int + + DNSStart, DNSDone time.Time + ConnectStart, ConnectDone time.Time + TLSHandshakeStart, TLSHandshakeDone time.Time + WroteHeaders time.Time + RequestWritten time.Time + Read100Continue time.Time + FirstResponseByte time.Time +} + +func (a RequestAttempt) String() string { + const sep = ", " + + var w strings.Builder + w.WriteString(a.ID + "-" + strconv.Itoa(a.AttemptNum) + sep) + w.WriteString("Latency:" + durToMSString(a.Finish.Sub(a.Start)) + sep) + w.WriteString("SDKPreSend:" + durToMSString(a.SendStart.Sub(a.Start)) + sep) + + writeStart := a.SendStart + fmt.Fprintf(&w, "ConnReused:%t"+sep, a.Reused) + if !a.Reused { + w.WriteString("DNS:" + durToMSString(a.DNSDone.Sub(a.DNSStart)) + sep) + w.WriteString("Connect:" + durToMSString(a.ConnectDone.Sub(a.ConnectStart)) + sep) + w.WriteString("TLS:" + durToMSString(a.TLSHandshakeDone.Sub(a.TLSHandshakeStart)) + sep) + writeStart = a.TLSHandshakeDone + } + + writeHeader := a.WroteHeaders.Sub(writeStart) + w.WriteString("WriteHeader:" + durToMSString(writeHeader) + sep) + if !a.Read100Continue.IsZero() { + // With 100-continue + w.WriteString("Read100Cont:" + durToMSString(a.Read100Continue.Sub(a.WroteHeaders)) + sep) + w.WriteString("WritePayload:" + durToMSString(a.FirstResponseByte.Sub(a.RequestWritten)) + sep) + + w.WriteString("RespRead:" + durToMSString(a.Finish.Sub(a.RequestWritten)) + sep) + } else { + // No 100-continue + w.WriteString("WritePayload:" + durToMSString(a.RequestWritten.Sub(a.WroteHeaders)) + sep) + + if !a.FirstResponseByte.IsZero() { + w.WriteString("RespFirstByte:" + durToMSString(a.FirstResponseByte.Sub(a.RequestWritten)) + sep) + w.WriteString("RespRead:" + durToMSString(a.Finish.Sub(a.FirstResponseByte)) + sep) + } + } + + return w.String() +} + +func durToMSString(v time.Duration) string { + ms := float64(v) / float64(time.Millisecond) + return fmt.Sprintf("%0.6f", ms) +} diff --git a/internal/awstesting/integration/performance/s3UploadManager/README.md b/internal/awstesting/integration/performance/s3UploadManager/README.md new file mode 100644 index 00000000000..c70ec213900 --- /dev/null +++ b/internal/awstesting/integration/performance/s3UploadManager/README.md @@ -0,0 +1,43 @@ +## Performance Utility + +Uploads a file to an S3 bucket using the SDK's S3 upload manager. Allows passing +in custom configuration for the HTTP client and SDK's Upload Manager behavior.
+ +## Build +### Standalone +```sh +go build -tags "integration perftest" -o s3UploadPerfGo ./awstesting/integration/performance/s3UploadManager +``` +### Benchmarking +```sh +go test -tags "integration perftest" -c -o s3UploadPerfGo ./awstesting/integration/performance/s3UploadManager +``` + +## Usage Example: +### Standalone +```sh +AWS_REGION=us-west-2 AWS_PROFILE=aws-go-sdk-team-test ./s3UploadPerfGo \ +-bucket aws-sdk-go-data \ +-key 10GB.file \ +-file /tmp/10GB.file \ +-client.idle-conns 1000 \ +-client.idle-conns-host 300 \ +-sdk.concurrency 100 \ +-sdk.unsigned \ +-sdk.100-continue=false \ +-client.timeout.connect=1s \ +-client.timeout.response-header=1s +``` + +### Benchmarking +```sh +AWS_REGION=us-west-2 AWS_PROFILE=aws-go-sdk-team-test ./s3UploadPerfGo \ +-test.bench=. \ +-test.benchmem \ +-test.benchtime 1x \ +-bucket aws-sdk-go-data \ +-client.idle-conns 1000 \ +-client.idle-conns-host 300 \ +-client.timeout.connect=1s \ +-client.timeout.response-header=1s +``` diff --git a/internal/awstesting/integration/performance/s3UploadManager/client.go b/internal/awstesting/integration/performance/s3UploadManager/client.go new file mode 100644 index 00000000000..3b16c5f900b --- /dev/null +++ b/internal/awstesting/integration/performance/s3UploadManager/client.go @@ -0,0 +1,32 @@ +// +build integration,perftest + +package main + +import ( + "net" + "net/http" + "time" +) + +func NewClient(cfg ClientConfig) *http.Client { + tr := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: cfg.Timeouts.Connect, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: cfg.MaxIdleConns, + MaxIdleConnsPerHost: cfg.MaxIdleConnsPerHost, + IdleConnTimeout: 90 * time.Second, + + DisableKeepAlives: !cfg.KeepAlive, + TLSHandshakeTimeout: cfg.Timeouts.TLSHandshake, + ExpectContinueTimeout: cfg.Timeouts.ExpectContinue, + ResponseHeaderTimeout: cfg.Timeouts.ResponseHeader, + } + + return &http.Client{ + Transport: tr, + } +} diff --git a/internal/awstesting/integration/performance/s3UploadManager/config.go b/internal/awstesting/integration/performance/s3UploadManager/config.go new file mode 100644 index 00000000000..33cc42d11c2 --- /dev/null +++ b/internal/awstesting/integration/performance/s3UploadManager/config.go @@ -0,0 +1,165 @@ +// +build integration,perftest + +package main + +import ( + "flag" + "fmt" + "net/http" + "os" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/service/s3/s3manager" +) + +type Config struct { + Bucket string + Filename string + Size int64 + TempDir string + LogVerbose bool + + SDK SDKConfig + Client ClientConfig +} + +func (c *Config) SetupFlags(prefix string, flagset *flag.FlagSet) { + flagset.StringVar(&c.Bucket, "bucket", "", + "The S3 bucket `name` to upload the object to.") + flagset.StringVar(&c.Filename, "file", "", + "The `path` of the local file to upload.") + flagset.Int64Var(&c.Size, "size", 0, + "The S3 object size in bytes to upload") + flagset.StringVar(&c.TempDir, "temp", os.TempDir(), "location to create temporary files") + flagset.BoolVar(&c.LogVerbose, "verbose", false, + "The output log will include verbose request information") + + c.SDK.SetupFlags(prefix, flagset) + c.Client.SetupFlags(prefix, flagset) +} + +func (c *Config) Validate() error { + var errs Errors + + if len(c.Bucket) == 0 || (c.Size <= 0 && c.Filename == "") { + errs = append(errs, fmt.Errorf("bucket and filename/size are required")) + } + + if err := c.SDK.Validate(); err != nil { + errs = append(errs, err) + 
} + if err := c.Client.Validate(); err != nil { + errs = append(errs, err) + } + + if len(errs) != 0 { + return errs + } + + return nil +} + +type SDKConfig struct { + PartSize int64 + Concurrency int + WithUnsignedPayload bool + ExpectContinue bool + BufferProvider s3manager.ReadSeekerWriteToProvider +} + +func (c *SDKConfig) SetupFlags(prefix string, flagset *flag.FlagSet) { + prefix += "sdk." + + flagset.Int64Var(&c.PartSize, prefix+"part-size", s3manager.DefaultUploadPartSize, + "Specifies the `size` of parts of the object to upload.") + flagset.IntVar(&c.Concurrency, prefix+"concurrency", s3manager.DefaultUploadConcurrency, + "Specifies the number of parts to upload `at once`.") + flagset.BoolVar(&c.WithUnsignedPayload, prefix+"unsigned", false, + "Specifies if the SDK will use UNSIGNED_PAYLOAD for part SHA256 in request signature.") + + flagset.BoolVar(&c.ExpectContinue, prefix+"100-continue", true, + "Specifies if the SDK requests will wait for the 100 continue response before sending request payload.") +} + +func (c *SDKConfig) Validate() error { + return nil +} + +type ClientConfig struct { + KeepAlive bool + Timeouts Timeouts + + MaxIdleConns int + MaxIdleConnsPerHost int +} + +func (c *ClientConfig) SetupFlags(prefix string, flagset *flag.FlagSet) { + prefix += "client." + + flagset.BoolVar(&c.KeepAlive, prefix+"http-keep-alive", true, + "Specifies if HTTP keep alive is enabled.") + + defTR := http.DefaultTransport.(*http.Transport) + + flagset.IntVar(&c.MaxIdleConns, prefix+"idle-conns", defTR.MaxIdleConns, + "Specifies max idle connection pool size.") + + flagset.IntVar(&c.MaxIdleConnsPerHost, prefix+"idle-conns-host", http.DefaultMaxIdleConnsPerHost, + "Specifies max idle connection pool per host, will be truncated by idle-conns.") + + c.Timeouts.SetupFlags(prefix, flagset) +} + +func (c *ClientConfig) Validate() error { + var errs Errors + + if err := c.Timeouts.Validate(); err != nil { + errs = append(errs, err) + } + + if len(errs) != 0 { + return errs + } + return nil +} + +type Timeouts struct { + Connect time.Duration + TLSHandshake time.Duration + ExpectContinue time.Duration + ResponseHeader time.Duration +} + +func (c *Timeouts) SetupFlags(prefix string, flagset *flag.FlagSet) { + prefix += "timeout." 
+ + flagset.DurationVar(&c.Connect, prefix+"connect", 30*time.Second, + "The `timeout` connecting to the remote host.") + + defTR := http.DefaultTransport.(*http.Transport) + + flagset.DurationVar(&c.TLSHandshake, prefix+"tls", defTR.TLSHandshakeTimeout, + "The `timeout` waiting for the TLS handshake to complete.") + + flagset.DurationVar(&c.ExpectContinue, prefix+"expect-continue", defTR.ExpectContinueTimeout, + "The `timeout` waiting for the 100-continue response after the request headers are sent.") + + flagset.DurationVar(&c.ResponseHeader, prefix+"response-header", defTR.ResponseHeaderTimeout, + "The `timeout` waiting for the response headers to be received.") +} + +func (c *Timeouts) Validate() error { + return nil +} + +type Errors []error + +func (es Errors) Error() string { + var buf strings.Builder + for _, e := range es { + buf.WriteString(e.Error()) + } + + return buf.String() +} diff --git a/internal/awstesting/integration/performance/s3UploadManager/main.go b/internal/awstesting/integration/performance/s3UploadManager/main.go new file mode 100644 index 00000000000..eebd0e391e8 --- /dev/null +++ b/internal/awstesting/integration/performance/s3UploadManager/main.go @@ -0,0 +1,189 @@ +// +build integration,perftest + +package main + +import ( + "flag" + "log" + "os" + "path/filepath" + "strconv" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/external" + "github.com/aws/aws-sdk-go-v2/internal/awstesting/integration" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/s3manager" +) + +var config Config + +func main() { + parseCommandLine() + + log.SetOutput(os.Stderr) + + var ( + file *os.File + err error + ) + + if config.Filename != "" { + file, err = os.Open(config.Filename) + if err != nil { + log.Fatalf("failed to open file: %v", err) + } + } else { + file, err = integration.CreateFileOfSize(config.TempDir, config.Size) + if err != nil { + log.Fatalf("failed to create file: %v", err) + } + defer os.Remove(file.Name()) + } + + defer file.Close() + + traces := make(chan *RequestTrace, config.SDK.Concurrency) + requestTracer := uploadRequestTracer(traces) + uploader := newUploader(config.Client, config.SDK, requestTracer) + + metricReportDone := make(chan struct{}) + go func() { + defer close(metricReportDone) + metrics := map[string]*RequestTrace{} + for trace := range traces { + curTrace, ok := metrics[trace.Operation] + if !ok { + curTrace = trace + } else { + curTrace.attempts = append(curTrace.attempts, trace.attempts...) + if len(trace.errs) != 0 { + curTrace.errs = append(curTrace.errs, trace.errs...)
+ } + curTrace.finish = trace.finish + } + + metrics[trace.Operation] = curTrace + } + + for _, name := range []string{ + "CreateMultipartUpload", + "CompleteMultipartUpload", + "UploadPart", + "PutObject", + } { + if trace, ok := metrics[name]; ok { + printAttempts(name, trace, config.LogVerbose) + } + } + }() + + log.Println("starting upload...") + start := time.Now() + _, err = uploader.Upload(&s3manager.UploadInput{ + Bucket: &config.Bucket, + Key: aws.String(filepath.Base(file.Name())), + Body: file, + }) + if err != nil { + log.Fatalf("failed to upload object, %v", err) + } + close(traces) + + fileInfo, _ := file.Stat() + size := fileInfo.Size() + dur := time.Since(start) + log.Printf("Upload finished, Size: %d, Dur: %s, Throughput: %.5f GB/s", + size, dur, (float64(size)/(float64(dur)/float64(time.Second)))/float64(1e9), + ) + + <-metricReportDone +} + +func parseCommandLine() { + config.SetupFlags("", flag.CommandLine) + + if err := flag.CommandLine.Parse(os.Args[1:]); err != nil { + flag.CommandLine.PrintDefaults() + log.Fatalf("failed to parse CLI commands") + } + if err := config.Validate(); err != nil { + flag.CommandLine.PrintDefaults() + log.Fatalf("invalid arguments: %v", err) + } +} + +func printAttempts(op string, trace *RequestTrace, verbose bool) { + if !verbose { + return + } + + log.Printf("%s: latency:%s requests:%d errors:%d", + op, + trace.finish.Sub(trace.start), + len(trace.attempts), + len(trace.errs), + ) + + for _, a := range trace.attempts { + log.Printf(" * %s", a) + } + if err := trace.Err(); err != nil { + log.Printf("Operation Errors: %v", err) + } + log.Println() +} + +func uploadRequestTracer(traces chan<- *RequestTrace) aws.Option { + tracerOption := func(r *aws.Request) { + id := "op" + if v, ok := r.Params.(*s3.UploadPartInput); ok { + id = strconv.FormatInt(*v.PartNumber, 10) + } + tracer := NewRequestTrace(r.Context(), r.Operation.Name, id) + r.SetContext(tracer) + + r.Handlers.Send.PushFront(tracer.OnSendAttempt) + r.Handlers.CompleteAttempt.PushBack(tracer.OnCompleteAttempt) + r.Handlers.Complete.PushBack(tracer.OnComplete) + r.Handlers.Complete.PushBack(func(rr *aws.Request) { + traces <- tracer + }) + } + + return tracerOption +} + +func SetUnsignedPayload(r *aws.Request) { + if r.Operation.Name != "UploadPart" && r.Operation.Name != "PutObject" { + return + } + r.HTTPRequest.Header.Set("X-Amz-Content-Sha256", "UNSIGNED-PAYLOAD") +} + +func newUploader(clientConfig ClientConfig, sdkConfig SDKConfig, options ...aws.Option) *s3manager.Uploader { + client := NewClient(clientConfig) + + if sdkConfig.WithUnsignedPayload { + options = append(options, SetUnsignedPayload) + } + + cfg, err := external.LoadDefaultAWSConfig(aws.Config{HTTPClient: client}) + if err != nil { + log.Fatalf("failed to load config: %v", err) + } + + svc := s3.New(cfg) + svc.Disable100Continue = !sdkConfig.ExpectContinue + + uploader := s3manager.NewUploaderWithClient(svc, func(u *s3manager.Uploader) { + u.PartSize = sdkConfig.PartSize + u.Concurrency = sdkConfig.Concurrency + u.BufferProvider = sdkConfig.BufferProvider + + u.RequestOptions = append(u.RequestOptions, options...) 
+ }) + + return uploader +} diff --git a/internal/awstesting/integration/performance/s3UploadManager/main_test.go b/internal/awstesting/integration/performance/s3UploadManager/main_test.go new file mode 100644 index 00000000000..aa64ec63d95 --- /dev/null +++ b/internal/awstesting/integration/performance/s3UploadManager/main_test.go @@ -0,0 +1,139 @@ +// +build integration,perftest + +package main + +import ( + "bytes" + "flag" + "fmt" + "io" + "os" + "path/filepath" + "testing" + + "github.com/aws/aws-sdk-go-v2/internal/awstesting/integration" + "github.com/aws/aws-sdk-go-v2/internal/sdkio" + "github.com/aws/aws-sdk-go-v2/service/s3/s3manager" +) + +var benchConfig BenchmarkConfig + +type BenchmarkConfig struct { + bucket string + tempdir string + clientConfig ClientConfig +} + +func (b *BenchmarkConfig) SetupFlags(prefix string, flagSet *flag.FlagSet) { + flagSet.StringVar(&b.bucket, "bucket", "", "Bucket to use for benchmark") + flagSet.StringVar(&b.tempdir, "temp", os.TempDir(), "location to create temporary files") + b.clientConfig.SetupFlags(prefix, flagSet) +} + +var benchStrategies = []struct { + name string + bufferProvider s3manager.ReadSeekerWriteToProvider +}{ + {name: "Unbuffered", bufferProvider: nil}, + {name: "Buffered", bufferProvider: s3manager.NewBufferedReadSeekerWriteToPool(1024 * 1024)}, +} + +func BenchmarkInMemory(b *testing.B) { + memBreader := bytes.NewReader(make([]byte, 1*sdkio.GibiByte)) + + baseSdkConfig := SDKConfig{WithUnsignedPayload: true, ExpectContinue: true} + + key := integration.UniqueID() + // Concurrency: 5, 10, 100 + for _, concurrency := range []int{s3manager.DefaultUploadConcurrency, 2 * s3manager.DefaultUploadConcurrency, 100} { + b.Run(fmt.Sprintf("%d_Concurrency", concurrency), func(b *testing.B) { + // PartSize: 5 MB, 25 MB, 100 MB + for _, partSize := range []int64{s3manager.DefaultUploadPartSize, 25 * sdkio.MebiByte, 100 * sdkio.MebiByte} { + b.Run(fmt.Sprintf("%s_PartSize", integration.SizeToName(int(partSize))), func(b *testing.B) { + sdkConfig := baseSdkConfig + + sdkConfig.BufferProvider = nil + sdkConfig.Concurrency = concurrency + sdkConfig.PartSize = partSize + + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchUpload(b, benchConfig.bucket, key, memBreader, sdkConfig, benchConfig.clientConfig) + _, err := memBreader.Seek(0, io.SeekStart) + if err != nil { + b.Fatalf("failed to seek to start of file: %v", err) + } + } + }) + } + }) + } +} + +func BenchmarkUpload(b *testing.B) { + baseSdkConfig := SDKConfig{WithUnsignedPayload: true, ExpectContinue: true} + + // FileSizes: 5 MB, 1 GB, 10 GB + for _, fileSize := range []int64{5 * sdkio.MebiByte, sdkio.GibiByte, 10 * sdkio.GibiByte} { + b.Run(fmt.Sprintf("%s_File", integration.SizeToName(int(fileSize))), func(b *testing.B) { + b.Logf("creating file of size: %s", integration.SizeToName(int(fileSize))) + file, err := integration.CreateFileOfSize(benchConfig.tempdir, fileSize) + if err != nil { + b.Fatalf("failed to create file: %v", err) + } + + // Concurrency: 5, 10, 100 + for _, concurrency := range []int{s3manager.DefaultUploadConcurrency, 2 * s3manager.DefaultUploadConcurrency, 100} { + b.Run(fmt.Sprintf("%d_Concurrency", concurrency), func(b *testing.B) { + // PartSize: 5 MB, 25 MB, 100 MB + for _, partSize := range []int64{s3manager.DefaultUploadPartSize, 25 * sdkio.MebiByte, 100 * sdkio.MebiByte} { + if partSize > fileSize { + continue + } + b.Run(fmt.Sprintf("%s_PartSize", integration.SizeToName(int(partSize))), func(b *testing.B) { + for _, strat := range 
benchStrategies { + b.Run(strat.name, func(b *testing.B) { + sdkConfig := baseSdkConfig + + sdkConfig.BufferProvider = strat.bufferProvider + sdkConfig.Concurrency = concurrency + sdkConfig.PartSize = partSize + + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchUpload(b, benchConfig.bucket, filepath.Base(file.Name()), file, sdkConfig, benchConfig.clientConfig) + _, err := file.Seek(0, io.SeekStart) + if err != nil { + b.Fatalf("failed to seek to start of file: %v", err) + } + } + }) + } + }) + } + }) + } + + os.Remove(file.Name()) + file.Close() + }) + } +} + +func benchUpload(b *testing.B, bucket, key string, reader io.ReadSeeker, sdkConfig SDKConfig, clientConfig ClientConfig) { + uploader := newUploader(clientConfig, sdkConfig, SetUnsignedPayload) + _, err := uploader.Upload(&s3manager.UploadInput{ + Bucket: &bucket, + Key: &key, + Body: reader, + }) + if err != nil { + b.Fatalf("failed to upload object, %v", err) + } +} + +func TestMain(m *testing.M) { + benchConfig.SetupFlags("", flag.CommandLine) + flag.Parse() + os.Exit(m.Run()) +} diff --git a/internal/awstesting/integration/performance/s3UploadManager/metric.go b/internal/awstesting/integration/performance/s3UploadManager/metric.go new file mode 100644 index 00000000000..13563a15100 --- /dev/null +++ b/internal/awstesting/integration/performance/s3UploadManager/metric.go @@ -0,0 +1,204 @@ +// +build integration,perftest + +package main + +import ( + "context" + "crypto/tls" + "fmt" + "net/http/httptrace" + "strconv" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" +) + +type RequestTrace struct { + Operation string + ID string + + context.Context + + start, finish time.Time + + errs Errors + attempts []RequestAttempt + curAttempt RequestAttempt +} + +func NewRequestTrace(ctx context.Context, op, id string) *RequestTrace { + rt := &RequestTrace{ + Operation: op, + ID: id, + start: time.Now(), + attempts: []RequestAttempt{}, + curAttempt: RequestAttempt{ + ID: id, + }, + } + + trace := &httptrace.ClientTrace{ + GetConn: rt.getConn, + GotConn: rt.gotConn, + PutIdleConn: rt.putIdleConn, + GotFirstResponseByte: rt.gotFirstResponseByte, + Got100Continue: rt.got100Continue, + DNSStart: rt.dnsStart, + DNSDone: rt.dnsDone, + ConnectStart: rt.connectStart, + ConnectDone: rt.connectDone, + TLSHandshakeStart: rt.tlsHandshakeStart, + TLSHandshakeDone: rt.tlsHandshakeDone, + WroteHeaders: rt.wroteHeaders, + Wait100Continue: rt.wait100Continue, + WroteRequest: rt.wroteRequest, + } + + rt.Context = httptrace.WithClientTrace(ctx, trace) + + return rt +} + +func (rt *RequestTrace) AppendError(err error) { + rt.errs = append(rt.errs, err) +} +func (rt *RequestTrace) OnSendAttempt(r *aws.Request) { + rt.curAttempt.SendStart = time.Now() +} +func (rt *RequestTrace) OnCompleteAttempt(r *aws.Request) { + rt.curAttempt.Start = r.AttemptTime + rt.curAttempt.Finish = time.Now() + rt.curAttempt.Err = r.Error + + if r.Error != nil { + rt.AppendError(r.Error) + } + + rt.attempts = append(rt.attempts, rt.curAttempt) + rt.curAttempt = RequestAttempt{ + ID: rt.curAttempt.ID, + AttemptNum: rt.curAttempt.AttemptNum + 1, + } +} +func (rt *RequestTrace) OnComplete(r *aws.Request) { + rt.finish = time.Now() + // Last attempt includes reading the response body + if len(rt.attempts) > 0 { + rt.attempts[len(rt.attempts)-1].Finish = rt.finish + } +} + +func (rt *RequestTrace) Err() error { + if len(rt.errs) != 0 { + return rt.errs + } + return nil +} +func (rt *RequestTrace) TotalLatency() time.Duration { + return rt.finish.Sub(rt.start) +} +func 
(rt *RequestTrace) Attempts() []RequestAttempt { + return rt.attempts +} +func (rt *RequestTrace) Retries() int { + return len(rt.attempts) - 1 +} + +func (rt *RequestTrace) getConn(hostPort string) {} +func (rt *RequestTrace) gotConn(info httptrace.GotConnInfo) { + rt.curAttempt.Reused = info.Reused +} +func (rt *RequestTrace) putIdleConn(err error) {} +func (rt *RequestTrace) gotFirstResponseByte() { + rt.curAttempt.FirstResponseByte = time.Now() +} +func (rt *RequestTrace) got100Continue() {} +func (rt *RequestTrace) dnsStart(info httptrace.DNSStartInfo) { + rt.curAttempt.DNSStart = time.Now() +} +func (rt *RequestTrace) dnsDone(info httptrace.DNSDoneInfo) { + rt.curAttempt.DNSDone = time.Now() +} +func (rt *RequestTrace) connectStart(network, addr string) { + rt.curAttempt.ConnectStart = time.Now() +} +func (rt *RequestTrace) connectDone(network, addr string, err error) { + rt.curAttempt.ConnectDone = time.Now() +} +func (rt *RequestTrace) tlsHandshakeStart() { + rt.curAttempt.TLSHandshakeStart = time.Now() +} +func (rt *RequestTrace) tlsHandshakeDone(state tls.ConnectionState, err error) { + rt.curAttempt.TLSHandshakeDone = time.Now() +} +func (rt *RequestTrace) wroteHeaders() { + rt.curAttempt.WroteHeaders = time.Now() +} +func (rt *RequestTrace) wait100Continue() { + rt.curAttempt.Read100Continue = time.Now() +} +func (rt *RequestTrace) wroteRequest(info httptrace.WroteRequestInfo) { + rt.curAttempt.RequestWritten = time.Now() +} + +type RequestAttempt struct { + Start, Finish time.Time + SendStart time.Time + Err error + + Reused bool + ID string + AttemptNum int + + DNSStart, DNSDone time.Time + ConnectStart, ConnectDone time.Time + TLSHandshakeStart, TLSHandshakeDone time.Time + WroteHeaders time.Time + RequestWritten time.Time + Read100Continue time.Time + FirstResponseByte time.Time +} + +func (a RequestAttempt) String() string { + const sep = ", " + + var w strings.Builder + w.WriteString(a.ID + "-" + strconv.Itoa(a.AttemptNum) + sep) + w.WriteString("Latency:" + durToMSString(a.Finish.Sub(a.Start)) + sep) + w.WriteString("SDKPreSend:" + durToMSString(a.SendStart.Sub(a.Start)) + sep) + + writeStart := a.SendStart + fmt.Fprintf(&w, "ConnReused:%t"+sep, a.Reused) + if !a.Reused { + w.WriteString("DNS:" + durToMSString(a.DNSDone.Sub(a.DNSStart)) + sep) + w.WriteString("Connect:" + durToMSString(a.ConnectDone.Sub(a.ConnectStart)) + sep) + w.WriteString("TLS:" + durToMSString(a.TLSHandshakeDone.Sub(a.TLSHandshakeStart)) + sep) + writeStart = a.TLSHandshakeDone + } + + writeHeader := a.WroteHeaders.Sub(writeStart) + w.WriteString("WriteHeader:" + durToMSString(writeHeader) + sep) + if !a.Read100Continue.IsZero() { + // With 100-continue + w.WriteString("Read100Cont:" + durToMSString(a.Read100Continue.Sub(a.WroteHeaders)) + sep) + w.WriteString("WritePayload:" + durToMSString(a.FirstResponseByte.Sub(a.RequestWritten)) + sep) + + w.WriteString("RespRead:" + durToMSString(a.Finish.Sub(a.RequestWritten)) + sep) + } else { + // No 100-continue + w.WriteString("WritePayload:" + durToMSString(a.RequestWritten.Sub(a.WroteHeaders)) + sep) + + if !a.FirstResponseByte.IsZero() { + w.WriteString("RespFirstByte:" + durToMSString(a.FirstResponseByte.Sub(a.RequestWritten)) + sep) + w.WriteString("RespRead:" + durToMSString(a.Finish.Sub(a.FirstResponseByte)) + sep) + } + } + + return w.String() +} + +func durToMSString(v time.Duration) string { + ms := float64(v) / float64(time.Millisecond) + return fmt.Sprintf("%0.6f", ms) +} diff --git a/internal/sdkio/byte.go b/internal/sdkio/byte.go new file 
mode 100644 index 00000000000..6c443988bbc --- /dev/null +++ b/internal/sdkio/byte.go @@ -0,0 +1,12 @@ +package sdkio + +const ( + // Byte is 8 bits + Byte int64 = 1 + // KibiByte (KiB) is 1024 Bytes + KibiByte = Byte * 1024 + // MebiByte (MiB) is 1024 KiB + MebiByte = KibiByte * 1024 + // GibiByte (GiB) is 1024 MiB + GibiByte = MebiByte * 1024 +) diff --git a/service/s3/s3manager/buffered_read_seeker.go b/service/s3/s3manager/buffered_read_seeker.go new file mode 100644 index 00000000000..818a2a4e7a8 --- /dev/null +++ b/service/s3/s3manager/buffered_read_seeker.go @@ -0,0 +1,79 @@ +package s3manager + +import ( + "io" +) + +// BufferedReadSeeker is a buffered io.ReadSeeker +type BufferedReadSeeker struct { + r io.ReadSeeker + buffer []byte + readIdx, writeIdx int +} + +// NewBufferedReadSeeker returns a new BufferedReadSeeker. +// If len(b) == 0 then the buffer will be initialized to 64 KiB. +func NewBufferedReadSeeker(r io.ReadSeeker, b []byte) *BufferedReadSeeker { + if len(b) == 0 { + b = make([]byte, 64*1024) + } + return &BufferedReadSeeker{r: r, buffer: b} +} + +func (b *BufferedReadSeeker) reset(r io.ReadSeeker) { + b.r = r + b.readIdx, b.writeIdx = 0, 0 +} + +// Read will read up to len(p) bytes into p and will return +// the number of bytes read and any error that occurred. +// If len(p) > the buffer size then a single read request +// will be issued to the underlying io.ReadSeeker for len(p) bytes. +// A Read request will at most perform a single Read to the underlying +// io.ReadSeeker, and may return < len(p) if serviced from the buffer. +func (b *BufferedReadSeeker) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return n, err + } + + if b.readIdx == b.writeIdx { + if len(p) >= len(b.buffer) { + n, err = b.r.Read(p) + return n, err + } + b.readIdx, b.writeIdx = 0, 0 + + n, err = b.r.Read(b.buffer) + if n == 0 { + return n, err + } + + b.writeIdx += n + } + + n = copy(p, b.buffer[b.readIdx:b.writeIdx]) + b.readIdx += n + + return n, err +} + +// Seek will position the underlying io.ReadSeeker to the given offset +// and will clear the buffer. +func (b *BufferedReadSeeker) Seek(offset int64, whence int) (int64, error) { + n, err := b.r.Seek(offset, whence) + + b.reset(b.r) + + return n, err +} + +// ReadAt will read up to len(p) bytes at the given file offset. +// This will result in the buffer being cleared. +func (b *BufferedReadSeeker) ReadAt(p []byte, off int64) (int, error) { + _, err := b.Seek(off, io.SeekStart) + if err != nil { + return 0, err + } + + return b.Read(p) +} diff --git a/service/s3/s3manager/buffered_read_seeker_test.go b/service/s3/s3manager/buffered_read_seeker_test.go new file mode 100644 index 00000000000..857cce56809 --- /dev/null +++ b/service/s3/s3manager/buffered_read_seeker_test.go @@ -0,0 +1,79 @@ +package s3manager + +import ( + "bytes" + "io" + "testing" +) + +func TestBufferedReadSeekerRead(t *testing.T) { + expected := []byte("testData") + + readSeeker := NewBufferedReadSeeker(bytes.NewReader(expected), make([]byte, 4)) + + var ( + actual []byte + buffer = make([]byte, 2) + ) + + for { + n, err := readSeeker.Read(buffer) + actual = append(actual, buffer[:n]...)
+ if err != nil && err == io.EOF { + break + } else if err != nil { + t.Fatalf("failed to read from reader: %v", err) + } + } + + if !bytes.Equal(expected, actual) { + t.Errorf("expected %v, got %v", expected, actual) + } +} + +func TestBufferedReadSeekerSeek(t *testing.T) { + content := []byte("testData") + + readSeeker := NewBufferedReadSeeker(bytes.NewReader(content), make([]byte, 4)) + + _, err := readSeeker.Seek(4, io.SeekStart) + if err != nil { + t.Fatalf("failed to seek reader: %v", err) + } + + var ( + actual []byte + buffer = make([]byte, 4) + ) + + for { + n, err := readSeeker.Read(buffer) + actual = append(actual, buffer[:n]...) + if err != nil && err == io.EOF { + break + } else if err != nil { + t.Fatalf("failed to read from reader: %v", err) + } + } + + if e := []byte("Data"); !bytes.Equal(e, actual) { + t.Errorf("expected %v, got %v", e, actual) + } +} + +func TestBufferedReadSeekerReadAt(t *testing.T) { + content := []byte("testData") + + readSeeker := NewBufferedReadSeeker(bytes.NewReader(content), make([]byte, 2)) + + buffer := make([]byte, 4) + + _, err := readSeeker.ReadAt(buffer, 0) + if err != nil { + t.Fatalf("failed to seek reader: %v", err) + } + + if e := content[:4]; !bytes.Equal(e, buffer) { + t.Errorf("expected %v, got %v", e, buffer) + } +} diff --git a/service/s3/s3manager/default_read_seeker_write_to.go b/service/s3/s3manager/default_read_seeker_write_to.go new file mode 100644 index 00000000000..42276530a8b --- /dev/null +++ b/service/s3/s3manager/default_read_seeker_write_to.go @@ -0,0 +1,7 @@ +// +build !windows + +package s3manager + +func defaultUploadBufferProvider() ReadSeekerWriteToProvider { + return nil +} diff --git a/service/s3/s3manager/default_read_seeker_write_to_windows.go b/service/s3/s3manager/default_read_seeker_write_to_windows.go new file mode 100644 index 00000000000..687082c3066 --- /dev/null +++ b/service/s3/s3manager/default_read_seeker_write_to_windows.go @@ -0,0 +1,5 @@ +package s3manager + +func defaultUploadBufferProvider() ReadSeekerWriteToProvider { + return NewBufferedReadSeekerWriteToPool(1024 * 1024) +} diff --git a/service/s3/s3manager/default_writer_read_from.go b/service/s3/s3manager/default_writer_read_from.go new file mode 100644 index 00000000000..ada50c24355 --- /dev/null +++ b/service/s3/s3manager/default_writer_read_from.go @@ -0,0 +1,7 @@ +// +build !windows + +package s3manager + +func defaultDownloadBufferProvider() WriterReadFromProvider { + return nil +} diff --git a/service/s3/s3manager/default_writer_read_from_windows.go b/service/s3/s3manager/default_writer_read_from_windows.go new file mode 100644 index 00000000000..7e9d9579f64 --- /dev/null +++ b/service/s3/s3manager/default_writer_read_from_windows.go @@ -0,0 +1,5 @@ +package s3manager + +func defaultDownloadBufferProvider() WriterReadFromProvider { + return NewPooledBufferedWriterReadFromProvider(1024 * 1024) +} diff --git a/service/s3/s3manager/download.go b/service/s3/s3manager/download.go index 3c743f1baec..749fb086a82 100644 --- a/service/s3/s3manager/download.go +++ b/service/s3/s3manager/download.go @@ -25,13 +25,25 @@ const DefaultDownloadPartSize = 1024 * 1024 * 5 // when using Download(). const DefaultDownloadConcurrency = 5 +type errReadingBody struct { + err error +} + +func (e *errReadingBody) Error() string { + return fmt.Sprintf("failed to read part body: %v", e.err) +} + +func (e *errReadingBody) Unwrap() error { + return e.err +} + // The Downloader structure that calls Download(). 
It is safe to call Download() // on this structure for multiple objects and across concurrent goroutines. // Mutating the Downloader's properties is not safe to be done concurrently. type Downloader struct { - // The buffer size (in bytes) to use when buffering data into chunks and - // sending them as parts to S3. The minimum allowed part size is 5MB, and - // if this value is set to zero, the DefaultDownloadPartSize value will be used. + // The size (in bytes) to request from S3 for each part. + // The minimum allowed part size is 5MB, and if this value is set to zero, + // the DefaultDownloadPartSize value will be used. // // PartSize is ignored if the Range input parameter is provided. PartSize int64 @@ -54,6 +66,14 @@ type Downloader struct { // The retryer that the downloader will use to determine how many times // a part download should retried before it is considered to of failed. Retryer aws.Retryer + + // Defines the buffer strategy used when downloading a part. + // + // If a WriterReadFromProvider is given the Download manager + // will pass the io.WriterAt of the Download request to the provider + // and will use the returned WriterReadFrom from the provider as the + // destination writer when copying from http response body. + BufferProvider WriterReadFromProvider } // WithDownloaderRequestOptions appends to the Downloader's API request options. @@ -80,11 +100,22 @@ func WithDownloaderRequestOptions(opts ...request.Option) func(*Downloader) { // d.PartSize = 64 * 1024 * 1024 // 64MB per part // }) func NewDownloader(cfg aws.Config, options ...func(*Downloader)) *Downloader { + return newDownloader(s3.New(cfg), options...) +} + +func newDownloader(client s3iface.ClientAPI, options ...func(*Downloader)) *Downloader { + var retryer aws.Retryer + + if s3Svc, ok := client.(*s3.Client); ok { + retryer = s3Svc.Retryer + } + d := &Downloader{ - S3: s3.New(cfg), - PartSize: DefaultDownloadPartSize, - Concurrency: DefaultDownloadConcurrency, - Retryer: cfg.Retryer, + S3: client, + PartSize: DefaultDownloadPartSize, + Concurrency: DefaultDownloadConcurrency, + Retryer: retryer, + BufferProvider: defaultDownloadBufferProvider(), } for _, option := range options { @@ -114,24 +145,7 @@ func NewDownloader(cfg aws.Config, options ...func(*Downloader)) *Downloader { // d.PartSize = 64 * 1024 * 1024 // 64MB per part // }) func NewDownloaderWithClient(svc s3iface.ClientAPI, options ...func(*Downloader)) *Downloader { - var retryer aws.Retryer - - if s3Svc, ok := svc.(*s3.Client); ok { - retryer = s3Svc.Retryer - } - - d := &Downloader{ - S3: svc, - PartSize: DefaultDownloadPartSize, - Concurrency: DefaultDownloadConcurrency, - Retryer: retryer, - } - - for _, option := range options { - option(d) - } - - return d + return newDownloader(svc, options...) } // Download downloads an object in S3 and writes the payload into w using @@ -412,21 +426,20 @@ func (d *downloader) downloadChunk(chunk dlchunk) error { var n int64 var err error for retry := 0; retry <= d.partBodyMaxRetries; retry++ { - req := d.cfg.S3.GetObjectRequest(in) - req.ApplyOptions(d.cfg.RequestOptions...) - - var resp *s3.GetObjectResponse - resp, err = req.Send(d.ctx) - if err != nil { - return err - } - d.setTotalBytes(resp.GetObjectOutput) // Set total if not yet set. - - n, err = io.Copy(&chunk, resp.Body) - resp.Body.Close() + n, err = d.tryDownloadChunk(in, &chunk) if err == nil { break } + // Check if the returned error is an errReadingBody. 
+ // If err is errReadingBody this indicates that an error + // occurred while copying the http response body. + // If this occurs we unwrap the error to set the underlying error + // and attempt any remaining retries. + if bodyErr, ok := err.(*errReadingBody); ok { + err = bodyErr.Unwrap() + } else { + return err + } chunk.cur = 0 logMessage(d.cfg.S3, aws.LogDebugWithRequestRetries, @@ -439,6 +452,30 @@ func (d *downloader) downloadChunk(chunk dlchunk) error { return err } +func (d *downloader) tryDownloadChunk(in *s3.GetObjectInput, w io.Writer) (int64, error) { + cleanup := func() {} + if d.cfg.BufferProvider != nil { + w, cleanup = d.cfg.BufferProvider.GetReadFrom(w) + } + defer cleanup() + + req := d.cfg.S3.GetObjectRequest(in) + req.ApplyOptions(d.cfg.RequestOptions...) + resp, err := req.Send(d.ctx) + if err != nil { + return 0, err + } + d.setTotalBytes(resp.GetObjectOutput) // Set total if not yet set. + + n, err := io.Copy(w, resp.Body) + resp.Body.Close() + if err != nil { + return n, &errReadingBody{err: err} + } + + return n, nil +} + func logMessage(svc s3iface.ClientAPI, level aws.LogLevel, msg string) { s, ok := svc.(*s3.Client) if !ok { diff --git a/service/s3/s3manager/download_test.go b/service/s3/s3manager/download_test.go index fe1bb422d23..e29be063e90 100644 --- a/service/s3/s3manager/download_test.go +++ b/service/s3/s3manager/download_test.go @@ -11,14 +11,15 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "testing" "time" "github.com/aws/aws-sdk-go-v2/aws" - request "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/awserr" "github.com/aws/aws-sdk-go-v2/internal/awstesting" "github.com/aws/aws-sdk-go-v2/internal/awstesting/unit" + "github.com/aws/aws-sdk-go-v2/internal/sdkio" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3/s3manager" ) @@ -30,7 +31,7 @@ func dlLoggingSvc(data []byte) (*s3.Client, *[]string, *[]string) { svc := s3.New(unit.Config()) svc.Handlers.Send.Clear() - svc.Handlers.Send.PushBack(func(r *request.Request) { + svc.Handlers.Send.PushBack(func(r *aws.Request) { m.Lock() defer m.Unlock() @@ -67,7 +68,7 @@ func dlLoggingSvcNoChunk(data []byte) (*s3.Client, *[]string) { svc := s3.New(unit.Config()) svc.Handlers.Send.Clear() - svc.Handlers.Send.PushBack(func(r *request.Request) { + svc.Handlers.Send.PushBack(func(r *aws.Request) { m.Lock() defer m.Unlock() @@ -91,7 +92,7 @@ func dlLoggingSvcNoContentRangeLength(data []byte, states []int) (*s3.Client, *[ svc := s3.New(unit.Config()) svc.Handlers.Send.Clear() - svc.Handlers.Send.PushBack(func(r *request.Request) { + svc.Handlers.Send.PushBack(func(r *aws.Request) { m.Lock() defer m.Unlock() @@ -116,7 +117,7 @@ func dlLoggingSvcContentRangeTotalAny(data []byte, states []int) (*s3.Client, *[ svc := s3.New(unit.Config()) svc.Handlers.Send.Clear() - svc.Handlers.Send.PushBack(func(r *request.Request) { + svc.Handlers.Send.PushBack(func(r *aws.Request) { m.Lock() defer m.Unlock() @@ -173,7 +174,7 @@ func dlLoggingSvcWithErrReader(cases []testErrReader) (*s3.Client, *[]string) { svc := s3.New(cfg) svc.Handlers.Send.Clear() - svc.Handlers.Send.PushBack(func(r *request.Request) { + svc.Handlers.Send.PushBack(func(r *aws.Request) { m.Lock() defer m.Unlock() @@ -297,7 +298,7 @@ func TestDownloadError(t *testing.T) { s, names, _ := dlLoggingSvc([]byte{1, 2, 3}) num := 0 - s.Handlers.Send.PushBack(func(r *request.Request) { + s.Handlers.Send.PushBack(func(r *aws.Request) { num++ if num > 1 { 
r.HTTPResponse.StatusCode = 400 @@ -544,7 +545,7 @@ func TestDownloadWithContextCanceled(t *testing.T) { t.Fatalf("expected error, did not get one") } aerr := err.(awserr.Error) - if e, a := request.ErrCodeRequestCanceled, aerr.Code(); e != a { + if e, a := aws.ErrCodeRequestCanceled, aerr.Code(); e != a { t.Errorf("expected error code %q, got %q", e, a) } if e, a := "canceled", aerr.Message(); !strings.Contains(a, e) { @@ -592,7 +593,7 @@ func TestDownload_WithFailure(t *testing.T) { svc.Handlers.Send.Clear() first := true - svc.Handlers.Send.PushBack(func(r *request.Request) { + svc.Handlers.Send.PushBack(func(r *aws.Request) { if first { first = false body := bytes.NewReader(make([]byte, s3manager.DefaultDownloadPartSize)) @@ -644,6 +645,72 @@ func TestDownload_WithFailure(t *testing.T) { } } +func TestDownloadBufferStrategy(t *testing.T) { + cases := map[string]struct { + partSize int64 + strategy *recordedWriterReadFromProvider + expectedSize int64 + }{ + "no strategy": { + partSize: s3manager.DefaultDownloadPartSize, + expectedSize: 10 * sdkio.MebiByte, + }, + "partSize modulo bufferSize == 0": { + partSize: 5 * sdkio.MebiByte, + strategy: &recordedWriterReadFromProvider{ + WriterReadFromProvider: s3manager.NewPooledBufferedWriterReadFromProvider(int(sdkio.MebiByte)), // 1 MiB + }, + expectedSize: 10 * sdkio.MebiByte, // 10 MiB + }, + "partSize modulo bufferSize > 0": { + partSize: 5 * 1024 * 1204, // 5 MiB + strategy: &recordedWriterReadFromProvider{ + WriterReadFromProvider: s3manager.NewPooledBufferedWriterReadFromProvider(2 * int(sdkio.MebiByte)), // 2 MiB + }, + expectedSize: 10 * sdkio.MebiByte, // 10 MiB + }, + } + + for name, tCase := range cases { + t.Logf("starting case: %v", name) + + expected := getTestBytes(int(tCase.expectedSize)) + + svc, _, _ := dlLoggingSvc(expected) + + d := s3manager.NewDownloaderWithClient(svc, func(d *s3manager.Downloader) { + d.PartSize = tCase.partSize + if tCase.strategy != nil { + d.BufferProvider = tCase.strategy + } + }) + + buffer := aws.NewWriteAtBuffer(make([]byte, len(expected))) + + n, err := d.Download(buffer, &s3.GetObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + }) + if err != nil { + t.Errorf("failed to download: %v", err) + } + + if e, a := len(expected), int(n); e != a { + t.Errorf("expected %v, got %v downloaded bytes", e, a) + } + + if e, a := expected, buffer.Bytes(); !bytes.Equal(e, a) { + t.Errorf("downloaded bytes did not match expected") + } + + if tCase.strategy != nil { + if e, a := tCase.strategy.callbacksVended, tCase.strategy.callbacksExecuted; e != a { + t.Errorf("expected %v, got %v", e, a) + } + } + } +} + type testErrReader struct { Buf []byte Err error @@ -665,3 +732,98 @@ func (r *testErrReader) Read(p []byte) (int, error) { return n, nil } + +func TestDownloadBufferStrategy_Errors(t *testing.T) { + expected := getTestBytes(int(10 * sdkio.MebiByte)) + + svc, _, _ := dlLoggingSvc(expected) + strat := &recordedWriterReadFromProvider{ + WriterReadFromProvider: s3manager.NewPooledBufferedWriterReadFromProvider(int(2 * sdkio.MebiByte)), + } + + d := s3manager.NewDownloaderWithClient(svc, func(d *s3manager.Downloader) { + d.PartSize = 5 * sdkio.MebiByte + d.BufferProvider = strat + d.Concurrency = 1 + }) + + seenOps := make(map[string]struct{}) + svc.Handlers.Send.PushFront(func(*aws.Request) {}) + svc.Handlers.Send.AfterEachFn = func(item aws.HandlerListRunItem) bool { + r := item.Request + + if r.Operation.Name != "GetObject" { + return true + } + + input := r.Params.(*s3.GetObjectInput) + + 
fingerPrint := fmt.Sprintf("%s/%s/%s/%s", r.Operation.Name, *input.Bucket, *input.Key, *input.Range) + if _, ok := seenOps[fingerPrint]; ok { + return true + } + seenOps[fingerPrint] = struct{}{} + + regex := regexp.MustCompile(`bytes=(\d+)-(\d+)`) + rng := regex.FindStringSubmatch(*input.Range) + start, _ := strconv.ParseInt(rng[1], 10, 64) + fin, _ := strconv.ParseInt(rng[2], 10, 64) + + _, _ = io.Copy(ioutil.Discard, r.Body) + r.HTTPResponse = &http.Response{ + StatusCode: 200, + Body: aws.ReadSeekCloser(&badReader{err: io.ErrUnexpectedEOF}), + ContentLength: fin - start, + } + + return false + } + + buffer := aws.NewWriteAtBuffer(make([]byte, len(expected))) + + n, err := d.Download(buffer, &s3.GetObjectInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + }) + if err != nil { + t.Errorf("failed to download: %v", err) + } + + if e, a := len(expected), int(n); e != a { + t.Errorf("expected %v, got %v downloaded bytes", e, a) + } + + if e, a := expected, buffer.Bytes(); !bytes.Equal(e, a) { + t.Errorf("downloaded bytes did not match expected") + } + + if e, a := strat.callbacksVended, strat.callbacksExecuted; e != a { + t.Errorf("expected %v, got %v", e, a) + } +} + +type recordedWriterReadFromProvider struct { + callbacksVended uint32 + callbacksExecuted uint32 + s3manager.WriterReadFromProvider +} + +func (r *recordedWriterReadFromProvider) GetReadFrom(writer io.Writer) (s3manager.WriterReadFrom, func()) { + w, cleanup := r.WriterReadFromProvider.GetReadFrom(writer) + + atomic.AddUint32(&r.callbacksVended, 1) + return w, func() { + atomic.AddUint32(&r.callbacksExecuted, 1) + cleanup() + } +} + +type badReader struct { + err error +} + +func (b *badReader) Read(p []byte) (int, error) { + tb := getTestBytes(len(p)) + copy(p, tb) + return len(p), b.err +} diff --git a/service/s3/s3manager/exmaples_1_13_test.go b/service/s3/s3manager/exmaples_1_13_test.go new file mode 100644 index 00000000000..2ce84df4e78 --- /dev/null +++ b/service/s3/s3manager/exmaples_1_13_test.go @@ -0,0 +1,45 @@ +// +build go1.13 + +package s3manager_test + +import ( + "bytes" + "fmt" + "net/http" + "time" + + "github.com/aws/aws-sdk-go-v2/aws/external" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3/s3manager" +) + +// ExampleNewUploader_overrideTransport gives an example +// on how to override the default HTTP transport. This can +// be used to tune timeouts such as response headers, or +// write / read buffer usage (go1.13) when writing or reading respectively +// from the net/http transport. 
+func ExampleNewUploader_overrideTransport() { + // Create Transport + tr := &http.Transport{ + ResponseHeaderTimeout: 1 * time.Second, + WriteBufferSize: 1024 * 1024, + ReadBufferSize: 1024 * 1024, + } + + cfg, err := external.LoadDefaultAWSConfig(aws.Config{HTTPClient: &http.Client{Transport: tr}}) + if err != nil { + panic(fmt.Sprintf("failed to load SDK config: %v", err)) + } + + uploader := s3manager.NewUploader(cfg) + + _, err = uploader.Upload(&s3manager.UploadInput{ + Bucket: aws.String("examplebucket"), + Key: aws.String("largeobject"), + Body: bytes.NewReader([]byte("large_multi_part_upload")), + }) + if err != nil { + fmt.Println(err.Error()) + } +} diff --git a/service/s3/s3manager/exmaples_test.go b/service/s3/s3manager/exmaples_test.go new file mode 100644 index 00000000000..edcaa0759f6 --- /dev/null +++ b/service/s3/s3manager/exmaples_test.go @@ -0,0 +1,35 @@ +package s3manager_test + +import ( + "bytes" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws/external" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3/s3manager" +) + +// ExampleNewUploader_overrideReadSeekerProvider gives an example +// of how a custom ReadSeekerWriteToProvider can be provided to Uploader +// to define how parts will be buffered in memory. +func ExampleNewUploader_overrideReadSeekerProvider() { + cfg, err := external.LoadDefaultAWSConfig() + if err != nil { + panic(fmt.Sprintf("failed to load SDK config: %v", err)) + } + + uploader := s3manager.NewUploader(cfg, func(u *s3manager.Uploader) { + // Define a strategy that will buffer 25 MiB in memory + u.BufferProvider = s3manager.NewBufferedReadSeekerWriteToPool(25 * 1024 * 1024) + }) + + _, err = uploader.Upload(&s3manager.UploadInput{ + Bucket: aws.String("examplebucket"), + Key: aws.String("largeobject"), + Body: bytes.NewReader([]byte("large_multi_part_upload")), + }) + if err != nil { + fmt.Println(err.Error()) + } +} diff --git a/service/s3/s3manager/read_seeker_write_to.go b/service/s3/s3manager/read_seeker_write_to.go new file mode 100644 index 00000000000..f62e1a45eef --- /dev/null +++ b/service/s3/s3manager/read_seeker_write_to.go @@ -0,0 +1,65 @@ +package s3manager + +import ( + "io" + "sync" +) + +// ReadSeekerWriteTo defines an interface implementing io.WriterTo and io.ReadSeeker +type ReadSeekerWriteTo interface { + io.ReadSeeker + io.WriterTo +} + +// BufferedReadSeekerWriteTo wraps a BufferedReadSeeker with an io.WriterTo +// implementation. +type BufferedReadSeekerWriteTo struct { + *BufferedReadSeeker +} + +// WriteTo writes to the given io.Writer from BufferedReadSeeker until there's no more data to write or +// an error occurs. Returns the number of bytes written and any error encountered during the write. +func (b *BufferedReadSeekerWriteTo) WriteTo(writer io.Writer) (int64, error) { + return io.Copy(writer, b.BufferedReadSeeker) +} + +// ReadSeekerWriteToProvider provides an implementation of io.WriterTo for an io.ReadSeeker
+type ReadSeekerWriteToProvider interface { + GetWriteTo(seeker io.ReadSeeker) (r ReadSeekerWriteTo, cleanup func()) +} + +// BufferedReadSeekerWriteToPool uses a sync.Pool to create and reuse +// []byte slices for buffering parts in memory +type BufferedReadSeekerWriteToPool struct { + pool sync.Pool +} + +// NewBufferedReadSeekerWriteToPool will return a new BufferedReadSeekerWriteToPool that will create +// a pool of reusable buffers. If size is less than 64 KiB then the buffer +// will default to 64 KiB.
Reason: io.Copy from writers or readers that don't support io.WriteTo or io.ReadFrom +// respectively will default to copying 32 KiB. +func NewBufferedReadSeekerWriteToPool(size int) *BufferedReadSeekerWriteToPool { + if size < 65536 { + size = 65536 + } + + return &BufferedReadSeekerWriteToPool{ + pool: sync.Pool{New: func() interface{} { + return make([]byte, size) + }}, + } +} + +// GetWriteTo will wrap the provided io.ReadSeeker with a BufferedReadSeekerWriteTo. +// The provided cleanup must be called after operations have been completed on the +// returned io.ReadSeekerWriteTo in order to signal the return of resources to the pool. +func (p *BufferedReadSeekerWriteToPool) GetWriteTo(seeker io.ReadSeeker) (r ReadSeekerWriteTo, cleanup func()) { + buffer := p.pool.Get().([]byte) + + r = &BufferedReadSeekerWriteTo{BufferedReadSeeker: NewBufferedReadSeeker(seeker, buffer)} + cleanup = func() { + p.pool.Put(buffer) + } + + return r, cleanup +} diff --git a/service/s3/s3manager/shared_test.go b/service/s3/s3manager/shared_test.go index b5b61314336..d8a95e7c421 100644 --- a/service/s3/s3manager/shared_test.go +++ b/service/s3/s3manager/shared_test.go @@ -1,4 +1,28 @@ package s3manager_test +import ( + "math/rand" + + "github.com/aws/aws-sdk-go-v2/internal/sdkio" +) + var buf12MB = make([]byte, 1024*1024*12) var buf2MB = make([]byte, 1024*1024*2) + +var randBytes = func() []byte { + b := make([]byte, 10*sdkio.MebiByte) + + // always returns len(b) and nil error + _, _ = rand.Read(b) + + return b +}() + +func getTestBytes(size int) []byte { + if len(randBytes) >= size { + return randBytes[:size] + } + + b := append(randBytes, getTestBytes(size-len(randBytes))...) + return b +} diff --git a/service/s3/s3manager/upload.go b/service/s3/s3manager/upload.go index 9181f4d32c1..cc2f2dad607 100644 --- a/service/s3/s3manager/upload.go +++ b/service/s3/s3manager/upload.go @@ -159,6 +159,12 @@ type Uploader struct { // List of request options that will be passed down to individual API // operation requests made by the uploader. RequestOptions []request.Option + + // Defines the buffer strategy used when uploading a part + BufferProvider ReadSeekerWriteToProvider + + // partPool allows for the re-usage of streaming payload part buffers between upload calls + partPool *partPool } // NewUploader creates a new Uploader instance to upload objects to S3. Pass In @@ -176,18 +182,25 @@ type Uploader struct { // u.PartSize = 64 * 1024 * 1024 // 64MB per part // }) func NewUploader(cfg aws.Config, options ...func(*Uploader)) *Uploader { + return newUploader(s3.New(cfg), options...) +} + +func newUploader(client s3iface.ClientAPI, options ...func(*Uploader)) *Uploader { u := &Uploader{ - S3: s3.New(cfg), + S3: client, PartSize: DefaultUploadPartSize, Concurrency: DefaultUploadConcurrency, LeavePartsOnError: false, MaxUploadParts: MaxUploadParts, + BufferProvider: defaultUploadBufferProvider(), } for _, option := range options { option(u) } + u.partPool = newPartPool(u.PartSize) + return u } @@ -210,19 +223,7 @@ func NewUploader(cfg aws.Config, options ...func(*Uploader)) *Uploader { // u.PartSize = 64 * 1024 * 1024 // 64MB per part // }) func NewUploaderWithClient(svc s3iface.ClientAPI, options ...func(*Uploader)) *Uploader { - u := &Uploader{ - S3: svc, - PartSize: DefaultUploadPartSize, - Concurrency: DefaultUploadConcurrency, - LeavePartsOnError: false, - MaxUploadParts: MaxUploadParts, - } - - for _, option := range options { - option(u) - } - - return u + return newUploader(svc, options...) 
} // Upload uploads an object to S3, intelligently buffering large files into @@ -282,6 +283,7 @@ func (u Uploader) UploadWithContext(ctx context.Context, input *UploadInput, opt for _, opt := range opts { opt(&i.cfg) } + i.cfg.RequestOptions = append(i.cfg.RequestOptions, request.WithAppendUserAgent("S3Manager")) return i.upload() @@ -366,15 +368,16 @@ func (u *uploader) upload() (*UploadOutput, error) { } // Do one read to determine if we have more than one part - reader, _, err := u.nextReader() + reader, _, cleanup, err := u.nextReader() if err == io.EOF { // single part - return u.singlePart(reader) + return u.singlePart(reader, cleanup) } else if err != nil { + cleanup() return nil, awserr.New("ReadRequestBody", "read upload data failed", err) } mu := multiuploader{uploader: u} - return mu.upload(reader) + return mu.upload(reader, cleanup) } // init will initialize all default options. @@ -385,6 +388,15 @@ func (u *uploader) init() error { if u.cfg.PartSize == 0 { u.cfg.PartSize = DefaultUploadPartSize } + if u.cfg.MaxUploadParts == 0 { + u.cfg.MaxUploadParts = MaxUploadParts + } + + // If PartSize was changed or partPool was never setup then we need to allocated a new pool + // so that we return []byte slices of the correct size + if u.cfg.partPool == nil || u.cfg.partPool.partSize != u.cfg.PartSize { + u.cfg.partPool = newPartPool(u.cfg.PartSize) + } // Try to get the total size for some optimizations return u.initSize() @@ -418,7 +430,7 @@ func (u *uploader) initSize() error { // This operation increases the shared u.readerPos counter, but note that it // does not need to be wrapped in a mutex because nextReader is only called // from the main thread. -func (u *uploader) nextReader() (io.ReadSeeker, int, error) { +func (u *uploader) nextReader() (io.ReadSeeker, int, func(), error) { type readerAtSeeker interface { io.ReaderAt io.ReadSeeker @@ -437,13 +449,24 @@ func (u *uploader) nextReader() (io.ReadSeeker, int, error) { } } - reader := io.NewSectionReader(r, u.readerPos, n) + var ( + reader io.ReadSeeker + cleanup func() + ) + + reader = io.NewSectionReader(r, u.readerPos, n) + if u.cfg.BufferProvider != nil { + reader, cleanup = u.cfg.BufferProvider.GetWriteTo(reader) + } else { + cleanup = func() {} + } + u.readerPos += n - return reader, int(n), err + return reader, int(n), cleanup, err default: - part := make([]byte, u.cfg.PartSize) + part := u.cfg.partPool.Get().([]byte) n, err := readFillBuf(r, part) if n < 0 { if err == nil { @@ -455,7 +478,11 @@ func (u *uploader) nextReader() (io.ReadSeeker, int, error) { } u.readerPos += int64(n) - return bytes.NewReader(part[0:n]), n, err + cleanup := func() { + u.cfg.partPool.Put(part) + } + + return bytes.NewReader(part[0:n]), n, cleanup, err } } @@ -472,7 +499,9 @@ func readFillBuf(r io.Reader, b []byte) (offset int, err error) { // singlePart contains upload logic for uploading a single chunk via // a regular PutObject request. Multipart requests require at least two // parts, or at least 5MB of data. -func (u *uploader) singlePart(buf io.ReadSeeker) (*UploadOutput, error) { +func (u *uploader) singlePart(buf io.ReadSeeker, cleanup func()) (*UploadOutput, error) { + defer cleanup() + params := &s3.PutObjectInput{} awsutil.Copy(params, u.in) params.Body = buf @@ -505,8 +534,9 @@ type multiuploader struct { // keeps track of a single chunk of data being sent to S3. 
type chunk struct { - buf io.ReadSeeker - num int64 + buf io.ReadSeeker + num int64 + cleanup func() } // completedParts is a wrapper to make parts sortable by their part number, @@ -519,7 +549,7 @@ func (a completedParts) Less(i, j int) bool { return *a[i].PartNumber < *a[j].Pa // upload will perform a multipart upload using the firstBuf buffer containing // the first chunk of data. -func (u *multiuploader) upload(firstBuf io.ReadSeeker) (*UploadOutput, error) { +func (u *multiuploader) upload(firstBuf io.ReadSeeker, cleanup func()) (*UploadOutput, error) { params := &s3.CreateMultipartUploadInput{} awsutil.Copy(params, u.in) @@ -541,45 +571,30 @@ func (u *multiuploader) upload(firstBuf io.ReadSeeker) (*UploadOutput, error) { // Send part 1 to the workers var num int64 = 1 - ch <- chunk{buf: firstBuf, num: num} + ch <- chunk{buf: firstBuf, num: num, cleanup: cleanup} // Read and queue the rest of the parts for u.geterr() == nil && err == nil { - var reader io.ReadSeeker - var nextChunkLen int - reader, nextChunkLen, err = u.nextReader() - - if err != nil && err != io.EOF { - u.seterr(awserr.New( - "ReadRequestBody", - "read multipart upload data failed", - err)) - break - } + var ( + reader io.ReadSeeker + nextChunkLen int + ok bool + ) - if nextChunkLen == 0 { - // No need to upload empty part, if file was empty to start - // with empty single part would of been created and never - // started multipart upload. - break - } + reader, nextChunkLen, cleanup, err = u.nextReader() - num++ - // This upload exceeded maximum number of supported parts, error now. - if num > int64(u.cfg.MaxUploadParts) || num > int64(MaxUploadParts) { - var msg string - if num > int64(u.cfg.MaxUploadParts) { - msg = fmt.Sprintf("exceeded total allowed configured MaxUploadParts (%d). Adjust PartSize to fit in this limit", - u.cfg.MaxUploadParts) - } else { - msg = fmt.Sprintf("exceeded total allowed S3 limit MaxUploadParts (%d). Adjust PartSize to fit in this limit", - MaxUploadParts) + ok, err = u.shouldContinue(num, nextChunkLen, err) + if !ok { + cleanup() + if err != nil { + u.seterr(err) } - u.seterr(awserr.New("TotalPartsExceeded", msg, nil)) break } - ch <- chunk{buf: reader, num: num} + num++ + + ch <- chunk{buf: reader, num: num, cleanup: cleanup} } // Close the channel, wait for workers, and complete upload @@ -613,6 +628,35 @@ func (u *multiuploader) upload(firstBuf io.ReadSeeker) (*UploadOutput, error) { }, nil } +func (u *multiuploader) shouldContinue(part int64, nextChunkLen int, err error) (bool, error) { + if err != nil && err != io.EOF { + return false, awserr.New("ReadRequestBody", "read multipart upload data failed", err) + } + + if nextChunkLen == 0 { + // No need to upload empty part, if file was empty to start + // with empty single part would of been created and never + // started multipart upload. + return false, nil + } + + part++ + // This upload exceeded maximum number of supported parts, error now. + if part > int64(u.cfg.MaxUploadParts) || part > int64(MaxUploadParts) { + var msg string + if part > int64(u.cfg.MaxUploadParts) { + msg = fmt.Sprintf("exceeded total allowed configured MaxUploadParts (%d). Adjust PartSize to fit in this limit", + u.cfg.MaxUploadParts) + } else { + msg = fmt.Sprintf("exceeded total allowed S3 limit MaxUploadParts (%d). 
Adjust PartSize to fit in this limit", + MaxUploadParts) + } + return false, awserr.New("TotalPartsExceeded", msg, nil) + } + + return true, err +} + // readChunk runs in worker goroutines to pull chunks off of the ch channel // and send() them as UploadPart requests. func (u *multiuploader) readChunk(ch chan chunk) { @@ -647,6 +691,7 @@ func (u *multiuploader) send(c chunk) error { req := u.cfg.S3.UploadPartRequest(params) req.ApplyOptions(u.cfg.RequestOptions...) resp, err := req.Send(u.ctx) + c.cleanup() if err != nil { return err } @@ -722,3 +767,18 @@ func (u *multiuploader) complete() *s3.CompleteMultipartUploadOutput { return resp.CompleteMultipartUploadOutput } + +type partPool struct { + partSize int64 + sync.Pool +} + +func newPartPool(partSize int64) *partPool { + p := &partPool{partSize: partSize} + + p.New = func() interface{} { + return make([]byte, p.partSize) + } + + return p +} diff --git a/service/s3/s3manager/upload_test.go b/service/s3/s3manager/upload_test.go index a9676cad0ec..43a79276e08 100644 --- a/service/s3/s3manager/upload_test.go +++ b/service/s3/s3manager/upload_test.go @@ -12,10 +12,10 @@ import ( "sort" "strings" "sync" + "sync/atomic" "testing" "github.com/aws/aws-sdk-go-v2/aws" - request "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/awserr" "github.com/aws/aws-sdk-go-v2/internal/awstesting" "github.com/aws/aws-sdk-go-v2/internal/awstesting/unit" @@ -61,7 +61,7 @@ func loggingSvc(ignoreOps []string) (*s3.Client, *[]string, *[]interface{}) { svc.Handlers.UnmarshalMeta.Clear() svc.Handlers.UnmarshalError.Clear() svc.Handlers.Send.Clear() - svc.Handlers.Send.PushBack(func(r *request.Request) { + svc.Handlers.Send.PushBack(func(r *aws.Request) { m.Lock() defer m.Unlock() @@ -353,7 +353,7 @@ func TestUploadOrderSingle(t *testing.T) { func TestUploadOrderSingleFailure(t *testing.T) { s, ops, _ := loggingSvc(emptyList) - s.Handlers.Send.PushBack(func(r *request.Request) { + s.Handlers.Send.PushBack(func(r *aws.Request) { r.HTTPResponse.StatusCode = 400 }) mgr := s3manager.NewUploaderWithClient(s) @@ -408,7 +408,7 @@ func TestUploadOrderZero(t *testing.T) { func TestUploadOrderMultiFailure(t *testing.T) { s, ops, _ := loggingSvc(emptyList) - s.Handlers.Send.PushBack(func(r *request.Request) { + s.Handlers.Send.PushBack(func(r *aws.Request) { switch t := r.Data.(type) { case *s3.UploadPartOutput: if *t.ETag == "ETAG2" { @@ -437,7 +437,7 @@ func TestUploadOrderMultiFailure(t *testing.T) { func TestUploadOrderMultiFailureOnComplete(t *testing.T) { s, ops, _ := loggingSvc(emptyList) - s.Handlers.Send.PushBack(func(r *request.Request) { + s.Handlers.Send.PushBack(func(r *aws.Request) { switch r.Data.(type) { case *s3.CompleteMultipartUploadOutput: r.HTTPResponse.StatusCode = 400 @@ -465,7 +465,7 @@ func TestUploadOrderMultiFailureOnComplete(t *testing.T) { func TestUploadOrderMultiFailureOnCreate(t *testing.T) { s, ops, _ := loggingSvc(emptyList) - s.Handlers.Send.PushBack(func(r *request.Request) { + s.Handlers.Send.PushBack(func(r *aws.Request) { switch r.Data.(type) { case *s3.CreateMultipartUploadOutput: r.HTTPResponse.StatusCode = 400 @@ -490,7 +490,7 @@ func TestUploadOrderMultiFailureOnCreate(t *testing.T) { func TestUploadOrderMultiFailureLeaveParts(t *testing.T) { s, ops, _ := loggingSvc(emptyList) - s.Handlers.Send.PushBack(func(r *request.Request) { + s.Handlers.Send.PushBack(func(r *aws.Request) { switch data := r.Data.(type) { case *s3.UploadPartOutput: if *data.ETag == "ETAG2" { @@ -601,7 
+601,7 @@ func (s *sizedReader) Read(p []byte) (n int, err error) { n -= s.cur - s.size } - return + return n, err } func TestUploadOrderMultiBufferedReader(t *testing.T) { @@ -926,7 +926,7 @@ func TestReaderAt(t *testing.T) { svc.Handlers.Send.Clear() contentLen := "" - svc.Handlers.Send.PushBack(func(r *request.Request) { + svc.Handlers.Send.PushBack(func(r *aws.Request) { contentLen = r.HTTPRequest.Header.Get("Content-Length") r.HTTPResponse = &http.Response{ StatusCode: 200, @@ -963,7 +963,7 @@ func TestSSE(t *testing.T) { partNum := 0 mutex := &sync.Mutex{} - svc.Handlers.Send.PushBack(func(r *request.Request) { + svc.Handlers.Send.PushBack(func(r *aws.Request) { mutex.Lock() defer mutex.Unlock() r.HTTPResponse = &http.Response{ @@ -1027,7 +1027,7 @@ func TestUploadWithContextCanceled(t *testing.T) { t.Fatalf("expected error, did not get one") } aerr := err.(awserr.Error) - if e, a := request.ErrCodeRequestCanceled, aerr.Code(); e != a { + if e, a := aws.ErrCodeRequestCanceled, aerr.Code(); e != a { t.Errorf("expected error code %q, got %q", e, a) } if e, a := "canceled", aerr.Message(); !strings.Contains(a, e) { @@ -1070,3 +1070,107 @@ func TestUploadMaxPartsEOF(t *testing.T) { t.Errorf("expect %v ops, got %v", e, a) } } + +func TestUploadBufferStrategy(t *testing.T) { + cases := map[string]struct { + PartSize int64 + Size int64 + Strategy s3manager.ReadSeekerWriteToProvider + callbacks int + }{ + "NoBuffer": { + PartSize: s3manager.DefaultUploadPartSize, + Strategy: nil, + }, + "SinglePart": { + PartSize: s3manager.DefaultUploadPartSize, + Size: s3manager.DefaultUploadPartSize, + Strategy: &recordedBufferProvider{size: int(s3manager.DefaultUploadPartSize)}, + callbacks: 1, + }, + "MultiPart": { + PartSize: s3manager.DefaultUploadPartSize, + Size: s3manager.DefaultUploadPartSize * 2, + Strategy: &recordedBufferProvider{size: int(s3manager.DefaultUploadPartSize)}, + callbacks: 2, + }, + } + + for name, tCase := range cases { + t.Run(name, func(t *testing.T) { + _ = tCase + cfg := unit.Config() + svc := s3.New(cfg) + svc.Handlers.Unmarshal.Clear() + svc.Handlers.UnmarshalMeta.Clear() + svc.Handlers.UnmarshalError.Clear() + svc.Handlers.Send.Clear() + + var etag int64 + svc.Handlers.Send.PushBack(func(r *aws.Request) { + if r.Body != nil { + io.Copy(ioutil.Discard, r.Body) + } + + r.HTTPResponse = &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewReader([]byte{})), + } + + switch data := r.Data.(type) { + case *s3.CreateMultipartUploadOutput: + data.UploadId = aws.String("UPLOAD-ID") + case *s3.UploadPartOutput: + data.ETag = aws.String(fmt.Sprintf("ETAG%d", atomic.AddInt64(&etag, 1))) + case *s3.CompleteMultipartUploadOutput: + data.Location = aws.String("https://location") + data.VersionId = aws.String("VERSION-ID") + case *s3.PutObjectOutput: + data.VersionId = aws.String("VERSION-ID") + } + }) + + uploader := s3manager.NewUploaderWithClient(svc, func(u *s3manager.Uploader) { + u.PartSize = tCase.PartSize + u.BufferProvider = tCase.Strategy + u.Concurrency = 1 + }) + + expected := getTestBytes(int(tCase.Size)) + _, err := uploader.Upload(&s3manager.UploadInput{ + Bucket: aws.String("bucket"), + Key: aws.String("key"), + Body: bytes.NewReader(expected), + }) + if err != nil { + t.Fatalf("failed to upload file: %v", err) + } + + switch strat := tCase.Strategy.(type) { + case *recordedBufferProvider: + if !bytes.Equal(expected, strat.content) { + t.Errorf("content buffered did not match expected") + } + if tCase.callbacks != strat.callbackCount { + t.Errorf("expected 
%v, got %v callbacks", tCase.callbacks, strat.callbackCount) + } + } + }) + } +} + +type recordedBufferProvider struct { + content []byte + size int + callbackCount int +} + +func (r *recordedBufferProvider) GetWriteTo(seeker io.ReadSeeker) (s3manager.ReadSeekerWriteTo, func()) { + b := make([]byte, r.size) + w := &s3manager.BufferedReadSeekerWriteTo{BufferedReadSeeker: s3manager.NewBufferedReadSeeker(seeker, b)} + + return w, func() { + r.content = append(r.content, b...) + r.callbackCount++ + } +} diff --git a/service/s3/s3manager/writer_read_from.go b/service/s3/s3manager/writer_read_from.go new file mode 100644 index 00000000000..f3aa15ab92e --- /dev/null +++ b/service/s3/s3manager/writer_read_from.go @@ -0,0 +1,75 @@ +package s3manager + +import ( + "bufio" + "io" + "sync" + + "github.com/aws/aws-sdk-go-v2/internal/sdkio" +) + +// WriterReadFrom defines an interface implementing io.Writer and io.ReaderFrom +type WriterReadFrom interface { + io.Writer + io.ReaderFrom +} + +// WriterReadFromProvider provides an implementation of io.ReadFrom for the given io.Writer +type WriterReadFromProvider interface { + GetReadFrom(writer io.Writer) (w WriterReadFrom, cleanup func()) +} + +type bufferedWriter interface { + WriterReadFrom + Flush() error + Reset(io.Writer) +} + +type bufferedReadFrom struct { + bufferedWriter +} + +func (b *bufferedReadFrom) ReadFrom(r io.Reader) (int64, error) { + n, err := b.bufferedWriter.ReadFrom(r) + if flushErr := b.Flush(); flushErr != nil && err == nil { + err = flushErr + } + return n, err +} + +// PooledBufferedReadFromProvider is a WriterReadFromProvider that uses a sync.Pool +// to manage allocation and reuse of *bufio.Writer structures. +type PooledBufferedReadFromProvider struct { + pool sync.Pool +} + +// NewPooledBufferedWriterReadFromProvider returns a new PooledBufferedReadFromProvider +// Size is used to control the size of the underlying *bufio.Writer created for +// calls to GetReadFrom. 
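+// Sizes smaller than 32 KiB fall back to a 64 KiB buffer.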
+func NewPooledBufferedWriterReadFromProvider(size int) *PooledBufferedReadFromProvider { + if size < int(32*sdkio.KibiByte) { + size = int(64 * sdkio.KibiByte) + } + + return &PooledBufferedReadFromProvider{ + pool: sync.Pool{ + New: func() interface{} { + return &bufferedReadFrom{bufferedWriter: bufio.NewWriterSize(nil, size)} + }, + }, + } +} + +// GetReadFrom takes an io.Writer and wraps it with a type which satisfies the WriterReadFrom +// interface. Additionally, a cleanup function is provided which must be called after usage of the WriterReadFrom +// has been completed in order to allow the reuse of the *bufio.Writer. +func (p *PooledBufferedReadFromProvider) GetReadFrom(writer io.Writer) (r WriterReadFrom, cleanup func()) { + buffer := p.pool.Get().(*bufferedReadFrom) + buffer.Reset(writer) + r = buffer + cleanup = func() { + buffer.Reset(nil) // Reset to nil writer to release reference + p.pool.Put(buffer) + } + return r, cleanup +} diff --git a/service/s3/s3manager/writer_read_from_test.go b/service/s3/s3manager/writer_read_from_test.go new file mode 100644 index 00000000000..6a5196e60e8 --- /dev/null +++ b/service/s3/s3manager/writer_read_from_test.go @@ -0,0 +1,73 @@ +package s3manager + +import ( + "fmt" + "io" + "reflect" + "testing" +) + +type testBufioWriter struct { + ReadFromN int64 + ReadFromErr error + FlushReturn error +} + +func (t testBufioWriter) Write(p []byte) (n int, err error) { + panic("unused") +} + +func (t testBufioWriter) ReadFrom(r io.Reader) (n int64, err error) { + return t.ReadFromN, t.ReadFromErr +} + +func (t testBufioWriter) Flush() error { + return t.FlushReturn +} + +func (t *testBufioWriter) Reset(io.Writer) { + panic("unused") +} + +func TestBufferedReadFromFlusher_ReadFrom(t *testing.T) { + cases := map[string]struct { + w testBufioWriter + expectedErr error + }{ + "no errors": {}, + "error returned from underlying ReadFrom": { + w: testBufioWriter{ + ReadFromN: 42, + ReadFromErr: fmt.Errorf("readfrom"), + }, + expectedErr: fmt.Errorf("readfrom"), + }, + "error returned from Flush": { + w: testBufioWriter{ + ReadFromN: 7, + FlushReturn: fmt.Errorf("flush"), + }, + expectedErr: fmt.Errorf("flush"), + }, + "error returned from ReadFrom and Flush": { + w: testBufioWriter{ + ReadFromN: 1337, + ReadFromErr: fmt.Errorf("readfrom"), + FlushReturn: fmt.Errorf("flush"), + }, + expectedErr: fmt.Errorf("readfrom"), + }, + } + + for name, tCase := range cases { + t.Log(name) + readFromFlusher := bufferedReadFrom{bufferedWriter: &tCase.w} + n, err := readFromFlusher.ReadFrom(nil) + if e, a := tCase.w.ReadFromN, n; e != a { + t.Errorf("expected %v bytes, got %v", e, a) + } + if e, a := tCase.expectedErr, err; !reflect.DeepEqual(e, a) { + t.Errorf("expected error %v. got %v", e, a) + } + } +}
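For reference, a minimal sketch of how the new Downloader.BufferProvider could be configured, mirroring the uploader examples above; it is illustrative only and not part of the change set, and the bucket, key, and 5 MiB buffer size are placeholder values rather than recommended defaults.

package s3manager_test

import (
	"fmt"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/aws/external"
	"github.com/aws/aws-sdk-go-v2/service/s3"
	"github.com/aws/aws-sdk-go-v2/service/s3/s3manager"
)

// ExampleNewDownloader_overrideWriterReadFromProvider configures a Downloader to buffer the
// copy from each part's HTTP response body through a pooled bufio.Writer before the data
// reaches the destination io.WriterAt.
func ExampleNewDownloader_overrideWriterReadFromProvider() {
	cfg, err := external.LoadDefaultAWSConfig()
	if err != nil {
		panic(fmt.Sprintf("failed to load SDK config: %v", err))
	}

	downloader := s3manager.NewDownloader(cfg, func(d *s3manager.Downloader) {
		// Reuse 5 MiB write buffers across part downloads via a sync.Pool backed provider.
		d.BufferProvider = s3manager.NewPooledBufferedWriterReadFromProvider(5 * 1024 * 1024)
	})

	// aws.NewWriteAtBuffer provides an in-memory io.WriterAt destination for the object.
	buffer := aws.NewWriteAtBuffer(make([]byte, 0))

	_, err = downloader.Download(buffer, &s3.GetObjectInput{
		Bucket: aws.String("examplebucket"),
		Key:    aws.String("largeobject"),
	})
	if err != nil {
		fmt.Println(err.Error())
	}
}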