diff --git a/driver.go b/driver.go index 0a98f66f..7f4bb16e 100644 --- a/driver.go +++ b/driver.go @@ -171,7 +171,7 @@ type connector struct { // propagated to the caller. This option is enabled by default. retryAbortsInternally bool - initClient sync.Mutex + clientMu sync.Mutex client *spanner.Client clientErr error adminClient *adminapi.DatabaseAdminClient @@ -317,8 +317,8 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) { // increaseConnCount initializes the client and increases the number of connections that are active. func (c *connector) increaseConnCount(ctx context.Context, databaseName string, opts []option.ClientOption) error { - c.initClient.Lock() - defer c.initClient.Unlock() + c.clientMu.Lock() + defer c.clientMu.Unlock() if c.clientErr != nil { return c.clientErr @@ -349,8 +349,8 @@ func (c *connector) increaseConnCount(ctx context.Context, databaseName string, // decreaseConnCount decreases the number of connections that are active and closes the underlying clients if it was the // last connection. func (c *connector) decreaseConnCount() error { - c.initClient.Lock() - defer c.initClient.Unlock() + c.clientMu.Lock() + defer c.clientMu.Unlock() c.connCount-- if c.connCount > 0 { @@ -373,6 +373,8 @@ func (c *connector) Close() error { delete(c.driver.connectors, c.dsn) c.driver.mu.Unlock() + c.clientMu.Lock() + defer c.clientMu.Unlock() return c.closeClients() } diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 81e1b0d3..1c7b4018 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -119,6 +119,39 @@ func TestSimpleQuery(t *testing.T) { } } +func TestConcurrentScanAndClose(t *testing.T) { + t.Parallel() + + db, _, teardown := setupTestDBConnection(t) + defer teardown() + rows, err := db.QueryContext(context.Background(), testutil.SelectFooFromBar) + if err != nil { + t.Fatal(err) + } + + // Only fetch the first row of the query to make sure that the rows are not auto-closed + // when the end of the stream is reached. + rows.Next() + var got int64 + err = rows.Scan(&got) + if err != nil { + t.Fatal(err) + } + + // Close both the database and the rows (connection) in parallel. + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + _ = db.Close() + }() + go func() { + defer wg.Done() + _ = rows.Close() + }() + wg.Wait() +} + func TestSingleQueryWithTimestampBound(t *testing.T) { t.Parallel()