Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: reconnect after all idle connections close #290

Merged
merged 8 commits into from
Aug 29, 2024
108 changes: 84 additions & 24 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"cloud.google.com/go/civil"
Expand Down Expand Up @@ -156,6 +155,9 @@ type connector struct {
dsn string
connectorConfig connectorConfig

closerMu sync.RWMutex
closed bool

// spannerClientConfig represents the optional advanced configuration to be used
// by the Google Cloud Spanner client.
spannerClientConfig spanner.ClientConfig
Expand All @@ -169,7 +171,7 @@ type connector struct {
// propagated to the caller. This option is enabled by default.
retryAbortsInternally bool

initClient sync.Once
initClient sync.Mutex
client *spanner.Client
clientErr error
adminClient *adminapi.DatabaseAdminClient
Expand Down Expand Up @@ -264,6 +266,7 @@ func newConnector(d *Driver, dsn string) (*connector, error) {
}
}
config.UserAgent = userAgent

c := &connector{
driver: d,
dsn: dsn,
Expand All @@ -277,6 +280,11 @@ func newConnector(d *Driver, dsn string) (*connector, error) {
}

func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
c.closerMu.RLock()
defer c.closerMu.RUnlock()
if c.closed {
return nil, fmt.Errorf("connector has been closed")
}
return openDriverConn(ctx, c)
}

Expand All @@ -288,17 +296,10 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) {
c.connectorConfig.instance,
c.connectorConfig.database)

c.initClient.Do(func() {
c.client, c.clientErr = spanner.NewClientWithConfig(ctx, databaseName, c.spannerClientConfig, opts...)
c.adminClient, c.adminClientErr = adminapi.NewDatabaseAdminClient(ctx, opts...)
})
if c.clientErr != nil {
return nil, c.clientErr
}
if c.adminClientErr != nil {
return nil, c.adminClientErr
if err := c.increaseConnCount(ctx, databaseName, opts); err != nil {
return nil, err
}
atomic.AddInt32(&c.connCount, 1)

return &conn{
connector: c,
client: c.client,
Expand All @@ -311,10 +312,80 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) {
}, nil
}

// increaseConnCount initializes the client and increases the number of connections that are active.
func (c *connector) increaseConnCount(ctx context.Context, databaseName string, opts []option.ClientOption) error {
c.initClient.Lock()
defer c.initClient.Unlock()

if c.clientErr != nil {
return c.clientErr
}
if c.adminClientErr != nil {
return c.adminClientErr
}

if c.client == nil {
c.client, c.clientErr = spanner.NewClientWithConfig(ctx, databaseName, c.spannerClientConfig, opts...)
if c.clientErr != nil {
return c.clientErr
}

c.adminClient, c.adminClientErr = adminapi.NewDatabaseAdminClient(ctx, opts...)
if c.adminClientErr != nil {
c.client = nil
c.client.Close()
c.adminClient = nil
return c.adminClientErr
}
}

c.connCount++
return nil
}

// decreaseConnCount decreases the number of connections that are active and closes the underlying clients if it was the
// last connection.
func (c *connector) decreaseConnCount() error {
c.initClient.Lock()
defer c.initClient.Unlock()

c.connCount--
if c.connCount > 0 {
return nil
}

return c.closeClients()
}

func (c *connector) Driver() driver.Driver {
return c.driver
}

func (c *connector) Close() error {
c.closerMu.Lock()
c.closed = true
c.closerMu.Unlock()

c.driver.mu.Lock()
delete(c.driver.connectors, c.dsn)
c.driver.mu.Unlock()

return c.closeClients()
}

// Closes the underlying clients.
func (c *connector) closeClients() (err error) {
if c.client != nil {
c.client.Close()
c.client = nil
}
if c.adminClient != nil {
err = c.adminClient.Close()
c.adminClient = nil
}
return err
}

// SpannerConn is the public interface for the raw Spanner connection for the
// sql driver. This interface can be used with the db.Conn().Raw() method.
type SpannerConn interface {
Expand Down Expand Up @@ -954,18 +1025,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
}

func (c *conn) Close() error {
// Check if this is the last open connection of the connector.
if count := atomic.AddInt32(&c.connector.connCount, -1); count > 0 {
return nil
}

// This was the last connection. Remove the connector and close the Spanner clients.
c.connector.driver.mu.Lock()
delete(c.connector.driver.connectors, c.connector.dsn)
c.connector.driver.mu.Unlock()

c.client.Close()
return c.adminClient.Close()
return c.connector.decreaseConnCount()
}

