Skip to content

Commit

Permalink
Reduce QueryContext allocations by reusing the channel
Browse files Browse the repository at this point in the history
This commit reduces the number of allocations and memory usage of
QueryContext by inverting the goroutine: instead of processing the
request in the goroutine and having it send the result, we now process
the request in the method itself and goroutine is only used to interrupt
the query if the context is canceled. The advantage of this approach is
that we no longer need to send anything on the channel, but instead can
treat the channel as a semaphore (this reduces the amount of memory
allocated by this method).

Additionally, we now reuse the channel used to communicate with the
goroutine which reduces the number of allocations.

This commit also adds a test that actually exercises the
sqlite3_interrupt logic since the existing tests did not. Those tests
cancelled the context before scanning any of the rows and could be made
to pass without ever calling sqlite3_interrupt. The below version of
SQLiteRows.Next passes the previous tests:

```go
func (rc *SQLiteRows) Next(dest []driver.Value) error {
	rc.s.mu.Lock()
	defer rc.s.mu.Unlock()
	if rc.s.closed {
		return io.EOF
	}
	if err := rc.ctx.Err(); err != nil {
		return err
	}
	return rc.nextSyncLocked(dest)
}
```

Benchmark results:
```
goos: darwin
goarch: arm64
pkg: github.com/mattn/go-sqlite3
cpu: Apple M1 Max
                                          │   old.txt   │              new.txt               │
                                          │   sec/op    │   sec/op     vs base               │
Suite/BenchmarkQueryContext/Background-10   3.994µ ± 2%   4.034µ ± 1%       ~ (p=0.289 n=10)
Suite/BenchmarkQueryContext/WithCancel-10   12.02µ ± 3%   11.56µ ± 4%  -3.87% (p=0.003 n=10)
geomean                                     6.930µ        6.829µ       -1.46%

                                          │   old.txt    │                new.txt                 │
                                          │     B/op     │     B/op      vs base                  │
Suite/BenchmarkQueryContext/Background-10     400.0 ± 0%     400.0 ± 0%        ~ (p=1.000 n=10) ¹
Suite/BenchmarkQueryContext/WithCancel-10   2.376Ki ± 0%   1.025Ki ± 0%  -56.87% (p=0.000 n=10)
geomean                                       986.6          647.9       -34.33%
¹ all samples are equal

                                          │  old.txt   │               new.txt                │
                                          │ allocs/op  │ allocs/op   vs base                  │
Suite/BenchmarkQueryContext/Background-10   12.00 ± 0%   12.00 ± 0%        ~ (p=1.000 n=10) ¹
Suite/BenchmarkQueryContext/WithCancel-10   38.00 ± 0%   28.00 ± 0%  -26.32% (p=0.000 n=10)
geomean                                     21.35        18.33       -14.16%
¹ all samples are equal
```
  • Loading branch information
charlievieth committed Nov 9, 2024
1 parent 41871ea commit 580000f
Show file tree
Hide file tree
Showing 3 changed files with 255 additions and 21 deletions.
47 changes: 33 additions & 14 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,9 @@ type SQLiteRows struct {
cls bool
closed bool
ctx context.Context // no better alternative to pass context into Next() method
// semaphore to signal the goroutine used to interrupt queries when a
// cancellable context is passed to QueryContext
sema chan struct{}
}

