Skip to content

Commit

Permalink
feat: add SQL statements for max_commit_delay
Browse files Browse the repository at this point in the history
  • Loading branch information
olavloite committed Jan 20, 2025
1 parent b9da4a8 commit 7a0f08f
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 23 deletions.
52 changes: 42 additions & 10 deletions client_side_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ func (s *statementExecutor) ShowExcludeTxnFromChangeStreams(_ context.Context, c
return &rows{it: it}, nil
}

func (s *statementExecutor) ShowMaxCommitDelay(_ context.Context, c *conn, _ string, _ []driver.NamedValue) (driver.Rows, error) {
it, err := createStringIterator("MaxCommitDelay", c.MaxCommitDelay().String())
if err != nil {
return nil, err
}
return &rows{it: it}, nil
}

func (s *statementExecutor) ShowTransactionTag(_ context.Context, c *conn, _ string, _ []driver.NamedValue) (driver.Rows, error) {
it, err := createStringIterator("TransactionTag", c.TransactionTag())
if err != nil {
Expand Down Expand Up @@ -163,6 +171,16 @@ func (s *statementExecutor) SetExcludeTxnFromChangeStreams(_ context.Context, c
return c.setExcludeTxnFromChangeStreams(exclude)
}

var maxCommitDelayRegexp = regexp.MustCompile(`(?i)^\s*('(?P<duration>(\d{1,19})(s|ms|us|ns))'|(?P<number>\d{1,19})|(?P<null>NULL))\s*$`)

func (s *statementExecutor) SetMaxCommitDelay(_ context.Context, c *conn, params string, _ []driver.NamedValue) (driver.Result, error) {
duration, err := parseDuration(maxCommitDelayRegexp, "max_commit_delay", params)
if err != nil {
return nil, err
}
return c.setMaxCommitDelay(duration)
}

func (s *statementExecutor) SetTransactionTag(_ context.Context, c *conn, params string, _ []driver.NamedValue) (driver.Result, error) {
tag, err := parseTag(params)
if err != nil {
Expand Down Expand Up @@ -209,13 +227,13 @@ func (s *statementExecutor) SetReadOnlyStaleness(_ context.Context, c *conn, par
if strongRegexp.MatchString(params) {
staleness = spanner.StrongRead()
} else if exactStalenessRegexp.MatchString(params) {
d, err := parseDuration(exactStalenessRegexp, params)
d, err := parseDuration(exactStalenessRegexp, "staleness", params)
if err != nil {
return nil, err
}
staleness = spanner.ExactStaleness(d)
} else if maxStalenessRegexp.MatchString(params) {
d, err := parseDuration(maxStalenessRegexp, params)
d, err := parseDuration(maxStalenessRegexp, "staleness", params)
if err != nil {
return nil, err
}
Expand All @@ -238,16 +256,27 @@ func (s *statementExecutor) SetReadOnlyStaleness(_ context.Context, c *conn, par
return c.setReadOnlyStaleness(staleness)
}

func parseDuration(re *regexp.Regexp, params string) (time.Duration, error) {
func parseDuration(re *regexp.Regexp, name, params string) (time.Duration, error) {
matches := matchesToMap(re, params)
if matches["duration"] == "" {
return 0, spanner.ToSpannerError(status.Error(codes.InvalidArgument, "No duration found in staleness string"))
if matches["duration"] == "" && matches["number"] == "" && matches["null"] == "" {
return 0, spanner.ToSpannerError(status.Error(codes.InvalidArgument, fmt.Sprintf("No duration found in %s string: %v", name, params)))
}
d, err := time.ParseDuration(matches["duration"])
if err != nil {
return 0, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "Invalid duration: %s", matches["duration"]))
if matches["duration"] != "" {
d, err := time.ParseDuration(matches["duration"])
if err != nil {
return 0, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "Invalid duration: %s", matches["duration"]))
}
return d, nil
} else if matches["number"] != "" {
d, err := strconv.Atoi(matches["number"])
if err != nil {
return 0, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "Invalid duration: %s", matches["number"]))
}
return time.Millisecond * time.Duration(d), nil
} else if matches["null"] != "" {
return time.Duration(0), nil
}
return d, nil
return time.Duration(0), spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "Unrecognized duration: %s", params))
}

func parseTimestamp(re *regexp.Regexp, params string) (time.Time, error) {
Expand All @@ -263,8 +292,11 @@ func parseTimestamp(re *regexp.Regexp, params string) (time.Time, error) {
}

func matchesToMap(re *regexp.Regexp, s string) map[string]string {
match := re.FindStringSubmatch(s)
matches := make(map[string]string)
match := re.FindStringSubmatch(s)
if match == nil {
return matches
}
for i, name := range re.SubexpNames() {
if i != 0 && name != "" {
matches[name] = match[i]
Expand Down
60 changes: 60 additions & 0 deletions client_side_statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,66 @@ func TestStatementExecutor_ExcludeTxnFromChangeStreams(t *testing.T) {
}
}

func TestStatementExecutor_MaxCommitDelay(t *testing.T) {
c := &conn{logger: noopLogger}
s := &statementExecutor{}
ctx := context.Background()
for i, test := range []struct {
wantValue time.Duration
setValue string
wantSetErr bool
}{
{time.Second, "'1s'", false},
{10 * time.Millisecond, "'10ms'", false},
{20 * time.Microsecond, "'20us'", false},
{30 * time.Nanosecond, "'30ns'", false},
{time.Duration(0), "NULL", false},
{100 * time.Millisecond, "100", false},
{100 * time.Millisecond, "true", true},
{100 * time.Millisecond, "ms", true},
{100 * time.Millisecond, "'ms'", true},
{100 * time.Millisecond, "20ms", true},
{100 * time.Millisecond, "'10'", true},
{100 * time.Millisecond, "'10ms", true},
{100 * time.Millisecond, "10ms'", true},
} {
res, err := s.SetMaxCommitDelay(ctx, c, test.setValue, nil)
if test.wantSetErr {
if err == nil {
t.Fatalf("%d: missing expected error for value %q", i, test.setValue)
}
} else {
if err != nil {
t.Fatalf("%d: could not set new value %q for max_commit_delay: %v", i, test.setValue, err)
}
if res != driver.ResultNoRows {
t.Fatalf("%d: result mismatch\nGot: %v\nWant: %v", i, res, driver.ResultNoRows)
}
}

it, err := s.ShowMaxCommitDelay(ctx, c, "", nil)
if err != nil {
t.Fatalf("%d: could not get current max_commit_delay value from connection: %v", i, err)
}
cols := it.Columns()
wantCols := []string{"MaxCommitDelay"}
if !cmp.Equal(cols, wantCols) {
t.Fatalf("%d: column names mismatch\nGot: %v\nWant: %v", i, cols, wantCols)
}
values := make([]driver.Value, len(cols))
if err := it.Next(values); err != nil {
t.Fatalf("%d: failed to get first row for max_commit_delay: %v", i, err)
}
wantValues := []driver.Value{test.wantValue.String()}
if !cmp.Equal(values, wantValues) {
t.Fatalf("%d: max_commit_delay values mismatch\nGot: %v\nWant: %v", i, values, wantValues)
}
if err := it.Next(values); err != io.EOF {
t.Fatalf("%d: error mismatch\nGot: %v\nWant: %v", i, err, io.EOF)
}
}
}

func TestStatementExecutor_SetTransactionTag(t *testing.T) {
ctx := context.Background()
for i, test := range []struct {
Expand Down
36 changes: 36 additions & 0 deletions client_side_statements_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ var jsonFile = `{
"method": "statementShowExcludeTxnFromChangeStreams",
"exampleStatements": ["show variable exclude_txn_from_change_streams"]
},
{
"name": "SHOW VARIABLE MAX_COMMIT_DELAY",
"executorName": "ClientSideStatementNoParamExecutor",
"resultType": "RESULT_SET",
"statementType": "SHOW_MAX_COMMIT_DELAY",
"regex": "(?is)\\A\\s*show\\s+variable\\s+max_commit_delay\\s*\\z",
"method": "statementShowMaxCommitDelay",
"exampleStatements": ["show variable max_commit_delay"]
},
{
"name": "SHOW VARIABLE TRANSACTION_TAG",
"executorName": "ClientSideStatementNoParamExecutor",
Expand Down Expand Up @@ -184,6 +193,33 @@ var jsonFile = `{
"converterName": "ClientSideStatementValueConverters$BooleanConverter"
}
},
{
"name": "SET MAX_COMMIT_DELAY = '<duration>'|NULL",
"executorName": "ClientSideStatementSetExecutor",
"resultType": "NO_RESULT",
"statementType": "SET_MAX_COMMIT_DELAY",
"regex": "(?is)\\A\\s*set\\s+max_commit_delay\\s*(?:=)\\s*(.*)\\z",
"method": "statementSetMaxCommitDelay",
"exampleStatements": [
"set max_commit_delay=null",
"set max_commit_delay = null",
"set max_commit_delay = null ",
"set max_commit_delay=1000",
"set max_commit_delay = 1000",
"set max_commit_delay = 1000 ",
"set max_commit_delay='1s'",
"set max_commit_delay = '1s'",
"set max_commit_delay = '1s' ",
"set max_commit_delay='100ms'",
"set max_commit_delay='10000us'",
"set max_commit_delay='9223372036854775807ns'"],
"setStatement": {
"propertyName": "MAX_COMMIT_DELAY",
"separator": "=",
"allowedValues": "('(\\d{1,19})(s|ms|us|ns)'|\\d{1,19}|NULL)",
"converterName": "ClientSideStatementValueConverters$DurationConverter"
}
},
{
"name": "SET TRANSACTION_TAG = '<tag>'",
"executorName": "ClientSideStatementSetExecutor",
Expand Down
19 changes: 10 additions & 9 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -783,12 +783,10 @@ type conn struct {
autocommitDMLMode AutocommitDMLMode
// readOnlyStaleness is used for queries in autocommit mode and for read-only transactions.
readOnlyStaleness spanner.TimestampBound
// maxCommitDelay is applied to commit requests both in autocommit and read/write transactions.
maxCommitDelay time.Duration

// execOptions are applied to the next statement that is executed on this connection.
// It can be set by passing it in as an argument to ExecContext or QueryContext
// and is cleared after each execution.
// execOptions are applied to the next statement or transaction that is executed
// on this connection. It can also be set by passing it in as an argument to
// ExecContext or QueryContext.
execOptions ExecOptions
}

Expand Down Expand Up @@ -882,7 +880,7 @@ func (c *conn) setReadOnlyStaleness(staleness spanner.TimestampBound) (driver.Re
}

func (c *conn) MaxCommitDelay() time.Duration {
return c.maxCommitDelay
return *c.execOptions.TransactionOptions.CommitOptions.MaxCommitDelay
}

func (c *conn) SetMaxCommitDelay(delay time.Duration) error {
Expand All @@ -891,7 +889,7 @@ func (c *conn) SetMaxCommitDelay(delay time.Duration) error {
}

func (c *conn) setMaxCommitDelay(delay time.Duration) (driver.Result, error) {
c.maxCommitDelay = delay
c.execOptions.TransactionOptions.CommitOptions.MaxCommitDelay = &delay
return driver.ResultNoRows, nil
}

Expand Down Expand Up @@ -1167,7 +1165,7 @@ func (c *conn) ResetSession(_ context.Context) error {
c.retryAborts = true
c.autocommitDMLMode = Transactional
c.readOnlyStaleness = spanner.TimestampBound{}
c.maxCommitDelay = time.Duration(0)
c.execOptions = ExecOptions{}
return nil
}

Expand Down Expand Up @@ -1413,7 +1411,10 @@ func (c *conn) execContext(ctx context.Context, query string, execOptions ExecOp

// options returns and resets the ExecOptions for the next statement.
func (c *conn) options() ExecOptions {
defer func() { c.execOptions = ExecOptions{} }()
defer func() {
c.execOptions.TransactionOptions.TransactionTag = ""
c.execOptions.QueryOptions.RequestTag = ""
}()
return c.execOptions
}

Expand Down
36 changes: 32 additions & 4 deletions driver_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,18 @@ func TestSimpleReadWriteTransaction(t *testing.T) {

db, server, teardown := setupTestDBConnection(t)
defer teardown()
tx, err := db.Begin()

ctx := context.Background()
conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
if _, err := conn.ExecContext(ctx, "set max_commit_delay='10ms'"); err != nil {
t.Fatal(err)
}

tx, err := conn.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -398,6 +409,9 @@ func TestSimpleReadWriteTransaction(t *testing.T) {
if c, e := commitReq.GetTransactionId(), req.Transaction.GetId(); !cmp.Equal(c, e) {
t.Fatalf("transaction id mismatch\nCommit: %c\nExecute: %v", c, e)
}
if g, w := commitReq.MaxCommitDelay.Nanos, int32(time.Millisecond*10); g != w {
t.Fatalf("max_commit_delay mismatch\n Got: %v\nWant: %v", g, w)
}
}

func TestPreparedQuery(t *testing.T) {
Expand Down Expand Up @@ -1235,7 +1249,18 @@ func TestDmlInAutocommit(t *testing.T) {

db, server, teardown := setupTestDBConnection(t)
defer teardown()
res, err := db.ExecContext(context.Background(), testutil.UpdateBarSetFoo)
ctx := context.Background()
conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
_, err = conn.ExecContext(ctx, "set max_commit_delay=100")
if err != nil {
t.Fatal(err)
}

res, err := conn.ExecContext(ctx, testutil.UpdateBarSetFoo)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1268,6 +1293,9 @@ func TestDmlInAutocommit(t *testing.T) {
if commitReq.GetTransactionId() == nil {
t.Fatalf("missing id selector for CommitRequest")
}
if g, w := commitReq.MaxCommitDelay.Nanos, int32(time.Millisecond*100); g != w {
t.Fatalf("max_commit_delay mismatch\n Got: %v\nWant: %v", g, w)
}
}

func TestQueryWithDuplicateNamedParameter(t *testing.T) {
Expand Down Expand Up @@ -2684,11 +2712,11 @@ func TestExcludeTxnFromChangeStreams_Transaction(t *testing.T) {
t.Fatalf("missing ExcludeTxnFromChangeStreams option on BeginTransaction option")
}

// Verify that the flag is reset after the transaction.
// Verify that the flag is NOT reset after the transaction.
if err := conn.QueryRowContext(ctx, "SHOW VARIABLE EXCLUDE_TXN_FROM_CHANGE_STREAMS").Scan(&exclude); err != nil {
t.Fatalf("failed to get exclude setting: %v", err)
}
if g, w := exclude, false; g != w {
if g, w := exclude, true; g != w {
t.Fatalf("exclude_txn_from_change_streams mismatch\n Got: %v\nWant: %v", g, w)
}
}
Expand Down

0 comments on commit 7a0f08f

Please sign in to comment.