func (c *conn) Begin() (driver.Tx, error) {
Expand Down
111 changes: 111 additions & 0 deletions driver_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2372,7 +2372,118 @@ func TestExcludeTxnFromChangeStreams_Transaction(t *testing.T) {
if g, w := exclude, false; g != w {
t.Fatalf("exclude_txn_from_change_streams mismatch\n Got: %v\nWant: %v", g, w)
}
}

func TestMaxIdleConnectionsNonZero(t *testing.T) {
t.Parallel()

// Set MinSessions=1, so we can use the number of BatchCreateSessions requests as an indication
// of the number of clients that was created.
db, server, teardown := setupTestDBConnectionWithParams(t, "MinSessions=1")
defer teardown()

db.SetMaxIdleConns(2)
for i := 0; i < 2; i++ {
openAndCloseConn(t, db)
}

// Verify that only one client was created.
// This happens because we have a non-zero value for the number of idle connections.
requests := drainRequestsFromServer(server.TestSpanner)
batchRequests := requestsOfType(requests, reflect.TypeOf(&sppb.BatchCreateSessionsRequest{}))
if g, w := len(batchRequests), 1; g != w {
t.Fatalf("BatchCreateSessions requests count mismatch\n Got: %v\nWant: %v", g, w)
}
}

func TestMaxIdleConnectionsZero(t *testing.T) {
t.Parallel()

// Set MinSessions=1, so we can use the number of BatchCreateSessions requests as an indication
// of the number of clients that was created.
db, server, teardown := setupTestDBConnectionWithParams(t, "MinSessions=1")
defer teardown()

db.SetMaxIdleConns(0)
for i := 0; i < 2; i++ {
openAndCloseConn(t, db)
}

// Verify that two clients were created and closed.
// This should happen because we do not keep any idle connections open.
requests := drainRequestsFromServer(server.TestSpanner)
batchRequests := requestsOfType(requests, reflect.TypeOf(&sppb.BatchCreateSessionsRequest{}))
if g, w := len(batchRequests), 2; g != w {
t.Fatalf("BatchCreateSessions requests count mismatch\n Got: %v\nWant: %v", g, w)
}
}

func openAndCloseConn(t *testing.T, db *sql.DB) {
ctx := context.Background()
conn, err := db.Conn(ctx)
if err != nil {
t.Fatalf("failed to get a connection: %v", err)
}
defer func() {
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %v", err)
}
}()

var result int64
if err := conn.QueryRowContext(ctx, "SELECT 1").Scan(&result); err != nil {
t.Fatalf("failed to select: %v", err)
}
if result != 1 {
t.Fatalf("expected 1 got %v", result)
}
}

func TestCannotReuseClosedConnector(t *testing.T) {
// Note: This test cannot be parallel, as it inspects the size of the shared
// map of connectors in the driver. There is no guarantee how many connectors
// will be open when the test is running, if there are also other tests running
// in parallel.

db, _, teardown := setupTestDBConnection(t)
defer teardown()

ctx := context.Background()
conn, err := db.Conn(ctx)
if err != nil {
t.Fatalf("failed to get a connection: %v", err)
}
_ = conn.Close()
connectors := db.Driver().(*Driver).connectors
if g, w := len(connectors), 1; g != w {
t.Fatal("underlying connector has not been created")
}
var connector *connector
for _, v := range connectors {
connector = v
}
if connector.closed {
t.Fatal("connector is closed")
}

if err := db.Close(); err != nil {
t.Fatalf("failed to close connector: %v", err)
}
_, err = db.Conn(ctx)
if err == nil {
t.Fatal("missing error for getting a connection from a closed connector")
}
if g, w := err.Error(), "sql: database is closed"; g != w {
t.Fatalf("error mismatch for getting a connection from a closed connector\n Got: %v\nWant: %v", g, w)
}
// Verify that the underlying connector also has been closed.
if g, w := len(connectors), 0; g != w {
t.Fatal("underlying connector has not been closed")
}
if !connector.closed {
t.Fatal("connector is not closed")
}
}

func numeric(v string) big.Rat {
Expand Down
6 changes: 3 additions & 3 deletions examples/ddl-batches/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ import (
// It is therefore recommended that DDL statements are always executed in batches whenever possible.
//
// DDL batches can be executed in two ways using the Spanner go sql driver:
// 1. By executing the SQL statements `START BATCH DDL` and `RUN BATCH`.
// 2. By unwrapping the Spanner specific driver interface spannerdriver.Driver and calling the
// spannerdriver.Driver#StartBatchDDL and spannerdriver.Driver#RunBatch methods.
// 1. By executing the SQL statements `START BATCH DDL` and `RUN BATCH`.
// 2. By unwrapping the Spanner specific driver interface spannerdriver.Driver and calling the
// spannerdriver.Driver#StartBatchDDL and spannerdriver.Driver#RunBatch methods.
//
// This sample shows how to use both possibilities.
//
Expand Down
6 changes: 3 additions & 3 deletions examples/dml-batches/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ var createTableStatement = "CREATE TABLE Singers (SingerId INT64, Name STRING(MA
// that are needed.
//
// DML batches can be executed in two ways using the Spanner go sql driver:
// 1. By executing the SQL statements `START BATCH DML` and `RUN BATCH`.
// 2. By unwrapping the Spanner specific driver interface spannerdriver.Driver and calling the
// spannerdriver.Driver#StartBatchDML and spannerdriver.Driver#RunBatch methods.
// 1. By executing the SQL statements `START BATCH DML` and `RUN BATCH`.
// 2. By unwrapping the Spanner specific driver interface spannerdriver.Driver and calling the
// spannerdriver.Driver#StartBatchDML and spannerdriver.Driver#RunBatch methods.
//
// This sample shows how to use both possibilities.
//
Expand Down
Loading