Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(spanner): add ResetForRetry method for stmt-based transactions #10956

Merged
merged 6 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion spanner/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ func stream(
rpc,
nil,
nil,
func(err error) error {
return err
},
setTimestamp,
release,
)
Expand All @@ -79,6 +82,7 @@ func streamWithReplaceSessionFunc(
rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error),
replaceSession func(ctx context.Context) error,
setTransactionID func(transactionID),
updateTxState func(err error) error,
setTimestamp func(time.Time),
release func(error),
) *RowIterator {
Expand All @@ -89,6 +93,7 @@ func streamWithReplaceSessionFunc(
streamd: newResumableStreamDecoder(ctx, logger, rpc, replaceSession),
rowd: &partialResultSetDecoder{},
setTransactionID: setTransactionID,
updateTxState: updateTxState,
setTimestamp: setTimestamp,
release: release,
cancel: cancel,
Expand Down Expand Up @@ -127,6 +132,7 @@ type RowIterator struct {
streamd *resumableStreamDecoder
rowd *partialResultSetDecoder
setTransactionID func(transactionID)
updateTxState func(err error) error
setTimestamp func(time.Time)
release func(error)
cancel func()
Expand Down Expand Up @@ -214,7 +220,7 @@ func (r *RowIterator) Next() (*Row, error) {
return row, nil
}
if err := r.streamd.lastErr(); err != nil {
r.err = ToSpannerError(err)
r.err = r.updateTxState(ToSpannerError(err))
} else if !r.rowd.done() {
r.err = errEarlyReadEnd()
} else {
Expand Down
82 changes: 71 additions & 11 deletions spanner/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package spanner

import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -63,6 +64,12 @@ type txReadOnly struct {
// operations.
txReadEnv

// updateTxStateFunc is a function that updates the state of the current
// transaction based on the given error. This function is by default a no-op,
// but is overridden for read/write transactions to set the state to txAborted
// if Spanner aborts the transaction.
updateTxStateFunc func(err error) error

// Atomic. Only needed for DML statements, but used forall.
sequenceNumber int64

Expand Down Expand Up @@ -98,6 +105,13 @@ type txReadOnly struct {
otConfig *openTelemetryConfig
}

func (t *txReadOnly) updateTxState(err error) error {
if t.updateTxStateFunc == nil {
return err
}
return t.updateTxStateFunc(err)
}

// TransactionOptions provides options for a transaction.
type TransactionOptions struct {
CommitOptions CommitOptions
Expand Down Expand Up @@ -323,7 +337,7 @@ func (t *txReadOnly) ReadWithOptions(ctx context.Context, table string, keys Key
t.setTransactionID(nil)
return client, errInlineBeginTransactionFailed()
}
return client, err
return client, t.updateTxState(err)
}
md, err := client.Header()
if getGFELatencyMetricsFlag() && md != nil && t.ct != nil {
Expand All @@ -338,6 +352,9 @@ func (t *txReadOnly) ReadWithOptions(ctx context.Context, table string, keys Key
},
t.replaceSessionFunc,
setTransactionID,
func(err error) error {
return t.updateTxState(err)
},
t.setTimestamp,
t.release,
)
Expand Down Expand Up @@ -607,7 +624,7 @@ func (t *txReadOnly) query(ctx context.Context, statement Statement, options Que
t.setTransactionID(nil)
return client, errInlineBeginTransactionFailed()
}
return client, err
return client, t.updateTxState(err)
}
md, err := client.Header()
if getGFELatencyMetricsFlag() && md != nil && t.ct != nil {
Expand All @@ -622,6 +639,9 @@ func (t *txReadOnly) query(ctx context.Context, statement Statement, options Que
},
t.replaceSessionFunc,
setTransactionID,
func(err error) error {
return t.updateTxState(err)
},
t.setTimestamp,
t.release)
}
Expand Down Expand Up @@ -673,6 +693,8 @@ const (
txActive
// transaction is closed, cannot be used anymore.
txClosed
// transaction was aborted by Spanner and should be retried.
txAborted
)

// errRtsUnavailable returns error for read transaction's read timestamp being
Expand Down Expand Up @@ -1216,7 +1238,7 @@ func (t *ReadWriteTransaction) update(ctx context.Context, stmt Statement, opts
t.setTransactionID(nil)
return 0, errInlineBeginTransactionFailed()
}
return 0, ToSpannerError(err)
return 0, t.txReadOnly.updateTxState(ToSpannerError(err))
}
if hasInlineBeginTransaction {
if resultSet != nil && resultSet.GetMetadata() != nil && resultSet.GetMetadata().GetTransaction() != nil &&
Expand Down Expand Up @@ -1325,7 +1347,7 @@ func (t *ReadWriteTransaction) batchUpdateWithOptions(ctx context.Context, stmts
t.setTransactionID(nil)
return nil, errInlineBeginTransactionFailed()
}
return nil, ToSpannerError(err)
return nil, t.txReadOnly.updateTxState(ToSpannerError(err))
}

haveTransactionID := false
Expand All @@ -1348,7 +1370,7 @@ func (t *ReadWriteTransaction) batchUpdateWithOptions(ctx context.Context, stmts
return counts, errInlineBeginTransactionFailed()
}
if resp.Status != nil && resp.Status.Code != 0 {
return counts, spannerErrorf(codes.Code(uint32(resp.Status.Code)), resp.Status.Message)
return counts, t.txReadOnly.updateTxState(spannerErrorf(codes.Code(uint32(resp.Status.Code)), resp.Status.Message))
}
return counts, nil
}
Expand Down Expand Up @@ -1666,7 +1688,7 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions
trace.TracePrintf(ctx, nil, "Error in recording GFE Latency through OpenTelemetry. Error: %v", metricErr)
}
if e != nil {
return resp, toSpannerErrorWithCommitInfo(e, true)
return resp, t.txReadOnly.updateTxState(toSpannerErrorWithCommitInfo(e, true))
}
if tstamp := res.GetCommitTimestamp(); tstamp != nil {
resp.CommitTs = time.Unix(tstamp.Seconds, int64(tstamp.Nanos))
Expand Down Expand Up @@ -1758,6 +1780,7 @@ type ReadWriteStmtBasedTransaction struct {
// ReadWriteTransaction contains methods for performing transactional reads.
ReadWriteTransaction

client *Client
options TransactionOptions
}

Expand All @@ -1783,30 +1806,51 @@ func NewReadWriteStmtBasedTransaction(ctx context.Context, c *Client) (*ReadWrit
// used by the transaction will not be returned to the pool and cause a session
// leak.
//
// ResetForRetry resets the transaction before a retry attempt. This function
// returns a new transaction that should be used for the retry attempt. The
// transaction that is returned by this function is assigned a higher priority
// than the previous transaction, making it less probable to be aborted by
// Spanner again during the retry.
//
// NewReadWriteStmtBasedTransactionWithOptions is a configurable version of
// NewReadWriteStmtBasedTransaction.
func NewReadWriteStmtBasedTransactionWithOptions(ctx context.Context, c *Client, options TransactionOptions) (*ReadWriteStmtBasedTransaction, error) {
return newReadWriteStmtBasedTransactionWithSessionHandle(ctx, c, options, nil)
}

func newReadWriteStmtBasedTransactionWithSessionHandle(ctx context.Context, c *Client, options TransactionOptions, sh *sessionHandle) (*ReadWriteStmtBasedTransaction, error) {
var (
sh *sessionHandle
err error
t *ReadWriteStmtBasedTransaction
)
sh, err = c.idleSessions.take(ctx)
if err != nil {
// If session retrieval fails, just fail the transaction.
return nil, err
if sh == nil {
sh, err = c.idleSessions.take(ctx)
if err != nil {
// If session retrieval fails, just fail the transaction.
return nil, err
}
}
t = &ReadWriteStmtBasedTransaction{
ReadWriteTransaction: ReadWriteTransaction{
txReadyOrClosed: make(chan struct{}),
},
client: c,
}
t.txReadOnly.sp = c.idleSessions
t.txReadOnly.sh = sh
t.txReadOnly.txReadEnv = t
t.txReadOnly.qo = c.qo
t.txReadOnly.ro = c.ro
t.txReadOnly.disableRouteToLeader = c.disableRouteToLeader
t.txReadOnly.updateTxStateFunc = func(err error) error {
if ErrCode(err) == codes.Aborted {
t.mu.Lock()
t.state = txAborted
t.mu.Unlock()
}
return err
}

t.txOpts = c.txo.merge(options)
t.ct = c.ct
t.otConfig = c.otConfig
Expand Down Expand Up @@ -1838,6 +1882,7 @@ func (t *ReadWriteStmtBasedTransaction) CommitWithReturnResp(ctx context.Context
}
if t.sh != nil {
t.sh.recycle()
t.sh = nil
}
return resp, err
}
Expand All @@ -1848,7 +1893,22 @@ func (t *ReadWriteStmtBasedTransaction) Rollback(ctx context.Context) {
t.rollback(ctx)
if t.sh != nil {
t.sh.recycle()
t.sh = nil
}
}

// ResetForRetry resets the transaction before a retry. This should be
// called if the transaction was aborted by Spanner and the application
// wants to retry the transaction.
// It is recommended to use this method above creating a new transaction,
// as this method will give the transaction a higher priority and thus a
// smaller probability of being aborted again by Spanner.
func (t *ReadWriteStmtBasedTransaction) ResetForRetry(ctx context.Context) (*ReadWriteStmtBasedTransaction, error) {
if t.state != txAborted {
return nil, fmt.Errorf("ResetForRetry should only be called on an active transaction that was aborted by Spanner")
}
// Create a new transaction that re-uses the current session if it is available.
return newReadWriteStmtBasedTransactionWithSessionHandle(ctx, t.client, t.options, t.sh)
}

// writeOnlyTransaction provides the most efficient way of doing write-only
Expand Down
104 changes: 102 additions & 2 deletions spanner/transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,103 @@ func TestReadWriteStmtBasedTransaction_CommitAborted(t *testing.T) {
}
}

func TestReadWriteStmtBasedTransaction_QueryAborted(t *testing.T) {
t.Parallel()
rowCount, attempts, err := testReadWriteStmtBasedTransaction(t, map[string]SimulatedExecutionTime{
MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}},
})
if err != nil {
t.Fatalf("transaction failed to commit: %v", err)
}
if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
t.Fatalf("Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
}
if g, w := attempts, 2; g != w {
t.Fatalf("number of attempts mismatch:\nGot%d\nWant:%d", g, w)
}
}

