diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index a77d63dc5eb83..0764c7d17a007 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -2944,15 +2944,17 @@ func (rs *Rows) initContextClose(ctx, txctx context.Context) { if bypassRowsAwaitDone { return } - ctx, rs.cancel = context.WithCancel(ctx) - go rs.awaitDone(ctx, txctx) + closectx, cancel := context.WithCancel(ctx) + rs.cancel = cancel + go rs.awaitDone(ctx, txctx, closectx) } -// awaitDone blocks until either ctx or txctx is canceled. The ctx is provided -// from the query context and is canceled when the query Rows is closed. +// awaitDone blocks until ctx, txctx, or closectx is canceled. +// The ctx is provided from the query context. // If the query was issued in a transaction, the transaction's context -// is also provided in txctx to ensure Rows is closed if the Tx is closed. -func (rs *Rows) awaitDone(ctx, txctx context.Context) { +// is also provided in txctx, to ensure Rows is closed if the Tx is closed. +// The closectx is closed by an explicit call to rs.Close. +func (rs *Rows) awaitDone(ctx, txctx, closectx context.Context) { var txctxDone <-chan struct{} if txctx != nil { txctxDone = txctx.Done() @@ -2964,6 +2966,9 @@ func (rs *Rows) awaitDone(ctx, txctx context.Context) { case <-txctxDone: err := txctx.Err() rs.contextDone.Store(&err) + case <-closectx.Done(): + // rs.cancel was called via Close(); don't store this into contextDone + // to ensure Err() is unaffected. } rs.close(ctx.Err()) } diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 718056c351260..e6a5cd912ac36 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -4493,6 +4493,31 @@ func TestContextCancelBetweenNextAndErr(t *testing.T) { } } +func TestNilErrorAfterClose(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + // This WithCancel is important; Rows contains an optimization to avoid + // spawning a goroutine when the query/transaction context cannot be + // canceled, but this test tests a bug which is caused by said goroutine. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + r, err := db.QueryContext(ctx, "SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + + if err := r.Close(); err != nil { + t.Fatal(err) + } + + time.Sleep(10 * time.Millisecond) // increase odds of seeing failure + if err := r.Err(); err != nil { + t.Fatal(err) + } +} + // badConn implements a bad driver.Conn, for TestBadDriver. // The Exec method panics. type badConn struct{}