diff --git a/.github/workflows/integration-tests-on-production.yml b/.github/workflows/integration-tests-on-production.yml index 20891459..dbb7c19f 100644 --- a/.github/workflows/integration-tests-on-production.yml +++ b/.github/workflows/integration-tests-on-production.yml @@ -20,7 +20,7 @@ jobs: needs: [check-env] if: needs.check-env.outputs.has-key == 'true' runs-on: ubuntu-latest - timeout-minutes: 30 + timeout-minutes: 45 steps: - name: Install Go uses: actions/setup-go@v2 @@ -35,7 +35,7 @@ jobs: service_account_key: ${{ secrets.GCP_SA_KEY }} export_default_credentials: true - name: Run integration tests on production - run: go test -timeout 30m + run: go test -timeout 45m env: JOB_TYPE: test SPANNER_TEST_PROJECT: ${{ secrets.GCP_PROJECT_ID }} diff --git a/driver.go b/driver.go index 0729ac70..3989ecc6 100644 --- a/driver.go +++ b/driver.go @@ -23,6 +23,8 @@ import ( "regexp" "strconv" "strings" + "sync" + "sync/atomic" "time" "cloud.google.com/go/civil" @@ -55,11 +57,13 @@ var dsnRegExp = regexp.MustCompile("((?P[\\w.-]+(?:\\.[\\w\\.-]+)*[\\ var _ driver.DriverContext = &Driver{} func init() { - sql.Register("spanner", &Driver{}) + sql.Register("spanner", &Driver{connectors: make(map[string]*connector)}) } // Driver represents a Google Cloud Spanner database/sql driver. type Driver struct { + mu sync.Mutex + connectors map[string]*connector } // Open opens a connection to a Google Cloud Spanner database. @@ -132,6 +136,7 @@ func extractConnectorParams(paramsString string) (map[string]string, error) { type connector struct { driver *Driver + dsn string connectorConfig connectorConfig // spannerClientConfig represents the optional advanced configuration to be used @@ -146,9 +151,22 @@ type connector struct { // retried internally (when possible), or whether all aborted errors will be // propagated to the caller. This option is enabled by default. retryAbortsInternally bool + + initClient sync.Once + client *spanner.Client + clientErr error + adminClient *adminapi.DatabaseAdminClient + adminClientErr error + connCount int32 } func newConnector(d *Driver, dsn string) (*connector, error) { + d.mu.Lock() + defer d.mu.Unlock() + if c, ok := d.connectors[dsn]; ok { + return c, nil + } + connectorConfig, err := extractConnectorConfig(dsn) if err != nil { return nil, err @@ -174,13 +192,31 @@ func newConnector(d *Driver, dsn string) (*connector, error) { config := spanner.ClientConfig{ SessionPoolConfig: spanner.DefaultSessionPoolConfig, } - return &connector{ + if strval, ok := connectorConfig.params["minsessions"]; ok { + if val, err := strconv.ParseUint(strval, 10, 64); err == nil { + config.MinOpened = val + } + } + if strval, ok := connectorConfig.params["maxsessions"]; ok { + if val, err := strconv.ParseUint(strval, 10, 64); err == nil { + config.MaxOpened = val + } + } + if strval, ok := connectorConfig.params["writesessions"]; ok { + if val, err := strconv.ParseFloat(strval, 64); err == nil { + config.WriteSessions = val + } + } + c := &connector{ driver: d, + dsn: dsn, connectorConfig: connectorConfig, spannerClientConfig: config, options: opts, retryAbortsInternally: retryAbortsInternally, - }, nil + } + d.connectors[dsn] = c + return c, nil } func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { @@ -194,18 +230,22 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) { c.connectorConfig.project, c.connectorConfig.instance, c.connectorConfig.database) - client, err := spanner.NewClientWithConfig(ctx, databaseName, c.spannerClientConfig, opts...) - if err != nil { - return nil, err - } - adminClient, err := adminapi.NewDatabaseAdminClient(ctx, opts...) - if err != nil { - return nil, err + c.initClient.Do(func() { + c.client, c.clientErr = spanner.NewClientWithConfig(ctx, databaseName, c.spannerClientConfig, opts...) + c.adminClient, c.adminClientErr = adminapi.NewDatabaseAdminClient(ctx, opts...) + }) + if c.clientErr != nil { + return nil, c.clientErr + } + if c.adminClientErr != nil { + return nil, c.adminClientErr } + atomic.AddInt32(&c.connCount, 1) return &conn{ - client: client, - adminClient: adminClient, + connector: c, + client: c.client, + adminClient: c.adminClient, database: databaseName, retryAborts: c.retryAbortsInternally, execSingleQuery: queryInSingleUse, @@ -215,7 +255,7 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) { } func (c *connector) Driver() driver.Driver { - return &Driver{} + return c.driver } // SpannerConn is the public interface for the raw Spanner connection for the @@ -295,6 +335,7 @@ type SpannerConn interface { } type conn struct { + connector *connector closed bool client *spanner.Client adminClient *adminapi.DatabaseAdminClient @@ -787,8 +828,18 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name } func (c *conn) Close() error { + // Check if this is the last open connection of the connector. + if count := atomic.AddInt32(&c.connector.connCount, -1); count > 0 { + return nil + } + + // This was the last connection. Remove the connector and close the Spanner clients. + c.connector.driver.mu.Lock() + delete(c.connector.driver.connectors, c.connector.dsn) + c.connector.driver.mu.Unlock() + c.client.Close() - return nil + return c.adminClient.Close() } func (c *conn) Begin() (driver.Tx, error) { diff --git a/driver_test.go b/driver_test.go index 3b22f098..c694e8b5 100644 --- a/driver_test.go +++ b/driver_test.go @@ -22,56 +22,70 @@ import ( "cloud.google.com/go/spanner" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc/codes" ) func TestExtractDnsParts(t *testing.T) { tests := []struct { - input string - want connectorConfig - wantErr error + input string + wantConnectorConfig connectorConfig + wantSpannerConfig spanner.ClientConfig + wantErr error }{ { input: "projects/p/instances/i/databases/d", - want: connectorConfig{ + wantConnectorConfig: connectorConfig{ project: "p", instance: "i", database: "d", params: map[string]string{}, }, + wantSpannerConfig: spanner.ClientConfig{ + SessionPoolConfig: spanner.DefaultSessionPoolConfig, + }, }, { input: "projects/DEFAULT_PROJECT_ID/instances/test-instance/databases/test-database", - want: connectorConfig{ + wantConnectorConfig: connectorConfig{ project: "DEFAULT_PROJECT_ID", instance: "test-instance", database: "test-database", params: map[string]string{}, }, + wantSpannerConfig: spanner.ClientConfig{ + SessionPoolConfig: spanner.DefaultSessionPoolConfig, + }, }, { input: "localhost:9010/projects/p/instances/i/databases/d", - want: connectorConfig{ + wantConnectorConfig: connectorConfig{ host: "localhost:9010", project: "p", instance: "i", database: "d", params: map[string]string{}, }, + wantSpannerConfig: spanner.ClientConfig{ + SessionPoolConfig: spanner.DefaultSessionPoolConfig, + }, }, { input: "spanner.googleapis.com/projects/p/instances/i/databases/d", - want: connectorConfig{ + wantConnectorConfig: connectorConfig{ host: "spanner.googleapis.com", project: "p", instance: "i", database: "d", params: map[string]string{}, }, + wantSpannerConfig: spanner.ClientConfig{ + SessionPoolConfig: spanner.DefaultSessionPoolConfig, + }, }, { input: "spanner.googleapis.com/projects/p/instances/i/databases/d?usePlainText=true", - want: connectorConfig{ + wantConnectorConfig: connectorConfig{ host: "spanner.googleapis.com", project: "p", instance: "i", @@ -80,10 +94,13 @@ func TestExtractDnsParts(t *testing.T) { "useplaintext": "true", }, }, + wantSpannerConfig: spanner.ClientConfig{ + SessionPoolConfig: spanner.DefaultSessionPoolConfig, + }, }, { input: "spanner.googleapis.com/projects/p/instances/i/databases/d;credentials=/path/to/credentials.json", - want: connectorConfig{ + wantConnectorConfig: connectorConfig{ host: "spanner.googleapis.com", project: "p", instance: "i", @@ -92,10 +109,13 @@ func TestExtractDnsParts(t *testing.T) { "credentials": "/path/to/credentials.json", }, }, + wantSpannerConfig: spanner.ClientConfig{ + SessionPoolConfig: spanner.DefaultSessionPoolConfig, + }, }, { input: "spanner.googleapis.com/projects/p/instances/i/databases/d?credentials=/path/to/credentials.json;readonly=true", - want: connectorConfig{ + wantConnectorConfig: connectorConfig{ host: "spanner.googleapis.com", project: "p", instance: "i", @@ -105,10 +125,13 @@ func TestExtractDnsParts(t *testing.T) { "readonly": "true", }, }, + wantSpannerConfig: spanner.ClientConfig{ + SessionPoolConfig: spanner.DefaultSessionPoolConfig, + }, }, { input: "spanner.googleapis.com/projects/p/instances/i/databases/d?usePlainText=true;", - want: connectorConfig{ + wantConnectorConfig: connectorConfig{ host: "spanner.googleapis.com", project: "p", instance: "i", @@ -117,6 +140,35 @@ func TestExtractDnsParts(t *testing.T) { "useplaintext": "true", }, }, + wantSpannerConfig: spanner.ClientConfig{ + SessionPoolConfig: spanner.DefaultSessionPoolConfig, + }, + }, + { + input: "spanner.googleapis.com/projects/p/instances/i/databases/d?minSessions=200;maxSessions=1000;writeSessions=0.5", + wantConnectorConfig: connectorConfig{ + host: "spanner.googleapis.com", + project: "p", + instance: "i", + database: "d", + params: map[string]string{ + "minsessions": "200", + "maxsessions": "1000", + "writesessions": "0.5", + }, + }, + wantSpannerConfig: spanner.ClientConfig{ + SessionPoolConfig: spanner.SessionPoolConfig{ + MinOpened: 200, + MaxOpened: 1000, + WriteSessions: 0.5, + HealthCheckInterval: spanner.DefaultSessionPoolConfig.HealthCheckInterval, + HealthCheckWorkers: spanner.DefaultSessionPoolConfig.HealthCheckWorkers, + MaxBurst: spanner.DefaultSessionPoolConfig.MaxBurst, + MaxIdle: spanner.DefaultSessionPoolConfig.MaxIdle, + TrackSessionHandles: spanner.DefaultSessionPoolConfig.TrackSessionHandles, + }, + }, }, } for _, tc := range tests { @@ -124,8 +176,15 @@ func TestExtractDnsParts(t *testing.T) { if err != nil { t.Errorf("extract failed for %q: %v", tc.input, err) } else { - if !cmp.Equal(config, tc.want, cmp.AllowUnexported(connectorConfig{})) { - t.Errorf("connector config mismatch for %q\ngot: %v\nwant %v", tc.input, config, tc.want) + if !cmp.Equal(config, tc.wantConnectorConfig, cmp.AllowUnexported(connectorConfig{})) { + t.Errorf("connector config mismatch for %q\ngot: %v\nwant %v", tc.input, config, tc.wantConnectorConfig) + } + conn, err := newConnector(&Driver{connectors: make(map[string]*connector)}, tc.input) + if err != nil { + t.Errorf("failed to get connector for %q: %v", tc.input, err) + } + if !cmp.Equal(conn.spannerClientConfig, tc.wantSpannerConfig, cmpopts.IgnoreUnexported(spanner.ClientConfig{}, spanner.SessionPoolConfig{})) { + t.Errorf("connector Spanner client config mismatch for %q\n Got: %v\nWant: %v", tc.input, conn.spannerClientConfig, tc.wantSpannerConfig) } } } diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 2088d101..c4dea3ea 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -22,7 +22,9 @@ import ( "encoding/json" "fmt" "math/big" + "math/rand" "reflect" + "sync" "testing" "time" @@ -1860,6 +1862,202 @@ func TestShowVariableCommitTimestamp(t *testing.T) { } } +func TestMinSessions(t *testing.T) { + t.Parallel() + + minSessions := int32(10) + ctx := context.Background() + db, server, teardown := setupTestDBConnectionWithParams(t, fmt.Sprintf("minSessions=%v", minSessions)) + defer teardown() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatalf("failed to get a connection: %v", err) + } + var res int64 + if err := conn.QueryRowContext(ctx, "SELECT 1").Scan(&res); err != nil { + t.Fatalf("failed to execute query on connection: %v", err) + } + // Wait until all sessions have been created. + waitFor(t, func() error { + created := int32(server.TestSpanner.TotalSessionsCreated()) + if created != minSessions { + return fmt.Errorf("num open sessions mismatch\n Got: %d\nWant: %d", created, minSessions) + } + return nil + }) + _ = conn.Close() + _ = db.Close() + + // Verify that the connector created 10 sessions on the server. + reqs := drainRequestsFromServer(server.TestSpanner) + createReqs := requestsOfType(reqs, reflect.TypeOf(&sppb.BatchCreateSessionsRequest{})) + numCreated := int32(0) + for _, req := range createReqs { + numCreated += req.(*sppb.BatchCreateSessionsRequest).SessionCount + } + if g, w := numCreated, minSessions; g != w { + t.Errorf("session creation count mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestMaxSessions(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnectionWithParams(t, "minSessions=0;maxSessions=2") + defer teardown() + + var wg sync.WaitGroup + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() + conn, err := db.Conn(ctx) + if err != nil { + t.Errorf("failed to get a connection: %v", err) + } + var res int64 + if err := conn.QueryRowContext(ctx, "SELECT 1").Scan(&res); err != nil { + t.Errorf("failed to execute query on connection: %v", err) + } + _ = conn.Close() + }() + } + wg.Wait() + + // Verify that the connector only created 2 sessions on the server. + reqs := drainRequestsFromServer(server.TestSpanner) + createReqs := requestsOfType(reqs, reflect.TypeOf(&sppb.BatchCreateSessionsRequest{})) + numCreated := int32(0) + for _, req := range createReqs { + numCreated += req.(*sppb.BatchCreateSessionsRequest).SessionCount + } + if g, w := numCreated, int32(2); g != w { + t.Errorf("session creation count mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestClientReuse(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnectionWithParams(t, "minSessions=2") + defer teardown() + + // Repeatedly get a connection and close it using the same DB instance. These + // connections should all share the same Spanner client, and only initialized + // one session pool. + for i := 0; i < 5; i++ { + conn, err := db.Conn(ctx) + if err != nil { + t.Fatalf("failed to get a connection: %v", err) + } + var res int64 + if err := conn.QueryRowContext(ctx, "SELECT 1").Scan(&res); err != nil { + t.Fatalf("failed to execute query on connection: %v", err) + } + _ = conn.Close() + } + // Verify that the connector only created 2 sessions on the server. + reqs := drainRequestsFromServer(server.TestSpanner) + createReqs := requestsOfType(reqs, reflect.TypeOf(&sppb.BatchCreateSessionsRequest{})) + numCreated := int32(0) + for _, req := range createReqs { + numCreated += req.(*sppb.BatchCreateSessionsRequest).SessionCount + } + if g, w := numCreated, int32(2); g != w { + t.Errorf("session creation count mismatch\n Got: %v\nWant: %v", g, w) + } + + // Now close the DB instance and create a new DB connection. + // This should cause the first Spanner client to be closed and + // a new one to be opened. + _ = db.Close() + + db, err := sql.Open( + "spanner", + fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true;minSessions=2", server.Address)) + if err != nil { + t.Fatalf("failed to open new DB instance: %v", err) + } + var res int64 + if err := db.QueryRowContext(ctx, "SELECT 1").Scan(&res); err != nil { + t.Fatalf("failed to execute query on db: %v", err) + } + reqs = drainRequestsFromServer(server.TestSpanner) + createReqs = requestsOfType(reqs, reflect.TypeOf(&sppb.BatchCreateSessionsRequest{})) + numCreated = int32(0) + for _, req := range createReqs { + numCreated += req.(*sppb.BatchCreateSessionsRequest).SessionCount + } + if g, w := numCreated, int32(2); g != w { + t.Errorf("session creation count mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestStressClientReuse(t *testing.T) { + t.Parallel() + + ctx := context.Background() + _, server, teardown := setupTestDBConnection(t) + defer teardown() + + rand.Seed(time.Now().UnixNano()) + numSessions := 10 + numClients := 5 + numParallel := 50 + var wg sync.WaitGroup + for clientIndex := 0; clientIndex < numClients; clientIndex++ { + // Open a DB using a dsn that contains a meaningless number. This will ensure that + // the underlying client will be different from the other connections that use a + // different number. + db, err := sql.Open("spanner", + fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true;minSessions=%v;maxSessions=%v;randomNumber=%v", server.Address, numSessions, numSessions, clientIndex)) + if err != nil { + t.Fatalf("failed to open DB: %v", err) + } + // Execute random operations in parallel on the database. + for i := 0; i < numParallel; i++ { + wg.Add(1) + go func() { + defer wg.Done() + conn, err := db.Conn(ctx) + if err != nil { + t.Errorf("failed to get a connection: %v", err) + } + if rand.Int()%2 == 0 { + if _, err := conn.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil { + t.Errorf("failed to execute update on connection: %v", err) + } + } else { + var res int64 + if err := conn.QueryRowContext(ctx, "SELECT 1").Scan(&res); err != nil { + t.Errorf("failed to execute query on connection: %v", err) + } + } + _ = conn.Close() + }() + } + } + wg.Wait() + + // Verify that each unique connection string created numSessions (10) sessions on the server. + reqs := drainRequestsFromServer(server.TestSpanner) + createReqs := requestsOfType(reqs, reflect.TypeOf(&sppb.BatchCreateSessionsRequest{})) + numCreated := int32(0) + for _, req := range createReqs { + numCreated += req.(*sppb.BatchCreateSessionsRequest).SessionCount + } + if g, w := numCreated, int32(numSessions*numClients); g != w { + t.Errorf("session creation count mismatch\n Got: %v\nWant: %v", g, w) + } + sqlReqs := requestsOfType(reqs, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(sqlReqs), numClients*numParallel; g != w { + t.Errorf("ExecuteSql request count mismatch\n Got: %v\nWant: %v", g, w) + } +} + func numeric(v string) big.Rat { res, _ := big.NewRat(1, 1).SetString(v) return *res @@ -1970,3 +2168,28 @@ loop: } return reqs } + +func waitFor(t *testing.T, assert func() error) { + t.Helper() + timeout := 5 * time.Second + ta := time.After(timeout) + + for { + select { + case <-ta: + if err := assert(); err != nil { + t.Fatalf("after %v waiting, got %v", timeout, err) + } + return + default: + } + + if err := assert(); err != nil { + // Fail. Let's pause and retry. + time.Sleep(time.Millisecond) + continue + } + + return + } +} diff --git a/statement_parser.go b/statement_parser.go index 1de5a634..62f2324c 100644 --- a/statement_parser.go +++ b/statement_parser.go @@ -21,6 +21,7 @@ import ( "reflect" "regexp" "strings" + "sync" "unicode" "cloud.google.com/go/spanner" @@ -311,7 +312,9 @@ type setStatement struct { ConverterName string `json:"converterName"` } +var statementsInit sync.Once var statements *clientSideStatements +var statementsCompileErr error // compileStatements loads all client side statements from the json file and // assigns the Go methods to the different statements that should be executed @@ -367,10 +370,13 @@ func (c *executableClientSideStatement) QueryContext(ctx context.Context, args [ // corresponds with the given query string, or nil if it is not a valid client // side statement. func parseClientSideStatement(c *conn, query string) (*executableClientSideStatement, error) { - if statements == nil { + statementsInit.Do(func() { if err := compileStatements(); err != nil { - return nil, err + statementsCompileErr = err } + }) + if statementsCompileErr != nil { + return nil, statementsCompileErr } for _, stmt := range statements.Statements { if stmt.regexp.MatchString(query) {