diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index 2fe5ea42da3a4d..cfeb3b34373919 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -15,6 +15,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "testing" "time" ) @@ -90,6 +91,8 @@ func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) { type fakeDB struct { name string + useRawBytes atomic.Bool + mu sync.Mutex tables map[string]*table badConn bool @@ -697,6 +700,8 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm switch cmd { case "WIPE": // Nothing + case "USE_RAWBYTES": + c.db.useRawBytes.Store(true) case "SELECT": stmt, err = c.prepareSelect(stmt, parts) case "CREATE": @@ -800,6 +805,9 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d case "WIPE": db.wipe() return driver.ResultNoRows, nil + case "USE_RAWBYTES": + s.c.db.useRawBytes.Store(true) + return driver.ResultNoRows, nil case "CREATE": if err := db.createTable(s.table, s.colName, s.colType); err != nil { return nil, err @@ -929,6 +937,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( txStatus = "transaction" } cursor := &rowsCursor{ + db: s.c.db, parentMem: s.c, posRow: -1, rows: [][]*row{ @@ -1025,6 +1034,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( } cursor := &rowsCursor{ + db: s.c.db, parentMem: s.c, posRow: -1, rows: setMRows, @@ -1067,6 +1077,7 @@ func (tx *fakeTx) Rollback() error { } type rowsCursor struct { + db *fakeDB parentMem memToucher cols [][]string colType [][]string @@ -1141,7 +1152,7 @@ func (rc *rowsCursor) Next(dest []driver.Value) error { // messing up conversions or doing them differently. dest[i] = v - if bs, ok := v.([]byte); ok { + if bs, ok := v.([]byte); ok && !rc.db.useRawBytes.Load() { if rc.bytesClone == nil { rc.bytesClone = make(map[*byte][]byte) } diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 8dd48107a6cf2c..3db387e841dcd2 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -2893,6 +2893,8 @@ type Rows struct { cancel func() // called when Rows is closed, may be nil. closeStmt *driverStmt // if non-nil, statement to Close on close + contextDone atomic.Pointer[error] // error that awaitDone saw; set before close attempt + // closemu prevents Rows from closing while there // is an active streaming result. It is held for read during non-close operations // and exclusively during close. @@ -2905,6 +2907,15 @@ type Rows struct { // lastcols is only used in Scan, Next, and NextResultSet which are expected // not to be called concurrently. lastcols []driver.Value + + // closemuScanHold is whether the previous call to Scan kept closemu RLock'ed + // without unlocking it. It does that when the user passes a *RawBytes scan + // target. In that case, we need to prevent awaitDone from closing the Rows + // while the user's still using the memory. See go.dev/issue/60304. + // + // It is only used by Scan, Next, and NextResultSet which are expected + // not to be called concurrently. + closemuScanHold bool } // lasterrOrErrLocked returns either lasterr or the provided err. @@ -2942,7 +2953,11 @@ func (rs *Rows) awaitDone(ctx, txctx context.Context) { } select { case <-ctx.Done(): + err := ctx.Err() + rs.contextDone.Store(&err) case <-txctxDone: + err := txctx.Err() + rs.contextDone.Store(&err) } rs.close(ctx.Err()) } @@ -2954,6 +2969,15 @@ func (rs *Rows) awaitDone(ctx, txctx context.Context) { // // Every call to Scan, even the first one, must be preceded by a call to Next. func (rs *Rows) Next() bool { + // If the user's calling Next, they're done with their previous row's Scan + // results (any RawBytes memory), so we can release the read lock that would + // be preventing awaitDone from calling close. + rs.closemuRUnlockIfHeldByScan() + + if rs.contextDone.Load() != nil { + return false + } + var doClose, ok bool withLock(rs.closemu.RLocker(), func() { doClose, ok = rs.nextLocked() @@ -3008,6 +3032,11 @@ func (rs *Rows) nextLocked() (doClose, ok bool) { // scanning. If there are further result sets they may not have rows in the result // set. func (rs *Rows) NextResultSet() bool { + // If the user's calling NextResultSet, they're done with their previous + // row's Scan results (any RawBytes memory), so we can release the read lock + // that would be preventing awaitDone from calling close. + rs.closemuRUnlockIfHeldByScan() + var doClose bool defer func() { if doClose { @@ -3044,6 +3073,10 @@ func (rs *Rows) NextResultSet() bool { // Err returns the error, if any, that was encountered during iteration. // Err may be called after an explicit or implicit Close. func (rs *Rows) Err() error { + if errp := rs.contextDone.Load(); errp != nil { + return *errp + } + rs.closemu.RLock() defer rs.closemu.RUnlock() return rs.lasterrOrErrLocked(nil) @@ -3237,6 +3270,11 @@ func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType { // If any of the first arguments implementing Scanner returns an error, // that error will be wrapped in the returned error. func (rs *Rows) Scan(dest ...any) error { + if rs.closemuScanHold { + // This should only be possible if the user calls Scan twice in a row + // without calling Next. + return fmt.Errorf("sql: Scan called without calling Next (closemuScanHold)") + } rs.closemu.RLock() if rs.lasterr != nil && rs.lasterr != io.EOF { @@ -3248,23 +3286,50 @@ func (rs *Rows) Scan(dest ...any) error { rs.closemu.RUnlock() return err } - rs.closemu.RUnlock() + + if scanArgsContainRawBytes(dest) { + rs.closemuScanHold = true + } else { + rs.closemu.RUnlock() + } if rs.lastcols == nil { + rs.closemuRUnlockIfHeldByScan() return errors.New("sql: Scan called without calling Next") } if len(dest) != len(rs.lastcols) { + rs.closemuRUnlockIfHeldByScan() return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest)) } + for i, sv := range rs.lastcols { err := convertAssignRows(dest[i], sv, rs) if err != nil { + rs.closemuRUnlockIfHeldByScan() return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err) } } return nil } +// closemuRUnlockIfHeldByScan releases any closemu.RLock held open by a previous +// call to Scan with *RawBytes. +func (rs *Rows) closemuRUnlockIfHeldByScan() { + if rs.closemuScanHold { + rs.closemuScanHold = false + rs.closemu.RUnlock() + } +} + +func scanArgsContainRawBytes(args []any) bool { + for _, a := range args { + if _, ok := a.(*RawBytes); ok { + return true + } + } + return false +} + // rowsCloseHook returns a function so tests may install the // hook through a test only mutex. var rowsCloseHook = func() func(*Rows, *error) { return nil } @@ -3274,6 +3339,11 @@ var rowsCloseHook = func() func(*Rows, *error) { return nil } // the Rows are closed automatically and it will suffice to check the // result of Err. Close is idempotent and does not affect the result of Err. func (rs *Rows) Close() error { + // If the user's calling Close, they're done with their previous row's Scan + // results (any RawBytes memory), so we can release the read lock that would + // be preventing awaitDone from calling the unexported close before we do so. + rs.closemuRUnlockIfHeldByScan() + return rs.close(nil) } diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 2b3d76f513097e..29a6709f234855 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -4385,6 +4385,64 @@ func TestRowsScanProperlyWrapsErrors(t *testing.T) { } } +// From go.dev/issue/60304 +func TestContextCancelDuringRawBytesScan(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + if _, err := db.Exec("USE_RAWBYTES"); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + r, err := db.QueryContext(ctx, "SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + numRows := 0 + var sink byte + for r.Next() { + numRows++ + var s RawBytes + err = r.Scan(&s) + if !r.closemuScanHold { + t.Errorf("expected closemu to be held") + } + if err != nil { + t.Fatal(err) + } + t.Logf("read %q", s) + if numRows == 2 { + cancel() // invalidate the context, which used to call close asynchronously + } + for _, b := range s { // some operation reading from the raw memory + sink += b + } + } + if r.closemuScanHold { + t.Errorf("closemu held; should not be") + } + + // There are 3 rows. We canceled after reading 2 so we expect either + // 2 or 3 depending on how the awaitDone goroutine schedules. + switch numRows { + case 0, 1: + t.Errorf("got %d rows; want 2+", numRows) + case 2: + if err := r.Err(); err != context.Canceled { + t.Errorf("unexpected error: %v (%T)", err, err) + } + default: + // Made it to the end. This is rare, but fine. Permit it. + } + + if err := r.Close(); err != nil { + t.Fatal(err) + } +} + // badConn implements a bad driver.Conn, for TestBadDriver. // The Exec method panics. type badConn struct{}