diff --git a/go/store/nbs/aws_table_persister.go b/go/store/nbs/aws_table_persister.go
index 920a04d8ff2..400cb3fdf97 100644
--- a/go/store/nbs/aws_table_persister.go
+++ b/go/store/nbs/aws_table_persister.go
@@ -34,6 +34,8 @@ import (
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/service/s3"
+	"github.com/aws/aws-sdk-go/service/s3/s3iface"
+	"github.com/aws/aws-sdk-go/service/s3/s3manager"
 
 	"github.com/dolthub/dolt/go/store/atomicerr"
 	"github.com/dolthub/dolt/go/store/chunks"
@@ -55,7 +57,7 @@ const (
 )
 
 type awsTablePersister struct {
-	s3     s3svc
+	s3     s3iface.S3API
 	bucket string
 	rl     chan struct{}
 	ddb    *ddbTableStore
@@ -120,21 +122,20 @@ func (s3p awsTablePersister) CopyTableFile(ctx context.Context, r io.ReadCloser,
 		}
 	}()
 
-	data, err := io.ReadAll(r)
-	if err != nil {
-		return err
-	}
-
 	name, err := parseAddr(fileId)
 	if err != nil {
 		return err
 	}
 
-	if s3p.limits.tableFitsInDynamo(name, len(data), chunkCount) {
+	if s3p.limits.tableFitsInDynamo(name, int(fileSz), chunkCount) {
+		data, err := io.ReadAll(r)
+		if err != nil {
+			return err
+		}
 		return s3p.ddb.Write(ctx, name, data)
 	}
 
-	return s3p.multipartUpload(ctx, data, fileId)
+	return s3p.multipartUpload(ctx, r, fileSz, fileId)
 }
 
 func (s3p awsTablePersister) Path() string {
@@ -174,7 +175,7 @@ func (s3p awsTablePersister) Persist(ctx context.Context, mt *memTable, haver ch
 		return newReaderFromIndexData(ctx, s3p.q, data, name, &dynamoTableReaderAt{ddb: s3p.ddb, h: name}, s3BlockSize)
 	}
 
-	err = s3p.multipartUpload(ctx, data, name.String())
+	err = s3p.multipartUpload(ctx, bytes.NewReader(data), uint64(len(data)), name.String())
 
 	if err != nil {
 		return emptyChunkSource{}, err
@@ -184,20 +185,16 @@ func (s3p awsTablePersister) Persist(ctx context.Context, mt *memTable, haver ch
 	return newReaderFromIndexData(ctx, s3p.q, data, name, tra, s3BlockSize)
 }
 
-func (s3p awsTablePersister) multipartUpload(ctx context.Context, data []byte, key string) error {
-	uploadID, err := s3p.startMultipartUpload(ctx, key)
-
-	if err != nil {
-		return err
-	}
-
-	multipartUpload, err := s3p.uploadParts(ctx, data, key, uploadID)
-	if err != nil {
-		_ = s3p.abortMultipartUpload(ctx, key, uploadID)
-		return err
-	}
-
-	return s3p.completeMultipartUpload(ctx, key, uploadID, multipartUpload)
+func (s3p awsTablePersister) multipartUpload(ctx context.Context, r io.Reader, sz uint64, key string) error {
+	uploader := s3manager.NewUploaderWithClient(s3p.s3, func(u *s3manager.Uploader) {
+		u.PartSize = int64(s3p.limits.partTarget)
+	})
+	_, err := uploader.Upload(&s3manager.UploadInput{
+		Bucket: aws.String(s3p.bucket),
+		Key:    aws.String(s3p.key(key)),
+		Body:   r,
+	})
+	return err
 }
 
 func (s3p awsTablePersister) startMultipartUpload(ctx context.Context, key string) (string, error) {
@@ -234,85 +231,6 @@ func (s3p awsTablePersister) completeMultipartUpload(ctx context.Context, key, u
 	return err
 }
 
-func (s3p awsTablePersister) uploadParts(ctx context.Context, data []byte, key, uploadID string) (*s3.CompletedMultipartUpload, error) {
-	sent, failed, done := make(chan s3UploadedPart), make(chan error), make(chan struct{})
-
-	numParts := getNumParts(uint64(len(data)), s3p.limits.partTarget)
-
-	if numParts > maxS3Parts {
-		return nil, errors.New("exceeded maximum parts")
-	}
-
-	var wg sync.WaitGroup
-	sendPart := func(partNum, start, end uint64) {
-		if s3p.rl != nil {
-			s3p.rl <- struct{}{}
-			defer func() { <-s3p.rl }()
-		}
-		defer wg.Done()
-
-		// Check if upload has been terminated
-		select {
-		case <-done:
-			return
-		default:
-		}
-		// Upload the desired part
-		if partNum == numParts { // If this is the last part, make sure it includes any overflow
-			end = uint64(len(data))
-		}
-		etag, err := s3p.uploadPart(ctx, data[start:end], key, uploadID, int64(partNum))
-		if err != nil {
-			failed <- err
-			return
-		}
-		// Try to send along part info. In the case that the upload was aborted, reading from done allows this worker to exit correctly.
-		select {
-		case sent <- s3UploadedPart{int64(partNum), etag}:
-		case <-done:
-			return
-		}
-	}
-	for i := uint64(0); i < numParts; i++ {
-		wg.Add(1)
-		partNum := i + 1 // Parts are 1-indexed
-		start, end := i*s3p.limits.partTarget, (i+1)*s3p.limits.partTarget
-		go sendPart(partNum, start, end)
-	}
-	go func() {
-		wg.Wait()
-		close(sent)
-		close(failed)
-	}()
-
-	multipartUpload := &s3.CompletedMultipartUpload{}
-	var firstFailure error
-	for cont := true; cont; {
-		select {
-		case sentPart, open := <-sent:
-			if open {
-				multipartUpload.Parts = append(multipartUpload.Parts, &s3.CompletedPart{
-					ETag:       aws.String(sentPart.etag),
-					PartNumber: aws.Int64(sentPart.idx),
-				})
-			}
-			cont = open
-
-		case err := <-failed:
-			if err != nil && firstFailure == nil { // nil err may happen when failed gets closed
-				firstFailure = err
-				close(done)
-			}
-		}
-	}
-
-	if firstFailure == nil {
-		close(done)
-	}
-	sort.Sort(partsByPartNum(multipartUpload.Parts))
-	return multipartUpload, firstFailure
-}
-
 func getNumParts(dataLen, minPartSize uint64) uint64 {
 	numParts := dataLen / minPartSize
 	if numParts == 0 {
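The core of this file's change: the hand-rolled CreateMultipartUpload / UploadPart / CompleteMultipartUpload orchestration (including the abort-on-failure and part-sorting plumbing removed above) is delegated to the SDK's s3manager.Uploader, which handles part splitting, concurrency, and cleanup itself and accepts a plain io.Reader, so CopyTableFile no longer has to buffer the whole table file in memory. A minimal standalone sketch of the same pattern, assuming a configured session; the region, bucket, key, and file name are placeholders, not values from this patch:

    package main

    import (
    	"context"
    	"log"
    	"os"

    	"github.com/aws/aws-sdk-go/aws"
    	"github.com/aws/aws-sdk-go/aws/session"
    	"github.com/aws/aws-sdk-go/service/s3"
    	"github.com/aws/aws-sdk-go/service/s3/s3manager"
    )

    func main() {
    	sess := session.Must(session.NewSession(aws.NewConfig().WithRegion("us-west-2")))

    	// NewUploaderWithClient accepts any s3iface.S3API, which is what lets
    	// the tests below drive this same code path with a fake client.
    	uploader := s3manager.NewUploaderWithClient(s3.New(sess), func(u *s3manager.Uploader) {
    		u.PartSize = 8 * 1024 * 1024 // must be >= s3manager.MinUploadPartSize (5MB)
    	})

    	f, err := os.Open("table-file")
    	if err != nil {
    		log.Fatal(err)
    	}
    	defer f.Close()

    	// Upload streams from the reader in PartSize pieces, using a multipart
    	// upload when the body spans more than one part and a single PutObject
    	// otherwise.
    	if _, err = uploader.UploadWithContext(context.Background(), &s3manager.UploadInput{
    		Bucket: aws.String("my-bucket"),
    		Key:    aws.String("my-key"),
    		Body:   f,
    	}); err != nil {
    		log.Fatal(err)
    	}
    }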
diff --git a/go/store/nbs/aws_table_persister_test.go b/go/store/nbs/aws_table_persister_test.go
index e2cdc930f61..f0b8aa2a29b 100644
--- a/go/store/nbs/aws_table_persister_test.go
+++ b/go/store/nbs/aws_table_persister_test.go
@@ -23,6 +23,7 @@ package nbs
 
 import (
 	"context"
+	crand "crypto/rand"
 	"io"
 	"math/rand"
 	"sync"
@@ -31,30 +32,66 @@ import (
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/request"
 	"github.com/aws/aws-sdk-go/service/s3"
+	"github.com/aws/aws-sdk-go/service/s3/s3iface"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
 	"github.com/dolthub/dolt/go/store/util/sizecache"
 )
 
+func randomChunks(t *testing.T, r *rand.Rand, sz int) [][]byte {
+	buf := make([]byte, sz)
+	_, err := io.ReadFull(crand.Reader, buf)
+	require.NoError(t, err)
+
+	var ret [][]byte
+	var i int
+	for i < len(buf) {
+		j := int(r.NormFloat64()*1024 + 4096)
+		if i+j >= len(buf) {
+			ret = append(ret, buf[i:])
+		} else {
+			ret = append(ret, buf[i:i+j])
+		}
+		i += j
+	}
+
+	return ret
+}
+
+func TestRandomChunks(t *testing.T) {
+	r := rand.New(rand.NewSource(1024))
+	res := randomChunks(t, r, 10)
+	assert.Len(t, res, 1)
+	res = randomChunks(t, r, 4096+2048)
+	assert.Len(t, res, 2)
+	res = randomChunks(t, r, 4096+4096)
+	assert.Len(t, res, 3)
+}
+
 func TestAWSTablePersisterPersist(t *testing.T) {
 	ctx := context.Background()
 
 	calcPartSize := func(rdr chunkReader, maxPartNum uint64) uint64 {
 		return maxTableSize(uint64(mustUint32(rdr.count())), mustUint64(rdr.uncompressedLen())) / maxPartNum
 	}
 
-	mt := newMemTable(testMemTableSize)
+	r := rand.New(rand.NewSource(1024))
+	const sz15mb = 1 << 20 * 15
+	mt := newMemTable(sz15mb)
+	testChunks := randomChunks(t, r, 1<<20*12)
 
 	for _, c := range testChunks {
 		assert.Equal(t, mt.addChunk(computeAddr(c), c), chunkAdded)
 	}
 
+	var limits5mb = awsLimits{partTarget: 1 << 20 * 5}
+	var limits64mb = awsLimits{partTarget: 1 << 20 * 64}
+
 	t.Run("PersistToS3", func(t *testing.T) {
 		testIt := func(t *testing.T, ns string) {
 			t.Run("InMultipleParts", func(t *testing.T) {
 				assert := assert.New(t)
 				s3svc, ddb := makeFakeS3(t), makeFakeDTS(makeFakeDDB(t), nil)
-				limits := awsLimits{partTarget: calcPartSize(mt, 3)}
-				s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits, ns: ns, q: &UnlimitedQuotaProvider{}}
+				s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}}
 
 				src, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
 				require.NoError(t, err)
@@ -72,8 +109,7 @@ func TestAWSTablePersisterPersist(t *testing.T) {
 				assert := assert.New(t)
 				s3svc, ddb := makeFakeS3(t), makeFakeDTS(makeFakeDDB(t), nil)
 
-				limits := awsLimits{partTarget: calcPartSize(mt, 1)}
-				s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits, ns: ns, q: &UnlimitedQuotaProvider{}}
+				s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits64mb, ns: ns, q: &UnlimitedQuotaProvider{}}
 
 				src, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
 				require.NoError(t, err)
@@ -89,8 +125,8 @@ func TestAWSTablePersisterPersist(t *testing.T) {
 		t.Run("NoNewChunks", func(t *testing.T) {
 			assert := assert.New(t)
 
-			mt := newMemTable(testMemTableSize)
-			existingTable := newMemTable(testMemTableSize)
+			mt := newMemTable(sz15mb)
+			existingTable := newMemTable(sz15mb)
 
 			for _, c := range testChunks {
 				assert.Equal(mt.addChunk(computeAddr(c), c), chunkAdded)
@@ -98,8 +134,7 @@ func TestAWSTablePersisterPersist(t *testing.T) {
 			}
 
 			s3svc, ddb := makeFakeS3(t), makeFakeDTS(makeFakeDDB(t), nil)
-			limits := awsLimits{partTarget: 1 << 10}
-			s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits, ns: ns, q: &UnlimitedQuotaProvider{}}
+			s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}}
 
 			src, err := s3p.Persist(context.Background(), mt, existingTable, &Stats{})
 			require.NoError(t, err)
@@ -115,8 +150,7 @@ func TestAWSTablePersisterPersist(t *testing.T) {
 			s3svc := &failingFakeS3{makeFakeS3(t), sync.Mutex{}, 1}
 			ddb := makeFakeDTS(makeFakeDDB(t), nil)
 
-			limits := awsLimits{partTarget: calcPartSize(mt, 4)}
-			s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits, ns: ns, q: &UnlimitedQuotaProvider{}}
+			s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}}
 
 			_, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
 			assert.Error(err)
@@ -330,14 +364,15 @@ func TestAWSTablePersisterCalcPartSizes(t *testing.T) {
 func TestAWSTablePersisterConjoinAll(t *testing.T) {
 	ctx := context.Background()
 
-	targetPartSize := uint64(1024)
+	const sz5mb = 1 << 20 * 5
+	targetPartSize := uint64(sz5mb)
 	minPartSize, maxPartSize := targetPartSize, 5*targetPartSize
 	maxItemSize, maxChunkCount := int(targetPartSize/2), uint32(4)
 
 	rl := make(chan struct{}, 8)
 	defer close(rl)
 
-	newPersister := func(s3svc s3svc, ddb *ddbTableStore) awsTablePersister {
+	newPersister := func(s3svc s3iface.S3API, ddb *ddbTableStore) awsTablePersister {
 		return awsTablePersister{
 			s3svc,
 			"bucket",
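A note on the new fixture sizes above: the SDK's uploader refuses part sizes below s3manager.MinUploadPartSize (5MB), so the old KB-scale partTarget values can no longer exercise the upload path. With the roughly 12MB of random chunk data built by randomChunks, a 5MB part target still splits the upload into multiple parts, while a 64MB target keeps the whole table in one part. A quick illustration of the arithmetic (not code from the patch):

    package main

    import (
    	"fmt"

    	"github.com/aws/aws-sdk-go/service/s3/s3manager"
    )

    func main() {
    	fmt.Println(s3manager.MinUploadPartSize) // 5242880: floor on Uploader.PartSize
    	fmt.Println(s3manager.MaxUploadParts)    // 10000: cap on parts per upload

    	const dataSz = 12 << 20 // ~12MB of test chunk data
    	fmt.Println((dataSz + (5<<20 - 1)) / (5 << 20)) // 3 parts at a 5MB part target
    	fmt.Println(dataSz <= 64<<20)                   // true: one part at a 64MB target
    }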
diff --git a/go/store/nbs/s3_fake_test.go b/go/store/nbs/s3_fake_test.go
index fe48bcea9cf..d05be408dbb 100644
--- a/go/store/nbs/s3_fake_test.go
+++ b/go/store/nbs/s3_fake_test.go
@@ -33,8 +33,10 @@ import (
 	"time"
 
 	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/aws/client/metadata"
 	"github.com/aws/aws-sdk-go/aws/request"
 	"github.com/aws/aws-sdk-go/service/s3"
+	"github.com/aws/aws-sdk-go/service/s3/s3iface"
 	"github.com/stretchr/testify/assert"
 
 	"github.com/dolthub/dolt/go/store/d"
@@ -58,6 +60,8 @@ func makeFakeS3(t *testing.T) *fakeS3 {
 }
 
 type fakeS3 struct {
+	s3iface.S3API
+
 	assert *assert.Assertions
 
 	mu sync.Mutex
@@ -279,3 +283,37 @@ func (m *fakeS3) PutObjectWithContext(ctx aws.Context, input *s3.PutObjectInput,
 
 	return &s3.PutObjectOutput{}, nil
 }
+
+func (m *fakeS3) GetObjectRequest(input *s3.GetObjectInput) (*request.Request, *s3.GetObjectOutput) {
+	out := &s3.GetObjectOutput{}
+	var handlers request.Handlers
+	handlers.Send.PushBack(func(r *request.Request) {
+		res, err := m.GetObjectWithContext(r.Context(), input)
+		r.Error = err
+		if res != nil {
+			*(r.Data.(*s3.GetObjectOutput)) = *res
+		}
+	})
+	return request.New(aws.Config{}, metadata.ClientInfo{}, handlers, nil, &request.Operation{
+		Name:       "GetObject",
+		HTTPMethod: "GET",
+		HTTPPath:   "/{Bucket}/{Key+}",
+	}, input, out), out
+}
+
+func (m *fakeS3) PutObjectRequest(input *s3.PutObjectInput) (*request.Request, *s3.PutObjectOutput) {
+	out := &s3.PutObjectOutput{}
+	var handlers request.Handlers
+	handlers.Send.PushBack(func(r *request.Request) {
+		res, err := m.PutObjectWithContext(r.Context(), input)
+		r.Error = err
+		if res != nil {
+			*(r.Data.(*s3.PutObjectOutput)) = *res
+		}
+	})
+	return request.New(aws.Config{}, metadata.ClientInfo{}, handlers, nil, &request.Operation{
+		Name:       "PutObject",
+		HTTPMethod: "PUT",
+		HTTPPath:   "/{Bucket}/{Key+}",
+	}, input, out), out
+}
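The two Request-form methods added above are what the SDK's transfer managers actually invoke: s3manager builds its single-part upload via the client's PutObjectRequest (and its download manager via GetObjectRequest), then calls Send on the returned *request.Request, while the multipart path goes through the *WithContext methods fakeS3 already implements — which is presumably why only these two Request forms needed adding. Injecting a Send handler that delegates to the existing WithContext implementations and copies the result into r.Data is enough to route both paths through the fake. A hypothetical test (not part of the patch) showing that flow; it would live in package nbs next to the fake and needs the file's existing imports plus "bytes", "context", and testify's require:

    func TestFakeS3PutObjectRequest(t *testing.T) {
    	fake := makeFakeS3(t)
    	req, out := fake.PutObjectRequest(&s3.PutObjectInput{
    		Bucket: aws.String("bucket"),
    		Key:    aws.String("some-key"),
    		Body:   bytes.NewReader([]byte("hello")),
    	})
    	req.SetContext(context.Background())

    	// Send() runs the handler chain; the Send handler installed by
    	// PutObjectRequest delegates to the fake's PutObjectWithContext
    	// and fills in out.
    	require.NoError(t, req.Send())
    	_ = out
    }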
"github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3iface" "github.com/stretchr/testify/assert" "github.com/dolthub/dolt/go/store/d" @@ -58,6 +60,8 @@ func makeFakeS3(t *testing.T) *fakeS3 { } type fakeS3 struct { + s3iface.S3API + assert *assert.Assertions mu sync.Mutex @@ -279,3 +283,37 @@ func (m *fakeS3) PutObjectWithContext(ctx aws.Context, input *s3.PutObjectInput, return &s3.PutObjectOutput{}, nil } + +func (m *fakeS3) GetObjectRequest(input *s3.GetObjectInput) (*request.Request, *s3.GetObjectOutput) { + out := &s3.GetObjectOutput{} + var handlers request.Handlers + handlers.Send.PushBack(func(r *request.Request) { + res, err := m.GetObjectWithContext(r.Context(), input) + r.Error = err + if res != nil { + *(r.Data.(*s3.GetObjectOutput)) = *res + } + }) + return request.New(aws.Config{}, metadata.ClientInfo{}, handlers, nil, &request.Operation{ + Name: "GetObject", + HTTPMethod: "GET", + HTTPPath: "/{Bucket}/{Key+}", + }, input, out), out +} + +func (m *fakeS3) PutObjectRequest(input *s3.PutObjectInput) (*request.Request, *s3.PutObjectOutput) { + out := &s3.PutObjectOutput{} + var handlers request.Handlers + handlers.Send.PushBack(func(r *request.Request) { + res, err := m.PutObjectWithContext(r.Context(), input) + r.Error = err + if res != nil { + *(r.Data.(*s3.PutObjectOutput)) = *res + } + }) + return request.New(aws.Config{}, metadata.ClientInfo{}, handlers, nil, &request.Operation{ + Name: "PutObject", + HTTPMethod: "PUT", + HTTPPath: "/{Bucket}/{Key+}", + }, input, out), out +} diff --git a/go/store/nbs/s3_table_reader.go b/go/store/nbs/s3_table_reader.go index 179d5b9a38b..05db80deb0e 100644 --- a/go/store/nbs/s3_table_reader.go +++ b/go/store/nbs/s3_table_reader.go @@ -34,8 +34,8 @@ import ( "time" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3iface" "github.com/jpillora/backoff" "golang.org/x/sync/errgroup" ) @@ -50,16 +50,6 @@ type s3TableReaderAt struct { h addr } -type s3svc interface { - AbortMultipartUploadWithContext(ctx aws.Context, input *s3.AbortMultipartUploadInput, opts ...request.Option) (*s3.AbortMultipartUploadOutput, error) - CreateMultipartUploadWithContext(ctx aws.Context, input *s3.CreateMultipartUploadInput, opts ...request.Option) (*s3.CreateMultipartUploadOutput, error) - UploadPartWithContext(ctx aws.Context, input *s3.UploadPartInput, opts ...request.Option) (*s3.UploadPartOutput, error) - UploadPartCopyWithContext(ctx aws.Context, input *s3.UploadPartCopyInput, opts ...request.Option) (*s3.UploadPartCopyOutput, error) - CompleteMultipartUploadWithContext(ctx aws.Context, input *s3.CompleteMultipartUploadInput, opts ...request.Option) (*s3.CompleteMultipartUploadOutput, error) - GetObjectWithContext(ctx aws.Context, input *s3.GetObjectInput, opts ...request.Option) (*s3.GetObjectOutput, error) - PutObjectWithContext(ctx aws.Context, input *s3.PutObjectInput, opts ...request.Option) (*s3.PutObjectOutput, error) -} - func (s3tra *s3TableReaderAt) Close() error { return nil } @@ -78,7 +68,7 @@ func (s3tra *s3TableReaderAt) ReadAtWithStats(ctx context.Context, p []byte, off // TODO: Bring all the multipart upload and remote-conjoin stuff over here and make this a better analogue to ddbTableStore type s3ObjectReader struct { - s3 s3svc + s3 s3iface.S3API bucket string readRl chan struct{} ns string diff --git a/go/store/nbs/s3_table_reader_test.go 
diff --git a/go/store/nbs/s3_table_reader_test.go b/go/store/nbs/s3_table_reader_test.go
index 264222b16b9..787353df638 100644
--- a/go/store/nbs/s3_table_reader_test.go
+++ b/go/store/nbs/s3_table_reader_test.go
@@ -32,6 +32,7 @@ import (
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/request"
 	"github.com/aws/aws-sdk-go/service/s3"
+	"github.com/aws/aws-sdk-go/service/s3/s3iface"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -88,16 +89,16 @@ func TestS3TableReaderAtNamespace(t *testing.T) {
 }
 
 type flakyS3 struct {
-	s3svc
+	s3iface.S3API
 	alreadyFailed map[string]struct{}
 }
 
-func makeFlakyS3(svc s3svc) *flakyS3 {
+func makeFlakyS3(svc s3iface.S3API) *flakyS3 {
 	return &flakyS3{svc, map[string]struct{}{}}
 }
 
 func (fs3 *flakyS3) GetObjectWithContext(ctx aws.Context, input *s3.GetObjectInput, opts ...request.Option) (*s3.GetObjectOutput, error) {
-	output, err := fs3.s3svc.GetObjectWithContext(ctx, input)
+	output, err := fs3.S3API.GetObjectWithContext(ctx, input)
 
 	if err != nil {
 		return nil, err
diff --git a/go/store/nbs/store.go b/go/store/nbs/store.go
index 5f1ee8172f8..a63bc59e776 100644
--- a/go/store/nbs/store.go
+++ b/go/store/nbs/store.go
@@ -33,6 +33,7 @@ import (
 	"time"
 
 	"cloud.google.com/go/storage"
+	"github.com/aws/aws-sdk-go/service/s3/s3iface"
 	"github.com/dustin/go-humanize"
 	lru "github.com/hashicorp/golang-lru/v2"
 	"github.com/pkg/errors"
@@ -483,7 +484,7 @@ func OverwriteStoreManifest(ctx context.Context, store *NomsBlockStore, root has
 	return nil
 }
 
-func NewAWSStoreWithMMapIndex(ctx context.Context, nbfVerStr string, table, ns, bucket string, s3 s3svc, ddb ddbsvc, memTableSize uint64, q MemoryQuotaProvider) (*NomsBlockStore, error) {
+func NewAWSStoreWithMMapIndex(ctx context.Context, nbfVerStr string, table, ns, bucket string, s3 s3iface.S3API, ddb ddbsvc, memTableSize uint64, q MemoryQuotaProvider) (*NomsBlockStore, error) {
 	cacheOnce.Do(makeGlobalCaches)
 	readRateLimiter := make(chan struct{}, 32)
 	p := &awsTablePersister{
@@ -499,7 +500,7 @@ func NewAWSStoreWithMMapIndex(ctx context.Context, nbfVerStr string, table, ns,
 	return newNomsBlockStore(ctx, nbfVerStr, mm, p, q, inlineConjoiner{defaultMaxTables}, memTableSize)
 }
 
-func NewAWSStore(ctx context.Context, nbfVerStr string, table, ns, bucket string, s3 s3svc, ddb ddbsvc, memTableSize uint64, q MemoryQuotaProvider) (*NomsBlockStore, error) {
+func NewAWSStore(ctx context.Context, nbfVerStr string, table, ns, bucket string, s3 s3iface.S3API, ddb ddbsvc, memTableSize uint64, q MemoryQuotaProvider) (*NomsBlockStore, error) {
 	cacheOnce.Do(makeGlobalCaches)
 	readRateLimiter := make(chan struct{}, 32)
 	p := &awsTablePersister{