Skip to content

Commit

Permalink
Reduce QueryContext allocations by reusing the channel
Browse files Browse the repository at this point in the history
This commit reduces the number of allocations and memory usage of
QueryContext by inverting the goroutine: instead of processing the
request in the goroutine and having it send the result, we now process
the request in the method itself and goroutine is only used to interrupt
the query if the context is canceled. The advantage of this approach is
that we no longer need to send anything on the channel, but instead can
treat the channel as a semaphore (this reduces the amount of memory
allocated by this method).

Additionally, we now reuse the channel used to communicate with the
goroutine which reduces the number of allocations.

This commit also adds a test that actually exercises the
sqlite3_interrupt logic since the existing tests did not. Those tests
cancelled the context before scanning any of the rows and could be made
to pass without ever calling sqlite3_interrupt. The below version of
SQLiteRows.Next passes the previous tests:

```go
func (rc *SQLiteRows) Next(dest []driver.Value) error {
	rc.s.mu.Lock()
	defer rc.s.mu.Unlock()
	if rc.s.closed {
		return io.EOF
	}
	if err := rc.ctx.Err(); err != nil {
		return err
	}
	return rc.nextSyncLocked(dest)
}
```

Benchmark results:
```
goos: darwin
goarch: arm64
pkg: github.com/mattn/go-sqlite3
cpu: Apple M1 Max
                                          │   old.txt   │              new.txt               │
                                          │   sec/op    │   sec/op     vs base               │
Suite/BenchmarkQueryContext/Background-10   3.994µ ± 2%   4.034µ ± 1%       ~ (p=0.289 n=10)
Suite/BenchmarkQueryContext/WithCancel-10   12.02µ ± 3%   11.56µ ± 4%  -3.87% (p=0.003 n=10)
geomean                                     6.930µ        6.829µ       -1.46%

                                          │   old.txt    │                new.txt                 │
                                          │     B/op     │     B/op      vs base                  │
Suite/BenchmarkQueryContext/Background-10     400.0 ± 0%     400.0 ± 0%        ~ (p=1.000 n=10) ¹
Suite/BenchmarkQueryContext/WithCancel-10   2.376Ki ± 0%   1.025Ki ± 0%  -56.87% (p=0.000 n=10)
geomean                                       986.6          647.9       -34.33%
¹ all samples are equal

                                          │  old.txt   │               new.txt                │
                                          │ allocs/op  │ allocs/op   vs base                  │
Suite/BenchmarkQueryContext/Background-10   12.00 ± 0%   12.00 ± 0%        ~ (p=1.000 n=10) ¹
Suite/BenchmarkQueryContext/WithCancel-10   38.00 ± 0%   28.00 ± 0%  -26.32% (p=0.000 n=10)
geomean                                     21.35        18.33       -14.16%
¹ all samples are equal
```
  • Loading branch information
charlievieth committed Dec 20, 2024
1 parent 7658c06 commit a8047d9
Show file tree
Hide file tree
Showing 3 changed files with 289 additions and 43 deletions.
92 changes: 56 additions & 36 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,9 @@ type SQLiteRows struct {
decltype []string
ctx context.Context // no better alternative to pass context into Next() method
closemu sync.Mutex
// semaphore to signal the goroutine used to interrupt queries when a
// cancellable context is passed to QueryContext
sema chan struct{}
}