func TestReadWriteStmtBasedTransaction_UpdateAborted(t *testing.T) {
t.Parallel()
server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
SessionPoolConfig: SessionPoolConfig{
// Use a session pool with size 1 to ensure that there are no session leaks.
MinOpened: 1,
MaxOpened: 1,
},
})
defer teardown()
server.TestSpanner.PutExecutionTime(
MethodExecuteSql,
SimulatedExecutionTime{Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}})

ctx := context.Background()
tx, err := NewReadWriteStmtBasedTransaction(ctx, client)
if err != nil {
t.Fatal(err)
}
_, err = tx.Update(ctx, Statement{SQL: UpdateBarSetFoo})
if g, w := ErrCode(err), codes.Aborted; g != w {
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
}
tx, err = tx.ResetForRetry(ctx)
if err != nil {
t.Fatal(err)
}
c, err := tx.Update(ctx, Statement{SQL: UpdateBarSetFoo})
if err != nil {
t.Fatal(err)
}
if g, w := c, int64(UpdateBarSetFooRowCount); g != w {
t.Fatalf("update count mismatch\n Got: %v\nWant: %v", g, w)
}
}

func TestReadWriteStmtBasedTransaction_BatchUpdateAborted(t *testing.T) {
t.Parallel()
server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
SessionPoolConfig: SessionPoolConfig{
// Use a session pool with size 1 to ensure that there are no session leaks.
MinOpened: 1,
MaxOpened: 1,
},
})
defer teardown()
server.TestSpanner.PutExecutionTime(
MethodExecuteBatchDml,
SimulatedExecutionTime{Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}})

