From daabca7baf1a54fc8817578fb7f334c4860bcfa6 Mon Sep 17 00:00:00 2001 From: Sunny Singh Date: Thu, 5 Oct 2023 10:14:09 +0000 Subject: [PATCH 1/5] feat(spanner): BatchWrite --- spanner/client.go | 190 ++++++++++++++ spanner/client_test.go | 248 ++++++++++++++++++ spanner/integration_test.go | 86 ++++++ .../internal/testutil/inmem_spanner_server.go | 34 +++ .../testutil/inmem_spanner_server_test.go | 101 +++++++ spanner/mutation.go | 20 ++ spanner/mutation_test.go | 107 ++++++++ 7 files changed, 786 insertions(+) diff --git a/spanner/client.go b/spanner/client.go index d28b95cfaf20..d9060281a5b7 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -19,6 +19,7 @@ package spanner import ( "context" "fmt" + "io" "log" "os" "regexp" @@ -26,6 +27,8 @@ import ( "cloud.google.com/go/internal/trace" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/gax-go/v2" + "google.golang.org/api/iterator" "google.golang.org/api/option" "google.golang.org/api/option/internaloption" gtransport "google.golang.org/api/transport/grpc" @@ -669,6 +672,193 @@ func (c *Client) Apply(ctx context.Context, ms []*Mutation, opts ...ApplyOption) return t.applyAtLeastOnce(ctx, ms...) } +// An iterator over BatchWriteResponse structures returned from BatchWrite RPC. +type BatchWriteResponseIterator struct { + ctx context.Context + stream sppb.Spanner_BatchWriteClient + err error + dataReceived bool + replaceSession func(ctx context.Context) error + rpc func(ctx context.Context) (sppb.Spanner_BatchWriteClient, error) + release func(error) + cancel func() +} + +// Next returns the next result. Its second return value is iterator.Done if +// there are no more results. Once Next returns Done, all subsequent calls +// will return Done. +func (r *BatchWriteResponseIterator) Next() (*sppb.BatchWriteResponse, error) { + for { + // Stream finished or in error state. + if r.err != nil { + return nil, r.err + } + + // RPC not made yet. + if r.stream == nil { + r.stream, r.err = r.rpc(r.ctx) + continue + } + + // Read from the stream. + var response *sppb.BatchWriteResponse + response, r.err = r.stream.Recv() + + // Return an item. + if r.err == nil { + r.dataReceived = true + return response, nil + } + + // Stream finished. + if r.err == io.EOF { + r.err = iterator.Done + return nil, r.err + } + + // Retry request on session not found error only if no data has been received before. + if !r.dataReceived && r.replaceSession != nil && isSessionNotFoundError(r.err) { + r.err = r.replaceSession(r.ctx) + r.stream = nil + } + } +} + +// Stop terminates the iteration. It should be called after you finish using the +// iterator. +func (r *BatchWriteResponseIterator) Stop() { + if r.stream != nil { + err := r.err + if err == iterator.Done { + err = nil + } + defer trace.EndSpan(r.ctx, err) + } + if r.cancel != nil { + r.cancel() + r.cancel = nil + } + if r.release != nil { + r.release(r.err) + r.release = nil + } + if r.err == nil { + r.err = spannerErrorf(codes.FailedPrecondition, "Next called after Stop") + } +} + +// Do calls the provided function once in sequence for each item in the +// iteration. If the function returns a non-nil error, Do immediately returns +// that error. +// +// If there are no items in the iterator, Do will return nil without calling the +// provided function. +// +// Do always calls Stop on the iterator. +func (r *BatchWriteResponseIterator) Do(f func(r *sppb.BatchWriteResponse) error) error { + defer r.Stop() + for { + row, err := r.Next() + switch err { + case iterator.Done: + return nil + case nil: + if err = f(row); err != nil { + return err + } + default: + return err + } + } +} + +// Batch Write applies a list of mutation groups in a collection of efficient +// transactions. The mutation groups are applied non-atomically in an +// unspecified order and thus, they must be independent of each other. Partial +// failure is possible, i.e., some mutation groups may have been applied +// successfully, while some may have failed. The results of individual batches +// are streamed into the response as the batches are applied. +// +// BatchWrite requests are not replay protected, meaning that each mutation +// group may be applied more than once. Replays of non-idempotent mutations +// may have undesirable effects. For example, replays of an insert mutation +// may produce an already exists error or if you use generated or commit +// timestamp-based keys, it may result in additional rows being added to the +// mutation's table. We recommend structuring your mutation groups to be +// idempotent to avoid this issue. +func (c *Client) BatchWrite(ctx context.Context, mgs []*MutationGroup, opts ...ApplyOption) *BatchWriteResponseIterator { + ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.BatchWrite") + + var err error + defer func() { + trace.EndSpan(ctx, err) + }() + + ao := &applyOption{} + for _, opt := range c.ao { + opt(ao) + } + for _, opt := range opts { + opt(ao) + } + + mgsPb, err := mutationGroupsProto(mgs) + if err != nil { + return &BatchWriteResponseIterator{err: err} + } + + var sh *sessionHandle + sh, err = c.idleSessions.take(ctx) + if err != nil { + return &BatchWriteResponseIterator{err: err} + } + + rpc := func(ct context.Context) (sppb.Spanner_BatchWriteClient, error) { + var md metadata.MD + stream, rpcErr := sh.getClient().BatchWrite(contextWithOutgoingMetadata(ct, sh.getMetadata(), c.disableRouteToLeader), &sppb.BatchWriteRequest{ + Session: sh.getID(), + MutationGroups: mgsPb, + RequestOptions: createRequestOptions(ao.priority, "", ao.transactionTag), + }, gax.WithGRPCOptions(grpc.Header(&md))) + + if getGFELatencyMetricsFlag() && md != nil && c.ct != nil { + if metricErr := createContextAndCaptureGFELatencyMetrics(ct, c.ct, md, "BatchWrite"); metricErr != nil { + trace.TracePrintf(ct, nil, "Error in recording GFE Latency. Try disabling and rerunning. Error: %v", err) + } + } + return stream, rpcErr + } + + replaceSession := func(ct context.Context) error { + if sh != nil { + sh.destroy() + } + var sessionErr error + sh, sessionErr = c.idleSessions.take(ct) + return sessionErr + } + + release := func(err error) { + if sh == nil { + return + } + if isSessionNotFoundError(err) { + sh.destroy() + } + sh.recycle() + } + + ctx, cancel := context.WithCancel(ctx) + ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.BatchWriteResponseIterator") + return &BatchWriteResponseIterator{ + ctx: ctx, + rpc: rpc, + replaceSession: replaceSession, + release: release, + cancel: cancel, + } +} + // logf logs the given message to the given logger, or the standard logger if // the given logger is nil. func logf(logger *log.Logger, format string, v ...interface{}) { diff --git a/spanner/client_test.go b/spanner/client_test.go index aa17d7454c3c..5813542a4d5e 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -31,6 +31,7 @@ import ( itestutil "cloud.google.com/go/internal/testutil" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/googleapis/gax-go/v2" "google.golang.org/api/iterator" "google.golang.org/api/option" @@ -4304,3 +4305,250 @@ func TestClient_CustomRetryAndTimeoutSettings(t *testing.T) { t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) } } + +func TestClient_BatchWrite(t *testing.T) { + t.Parallel() + + server, client, teardown := setupMockedTestServer(t) + defer teardown() + mutationGroups := []*MutationGroup{ + {[]*Mutation{ + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo1", 1}}, + }}, + } + iter := client.BatchWrite(context.Background(), mutationGroups) + responseCount := int(0) + doFunc := func(r *sppb.BatchWriteResponse) error { + responseCount++ + return nil + } + if err := iter.Do(doFunc); err != nil { + t.Fatal(err) + } + if responseCount != len(mutationGroups) { + t.Fatalf("Response count mismatch.\nGot: %v\nWant:%v", responseCount, len(mutationGroups)) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.BatchWriteRequest{}, + }, requests); err != nil { + t.Fatal(err) + } +} + +func TestClient_BatchWrite_SessionNotFound(t *testing.T) { + t.Parallel() + + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime( + MethodBatchWrite, + SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + mutationGroups := []*MutationGroup{ + {[]*Mutation{ + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo1", 1}}, + }}, + } + iter := client.BatchWrite(context.Background(), mutationGroups) + responseCount := int(0) + doFunc := func(r *sppb.BatchWriteResponse) error { + responseCount++ + return nil + } + if err := iter.Do(doFunc); err != nil { + t.Fatal(err) + } + if responseCount != len(mutationGroups) { + t.Fatalf("Response count mismatch.\nGot: %v\nWant:%v", responseCount, len(mutationGroups)) + } + + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.BatchWriteRequest{}, + &sppb.BatchWriteRequest{}, + }, requests); err != nil { + t.Fatal(err) + } +} + +func TestClient_BatchWrite_Error(t *testing.T) { + t.Parallel() + + server, client, teardown := setupMockedTestServer(t) + defer teardown() + injectedErr := status.Error(codes.InvalidArgument, "Invalid argument") + server.TestSpanner.PutExecutionTime( + MethodBatchWrite, + SimulatedExecutionTime{Errors: []error{injectedErr}}, + ) + mutationGroups := []*MutationGroup{ + {[]*Mutation{ + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo1", 1}}, + }}, + } + iter := client.BatchWrite(context.Background(), mutationGroups) + responseCount := int(0) + doFunc := func(r *sppb.BatchWriteResponse) error { + responseCount++ + return nil + } + if err := iter.Do(doFunc); status.Code(err) != status.Code(injectedErr) { + t.Fatalf("Error mismatch.\nGot:%v\nExpected:%v\n", err, injectedErr) + } + if responseCount != 0 { + t.Fatalf("Do function unexpectedly called %v times", responseCount) + } +} + +func checkBatchWriteForExpectedRequestOptions(t *testing.T, server InMemSpannerServer, want *sppb.RequestOptions) { + reqs := drainRequestsFromServer(server) + var got *sppb.RequestOptions + + for _, req := range reqs { + if request, ok := req.(*sppb.BatchWriteRequest); ok { + got = request.RequestOptions + break + } + } + + if got == nil { + t.Fatalf("Missing BatchWrite RequestOptions") + } + + if diff := itestutil.Diff(got, want, cmpopts.IgnoreUnexported(sppb.RequestOptions{})); diff != "" { + t.Fatalf("RequestOptions mismatch. (+Got, -Want):%v", diff) + } +} + +func TestClient_BatchWrite_ApplyOptions(t *testing.T) { + t.Parallel() + + testcases := []struct { + name string + client []ApplyOption + write []ApplyOption + wantTransactionTag string + wantPriority sppb.RequestOptions_Priority + }{ + { + name: "Client level", + client: []ApplyOption{TransactionTag("testTransactionTag"), Priority(sppb.RequestOptions_PRIORITY_LOW)}, + wantTransactionTag: "testTransactionTag", + wantPriority: sppb.RequestOptions_PRIORITY_LOW, + }, + { + name: "Write level", + write: []ApplyOption{TransactionTag("testTransactionTag"), Priority(sppb.RequestOptions_PRIORITY_LOW)}, + wantTransactionTag: "testTransactionTag", + wantPriority: sppb.RequestOptions_PRIORITY_LOW, + }, + { + name: "Write level has precedence over client level", + client: []ApplyOption{TransactionTag("clientTransactionTag"), Priority(sppb.RequestOptions_PRIORITY_LOW)}, + write: []ApplyOption{TransactionTag("writeTransactionTag"), Priority(sppb.RequestOptions_PRIORITY_MEDIUM)}, + wantTransactionTag: "writeTransactionTag", + wantPriority: sppb.RequestOptions_PRIORITY_MEDIUM, + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ApplyOptions: tt.client}) + defer teardown() + + mutationGroups := []*MutationGroup{ + {[]*Mutation{ + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo1", 1}}, + }}, + } + iter := client.BatchWrite(context.Background(), mutationGroups, tt.write...) + doFunc := func(r *sppb.BatchWriteResponse) error { + return nil + } + if err := iter.Do(doFunc); err != nil { + t.Fatal(err) + } + checkBatchWriteForExpectedRequestOptions(t, server.TestSpanner, &sppb.RequestOptions{Priority: tt.wantPriority, TransactionTag: tt.wantTransactionTag}) + }) + } +} + +func checkBatchWriteSpan(t *testing.T, errors []error, code codes.Code) { + // This test cannot be parallel, as the TestExporter does not support that. + te := itestutil.NewTestExporter() + defer te.Unregister() + minOpened := uint64(1) + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: minOpened, + WriteSessions: 0, + }, + }) + defer teardown() + + // Wait until all sessions have been created, so we know that those requests will not interfere with the test. + sp := client.idleSessions + waitFor(t, func() error { + sp.mu.Lock() + defer sp.mu.Unlock() + if uint64(sp.idleList.Len()) != minOpened { + return fmt.Errorf("num open sessions mismatch\nWant: %d\nGot: %d", sp.MinOpened, sp.numOpened) + } + return nil + }) + + server.TestSpanner.PutExecutionTime( + MethodBatchWrite, + SimulatedExecutionTime{Errors: errors}, + ) + mutationGroups := []*MutationGroup{ + {[]*Mutation{ + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo1", 1}}, + }}, + } + iter := client.BatchWrite(context.Background(), mutationGroups) + iter.Do(func(r *sppb.BatchWriteResponse) error { + return nil + }) + select { + case <-te.Stats: + case <-time.After(1 * time.Second): + t.Fatal("No stats were exported before timeout") + } + // Preferably we would want to lock the TestExporter here, but the mutex TestExporter.mu is not exported, so we + // cannot do that. + if len(te.Spans) == 0 { + t.Fatal("No spans were exported") + } + s := te.Spans[len(te.Spans)-1].Status + if s.Code != int32(code) { + t.Errorf("Span status mismatch\nGot: %v\nWant: %v", s.Code, code) + } +} +func TestClient_BatchWrite_SpanExported(t *testing.T) { + testcases := []struct { + name string + code codes.Code + errors []error + }{ + { + name: "Success", + code: codes.OK, + errors: []error{}, + }, + { + name: "Error", + code: codes.InvalidArgument, + errors: []error{status.Error(codes.InvalidArgument, "Invalid argument")}, + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + checkBatchWriteSpan(t, tt.errors, tt.code) + }) + } +} diff --git a/spanner/integration_test.go b/spanner/integration_test.go index 5d016ede945d..8ad9218a87f0 100644 --- a/spanner/integration_test.go +++ b/spanner/integration_test.go @@ -897,6 +897,92 @@ func deleteTestSession(ctx context.Context, t *testing.T, sessionName string) { } } +func TestIntegration_BatchWrite(t *testing.T) { + skipEmulatorTest(t) + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + // Set up testing environment. + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) + defer cleanup() + + writes := []struct { + row []interface{} + ts time.Time + }{ + {row: []interface{}{1, "Marc", "Foo"}}, + {row: []interface{}{2, "Tars", "Bar"}}, + {row: []interface{}{3, "Alpha", "Beta"}}, + {row: []interface{}{4, "Last", "End"}}, + } + mgs := make([]*MutationGroup, len(writes)) + // Try to write four rows through the BatchWrite API. + for i, w := range writes { + m := InsertOrUpdate("Singers", + []string{"SingerId", "FirstName", "LastName"}, + w.row) + ms := make([]*Mutation, 1) + ms[0] = m + mgs[i] = &MutationGroup{mutations: ms} + } + // Records the mutation group indexes received in the response. + seen := make(map[int32]int32) + numMutationGroups := len(mgs) + validate := func(res *sppb.BatchWriteResponse) error { + if status := status.ErrorProto(res.GetStatus()); status != nil { + t.Fatalf("Invalid status: %v", status) + } + if ts := res.GetCommitTimestamp(); ts == nil { + t.Fatal("Invalid commit timestamp") + } + for _, idx := range res.GetIndexes() { + if idx >= 0 && idx < int32(numMutationGroups) { + seen[idx] += 1 + } else { + t.Fatalf("Index %v out of range. Expected range [%v,%v]", idx, 0, numMutationGroups-1) + } + } + return nil + } + iter := client.BatchWrite(ctx, mgs, TransactionTag("golang-test-batch-write")) + if err := iter.Do(validate); err != nil { + t.Fatal(err) + } + // Validate that each mutation group index is seen exactly once. + if numMutationGroups != len(seen) { + t.Fatalf("Expected %v indexes, got %v indexes", numMutationGroups, len(seen)) + } + for idx, ct := range seen { + if ct != 1 { + t.Fatalf("Index %v seen %v times instead of exactly once", idx, ct) + } + } + + // Verify the writes by reading the database. + singersQuery := "SELECT SingerId, FirstName, LastName FROM Singers WHERE SingerId IN (@p1, @p2, @p3)" + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + singersQuery = "SELECT SingerId, FirstName, LastName FROM Singers WHERE SingerId = $1 OR SingerId = $2 OR SingerId = $3" + } + qo := QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{ + OptimizerVersion: "1", + OptimizerStatisticsPackage: "latest", + }} + got, err := readAll(client.Single().QueryWithOptions(ctx, Statement{ + singersQuery, + map[string]interface{}{"p1": int64(1), "p2": int64(3), "p3": int64(4)}, + }, qo)) + + if err != nil { + t.Errorf("ReadOnlyTransaction.QueryWithOptions returns error %v, want nil", err) + } + + want := [][]interface{}{{int64(1), "Marc", "Foo"}, {int64(3), "Alpha", "Beta"}, {int64(4), "Last", "End"}} + if !testEqual(got, want) { + t.Errorf("got unexpected result from ReadOnlyTransaction.QueryWithOptions: %v, want %v", got, want) + } +} + func TestIntegration_SingleUse_ReadingWithLimit(t *testing.T) { t.Parallel() diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index b1adf02f2182..6fc06f59e776 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -87,6 +87,7 @@ const ( MethodExecuteStreamingSql string = "EXECUTE_STREAMING_SQL" MethodExecuteBatchDml string = "EXECUTE_BATCH_DML" MethodStreamingRead string = "EXECUTE_STREAMING_READ" + MethodBatchWrite string = "BATCH_WRITE" ) // StatementResult represents a mocked result on the test server. The result is @@ -1144,3 +1145,36 @@ func DecodeResumeToken(t []byte) (uint64, error) { } return s, nil } + +func (s *inMemSpannerServer) BatchWrite(req *spannerpb.BatchWriteRequest, stream spannerpb.Spanner_BatchWriteServer) error { + if err := s.simulateExecutionTime(MethodBatchWrite, req); err != nil { + return err + } + return s.batchWrite(req, stream) +} + +func (s *inMemSpannerServer) batchWrite(req *spannerpb.BatchWriteRequest, stream spannerpb.Spanner_BatchWriteServer) error { + if req.Session == "" { + return gstatus.Error(codes.InvalidArgument, "Missing session name") + } + session, err := s.findSession(req.Session) + if err != nil { + return err + } + s.updateSessionLastUseTime(session.Name) + if len(req.GetMutationGroups()) == 0 { + return gstatus.Error(codes.InvalidArgument, "No mutations in Batch Write") + } + // For each MutationGroup, write a BatchWriteResponse to the response stream + for idx := range req.GetMutationGroups() { + res := &spannerpb.BatchWriteResponse{ + Indexes: []int32{int32(idx)}, + CommitTimestamp: getCurrentTimestamp(), + Status: &status.Status{}, + } + if err = stream.Send(res); err != nil { + return err + } + } + return nil +} diff --git a/spanner/internal/testutil/inmem_spanner_server_test.go b/spanner/internal/testutil/inmem_spanner_server_test.go index 18f667835881..fb864e3f236a 100644 --- a/spanner/internal/testutil/inmem_spanner_server_test.go +++ b/spanner/internal/testutil/inmem_spanner_server_test.go @@ -15,6 +15,7 @@ package testutil_test import ( + "io" "strconv" . "cloud.google.com/go/spanner/internal/testutil" @@ -603,3 +604,103 @@ func TestRollbackTransaction(t *testing.T) { t.Fatal(err) } } + +func TestSpannerBatchWrite(t *testing.T) { + setup() + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var createRequest = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + session, err := c.CreateSession(context.Background(), createRequest) + if err != nil { + t.Fatal(err) + } + + batchWriteRequest := &spannerpb.BatchWriteRequest{ + Session: session.Name, + MutationGroups: []*spannerpb.BatchWriteRequest_MutationGroup{ + {Mutations: []*spannerpb.Mutation{ + { + Operation: &spannerpb.Mutation_Delete_{ + Delete: &spannerpb.Mutation_Delete{ + Table: "t_test", + KeySet: &spannerpb.KeySet{ + Keys: []*structpb.ListValue{ + { + Values: []*structpb.Value{ + {Kind: &structpb.Value_StringValue{StringValue: "k"}}, + }, + }, + }, + }, + }, + }, + }, + }}, + {Mutations: []*spannerpb.Mutation{ + { + Operation: &spannerpb.Mutation_Insert{ + Insert: &spannerpb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*structpb.ListValue{ + { + Values: []*structpb.Value{ + {Kind: &structpb.Value_StringValue{StringValue: "k"}}, + {Kind: &structpb.Value_StringValue{StringValue: "v"}}, + }, + }, + }, + }, + }, + }, + }}, + }, + } + stream, err := c.BatchWrite(context.Background(), batchWriteRequest) + if err != nil { + t.Fatal(err) + } + // Records the mutation group indexes received in the response. + seen := make(map[int32]int32) + numMutationGroups := len(batchWriteRequest.GetMutationGroups()) + validate := func(res *spannerpb.BatchWriteResponse) { + if status := res.GetStatus().GetCode(); status != int32(codes.OK) { + t.Fatalf("Invalid status: %v", status) + } + if ts := res.GetCommitTimestamp(); ts == nil { + t.Fatal("Invalid commit timestamp") + } + for _, idx := range res.GetIndexes() { + if idx >= 0 && idx < int32(numMutationGroups) { + seen[idx] += 1 + } else { + t.Fatalf("Index %v out of range. Expected range [%v,%v]", idx, 0, numMutationGroups-1) + } + } + } + for { + response, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + validate(response) + } + // Validate that each mutation group index is seen exactly once. + if numMutationGroups != len(seen) { + t.Fatalf("Expected %v indexes, got %v indexes", numMutationGroups, len(seen)) + } + for idx, ct := range seen { + if ct != 1 { + t.Fatalf("Index %v seen %v times instead of exactly once", idx, ct) + } + } +} diff --git a/spanner/mutation.go b/spanner/mutation.go index 19ced2be66fb..1d8718065de9 100644 --- a/spanner/mutation.go +++ b/spanner/mutation.go @@ -141,6 +141,12 @@ type Mutation struct { values []interface{} } +// A group of mutations to be committed atomically. +type MutationGroup struct { + // The mutations in this group + mutations []*Mutation +} + // mapToMutationParams converts Go map into mutation parameters. func mapToMutationParams(in map[string]interface{}) ([]string, []interface{}) { cols := []string{} @@ -432,3 +438,17 @@ func mutationsProto(ms []*Mutation) ([]*sppb.Mutation, error) { } return l, nil } + +// mutationGroupsProto turns a spanner.MutationGroup array into a +// sppb.BatchWriteRequest_MutationGroup array, in preparation to send RPCs. +func mutationGroupsProto(mgs []*MutationGroup) ([]*sppb.BatchWriteRequest_MutationGroup, error) { + gs := make([]*sppb.BatchWriteRequest_MutationGroup, 0, len(mgs)) + for _, mg := range mgs { + ms, err := mutationsProto(mg.mutations) + if err != nil { + return nil, err + } + gs = append(gs, &sppb.BatchWriteRequest_MutationGroup{Mutations: ms}) + } + return gs, nil +} diff --git a/spanner/mutation_test.go b/spanner/mutation_test.go index c5dacf65f985..8047a6b9f9c4 100644 --- a/spanner/mutation_test.go +++ b/spanner/mutation_test.go @@ -618,3 +618,110 @@ func TestEncodeMutationArray(t *testing.T) { } } } + +func TestEncodeMutationGroupArray(t *testing.T) { + for _, test := range []struct { + name string + mgs []*MutationGroup + want []*sppb.BatchWriteRequest_MutationGroup + wantErr error + }{ + { + "Multiple Mutations", + []*MutationGroup{ + {[]*Mutation{ + {opDelete, "t_test", Key{"bar"}, nil, nil}, + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo1", 1}}, + }}, + {[]*Mutation{ + {opInsert, "t_test", nil, []string{"key", "val"}, []interface{}{"foo2", 1}}, + {opUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo3", 1}}, + }}, + {[]*Mutation{ + {opReplace, "t_test", nil, []string{"key", "val"}, []interface{}{"foo4", 1}}, + }}, + }, + []*sppb.BatchWriteRequest_MutationGroup{ + {Mutations: []*sppb.Mutation{ + { + Operation: &sppb.Mutation_Delete_{ + Delete: &sppb.Mutation_Delete{ + Table: "t_test", + KeySet: &sppb.KeySet{ + Keys: []*proto3.ListValue{listValueProto(stringProto("bar"))}, + }, + }, + }, + }, + { + Operation: &sppb.Mutation_InsertOrUpdate{ + InsertOrUpdate: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*proto3.ListValue{listValueProto(stringProto("foo1"), intProto(1))}, + }, + }, + }, + }}, + {Mutations: []*sppb.Mutation{ + { + Operation: &sppb.Mutation_Insert{ + Insert: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*proto3.ListValue{listValueProto(stringProto("foo2"), intProto(1))}, + }, + }, + }, + { + Operation: &sppb.Mutation_Update{ + Update: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*proto3.ListValue{listValueProto(stringProto("foo3"), intProto(1))}, + }, + }, + }, + }}, + {Mutations: []*sppb.Mutation{ + { + Operation: &sppb.Mutation_Replace{ + Replace: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*proto3.ListValue{listValueProto(stringProto("foo4"), intProto(1))}, + }, + }, + }, + }}, + }, + nil, + }, + { + "Multiple Mutations - Bad Mutation", + []*MutationGroup{ + {[]*Mutation{ + {opDelete, "t_test", Key{"bar"}, nil, nil}, + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo1", struct{}{}}}, + }}, + {[]*Mutation{ + {opInsert, "t_test", nil, []string{"key", "val"}, []interface{}{"foo2", 1}}, + {opUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo3", 1}}, + }}, + }, + []*sppb.BatchWriteRequest_MutationGroup{}, + errEncoderUnsupportedType(struct{}{}), + }, + } { + gotProto, gotErr := mutationGroupsProto(test.mgs) + if gotErr != nil { + if !testEqual(gotErr, test.wantErr) { + t.Errorf("%v: mutationGroupsProto(%v) returns error %v, want %v", test.name, test.mgs, gotErr, test.wantErr) + } + continue + } + if !testEqual(gotProto, test.want) { + t.Errorf("%v: mutationGroupsProto(%v) = (%v, nil), want (%v, nil)", test.name, test.mgs, gotProto, test.want) + } + } +} From 96f048597711045e50b65db8cbe8c5f780785d7d Mon Sep 17 00:00:00 2001 From: Sunny Singh Date: Thu, 12 Oct 2023 10:09:56 +0000 Subject: [PATCH 2/5] vet: apply recommendations --- spanner/client.go | 4 ++-- spanner/integration_test.go | 2 +- spanner/internal/testutil/inmem_spanner_server_test.go | 2 +- spanner/mutation.go | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/spanner/client.go b/spanner/client.go index d9060281a5b7..0dab33085f9f 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -672,7 +672,7 @@ func (c *Client) Apply(ctx context.Context, ms []*Mutation, opts ...ApplyOption) return t.applyAtLeastOnce(ctx, ms...) } -// An iterator over BatchWriteResponse structures returned from BatchWrite RPC. +// BatchWriteResponseIterator is an iterator over BatchWriteResponse structures returned from BatchWrite RPC. type BatchWriteResponseIterator struct { ctx context.Context stream sppb.Spanner_BatchWriteClient @@ -772,7 +772,7 @@ func (r *BatchWriteResponseIterator) Do(f func(r *sppb.BatchWriteResponse) error } } -// Batch Write applies a list of mutation groups in a collection of efficient +// BatchWrite applies a list of mutation groups in a collection of efficient // transactions. The mutation groups are applied non-atomically in an // unspecified order and thus, they must be independent of each other. Partial // failure is possible, i.e., some mutation groups may have been applied diff --git a/spanner/integration_test.go b/spanner/integration_test.go index 8ad9218a87f0..9ac1fdec86d9 100644 --- a/spanner/integration_test.go +++ b/spanner/integration_test.go @@ -938,7 +938,7 @@ func TestIntegration_BatchWrite(t *testing.T) { } for _, idx := range res.GetIndexes() { if idx >= 0 && idx < int32(numMutationGroups) { - seen[idx] += 1 + seen[idx]++ } else { t.Fatalf("Index %v out of range. Expected range [%v,%v]", idx, 0, numMutationGroups-1) } diff --git a/spanner/internal/testutil/inmem_spanner_server_test.go b/spanner/internal/testutil/inmem_spanner_server_test.go index fb864e3f236a..7ee83ac9c4ac 100644 --- a/spanner/internal/testutil/inmem_spanner_server_test.go +++ b/spanner/internal/testutil/inmem_spanner_server_test.go @@ -678,7 +678,7 @@ func TestSpannerBatchWrite(t *testing.T) { } for _, idx := range res.GetIndexes() { if idx >= 0 && idx < int32(numMutationGroups) { - seen[idx] += 1 + seen[idx]++ } else { t.Fatalf("Index %v out of range. Expected range [%v,%v]", idx, 0, numMutationGroups-1) } diff --git a/spanner/mutation.go b/spanner/mutation.go index 1d8718065de9..03718c3c3ffd 100644 --- a/spanner/mutation.go +++ b/spanner/mutation.go @@ -141,7 +141,7 @@ type Mutation struct { values []interface{} } -// A group of mutations to be committed atomically. +// A MutationGroup is a list of mutations to be committed atomically. type MutationGroup struct { // The mutations in this group mutations []*Mutation From 94a921ae7705985d40068f93550b7606a75633d7 Mon Sep 17 00:00:00 2001 From: Sunny Singh <126051413+sunnsing-google@users.noreply.github.com> Date: Thu, 19 Oct 2023 20:34:59 +0530 Subject: [PATCH 3/5] Update comment Co-authored-by: Sri Harsha CH <57220027+harshachinta@users.noreply.github.com> --- spanner/mutation.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spanner/mutation.go b/spanner/mutation.go index 03718c3c3ffd..80a992e6c2cf 100644 --- a/spanner/mutation.go +++ b/spanner/mutation.go @@ -141,7 +141,7 @@ type Mutation struct { values []interface{} } -// A MutationGroup is a list of mutations to be committed atomically. +// A MutationGroup is a list of Mutation to be committed atomically. type MutationGroup struct { // The mutations in this group mutations []*Mutation From 5b9ce8a11ca9efc3976bb5ac579c132daf749c7d Mon Sep 17 00:00:00 2001 From: Sunny Singh Date: Mon, 23 Oct 2023 14:34:57 +0000 Subject: [PATCH 4/5] Add BatchWriteOptions --- spanner/client.go | 47 ++++++++++++++++++++++++++++++------- spanner/client_test.go | 18 +++++++------- spanner/integration_test.go | 2 +- 3 files changed, 48 insertions(+), 19 deletions(-) diff --git a/spanner/client.go b/spanner/client.go index 0dab33085f9f..c2e8447c155e 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -100,6 +100,7 @@ type Client struct { ro ReadOptions ao []ApplyOption txo TransactionOptions + bwo BatchWriteOptions ct *commonTags disableRouteToLeader bool } @@ -141,6 +142,9 @@ type ClientConfig struct { // TransactionOptions is the configuration for a transaction. TransactionOptions TransactionOptions + // BatchWriteOptions is the configuration for a BatchWrite request. + BatchWriteOptions BatchWriteOptions + // CallOptions is the configuration for providing custom retry settings that // override the default values. CallOptions *vkit.CallOptions @@ -284,6 +288,7 @@ func NewClientWithConfig(ctx context.Context, database string, config ClientConf ro: config.ReadOptions, ao: config.ApplyOptions, txo: config.TransactionOptions, + bwo: config.BatchWriteOptions, ct: getCommonTags(sc), disableRouteToLeader: config.DisableRouteToLeader, } @@ -672,6 +677,31 @@ func (c *Client) Apply(ctx context.Context, ms []*Mutation, opts ...ApplyOption) return t.applyAtLeastOnce(ctx, ms...) } +// BatchWriteOptions provides options for a BatchWriteRequest. +type BatchWriteOptions struct { + // Priority is the RPC priority to use for this request. + Priority sppb.RequestOptions_Priority + + // The transaction tag to use for this request. + TransactionTag string +} + +// merge combines two BatchWriteOptions such that the input parameter will have higher +// order of precedence. +func (bwo BatchWriteOptions) merge(opts BatchWriteOptions) BatchWriteOptions { + merged := BatchWriteOptions{ + TransactionTag: bwo.TransactionTag, + Priority: bwo.Priority, + } + if opts.TransactionTag != "" { + merged.TransactionTag = opts.TransactionTag + } + if opts.Priority != sppb.RequestOptions_PRIORITY_UNSPECIFIED { + merged.Priority = opts.Priority + } + return merged +} + // BatchWriteResponseIterator is an iterator over BatchWriteResponse structures returned from BatchWrite RPC. type BatchWriteResponseIterator struct { ctx context.Context @@ -786,7 +816,12 @@ func (r *BatchWriteResponseIterator) Do(f func(r *sppb.BatchWriteResponse) error // timestamp-based keys, it may result in additional rows being added to the // mutation's table. We recommend structuring your mutation groups to be // idempotent to avoid this issue. -func (c *Client) BatchWrite(ctx context.Context, mgs []*MutationGroup, opts ...ApplyOption) *BatchWriteResponseIterator { +func (c *Client) BatchWrite(ctx context.Context, mgs []*MutationGroup) *BatchWriteResponseIterator { + return c.BatchWriteWithOptions(ctx, mgs, BatchWriteOptions{}) +} + +// BatchWriteWithOptions is same as BatchWrite. It accepts additional options to customize the request. +func (c *Client) BatchWriteWithOptions(ctx context.Context, mgs []*MutationGroup, opts BatchWriteOptions) *BatchWriteResponseIterator { ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.BatchWrite") var err error @@ -794,13 +829,7 @@ func (c *Client) BatchWrite(ctx context.Context, mgs []*MutationGroup, opts ...A trace.EndSpan(ctx, err) }() - ao := &applyOption{} - for _, opt := range c.ao { - opt(ao) - } - for _, opt := range opts { - opt(ao) - } + opts = c.bwo.merge(opts) mgsPb, err := mutationGroupsProto(mgs) if err != nil { @@ -818,7 +847,7 @@ func (c *Client) BatchWrite(ctx context.Context, mgs []*MutationGroup, opts ...A stream, rpcErr := sh.getClient().BatchWrite(contextWithOutgoingMetadata(ct, sh.getMetadata(), c.disableRouteToLeader), &sppb.BatchWriteRequest{ Session: sh.getID(), MutationGroups: mgsPb, - RequestOptions: createRequestOptions(ao.priority, "", ao.transactionTag), + RequestOptions: createRequestOptions(opts.Priority, "", opts.TransactionTag), }, gax.WithGRPCOptions(grpc.Header(&md))) if getGFELatencyMetricsFlag() && md != nil && c.ct != nil { diff --git a/spanner/client_test.go b/spanner/client_test.go index 5813542a4d5e..6abd4435b339 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -4423,32 +4423,32 @@ func checkBatchWriteForExpectedRequestOptions(t *testing.T, server InMemSpannerS } } -func TestClient_BatchWrite_ApplyOptions(t *testing.T) { +func TestClient_BatchWrite_Options(t *testing.T) { t.Parallel() testcases := []struct { name string - client []ApplyOption - write []ApplyOption + client BatchWriteOptions + write BatchWriteOptions wantTransactionTag string wantPriority sppb.RequestOptions_Priority }{ { name: "Client level", - client: []ApplyOption{TransactionTag("testTransactionTag"), Priority(sppb.RequestOptions_PRIORITY_LOW)}, + client: BatchWriteOptions{TransactionTag: "testTransactionTag", Priority: sppb.RequestOptions_PRIORITY_LOW}, wantTransactionTag: "testTransactionTag", wantPriority: sppb.RequestOptions_PRIORITY_LOW, }, { name: "Write level", - write: []ApplyOption{TransactionTag("testTransactionTag"), Priority(sppb.RequestOptions_PRIORITY_LOW)}, + write: BatchWriteOptions{TransactionTag: "testTransactionTag", Priority: sppb.RequestOptions_PRIORITY_LOW}, wantTransactionTag: "testTransactionTag", wantPriority: sppb.RequestOptions_PRIORITY_LOW, }, { name: "Write level has precedence over client level", - client: []ApplyOption{TransactionTag("clientTransactionTag"), Priority(sppb.RequestOptions_PRIORITY_LOW)}, - write: []ApplyOption{TransactionTag("writeTransactionTag"), Priority(sppb.RequestOptions_PRIORITY_MEDIUM)}, + client: BatchWriteOptions{TransactionTag: "clientTransactionTag", Priority: sppb.RequestOptions_PRIORITY_LOW}, + write: BatchWriteOptions{TransactionTag: "writeTransactionTag", Priority: sppb.RequestOptions_PRIORITY_MEDIUM}, wantTransactionTag: "writeTransactionTag", wantPriority: sppb.RequestOptions_PRIORITY_MEDIUM, }, @@ -4456,7 +4456,7 @@ func TestClient_BatchWrite_ApplyOptions(t *testing.T) { for _, tt := range testcases { t.Run(tt.name, func(t *testing.T) { - server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ApplyOptions: tt.client}) + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{BatchWriteOptions: tt.client}) defer teardown() mutationGroups := []*MutationGroup{ @@ -4464,7 +4464,7 @@ func TestClient_BatchWrite_ApplyOptions(t *testing.T) { {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo1", 1}}, }}, } - iter := client.BatchWrite(context.Background(), mutationGroups, tt.write...) + iter := client.BatchWriteWithOptions(context.Background(), mutationGroups, tt.write) doFunc := func(r *sppb.BatchWriteResponse) error { return nil } diff --git a/spanner/integration_test.go b/spanner/integration_test.go index 9ac1fdec86d9..7ad2b82268f0 100644 --- a/spanner/integration_test.go +++ b/spanner/integration_test.go @@ -945,7 +945,7 @@ func TestIntegration_BatchWrite(t *testing.T) { } return nil } - iter := client.BatchWrite(ctx, mgs, TransactionTag("golang-test-batch-write")) + iter := client.BatchWrite(ctx, mgs) if err := iter.Do(validate); err != nil { t.Fatal(err) } From e9feab5ab43a58e65f0ce568dd83b0d9d8aa2db9 Mon Sep 17 00:00:00 2001 From: Sunny Singh Date: Thu, 26 Oct 2023 18:28:59 +0000 Subject: [PATCH 5/5] remove redundant int conversion --- spanner/client_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/spanner/client_test.go b/spanner/client_test.go index 6abd4435b339..da3309adccb5 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -4317,7 +4317,7 @@ func TestClient_BatchWrite(t *testing.T) { }}, } iter := client.BatchWrite(context.Background(), mutationGroups) - responseCount := int(0) + responseCount := 0 doFunc := func(r *sppb.BatchWriteResponse) error { responseCount++ return nil @@ -4352,7 +4352,7 @@ func TestClient_BatchWrite_SessionNotFound(t *testing.T) { }}, } iter := client.BatchWrite(context.Background(), mutationGroups) - responseCount := int(0) + responseCount := 0 doFunc := func(r *sppb.BatchWriteResponse) error { responseCount++ return nil @@ -4390,7 +4390,7 @@ func TestClient_BatchWrite_Error(t *testing.T) { }}, } iter := client.BatchWrite(context.Background(), mutationGroups) - responseCount := int(0) + responseCount := 0 doFunc := func(r *sppb.BatchWriteResponse) error { responseCount++ return nil