Remotes: AWS: Fix a bug where uploading table files to S3 could have unbounded memory usage. #6883

Merged · 3 commits · Oct 26, 2023

Changes from all commits
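Summary of the change, per the diff below: CopyTableFile previously slurped the entire table file into memory with io.ReadAll before deciding where to store it, so upload memory grew with file size. The fix threads the caller-supplied fileSz through the size check, buffers a file only when it is small enough to be written to DynamoDB, and streams everything else to S3 through the SDK's s3manager.Uploader, which consumes the io.Reader one part at a time. Peak memory per upload is then bounded roughly by the uploader's part size times its part concurrency, not by the file size.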
122 changes: 20 additions & 102 deletions go/store/nbs/aws_table_persister.go
@@ -34,6 +34,8 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/aws/aws-sdk-go/service/s3/s3manager"

"github.com/dolthub/dolt/go/store/atomicerr"
"github.com/dolthub/dolt/go/store/chunks"
@@ -55,7 +57,7 @@ const (
)

type awsTablePersister struct {
s3 s3svc
s3 s3iface.S3API
bucket string
rl chan struct{}
ddb *ddbTableStore
@@ -120,21 +122,20 @@ func (s3p awsTablePersister) CopyTableFile(ctx context.Context, r io.ReadCloser,
}
}()

data, err := io.ReadAll(r)
if err != nil {
return err
}

name, err := parseAddr(fileId)
if err != nil {
return err
}

if s3p.limits.tableFitsInDynamo(name, len(data), chunkCount) {
if s3p.limits.tableFitsInDynamo(name, int(fileSz), chunkCount) {
data, err := io.ReadAll(r)
if err != nil {
return err
}
return s3p.ddb.Write(ctx, name, data)
}

return s3p.multipartUpload(ctx, data, fileId)
return s3p.multipartUpload(ctx, r, fileSz, fileId)
}
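After this hunk, the only remaining full buffering in CopyTableFile is for files that pass tableFitsInDynamo, which are capped by DynamoDB's item-size ceiling; larger files are handed to multipartUpload as a reader and never fully materialized in memory.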

func (s3p awsTablePersister) Path() string {
@@ -174,7 +175,7 @@ func (s3p awsTablePersister) Persist(ctx context.Context, mt *memTable, haver ch
return newReaderFromIndexData(ctx, s3p.q, data, name, &dynamoTableReaderAt{ddb: s3p.ddb, h: name}, s3BlockSize)
}

err = s3p.multipartUpload(ctx, data, name.String())
err = s3p.multipartUpload(ctx, bytes.NewReader(data), uint64(len(data)), name.String())

if err != nil {
return emptyChunkSource{}, err
@@ -184,20 +185,16 @@
return newReaderFromIndexData(ctx, s3p.q, data, name, tra, s3BlockSize)
}

func (s3p awsTablePersister) multipartUpload(ctx context.Context, data []byte, key string) error {
uploadID, err := s3p.startMultipartUpload(ctx, key)

if err != nil {
return err
}

multipartUpload, err := s3p.uploadParts(ctx, data, key, uploadID)
if err != nil {
_ = s3p.abortMultipartUpload(ctx, key, uploadID)
return err
}

return s3p.completeMultipartUpload(ctx, key, uploadID, multipartUpload)
func (s3p awsTablePersister) multipartUpload(ctx context.Context, r io.Reader, sz uint64, key string) error {
uploader := s3manager.NewUploaderWithClient(s3p.s3, func(u *s3manager.Uploader) {
u.PartSize = int64(s3p.limits.partTarget)
})
_, err := uploader.Upload(&s3manager.UploadInput{
Bucket: aws.String(s3p.bucket),
Key: aws.String(s3p.key(key)),
Body: r,
})
return err
}
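For context, a minimal, self-contained sketch of the same uploader pattern the new multipartUpload uses; the region, bucket, key, and file path are placeholders, not values from this PR. s3manager.Uploader buffers at most Concurrency in-flight parts of PartSize bytes each, which is what bounds memory:

package main

import (
    "log"
    "os"

    "github.com/aws/aws-sdk-go/aws"
    "github.com/aws/aws-sdk-go/aws/session"
    "github.com/aws/aws-sdk-go/service/s3"
    "github.com/aws/aws-sdk-go/service/s3/s3manager"
)

func main() {
    // Stock SDK session; the region is a placeholder.
    sess := session.Must(session.NewSession(aws.NewConfig().WithRegion("us-west-2")))

    // Same construction as the diff: wrap an existing client, pin the part size.
    uploader := s3manager.NewUploaderWithClient(s3.New(sess), func(u *s3manager.Uploader) {
        u.PartSize = 8 * 1024 * 1024 // must be >= s3manager.MinUploadPartSize (5 MiB)
        u.Concurrency = 4            // at most 4 parts buffered/in flight at once
    })

    f, err := os.Open("table-file") // placeholder path
    if err != nil {
        log.Fatal(err)
    }
    defer f.Close()

    // Upload consumes the reader part by part; the file is never whole in memory.
    if _, err = uploader.Upload(&s3manager.UploadInput{
        Bucket: aws.String("my-bucket"), // placeholder
        Key:    aws.String("my-key"),    // placeholder
        Body:   f,
    }); err != nil {
        log.Fatal(err)
    }
}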

func (s3p awsTablePersister) startMultipartUpload(ctx context.Context, key string) (string, error) {
@@ -234,85 +231,6 @@ func (s3p awsTablePersister) completeMultipartUpload(ctx context.Context, key, u
return err
}

func (s3p awsTablePersister) uploadParts(ctx context.Context, data []byte, key, uploadID string) (*s3.CompletedMultipartUpload, error) {
sent, failed, done := make(chan s3UploadedPart), make(chan error), make(chan struct{})

numParts := getNumParts(uint64(len(data)), s3p.limits.partTarget)

if numParts > maxS3Parts {
return nil, errors.New("exceeded maximum parts")
}

var wg sync.WaitGroup
sendPart := func(partNum, start, end uint64) {
if s3p.rl != nil {
s3p.rl <- struct{}{}
defer func() { <-s3p.rl }()
}
defer wg.Done()

// Check if upload has been terminated
select {
case <-done:
return
default:
}
// Upload the desired part
if partNum == numParts { // If this is the last part, make sure it includes any overflow
end = uint64(len(data))
}
etag, err := s3p.uploadPart(ctx, data[start:end], key, uploadID, int64(partNum))
if err != nil {
failed <- err
return
}
// Try to send along part info. In the case that the upload was aborted, reading from done allows this worker to exit correctly.
select {
case sent <- s3UploadedPart{int64(partNum), etag}:
case <-done:
return
}
}
for i := uint64(0); i < numParts; i++ {
wg.Add(1)
partNum := i + 1 // Parts are 1-indexed
start, end := i*s3p.limits.partTarget, (i+1)*s3p.limits.partTarget
go sendPart(partNum, start, end)
}
go func() {
wg.Wait()
close(sent)
close(failed)
}()

multipartUpload := &s3.CompletedMultipartUpload{}
var firstFailure error
for cont := true; cont; {
select {
case sentPart, open := <-sent:
if open {
multipartUpload.Parts = append(multipartUpload.Parts, &s3.CompletedPart{
ETag: aws.String(sentPart.etag),
PartNumber: aws.Int64(sentPart.idx),
})
}
cont = open

case err := <-failed:
if err != nil && firstFailure == nil { // nil err may happen when failed gets closed
firstFailure = err
close(done)
}
}
}

if firstFailure == nil {
close(done)
}
sort.Sort(partsByPartNum(multipartUpload.Parts))
return multipartUpload, firstFailure
}
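Everything above in this removed block is the hand-rolled multipart machinery the PR deletes: uploadParts required the entire table as one []byte, sliced it into partTarget-sized windows, and coordinated per-part goroutines over the sent/failed/done channels. s3manager.Uploader subsumes the part numbering, the concurrency limiting, and the abort-on-failure handling, without ever needing the full byte slice.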

func getNumParts(dataLen, minPartSize uint64) uint64 {
numParts := dataLen / minPartSize
if numParts == 0 {
61 changes: 48 additions & 13 deletions go/store/nbs/aws_table_persister_test.go
@@ -23,6 +23,7 @@ package nbs

import (
"context"
crand "crypto/rand"
"io"
"math/rand"
"sync"
@@ -31,30 +32,66 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/dolthub/dolt/go/store/util/sizecache"
)

func randomChunks(t *testing.T, r *rand.Rand, sz int) [][]byte {
buf := make([]byte, sz)
_, err := io.ReadFull(crand.Reader, buf)
require.NoError(t, err)

var ret [][]byte
var i int
for i < len(buf) {
j := int(r.NormFloat64()*1024 + 4096)
if i+j >= len(buf) {
ret = append(ret, buf[i:])
} else {
ret = append(ret, buf[i:i+j])
}
i += j
}

return ret
}
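The new randomChunks helper carves sz bytes of crypto-random data into consecutive chunks whose lengths are drawn from roughly a normal distribution with mean 4096 and standard deviation 1024, so the ~12 MiB used below yields a few thousand variably sized chunks.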

func TestRandomChunks(t *testing.T) {
r := rand.New(rand.NewSource(1024))
res := randomChunks(t, r, 10)
assert.Len(t, res, 1)
res = randomChunks(t, r, 4096+2048)
assert.Len(t, res, 2)
res = randomChunks(t, r, 4096+4096)
assert.Len(t, res, 3)
}

func TestAWSTablePersisterPersist(t *testing.T) {
ctx := context.Background()
calcPartSize := func(rdr chunkReader, maxPartNum uint64) uint64 {
return maxTableSize(uint64(mustUint32(rdr.count())), mustUint64(rdr.uncompressedLen())) / maxPartNum
}

mt := newMemTable(testMemTableSize)
r := rand.New(rand.NewSource(1024))
const sz15mb = 1 << 20 * 15
mt := newMemTable(sz15mb)
testChunks := randomChunks(t, r, 1<<20*12)
for _, c := range testChunks {
assert.Equal(t, mt.addChunk(computeAddr(c), c), chunkAdded)
}

var limits5mb = awsLimits{partTarget: 1 << 20 * 5}
var limits64mb = awsLimits{partTarget: 1 << 20 * 64}
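The fixed limits5mb/limits64mb values replace the per-test calcPartSize computations: with ~12 MiB of chunk data, a 5 MiB part target forces the S3 path through a genuine multipart upload, while a 64 MiB target keeps the same data in a single part.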

t.Run("PersistToS3", func(t *testing.T) {
testIt := func(t *testing.T, ns string) {
t.Run("InMultipleParts", func(t *testing.T) {
assert := assert.New(t)
s3svc, ddb := makeFakeS3(t), makeFakeDTS(makeFakeDDB(t), nil)
limits := awsLimits{partTarget: calcPartSize(mt, 3)}
s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits, ns: ns, q: &UnlimitedQuotaProvider{}}
s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}}

src, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
require.NoError(t, err)
@@ -72,8 +109,7 @@ func TestAWSTablePersisterPersist(t *testing.T) {
assert := assert.New(t)

s3svc, ddb := makeFakeS3(t), makeFakeDTS(makeFakeDDB(t), nil)
limits := awsLimits{partTarget: calcPartSize(mt, 1)}
s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits, ns: ns, q: &UnlimitedQuotaProvider{}}
s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits64mb, ns: ns, q: &UnlimitedQuotaProvider{}}

src, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
require.NoError(t, err)
@@ -89,17 +125,16 @@
t.Run("NoNewChunks", func(t *testing.T) {
assert := assert.New(t)

mt := newMemTable(testMemTableSize)
existingTable := newMemTable(testMemTableSize)
mt := newMemTable(sz15mb)
existingTable := newMemTable(sz15mb)

for _, c := range testChunks {
assert.Equal(mt.addChunk(computeAddr(c), c), chunkAdded)
assert.Equal(existingTable.addChunk(computeAddr(c), c), chunkAdded)
}

s3svc, ddb := makeFakeS3(t), makeFakeDTS(makeFakeDDB(t), nil)
limits := awsLimits{partTarget: 1 << 10}
s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits, ns: ns, q: &UnlimitedQuotaProvider{}}
s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}}

src, err := s3p.Persist(context.Background(), mt, existingTable, &Stats{})
require.NoError(t, err)
@@ -115,8 +150,7 @@

s3svc := &failingFakeS3{makeFakeS3(t), sync.Mutex{}, 1}
ddb := makeFakeDTS(makeFakeDDB(t), nil)
limits := awsLimits{partTarget: calcPartSize(mt, 4)}
s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits, ns: ns, q: &UnlimitedQuotaProvider{}}
s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}}

_, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
assert.Error(err)
@@ -330,14 +364,15 @@ func TestAWSTablePersisterCalcPartSizes(t *testing.T) {

func TestAWSTablePersisterConjoinAll(t *testing.T) {
ctx := context.Background()
targetPartSize := uint64(1024)
const sz5mb = 1 << 20 * 5
targetPartSize := uint64(sz5mb)
minPartSize, maxPartSize := targetPartSize, 5*targetPartSize
maxItemSize, maxChunkCount := int(targetPartSize/2), uint32(4)

rl := make(chan struct{}, 8)
defer close(rl)

newPersister := func(s3svc s3svc, ddb *ddbTableStore) awsTablePersister {
newPersister := func(s3svc s3iface.S3API, ddb *ddbTableStore) awsTablePersister {
return awsTablePersister{
s3svc,
"bucket",
38 changes: 38 additions & 0 deletions go/store/nbs/s3_fake_test.go
@@ -33,8 +33,10 @@ import (
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/stretchr/testify/assert"

"github.com/dolthub/dolt/go/store/d"
Expand All @@ -58,6 +60,8 @@ func makeFakeS3(t *testing.T) *fakeS3 {
}

type fakeS3 struct {
s3iface.S3API

assert *assert.Assertions

mu sync.Mutex
@@ -279,3 +283,37 @@ func (m *fakeS3) PutObjectWithContext(ctx aws.Context, input *s3.PutObjectInput,

return &s3.PutObjectOutput{}, nil
}

func (m *fakeS3) GetObjectRequest(input *s3.GetObjectInput) (*request.Request, *s3.GetObjectOutput) {
out := &s3.GetObjectOutput{}
var handlers request.Handlers
handlers.Send.PushBack(func(r *request.Request) {
res, err := m.GetObjectWithContext(r.Context(), input)
r.Error = err
if res != nil {
*(r.Data.(*s3.GetObjectOutput)) = *res
}
})
return request.New(aws.Config{}, metadata.ClientInfo{}, handlers, nil, &request.Operation{
Name: "GetObject",
HTTPMethod: "GET",
HTTPPath: "/{Bucket}/{Key+}",
}, input, out), out
}

func (m *fakeS3) PutObjectRequest(input *s3.PutObjectInput) (*request.Request, *s3.PutObjectOutput) {
out := &s3.PutObjectOutput{}
var handlers request.Handlers
handlers.Send.PushBack(func(r *request.Request) {
res, err := m.PutObjectWithContext(r.Context(), input)
r.Error = err
if res != nil {
*(r.Data.(*s3.PutObjectOutput)) = *res
}
})
return request.New(aws.Config{}, metadata.ClientInfo{}, handlers, nil, &request.Operation{
Name: "PutObject",
HTTPMethod: "PUT",
HTTPPath: "/{Bucket}/{Key+}",
}, input, out), out
}
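s3manager drives single-part uploads through PutObjectRequest (and its downloader reads through GetObjectRequest), so the fake now implements both by replaying its existing *WithContext methods from a Send handler on a hand-built request. A hypothetical usage sketch, not a test from this PR, of the fake sitting directly behind an uploader (assumes this file additionally imports bytes, require, and s3manager):

// Hypothetical sketch: TestFakeS3BehindUploader is not part of this PR.
func TestFakeS3BehindUploader(t *testing.T) {
    s3svc := makeFakeS3(t)
    uploader := s3manager.NewUploaderWithClient(s3svc, func(u *s3manager.Uploader) {
        u.PartSize = s3manager.MinUploadPartSize // 5 MiB, the S3 minimum
    })
    // A 1 MiB body stays under PartSize, so the uploader issues a single
    // PutObjectRequest, which fakeS3 serves via its new Send handler.
    _, err := uploader.Upload(&s3manager.UploadInput{
        Bucket: aws.String("bucket"),
        Key:    aws.String("key"),
        Body:   bytes.NewReader(make([]byte, 1<<20)),
    })
    require.NoError(t, err)
}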