ctx := context.Background()
tx, err := NewReadWriteStmtBasedTransaction(ctx, client)
if err != nil {
t.Fatal(err)
}
_, err = tx.BatchUpdate(ctx, []Statement{{SQL: UpdateBarSetFoo}})
if g, w := ErrCode(err), codes.Aborted; g != w {
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
}
tx, err = tx.ResetForRetry(ctx)
if err != nil {
t.Fatal(err)
}
c, err := tx.BatchUpdate(ctx, []Statement{{SQL: UpdateBarSetFoo}})
if err != nil {
t.Fatal(err)
}
if g, w := c, []int64{UpdateBarSetFooRowCount}; !reflect.DeepEqual(g, w) {
t.Fatalf("update count mismatch\n Got: %v\nWant: %v", g, w)
}
}

func testReadWriteStmtBasedTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime) (rowCount int64, attempts int, err error) {
server, client, teardown := setupMockedTestServer(t)
// server, client, teardown := setupMockedTestServer(t)
server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
SessionPoolConfig: SessionPoolConfig{
// Use a session pool with size 1 to ensure that there are no session leaks.
MinOpened: 1,
MaxOpened: 1,
},
})
defer teardown()
for method, exec := range executionTimes {
server.TestSpanner.PutExecutionTime(method, exec)
Expand Down Expand Up @@ -500,9 +595,14 @@ func testReadWriteStmtBasedTransaction(t *testing.T, executionTimes map[string]S
return rowCount, nil
}

var tx *ReadWriteStmtBasedTransaction
for {
attempts++
tx, err := NewReadWriteStmtBasedTransaction(ctx, client)
if attempts > 1 {
tx, err = tx.ResetForRetry(ctx)
} else {
tx, err = NewReadWriteStmtBasedTransaction(ctx, client)
}
if err != nil {
return 0, attempts, fmt.Errorf("failed to begin a transaction: %v", err)
}
Expand Down
Loading