type functionInfo struct {
Expand Down Expand Up @@ -2050,36 +2053,37 @@ func isInterruptErr(err error) bool {

// exec executes a query that doesn't return rows. Attempts to honor context timeout.
func (s *SQLiteStmt) exec(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
if ctx.Done() == nil {
done := ctx.Done()
if done == nil {
return s.execSync(args)
}

type result struct {
r driver.Result
err error
if err := ctx.Err(); err != nil {
return nil, err // Fast check if the channel is closed
}
resultCh := make(chan result)
defer close(resultCh)

sema := make(chan struct{})
go func() {
r, err := s.execSync(args)
resultCh <- result{r, err}
}()
var rv result
select {
case rv = <-resultCh:
case <-ctx.Done():
select {
case rv = <-resultCh: // no need to interrupt, operation completed in db
default:
// this is still racy and can be no-op if executed between sqlite3_* calls in execSync.
case <-done:
C.sqlite3_interrupt(s.c.db)
rv = <-resultCh // wait for goroutine completed
if isInterruptErr(rv.err) {
return nil, ctx.Err()
}
// Wait until signaled. We need to ensure that this goroutine
// will not call interrupt after this method returns.
<-sema
case <-sema:
}
}()
r, err := s.execSync(args)
// Signal the goroutine to exit. This send will only succeed at a point
// where it is impossible for the goroutine to call sqlite3_interrupt.
//
// This is necessary to ensure the goroutine does not interrupt an
// unrelated query if the context is cancelled after this method returns
// but before the goroutine exits (we don't wait for it to exit).
sema <- struct{}{}
if err != nil && isInterruptErr(err) {
return nil, ctx.Err()
}
return rv.r, rv.err
return r, err
}

func (s *SQLiteStmt) execSync(args []driver.NamedValue) (driver.Result, error) {
Expand Down Expand Up @@ -2117,6 +2121,9 @@ func (rc *SQLiteRows) Close() error {
return nil
}
rc.s = nil // remove reference to SQLiteStmt
if rc.sema != nil {
close(rc.sema)
}
s.mu.Lock()
if s.closed {
s.mu.Unlock()
Expand Down Expand Up @@ -2174,27 +2181,40 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
return io.EOF
}

if rc.ctx.Done() == nil {
done := rc.ctx.Done()
if done == nil {
return rc.nextSyncLocked(dest)
}
resultCh := make(chan error)
defer close(resultCh)
if err := rc.ctx.Err(); err != nil {
return err // Fast check if the channel is closed
}

if rc.sema == nil {
rc.sema = make(chan struct{})
}
go func() {
resultCh <- rc.nextSyncLocked(dest)
}()
select {
case err := <-resultCh:
return err
case <-rc.ctx.Done():
select {
case <-resultCh: // no need to interrupt
default:
// this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked.
case <-done:
C.sqlite3_interrupt(rc.s.c.db)
<-resultCh // ensure goroutine completed
// Wait until signaled. We need to ensure that this goroutine
// will not call interrupt after this method returns.
<-rc.sema
case <-rc.sema:
}
return rc.ctx.Err()
}()

err := rc.nextSyncLocked(dest)
// Signal the goroutine to exit. This send will only succeed at a point
// where it is impossible for the goroutine to call sqlite3_interrupt.
//
// This is necessary to ensure the goroutine does not interrupt an
// unrelated query if the context is cancelled after this method returns
// but before the goroutine exits (we don't wait for it to exit).
rc.sema <- struct{}{}
if err != nil && isInterruptErr(err) {
err = rc.ctx.Err()
}
return err
}

// nextSyncLocked moves cursor to next; must be called with locked mutex.
Expand Down
147 changes: 147 additions & 0 deletions sqlite3_go18_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ package sqlite3
import (
"context"
"database/sql"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"os"
"strings"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -268,6 +270,151 @@ func TestQueryRowContextCancelParallel(t *testing.T) {
}
}

// Test that we can successfully interrupt a long running query when
// the context is canceled. The previous two QueryRowContext tests
// only test that we handle a previously cancelled context and thus
// do not call sqlite3_interrupt.
func TestQueryRowContextCancelInterrupt(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()

// Test that we have the unixepoch function and if not skip the test.
if _, err := db.Exec(`SELECT unixepoch(datetime(100000, 'unixepoch', 'localtime'))`); err != nil {
libVersion, libVersionNumber, sourceID := Version()
if strings.Contains(err.Error(), "no such function: unixepoch") {
t.Skip("Skipping the 'unixepoch' function is not implemented in "+
"this version of sqlite3:", libVersion, libVersionNumber, sourceID)
}
t.Fatal(err)
}

const createTableStmt = `
CREATE TABLE timestamps (
ts TIMESTAMP NOT NULL
);`
if _, err := db.Exec(createTableStmt); err != nil {
t.Fatal(err)
}

stmt, err := db.Prepare(`INSERT INTO timestamps VALUES (?);`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()

// Computationally expensive query that consumes many rows. This is needed
// to test cancellation because queries are not interrupted immediately.
// Instead, queries are only halted at certain checkpoints where the
// sqlite3.isInterrupted is checked and true.
queryStmt := `
SELECT
SUM(unixepoch(datetime(ts + 10, 'unixepoch', 'localtime'))) AS c1,
SUM(unixepoch(datetime(ts + 20, 'unixepoch', 'localtime'))) AS c2,
SUM(unixepoch(datetime(ts + 30, 'unixepoch', 'localtime'))) AS c3,
SUM(unixepoch(datetime(ts + 40, 'unixepoch', 'localtime'))) AS c4
FROM
timestamps
WHERE datetime(ts, 'unixepoch', 'localtime')
LIKE
?;`

query := func(t *testing.T, timeout time.Duration) (int, error) {
// Create a complicated pattern to match timestamps
const pattern = "%2%0%2%4%-%-%:%:%"
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
rows, err := db.QueryContext(ctx, queryStmt, pattern)
if err != nil {
return 0, err
}
var count int
for rows.Next() {
var n int64
if err := rows.Scan(&n, &n, &n, &n); err != nil {
return count, err
}
count++
}
return count, rows.Err()
}

average := func(n int, fn func()) time.Duration {
start := time.Now()
for i := 0; i < n; i++ {
fn()
}
return time.Since(start) / time.Duration(n)
}

createRows := func(n int) {
t.Logf("Creating %d rows", n)
if _, err := db.Exec(`DELETE FROM timestamps; VACUUM;`); err != nil {
t.Fatal(err)
}
ts := time.Date(2024, 6, 6, 8, 9, 10, 12345, time.UTC).Unix()
rr := rand.New(rand.NewSource(1234))
for i := 0; i < n; i++ {
if _, err := stmt.Exec(ts + rr.Int63n(10_000) - 5_000); err != nil {
t.Fatal(err)
}
}
}

const TargetRuntime = 200 * time.Millisecond
const N = 5_000 // Number of rows to insert at a time

// Create enough rows that the query takes ~200ms to run.
start := time.Now()
createRows(N)
baseAvg := average(4, func() {
if _, err := query(t, time.Hour); err != nil {
t.Fatal(err)
}
})
t.Log("Base average:", baseAvg)
rowCount := N * (int(TargetRuntime/baseAvg) + 1)
createRows(rowCount)
t.Log("Table setup time:", time.Since(start))

// Set the timeout to 1/10 of the average query time.
avg := average(2, func() {
n, err := query(t, time.Hour)
if err != nil {
t.Fatal(err)
}
if n == 0 {
t.Fatal("scanned zero rows")
}
})
// Guard against the timeout being too short to reliably test.
if avg < TargetRuntime/2 {
t.Fatalf("Average query runtime should be around %s got: %s ",
TargetRuntime, avg)
}
timeout := (avg / 10).Round(100 * time.Microsecond)
t.Logf("Average: %s Timeout: %s", avg, timeout)

for i := 0; i < 10; i++ {
tt := time.Now()
n, err := query(t, timeout)
if !errors.Is(err, context.DeadlineExceeded) {
fn := t.Errorf
if err != nil {
fn = t.Fatalf
}
fn("expected error %v got %v", context.DeadlineExceeded, err)
}
d := time.Since(tt)
t.Logf("%d: rows: %d duration: %s", i, n, d)
if d > timeout*4 {
t.Errorf("query was cancelled after %s but did not abort until: %s", timeout, d)
}
}
}

func TestExecCancel(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
Expand Down
Loading

0 comments on commit a8047d9

Please sign in to comment.