diff --git a/aborted_transactions_test.go b/aborted_transactions_test.go index c5dbcc5d..01d62409 100644 --- a/aborted_transactions_test.go +++ b/aborted_transactions_test.go @@ -19,6 +19,7 @@ import ( "database/sql" "reflect" "testing" + "time" "cloud.google.com/go/spanner" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" @@ -31,10 +32,11 @@ import ( func TestCommitAborted(t *testing.T) { t.Parallel() - db, server, teardown := setupTestDBConnection(t) + db, server, teardown := setupTestDBConnectionWithParams(t, "minSessions=1;maxSessions=1") defer teardown() - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() tx, err := db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { t.Fatalf("begin failed: %v", err) @@ -56,10 +58,11 @@ func TestCommitAborted(t *testing.T) { func TestCommitAbortedWithInternalRetriesDisabled(t *testing.T) { t.Parallel() - db, server, teardown := setupTestDBConnectionWithParams(t, "retryAbortsInternally=false") + db, server, teardown := setupTestDBConnectionWithParams(t, "retryAbortsInternally=false;minSessions=1;maxSessions=1") defer teardown() - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() tx, err := db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { t.Fatalf("begin failed: %v", err) @@ -82,10 +85,11 @@ func TestCommitAbortedWithInternalRetriesDisabled(t *testing.T) { func TestUpdateAborted(t *testing.T) { t.Parallel() - db, server, teardown := setupTestDBConnection(t) + db, server, teardown := setupTestDBConnectionWithParams(t, "minSessions=1;maxSessions=1") defer teardown() - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() tx, err := db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { t.Fatalf("begin failed: %v", err) @@ -122,10 +126,11 @@ func TestUpdateAborted(t *testing.T) { func TestBatchUpdateAborted(t *testing.T) { t.Parallel() - db, server, teardown := setupTestDBConnection(t) + db, server, teardown := setupTestDBConnectionWithParams(t, "minSessions=1;maxSessions=1") defer teardown() - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() tx, err := db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { t.Fatalf("begin failed: %v", err) @@ -359,13 +364,14 @@ func testRetryReadWriteTransactionWithQuery(t *testing.T, setupServer func(serve t.Parallel() - db, server, teardown := setupTestDBConnection(t) + db, server, teardown := setupTestDBConnectionWithParams(t, "minSessions=1;maxSessions=1") defer teardown() if setupServer != nil { setupServer(server.TestSpanner) } - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() tx, err := db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { t.Fatalf("begin failed: %v", err) diff --git a/transaction.go b/transaction.go index f8cf1778..6e92b6a5 100644 --- a/transaction.go +++ b/transaction.go @@ -246,6 +246,10 @@ func (tx *readWriteTransaction) runWithRetry(ctx context.Context, f func(ctx con // retry retries the entire read/write transaction on a new Spanner transaction. // It will return ErrAbortedDueToConcurrentModification if the retry fails. func (tx *readWriteTransaction) retry(ctx context.Context) (err error) { + // TODO: This should use t.ResetForRetry(ctx) instead when that function is available. + if tx.rwTx != nil { + tx.rwTx.Rollback(tx.ctx) + } tx.rwTx, err = spanner.NewReadWriteStmtBasedTransaction(ctx, tx.client) if err != nil { return err