Skip to content

Commit

Permalink
perf(storage): remove protobuf's copy of data on unmarshalling (#9526)
Browse files Browse the repository at this point in the history
  • Loading branch information
BrennaEpp authored Mar 19, 2024
1 parent a3bb7c0 commit 81281c0
Show file tree
Hide file tree
Showing 4 changed files with 405 additions and 12 deletions.
2 changes: 1 addition & 1 deletion storage/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ require (
cloud.google.com/go v0.112.1
cloud.google.com/go/compute/metadata v0.2.3
cloud.google.com/go/iam v1.1.6
github.com/golang/protobuf v1.5.3
github.com/google/go-cmp v0.6.0
github.com/google/uuid v1.6.0
github.com/googleapis/gax-go/v2 v2.12.2
Expand All @@ -26,7 +27,6 @@ require (
github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/martian/v3 v3.3.2 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
Expand Down
262 changes: 252 additions & 10 deletions storage/grpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,18 @@ import (
"cloud.google.com/go/internal/trace"
gapic "cloud.google.com/go/storage/internal/apiv2"
"cloud.google.com/go/storage/internal/apiv2/storagepb"
"github.com/golang/protobuf/proto"
"github.com/googleapis/gax-go/v2"
"google.golang.org/api/googleapi"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"google.golang.org/api/option/internaloption"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protowire"
fieldmaskpb "google.golang.org/protobuf/types/known/fieldmaskpb"
)

Expand Down Expand Up @@ -902,12 +905,50 @@ func (c *grpcStorageClient) RewriteObject(ctx context.Context, req *rewriteObjec
return r, nil
}

// bytesCodec is a grpc codec which permits receiving messages as either
// protobuf messages, or as raw []bytes.
type bytesCodec struct {
encoding.Codec
}

func (bytesCodec) Marshal(v any) ([]byte, error) {
vv, ok := v.(proto.Message)
if !ok {
return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v)
}
return proto.Marshal(vv)
}

func (bytesCodec) Unmarshal(data []byte, v any) error {
switch v := v.(type) {
case *[]byte:
// If gRPC could recycle the data []byte after unmarshaling (through
// buffer pools), we would need to make a copy here.
*v = data
return nil
case proto.Message:
return proto.Unmarshal(data, v)
default:
return fmt.Errorf("can not unmarshal type %T", v)
}
}

func (bytesCodec) Name() string {
// If this isn't "", then gRPC sets the content-subtype of the call to this
// value and we get errors.
return ""
}

func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRangeReaderParams, opts ...storageOption) (r *Reader, err error) {
ctx = trace.StartSpan(ctx, "cloud.google.com/go/storage.grpcStorageClient.NewRangeReader")
defer func() { trace.EndSpan(ctx, err) }()

s := callSettings(c.settings, opts...)

s.gax = append(s.gax, gax.WithGRPCOptions(
grpc.ForceCodec(bytesCodec{}),
))

if s.userProject != "" {
ctx = setUserProjectMetadata(ctx, s.userProject)
}
Expand All @@ -923,6 +964,8 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange
req.Generation = params.gen
}

var databuf []byte

// Define a function that initiates a Read with offset and length, assuming
// we have already read seen bytes.
reopen := func(seen int64) (*readStreamResponse, context.CancelFunc, error) {
Expand Down Expand Up @@ -957,12 +1000,23 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange
return err
}

msg, err = stream.Recv()
// Receive the message into databuf as a wire-encoded message so we can
// use a custom decoder to avoid an extra copy at the protobuf layer.
err := stream.RecvMsg(&databuf)
// These types of errors show up on the Recv call, rather than the
// initialization of the stream via ReadObject above.
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
return ErrObjectNotExist
}
if err != nil {
return err
}
// Use a custom decoder that uses protobuf unmarshalling for all
// fields except the checksummed data.
// Subsequent receives in Read calls will skip all protobuf
// unmarshalling and directly read the content from the gRPC []byte
// response, since only the first call will contain other fields.
msg, err = readFullObjectResponse(databuf)

return err
}, s.retry, s.idempotent)
Expand Down Expand Up @@ -1008,6 +1062,7 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange
leftovers: msg.GetChecksummedData().GetContent(),
settings: s,
zeroRange: params.length == 0,
databuf: databuf,
},
}

