diff --git a/benchmark_test.go b/benchmark_test.go
index 5828d40f9..3e25a3bf2 100644
--- a/benchmark_test.go
+++ b/benchmark_test.go
@@ -317,3 +317,57 @@ func BenchmarkExecContext(b *testing.B) {
 		})
 	}
 }
+
+// BenchmarkQueryRawBytes benchmarks fetching 100 blobs using sql.RawBytes.
+// "size=" means the size of each blob.
+func BenchmarkQueryRawBytes(b *testing.B) {
+	var sizes []int = []int{100, 1000, 2000, 4000, 8000, 12000, 16000, 32000, 64000, 256000}
+	db := initDB(b,
+		"DROP TABLE IF EXISTS bench_rawbytes",
+		"CREATE TABLE bench_rawbytes (id INT PRIMARY KEY, val LONGBLOB)",
+	)
+	defer db.Close()
+
+	blob := make([]byte, sizes[len(sizes)-1])
+	for i := range blob {
+		blob[i] = 42
+	}
+	for i := 0; i < 100; i++ {
+		_, err := db.Exec("INSERT INTO bench_rawbytes VALUES (?, ?)", i, blob)
+		if err != nil {
+			b.Fatal(err)
+		}
+	}
+
+	for _, s := range sizes {
+		b.Run(fmt.Sprintf("size=%v", s), func(b *testing.B) {
+			db.SetMaxIdleConns(0)
+			db.SetMaxIdleConns(1)
+			b.ReportAllocs()
+			b.ResetTimer()
+
+			for j := 0; j < b.N; j++ {
+				rows, err := db.Query("SELECT LEFT(val, ?) as v FROM bench_rawbytes", s)
+				if err != nil {
+					b.Fatal(err)
+				}
+				nrows := 0
+				for rows.Next() {
+					var buf sql.RawBytes
+					err := rows.Scan(&buf)
+					if err != nil {
+						b.Fatal(err)
+					}
+					if len(buf) != s {
+						b.Fatalf("size mismatch: expected %v, got %v", s, len(buf))
+					}
+					nrows++
+				}
+				rows.Close()
+				if nrows != 100 {
+					b.Fatalf("number of rows mismatch: expected %v, got %v", 100, nrows)
+				}
+			}
+		})
+	}
+}
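The sub-benchmarks above can be run with the standard tooling, e.g. `go test -run '^$' -bench BenchmarkQueryRawBytes -benchmem`; since the benchmark calls b.ReportAllocs, each `size=` variant reports allocations per iteration alongside the timings.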
diff --git a/buffer.go b/buffer.go
index 19486bd6f..0774c5c8c 100644
--- a/buffer.go
+++ b/buffer.go
@@ -15,47 +15,69 @@ import (
 )
 
 const defaultBufSize = 4096
+const maxCachedBufSize = 256 * 1024
 
 // A buffer which is used for both reading and writing.
 // This is possible since communication on each connection is synchronous.
 // In other words, we can't write and read simultaneously on the same connection.
 // The buffer is similar to bufio.Reader / Writer but zero-copy-ish
 // Also highly optimized for this particular use case.
+// This buffer is backed by two byte slices in a double-buffering scheme.
 type buffer struct {
 	buf     []byte // buf is a byte buffer who's length and capacity are equal.
 	nc      net.Conn
 	idx     int
 	length  int
 	timeout time.Duration
+	dbuf    [2][]byte // dbuf is an array with the two byte slices that back this buffer
+	flipcnt uint      // flipcnt is the current buffer counter for double-buffering
 }
 
 // newBuffer allocates and returns a new buffer.
 func newBuffer(nc net.Conn) buffer {
+	fg := make([]byte, defaultBufSize)
 	return buffer{
-		buf: make([]byte, defaultBufSize),
-		nc:  nc,
+		buf:  fg,
+		nc:   nc,
+		dbuf: [2][]byte{fg, nil},
 	}
 }
 
+// flip replaces the active buffer with the background buffer.
+// This is a delayed flip that simply increases the buffer counter;
+// the actual flip will be performed the next time we call `buffer.fill`.
+func (b *buffer) flip() {
+	b.flipcnt += 1
+}
+
 // fill reads into the buffer until at least _need_ bytes are in it
 func (b *buffer) fill(need int) error {
 	n := b.length
+	// fill data into its double-buffering target: if we've called
+	// flip on this buffer, we'll be copying to the background buffer,
+	// and then filling it with network data; otherwise we'll just move
+	// the contents of the current buffer to the front before filling it
+	dest := b.dbuf[b.flipcnt&1]
+
+	// grow buffer if necessary to fit the whole packet.
+	if need > len(dest) {
+		// Round up to the next multiple of the default size
+		dest = make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
 
-	// move existing data to the beginning
-	if n > 0 && b.idx > 0 {
-		copy(b.buf[0:n], b.buf[b.idx:])
+		// if the allocated buffer is not too large, move it to backing storage
+		// to prevent extra allocations on applications that perform large reads
+		if len(dest) <= maxCachedBufSize {
+			b.dbuf[b.flipcnt&1] = dest
+		}
 	}
 
-	// grow buffer if necessary
-	// TODO: let the buffer shrink again at some point
-	// Maybe keep the org buf slice and swap back?
-	if need > len(b.buf) {
-		// Round up to the next multiple of the default size
-		newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
-		copy(newBuf, b.buf)
-		b.buf = newBuf
+	// if we're filling the fg buffer, move the existing data to the start of it.
+	// if we're filling the bg buffer, copy over the data
+	if n > 0 {
+		copy(dest[:n], b.buf[b.idx:])
 	}
 
+	b.buf = dest
 	b.idx = 0
 
 	for {
diff --git a/driver_test.go b/driver_test.go
index c35588a09..9c3d286ce 100644
--- a/driver_test.go
+++ b/driver_test.go
@@ -2938,3 +2938,58 @@ func TestValuerWithValueReceiverGivenNilValue(t *testing.T) {
 		// This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value()
 	})
 }
+
+// TestRawBytesAreNotModified checks for a race condition that arises when a query context
+// is canceled while a user is calling rows.Scan. This is a more stringent test than the one
+// proposed in https://github.com/golang/go/issues/23519. Here we're explicitly using
+// `sql.RawBytes` to check that the contents of our internal buffers are not modified after an implicit
+// call to `Rows.Close`, so Context cancellation should **not** invalidate the backing buffers.
+func TestRawBytesAreNotModified(t *testing.T) {
+	const blob = "abcdefghijklmnop"
+	const contextRaceIterations = 20
+	const blobSize = defaultBufSize * 3 / 4 // Second row overwrites first row.
+	const insertRows = 4
+
+	var sqlBlobs = [2]string{
+		strings.Repeat(blob, blobSize/len(blob)),
+		strings.Repeat(strings.ToUpper(blob), blobSize/len(blob)),
+	}
+
+	runTests(t, dsn, func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8")
+		for i := 0; i < insertRows; i++ {
+			dbt.mustExec("INSERT INTO test VALUES (?, ?)", i+1, sqlBlobs[i&1])
+		}
+
+		for i := 0; i < contextRaceIterations; i++ {
+			func() {
+				ctx, cancel := context.WithCancel(context.Background())
+				defer cancel()
+
+				rows, err := dbt.db.QueryContext(ctx, `SELECT id, value FROM test`)
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				var b int
+				var raw sql.RawBytes
+				for rows.Next() {
+					if err := rows.Scan(&b, &raw); err != nil {
+						t.Fatal(err)
+					}
+
+					before := string(raw)
+					// Ensure cancelling the query does not corrupt the contents of `raw`
+					cancel()
+					time.Sleep(time.Microsecond * 100)
+					after := string(raw)
+
+					if before != after {
+						t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i)
+					}
+				}
+				rows.Close()
+			}()
+		}
+	})
+}
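The buffer.go hunk above introduces a double-buffering scheme: flip only bumps a counter, and the next fill writes into the other backing slice, so bytes that were already handed out (for example as sql.RawBytes) survive one more read cycle. Below is a minimal, standalone sketch of that idea; the type and function names are illustrative only, and the real buffer's net.Conn reads, idx/length bookkeeping, and maxCachedBufSize handling are omitted.

// Sketch of the double-buffering idea from buffer.go, reduced to its essentials.
package main

import "fmt"

type dbuffer struct {
	buf     []byte    // the currently active slice
	dbuf    [2][]byte // the two byte slices that back the buffer
	flipcnt uint      // parity picks which slice the next fill writes into
}

func newDbuffer(size int) *dbuffer {
	fg := make([]byte, size)
	return &dbuffer{buf: fg, dbuf: [2][]byte{fg, nil}}
}

// flip is the delayed swap: it only bumps the counter, so the actual switch to
// the background slice happens on the next fill.
func (b *dbuffer) flip() { b.flipcnt++ }

// fill copies src into the current fill target, growing it on demand, and makes
// it the active slice. Data previously handed out from the other slice is untouched.
func (b *dbuffer) fill(src []byte) []byte {
	dest := b.dbuf[b.flipcnt&1]
	if len(src) > len(dest) {
		dest = make([]byte, len(src))
		b.dbuf[b.flipcnt&1] = dest
	}
	copy(dest, src)
	b.buf = dest[:len(src)]
	return b.buf
}

func main() {
	b := newDbuffer(16)
	row := b.fill([]byte("first row bytes")) // e.g. bytes backing sql.RawBytes
	b.flip()                                 // e.g. rows.Close before the stream is drained
	b.fill([]byte("drained packets"))        // would have clobbered `row` without the flip
	fmt.Println(string(row))                 // still prints "first row bytes"
}

The rows.go hunk that follows relies on exactly this property: flipping before draining unread packets keeps the caller's previously scanned bytes intact.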
diff --git a/rows.go b/rows.go
index d3b1e2822..888bdb5f0 100644
--- a/rows.go
+++ b/rows.go
@@ -111,6 +111,13 @@ func (rows *mysqlRows) Close() (err error) {
 		return err
 	}
 
+	// flip the buffer for this connection if we need to drain it.
+	// note that for a successful query (i.e. one where rows.Next()
+	// has been called until it returns false), `rows.mc` will be nil
+	// by the time the user calls `(*Rows).Close`, so we won't reach this;
+	// see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47
+	mc.buf.flip()
+
 	// Remove unread packets from stream
 	if !rows.rs.done {
 		err = mc.readUntilEOF()
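The new test is best exercised under the race detector, e.g. `go test -race -run TestRawBytesAreNotModified`, so that any unsynchronized write to the backing buffer during cancellation would be flagged directly rather than only surfacing as the content check failing.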