type functionInfo struct {
Expand Down Expand Up @@ -2118,6 +2121,9 @@ func (rc *SQLiteRows) Close() error {
return nil
}
rc.closed = true
if rc.sema != nil {
close(rc.sema)
}
if rc.cls {
rc.s.mu.Unlock()
return rc.s.Close()
Expand Down Expand Up @@ -2172,27 +2178,40 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
return io.EOF
}

if rc.ctx.Done() == nil {
done := rc.ctx.Done()
if done == nil {
return rc.nextSyncLocked(dest)
}
resultCh := make(chan error)
defer close(resultCh)
if err := rc.ctx.Err(); err != nil {
return err // Fast check if the channel is closed
}

if rc.sema == nil {
rc.sema = make(chan struct{})
}
go func() {
resultCh <- rc.nextSyncLocked(dest)
}()
select {
case err := <-resultCh:
return err
case <-rc.ctx.Done():
select {
case <-resultCh: // no need to interrupt
default:
// this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked.
case <-done:
C.sqlite3_interrupt(rc.s.c.db)
<-resultCh // ensure goroutine completed
// Wait until signaled. We need to ensure that this goroutine
// will not call interrupt after this method returns.
<-rc.sema
case <-rc.sema:
}
return rc.ctx.Err()
}()

err := rc.nextSyncLocked(dest)
// Signal the goroutine to exit. This send will only succeed at a point
// where it is impossible for the goroutine to call sqlite3_interrupt.
//
// This is necessary to ensure the goroutine does not interrupt an
// unrelated query if the context is cancelled after this method returns
// but before the goroutine exits (we don't wait for it to exit).
rc.sema <- struct{}{}
if err != nil && isInterruptErr(err) {
err = rc.ctx.Err()
}
return err
}

// nextSyncLocked moves cursor to next; must be called with locked mutex.
Expand Down
147 changes: 147 additions & 0 deletions sqlite3_go18_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ package sqlite3
import (
"context"
"database/sql"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"os"
"strings"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -268,6 +270,151 @@ func TestQueryRowContextCancelParallel(t *testing.T) {
}
}

// Test that we can successfully interrupt a long running query when
// the context is canceled. The previous two QueryRowContext tests
// only test that we handle a previously cancelled context and thus
// do not call sqlite3_interrupt.
func TestQueryRowContextCancelInterrupt(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()

// Test that we have the unixepoch function and if not skip the test.
if _, err := db.Exec(`SELECT unixepoch(datetime(100000, 'unixepoch', 'localtime'))`); err != nil {
libVersion, libVersionNumber, sourceID := Version()
if strings.Contains(err.Error(), "no such function: unixepoch") {
t.Skip("Skipping the 'unixepoch' function is not implemented in "+
"this version of sqlite3:", libVersion, libVersionNumber, sourceID)
}
t.Fatal(err)
}

const createTableStmt = `
CREATE TABLE timestamps (
ts TIMESTAMP NOT NULL
);`
if _, err := db.Exec(createTableStmt); err != nil {
t.Fatal(err)
}

stmt, err := db.Prepare(`INSERT INTO timestamps VALUES (?);`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()

// Computationally expensive query that consumes many rows. This is needed
// to test cancellation because queries are not interrupted immediately.
// Instead, queries are only halted at certain checkpoints where the
// sqlite3.isInterrupted is checked and true.
queryStmt := `
SELECT
SUM(unixepoch(datetime(ts + 10, 'unixepoch', 'localtime'))) AS c1,
SUM(unixepoch(datetime(ts + 20, 'unixepoch', 'localtime'))) AS c2,
SUM(unixepoch(datetime(ts + 30, 'unixepoch', 'localtime'))) AS c3,
SUM(unixepoch(datetime(ts + 40, 'unixepoch', 'localtime'))) AS c4
FROM
timestamps
WHERE datetime(ts, 'unixepoch', 'localtime')
LIKE
?;`

query := func(t *testing.T, timeout time.Duration) (int, error) {
// Create a complicated pattern to match timestamps
const pattern = "%2%0%2%4%-%-%:%:%"
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
rows, err := db.QueryContext(ctx, queryStmt, pattern)
if err != nil {
return 0, err
}
var count int
for rows.Next() {
var n int64
if err := rows.Scan(&n, &n, &n, &n); err != nil {
return count, err
}
count++
}
return count, rows.Err()
}

average := func(n int, fn func()) time.Duration {
start := time.Now()
for i := 0; i < n; i++ {
fn()
}
return time.Since(start) / time.Duration(n)
}

createRows := func(n int) {
t.Logf("Creating %d rows", n)
if _, err := db.Exec(`DELETE FROM timestamps; VACUUM;`); err != nil {
t.Fatal(err)
}
ts := time.Date(2024, 6, 6, 8, 9, 10, 12345, time.UTC).Unix()
rr := rand.New(rand.NewSource(1234))
for i := 0; i < n; i++ {
if _, err := stmt.Exec(ts + rr.Int63n(10_000) - 5_000); err != nil {
t.Fatal(err)
}
}
}

const TargetRuntime = 200 * time.Millisecond
const N = 5_000 // Number of rows to insert at a time

// Create enough rows that the query takes ~200ms to run.
start := time.Now()
createRows(N)
baseAvg := average(4, func() {
if _, err := query(t, time.Hour); err != nil {
t.Fatal(err)
}
})
t.Log("Base average:", baseAvg)
rowCount := N * (int(TargetRuntime/baseAvg) + 1)
createRows(rowCount)
t.Log("Table setup time:", time.Since(start))

// Set the timeout to 1/10 of the average query time.
avg := average(2, func() {
n, err := query(t, time.Hour)
if err != nil {
t.Fatal(err)
}
if n == 0 {
t.Fatal("scanned zero rows")
}
})
// Guard against the timeout being too short to reliably test.
if avg < TargetRuntime/2 {
t.Fatalf("Average query runtime should be around %s got: %s ",
TargetRuntime, avg)
}
timeout := (avg / 10).Round(100 * time.Microsecond)
t.Logf("Average: %s Timeout: %s", avg, timeout)

for i := 0; i < 10; i++ {
tt := time.Now()
n, err := query(t, timeout)
if !errors.Is(err, context.DeadlineExceeded) {
fn := t.Errorf
if err != nil {
fn = t.Fatalf
}
fn("expected error %v got %v", context.DeadlineExceeded, err)
}
d := time.Since(tt)
t.Logf("%d: rows: %d duration: %s", i, n, d)
if d > timeout*4 {
t.Errorf("query was cancelled after %s but did not abort until: %s", timeout, d)
}
}
}

func TestExecCancel(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
Expand Down
82 changes: 75 additions & 7 deletions sqlite3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package sqlite3

import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"errors"
Expand Down Expand Up @@ -2030,7 +2031,7 @@ func BenchmarkCustomFunctions(b *testing.B) {
}

func TestSuite(t *testing.T) {
initializeTestDB(t)
initializeTestDB(t, false)
defer freeTestDB()

for _, test := range tests {
Expand All @@ -2039,7 +2040,7 @@ func TestSuite(t *testing.T) {
}

func BenchmarkSuite(b *testing.B) {
initializeTestDB(b)
initializeTestDB(b, true)
defer freeTestDB()

for _, benchmark := range benchmarks {
Expand Down Expand Up @@ -2068,8 +2069,13 @@ type TestDB struct {

var db *TestDB

func initializeTestDB(t testing.TB) {
tempFilename := TempFilename(t)
func initializeTestDB(t testing.TB, memory bool) {
var tempFilename string
if memory {
tempFilename = ":memory:"
} else {
tempFilename = TempFilename(t)
}
d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
if err != nil {
os.Remove(tempFilename)
Expand All @@ -2084,9 +2090,11 @@ func freeTestDB() {
if err != nil {
panic(err)
}
err = os.Remove(db.tempFilename)
if err != nil {
panic(err)
if db.tempFilename != "" && db.tempFilename != ":memory:" {
err := os.Remove(db.tempFilename)
if err != nil {
panic(err)
}
}
}

Expand All @@ -2107,6 +2115,7 @@ var tests = []testing.InternalTest{
var benchmarks = []testing.InternalBenchmark{
{Name: "BenchmarkExec", F: benchmarkExec},
{Name: "BenchmarkQuery", F: benchmarkQuery},
{Name: "BenchmarkQueryContext", F: benchmarkQueryContext},
{Name: "BenchmarkParams", F: benchmarkParams},
{Name: "BenchmarkStmt", F: benchmarkStmt},
{Name: "BenchmarkRows", F: benchmarkRows},
Expand Down Expand Up @@ -2479,6 +2488,65 @@ func benchmarkQuery(b *testing.B) {
}
}

// benchmarkQueryContext is benchmark for QueryContext
func benchmarkQueryContext(b *testing.B) {
const createTableStmt = `
CREATE TABLE IF NOT EXISTS query_context(
id INTEGER PRIMARY KEY
);
DELETE FROM query_context;
VACUUM;`
test := func(ctx context.Context, b *testing.B) {
if _, err := db.Exec(createTableStmt); err != nil {
b.Fatal(err)
}
for i := 0; i < 10; i++ {
_, err := db.Exec("INSERT INTO query_context VALUES (?);", int64(i))
if err != nil {
db.Fatal(err)
}
}
stmt, err := db.PrepareContext(ctx, `SELECT id FROM query_context;`)
if err != nil {
b.Fatal(err)
}
b.Cleanup(func() { stmt.Close() })

var n int
for i := 0; i < b.N; i++ {
rows, err := stmt.QueryContext(ctx)
if err != nil {
b.Fatal(err)
}
for rows.Next() {
if err := rows.Scan(&n); err != nil {
b.Fatal(err)
}
}
if err := rows.Err(); err != nil {
b.Fatal(err)
}
}
}

// When the context does not have a Done channel we should use
// the fast path that directly handles the query instead of
// handling it in a goroutine. This benchmark also serves to
// highlight the performance impact of using a cancelable
// context.
b.Run("Background", func(b *testing.B) {
test(context.Background(), b)
})

// Benchmark a query with a context that can be canceled. This
// requires using a goroutine and is thus much slower.
b.Run("WithCancel", func(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
test(ctx, b)
})
}

// benchmarkParams is benchmark for params
func benchmarkParams(b *testing.B) {
for i := 0; i < b.N; i++ {
Expand Down

0 comments on commit 580000f

Please sign in to comment.