Expand Down Expand Up @@ -1406,6 +1461,7 @@ type gRPCReader struct {
stream storagepb.Storage_ReadObjectClient
reopen func(seen int64) (*readStreamResponse, context.CancelFunc, error)
leftovers []byte
databuf []byte
cancel context.CancelFunc
settings *settings
}
Expand Down Expand Up @@ -1436,7 +1492,7 @@ func (r *gRPCReader) Read(p []byte) (int, error) {
}

// Attempt to Recv the next message on the stream.
msg, err := r.recv()
content, err := r.recv()
if err != nil {
return 0, err
}
Expand All @@ -1448,7 +1504,6 @@ func (r *gRPCReader) Read(p []byte) (int, error) {
// present in the response here.
// TODO: Figure out if we need to support decompressive transcoding
// https://cloud.google.com/storage/docs/transcoding.
content := msg.GetChecksummedData().GetContent()
n = copy(p[n:], content)
leftover := len(content) - n
if leftover > 0 {
Expand All @@ -1471,18 +1526,20 @@ func (r *gRPCReader) Close() error {
return nil
}

// recv attempts to Recv the next message on the stream. In the event
// that a retryable error is encountered, the stream will be closed, reopened,
// and Recv again. This will attempt to Recv until one of the following is true:
// recv attempts to Recv the next message on the stream and extract the object
// data that it contains. In the event that a retryable error is encountered,
// the stream will be closed, reopened, and RecvMsg again.
// This will attempt to Recv until one of the following is true:
//
// * Recv is successful
// * A non-retryable error is encountered
// * The Reader's context is canceled
//
// The last error received is the one that is returned, which could be from
// an attempt to reopen the stream.
func (r *gRPCReader) recv() (*storagepb.ReadObjectResponse, error) {
msg, err := r.stream.Recv()
func (r *gRPCReader) recv() ([]byte, error) {
err := r.stream.RecvMsg(&r.databuf)

var shouldRetry = ShouldRetry
if r.settings.retry != nil && r.settings.retry.shouldRetry != nil {
shouldRetry = r.settings.retry.shouldRetry
Expand All @@ -1492,10 +1549,195 @@ func (r *gRPCReader) recv() (*storagepb.ReadObjectResponse, error) {
// reopen the stream, but will backoff if further attempts are necessary.
// Reopening the stream Recvs the first message, so if retrying is
// successful, the next logical chunk will be returned.
msg, err = r.reopenStream()
msg, err := r.reopenStream()
return msg.GetChecksummedData().GetContent(), err
}

if err != nil {
return nil, err
}

return readObjectResponseContent(r.databuf)
}

// ReadObjectResponse field and subfield numbers.
const (
checksummedDataField = protowire.Number(1)
checksummedDataContentField = protowire.Number(1)
checksummedDataCRC32CField = protowire.Number(2)
objectChecksumsField = protowire.Number(2)
contentRangeField = protowire.Number(3)
metadataField = protowire.Number(4)
)

// readObjectResponseContent returns the checksummed_data.content field of a
// ReadObjectResponse message, or an error if the message is invalid.
// This can be used on recvs of objects after the first recv, since only the
// first message will contain non-data fields.
func readObjectResponseContent(b []byte) ([]byte, error) {
checksummedData, err := readProtoBytes(b, checksummedDataField)
if err != nil {
return b, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData: %v", err)
}
content, err := readProtoBytes(checksummedData, checksummedDataContentField)
if err != nil {
return content, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Content: %v", err)
}

return msg, err
return content, nil
}

// readFullObjectResponse returns the ReadObjectResponse that is encoded in the
// wire-encoded message buffer b, or an error if the message is invalid.
// This must be used on the first recv of an object as it may contain all fields
// of ReadObjectResponse, and we use or pass on those fields to the user.
// This function is essentially identical to proto.Unmarshal, except it aliases
// the data in the input []byte. If the proto library adds a feature to
// Unmarshal that does that, this function can be dropped.
func readFullObjectResponse(b []byte) (*storagepb.ReadObjectResponse, error) {
msg := &storagepb.ReadObjectResponse{}

// Loop over the entire message, extracting fields as we go. This does not
// handle field concatenation, in which the contents of a single field
// are split across multiple protobuf tags.
off := 0
for off < len(b) {
// Consume the next tag. This will tell us which field is next in the
// buffer, its type, and how much space it takes up.
fieldNum, fieldType, fieldLength := protowire.ConsumeTag(b[off:])
if fieldLength < 0 {
return nil, protowire.ParseError(fieldLength)
}
off += fieldLength

// Unmarshal the field according to its type. Only fields that are not
// nil will be present.
switch {
case fieldNum == checksummedDataField && fieldType == protowire.BytesType:
// The ChecksummedData field was found. Initialize the struct.
msg.ChecksummedData = &storagepb.ChecksummedData{}

// Get the bytes corresponding to the checksummed data.
fieldContent, n := protowire.ConsumeBytes(b[off:])
if n < 0 {
return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData: %v", protowire.ParseError(n))
}
off += n

// Get the nested fields. We need to do this manually as it contains
// the object content bytes.
contentOff := 0
for contentOff < len(fieldContent) {
gotNum, gotTyp, n := protowire.ConsumeTag(fieldContent[contentOff:])
if n < 0 {
return nil, protowire.ParseError(n)
}
contentOff += n

switch {
case gotNum == checksummedDataContentField && gotTyp == protowire.BytesType:
// Get the content bytes.
bytes, n := protowire.ConsumeBytes(fieldContent[contentOff:])
if n < 0 {
return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Content: %v", protowire.ParseError(n))
}
msg.ChecksummedData.Content = bytes
contentOff += n
case gotNum == checksummedDataCRC32CField && gotTyp == protowire.Fixed32Type:
v, n := protowire.ConsumeFixed32(fieldContent[contentOff:])
if n < 0 {
return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Crc32C: %v", protowire.ParseError(n))
}
msg.ChecksummedData.Crc32C = &v
contentOff += n
default:
n = protowire.ConsumeFieldValue(gotNum, gotTyp, fieldContent[contentOff:])
if n < 0 {
return nil, protowire.ParseError(n)
}
contentOff += n
}
}
case fieldNum == objectChecksumsField && fieldType == protowire.BytesType:
// The field was found. Initialize the struct.
msg.ObjectChecksums = &storagepb.ObjectChecksums{}

// Get the bytes corresponding to the checksums.
bytes, n := protowire.ConsumeBytes(b[off:])
if n < 0 {
return nil, fmt.Errorf("invalid ReadObjectResponse.ObjectChecksums: %v", protowire.ParseError(n))
}
off += n

// Unmarshal.
if err := proto.Unmarshal(bytes, msg.ObjectChecksums); err != nil {
return nil, err
}
case fieldNum == contentRangeField && fieldType == protowire.BytesType:
msg.ContentRange = &storagepb.ContentRange{}

bytes, n := protowire.ConsumeBytes(b[off:])
if n < 0 {
return nil, fmt.Errorf("invalid ReadObjectResponse.ContentRange: %v", protowire.ParseError(n))
}
off += n

if err := proto.Unmarshal(bytes, msg.ContentRange); err != nil {
return nil, err
}
case fieldNum == metadataField && fieldType == protowire.BytesType:
msg.Metadata = &storagepb.Object{}

bytes, n := protowire.ConsumeBytes(b[off:])
if n < 0 {
return nil, fmt.Errorf("invalid ReadObjectResponse.Metadata: %v", protowire.ParseError(n))
}
off += n

if err := proto.Unmarshal(bytes, msg.Metadata); err != nil {
return nil, err
}
default:
fieldLength = protowire.ConsumeFieldValue(fieldNum, fieldType, b[off:])
if fieldLength < 0 {
return nil, fmt.Errorf("default: %v", protowire.ParseError(fieldLength))
}
off += fieldLength
}
}

return msg, nil
}

// readProtoBytes returns the contents of the protobuf field with number num
// and type bytes from a wire-encoded message. If the field cannot be found,
// the returned slice will be nil and no error will be returned.
//
// It does not handle field concatenation, in which the contents of a single field
// are split across multiple protobuf tags. Encoded data containing split fields
// of this form is technically permissable, but uncommon.
func readProtoBytes(b []byte, num protowire.Number) ([]byte, error) {
off := 0
for off < len(b) {
gotNum, gotTyp, n := protowire.ConsumeTag(b[off:])
if n < 0 {
return nil, protowire.ParseError(n)
}
off += n
if gotNum == num && gotTyp == protowire.BytesType {
b, n := protowire.ConsumeBytes(b[off:])
if n < 0 {
return nil, protowire.ParseError(n)
}
return b, nil
}
n = protowire.ConsumeFieldValue(gotNum, gotTyp, b[off:])
if n < 0 {
return nil, protowire.ParseError(n)
}
off += n
}
return nil, nil
}

// reopenStream "closes" the existing stream and attempts to reopen a stream and
Expand Down
Loading

0 comments on commit 81281c0

Please sign in to comment.