diff --git a/spanner/client.go b/spanner/client.go index 3560ed3baacf..a95b02431ad4 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -579,6 +579,7 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea t.txReadOnly.qo = c.qo t.txReadOnly.ro = c.ro t.txReadOnly.disableRouteToLeader = c.disableRouteToLeader + t.wb = []*Mutation{} t.txOpts = c.txo.merge(options) t.ct = c.ct diff --git a/spanner/client_test.go b/spanner/client_test.go index 370c41560ea3..6f963102805e 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -811,6 +811,308 @@ func TestClient_ReadWriteTransactionCommitAborted(t *testing.T) { } } +func TestClient_ReadWriteTransaction_BufferedWriteBeforeAbortedFirstSqlStatement(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime(MethodExecuteSql, SimulatedExecutionTime{Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}}) + + var attempts int + expectedAttempts := 2 + _, err := client.ReadWriteTransaction( + ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + // Buffer mutations before executing a SQL statement. + if err := tx.BufferWrite([]*Mutation{ + Insert("foo", []string{"col1"}, []interface{}{"key1"}), + }); err != nil { + return err + } + // Then execute a SQL statement that will return Aborted from the backend. + // This will force a retry of the transaction with an explicit BeginTransaction RPC. + c, err := tx.Update(ctx, NewStatement(UpdateBarSetFoo)) + if err != nil { + return err + } + if g, w := c, int64(UpdateBarSetFooRowCount); g != w { + return fmt.Errorf("update count mismatch\nGot: %v\nWant: %v", g, w) + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.CommitRequest{}, + }, requests); err != nil { + t.Fatal(err) + } + commit := requests[len(requests)-1].(*sppb.CommitRequest) + g, w := len(commit.Mutations), 1 + if g != w { + t.Fatalf("mutations count mismatch\nGot: %v\nWant: %v", g, w) + } +} + +func TestClient_ReadWriteTransaction_BufferedWriteBeforeAbortedFirstSqlStatementTwice(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime(MethodExecuteSql, SimulatedExecutionTime{Errors: []error{ + status.Error(codes.Aborted, "Transaction aborted"), + status.Error(codes.Aborted, "Transaction aborted"), + }}) + + var attempts int + expectedAttempts := 3 + _, err := client.ReadWriteTransaction( + ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + // Buffer mutations before executing a SQL statement. + if err := tx.BufferWrite([]*Mutation{ + Insert("foo", []string{"col1"}, []interface{}{"key1"}), + }); err != nil { + return err + } + // Then execute a SQL statement that will return Aborted from the backend. + // This will force a retry of the transaction with an explicit BeginTransaction RPC. + c, err := tx.Update(ctx, NewStatement(UpdateBarSetFoo)) + if err != nil { + return err + } + if g, w := c, int64(UpdateBarSetFooRowCount); g != w { + return fmt.Errorf("update count mismatch\nGot: %v\nWant: %v", g, w) + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.CommitRequest{}, + }, requests); err != nil { + t.Fatal(err) + } + commit := requests[len(requests)-1].(*sppb.CommitRequest) + g, w := len(commit.Mutations), 1 + if g != w { + t.Fatalf("mutations count mismatch\nGot: %v\nWant: %v", g, w) + } +} + +func TestClient_ReadWriteTransaction_BufferedWriteBeforeSqlStatementWithError(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime(MethodExecuteSql, SimulatedExecutionTime{Errors: []error{ + status.Error(codes.InvalidArgument, "Invalid"), + status.Error(codes.InvalidArgument, "Invalid"), + }}) + + var attempts int + expectedAttempts := 2 + _, err := client.ReadWriteTransaction( + ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + // Buffer mutations before executing a SQL statement. + if err := tx.BufferWrite([]*Mutation{ + Insert("foo", []string{"col1"}, []interface{}{"key1"}), + }); err != nil { + return err + } + // Then execute a SQL statement that will return InvalidArgument from the backend. + // This will initially force a retry of the transaction with an explicit BeginTransaction RPC. + // We ignore the error and proceed to commit the transaction. + _, err := tx.Update(ctx, NewStatement(UpdateBarSetFoo)) + if err == nil { + return fmt.Errorf("missing expected InvalidArgument error") + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.CommitRequest{}, + }, requests); err != nil { + t.Fatal(err) + } + commit := requests[len(requests)-1].(*sppb.CommitRequest) + g, w := len(commit.Mutations), 1 + if g != w { + t.Fatalf("mutations count mismatch\nGot: %v\nWant: %v", g, w) + } +} + +func TestClient_ReadWriteTransaction_BufferedWriteBeforeSqlStatementWithErrorThatGoesAway(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime(MethodExecuteSql, SimulatedExecutionTime{Errors: []error{ + status.Error(codes.AlreadyExists, "Row already exists"), + }}) + + var attempts int + expectedAttempts := 2 + _, err := client.ReadWriteTransaction( + ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + // Buffer mutations before executing a SQL statement. + if err := tx.BufferWrite([]*Mutation{ + Insert("foo", []string{"col1"}, []interface{}{"key1"}), + }); err != nil { + return err + } + // Then execute a SQL statement that will return InvalidArgument from the backend. + // This will initially force a retry of the transaction with an explicit BeginTransaction RPC. + // The error does not occur during the retry and the transaction is allowed to continue. + _, err := tx.Update(ctx, NewStatement(UpdateBarSetFoo)) + if err != nil { + return err + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.CommitRequest{}, + }, requests); err != nil { + t.Fatal(err) + } + commit := requests[len(requests)-1].(*sppb.CommitRequest) + g, w := len(commit.Mutations), 1 + if g != w { + t.Fatalf("mutations count mismatch\nGot: %v\nWant: %v", g, w) + } +} + +func TestClient_ReadWriteTransaction_OnlyBufferWritesDuringInitialAttempt(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime(MethodExecuteSql, SimulatedExecutionTime{Errors: []error{ + status.Error(codes.AlreadyExists, "Row already exists"), + }}) + + expectedAttempts := 2 + var attempts int + _, err := client.ReadWriteTransaction( + ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + if attempts == 1 { + // Only do a blind write if it is not a retry of the transaction. + if err := tx.BufferWrite([]*Mutation{ + Delete("foo", AllKeys()), + }); err != nil { + return err + } + } + // Then execute a SQL statement that will return InvalidArgument from the backend. + // This will initially force a retry of the transaction with an explicit BeginTransaction RPC. + // The error does not occur during the retry and the transaction is allowed to continue. + _, err := tx.Update(ctx, NewStatement(UpdateBarSetFoo)) + if err != nil { + return err + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.CommitRequest{}, + }, requests); err != nil { + t.Fatal(err) + } + commit := requests[len(requests)-1].(*sppb.CommitRequest) + g, w := len(commit.Mutations), 0 + if g != w { + t.Fatalf("mutations count mismatch\nGot: %v\nWant: %v", g, w) + } +} + +func TestClient_ReadWriteTransaction_BlindWriteWithAbortedCommit(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime(MethodCommitTransaction, SimulatedExecutionTime{Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}}) + + var attempts int + expectedAttempts := 2 + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + // Do a blind write and then commit. The CommitRequest will be aborted and cause the transaction to retry. + if err := tx.BufferWrite([]*Mutation{Insert("foo", []string{"col1"}, []interface{}{"key1"})}); err != nil { + return err + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.CommitRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.CommitRequest{}, + }, requests); err != nil { + t.Fatal(err) + } + commit := requests[len(requests)-1].(*sppb.CommitRequest) + // TODO: Update to 1 when the bug is fixed + g, w := len(commit.Mutations), 1 + if g != w { + t.Fatalf("mutations count mismatch\nGot: %v\nWant: %v", g, w) + } +} + func TestClient_ReadWriteTransaction_SessionNotFoundOnCommit(t *testing.T) { t.Parallel() if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{