From a8047d9a614b2b313e3d9817caa2b7e1e641e346 Mon Sep 17 00:00:00 2001 From: Charlie Vieth Date: Wed, 6 Nov 2024 21:43:23 -0500 Subject: [PATCH] Reduce QueryContext allocations by reusing the channel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit reduces the number of allocations and memory usage of QueryContext by inverting the goroutine: instead of processing the request in the goroutine and having it send the result, we now process the request in the method itself and goroutine is only used to interrupt the query if the context is canceled. The advantage of this approach is that we no longer need to send anything on the channel, but instead can treat the channel as a semaphore (this reduces the amount of memory allocated by this method). Additionally, we now reuse the channel used to communicate with the goroutine which reduces the number of allocations. This commit also adds a test that actually exercises the sqlite3_interrupt logic since the existing tests did not. Those tests cancelled the context before scanning any of the rows and could be made to pass without ever calling sqlite3_interrupt. The below version of SQLiteRows.Next passes the previous tests: ```go func (rc *SQLiteRows) Next(dest []driver.Value) error { rc.s.mu.Lock() defer rc.s.mu.Unlock() if rc.s.closed { return io.EOF } if err := rc.ctx.Err(); err != nil { return err } return rc.nextSyncLocked(dest) } ``` Benchmark results: ``` goos: darwin goarch: arm64 pkg: github.com/mattn/go-sqlite3 cpu: Apple M1 Max │ old.txt │ new.txt │ │ sec/op │ sec/op vs base │ Suite/BenchmarkQueryContext/Background-10 3.994µ ± 2% 4.034µ ± 1% ~ (p=0.289 n=10) Suite/BenchmarkQueryContext/WithCancel-10 12.02µ ± 3% 11.56µ ± 4% -3.87% (p=0.003 n=10) geomean 6.930µ 6.829µ -1.46% │ old.txt │ new.txt │ │ B/op │ B/op vs base │ Suite/BenchmarkQueryContext/Background-10 400.0 ± 0% 400.0 ± 0% ~ (p=1.000 n=10) ¹ Suite/BenchmarkQueryContext/WithCancel-10 2.376Ki ± 0% 1.025Ki ± 0% -56.87% (p=0.000 n=10) geomean 986.6 647.9 -34.33% ¹ all samples are equal │ old.txt │ new.txt │ │ allocs/op │ allocs/op vs base │ Suite/BenchmarkQueryContext/Background-10 12.00 ± 0% 12.00 ± 0% ~ (p=1.000 n=10) ¹ Suite/BenchmarkQueryContext/WithCancel-10 38.00 ± 0% 28.00 ± 0% -26.32% (p=0.000 n=10) geomean 21.35 18.33 -14.16% ¹ all samples are equal ``` --- sqlite3.go | 92 ++++++++++++++++----------- sqlite3_go18_test.go | 147 +++++++++++++++++++++++++++++++++++++++++++ sqlite3_test.go | 93 ++++++++++++++++++++++++--- 3 files changed, 289 insertions(+), 43 deletions(-) diff --git a/sqlite3.go b/sqlite3.go index 3025a500..f7bfa363 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -399,6 +399,9 @@ type SQLiteRows struct { decltype []string ctx context.Context // no better alternative to pass context into Next() method closemu sync.Mutex + // semaphore to signal the goroutine used to interrupt queries when a + // cancellable context is passed to QueryContext + sema chan struct{} } type functionInfo struct { @@ -2050,36 +2053,37 @@ func isInterruptErr(err error) bool { // exec executes a query that doesn't return rows. Attempts to honor context timeout. func (s *SQLiteStmt) exec(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - if ctx.Done() == nil { + done := ctx.Done() + if done == nil { return s.execSync(args) } - - type result struct { - r driver.Result - err error + if err := ctx.Err(); err != nil { + return nil, err // Fast check if the channel is closed } - resultCh := make(chan result) - defer close(resultCh) + + sema := make(chan struct{}) go func() { - r, err := s.execSync(args) - resultCh <- result{r, err} - }() - var rv result - select { - case rv = <-resultCh: - case <-ctx.Done(): select { - case rv = <-resultCh: // no need to interrupt, operation completed in db - default: - // this is still racy and can be no-op if executed between sqlite3_* calls in execSync. + case <-done: C.sqlite3_interrupt(s.c.db) - rv = <-resultCh // wait for goroutine completed - if isInterruptErr(rv.err) { - return nil, ctx.Err() - } + // Wait until signaled. We need to ensure that this goroutine + // will not call interrupt after this method returns. + <-sema + case <-sema: } + }() + r, err := s.execSync(args) + // Signal the goroutine to exit. This send will only succeed at a point + // where it is impossible for the goroutine to call sqlite3_interrupt. + // + // This is necessary to ensure the goroutine does not interrupt an + // unrelated query if the context is cancelled after this method returns + // but before the goroutine exits (we don't wait for it to exit). + sema <- struct{}{} + if err != nil && isInterruptErr(err) { + return nil, ctx.Err() } - return rv.r, rv.err + return r, err } func (s *SQLiteStmt) execSync(args []driver.NamedValue) (driver.Result, error) { @@ -2117,6 +2121,9 @@ func (rc *SQLiteRows) Close() error { return nil } rc.s = nil // remove reference to SQLiteStmt + if rc.sema != nil { + close(rc.sema) + } s.mu.Lock() if s.closed { s.mu.Unlock() @@ -2174,27 +2181,40 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { return io.EOF } - if rc.ctx.Done() == nil { + done := rc.ctx.Done() + if done == nil { return rc.nextSyncLocked(dest) } - resultCh := make(chan error) - defer close(resultCh) + if err := rc.ctx.Err(); err != nil { + return err // Fast check if the channel is closed + } + + if rc.sema == nil { + rc.sema = make(chan struct{}) + } go func() { - resultCh <- rc.nextSyncLocked(dest) - }() - select { - case err := <-resultCh: - return err - case <-rc.ctx.Done(): select { - case <-resultCh: // no need to interrupt - default: - // this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked. + case <-done: C.sqlite3_interrupt(rc.s.c.db) - <-resultCh // ensure goroutine completed + // Wait until signaled. We need to ensure that this goroutine + // will not call interrupt after this method returns. + <-rc.sema + case <-rc.sema: } - return rc.ctx.Err() + }() + + err := rc.nextSyncLocked(dest) + // Signal the goroutine to exit. This send will only succeed at a point + // where it is impossible for the goroutine to call sqlite3_interrupt. + // + // This is necessary to ensure the goroutine does not interrupt an + // unrelated query if the context is cancelled after this method returns + // but before the goroutine exits (we don't wait for it to exit). + rc.sema <- struct{}{} + if err != nil && isInterruptErr(err) { + err = rc.ctx.Err() } + return err } // nextSyncLocked moves cursor to next; must be called with locked mutex. diff --git a/sqlite3_go18_test.go b/sqlite3_go18_test.go index eec7479d..879d9fda 100644 --- a/sqlite3_go18_test.go +++ b/sqlite3_go18_test.go @@ -11,10 +11,12 @@ package sqlite3 import ( "context" "database/sql" + "errors" "fmt" "io/ioutil" "math/rand" "os" + "strings" "sync" "testing" "time" @@ -268,6 +270,151 @@ func TestQueryRowContextCancelParallel(t *testing.T) { } } +// Test that we can successfully interrupt a long running query when +// the context is canceled. The previous two QueryRowContext tests +// only test that we handle a previously cancelled context and thus +// do not call sqlite3_interrupt. +func TestQueryRowContextCancelInterrupt(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + // Test that we have the unixepoch function and if not skip the test. + if _, err := db.Exec(`SELECT unixepoch(datetime(100000, 'unixepoch', 'localtime'))`); err != nil { + libVersion, libVersionNumber, sourceID := Version() + if strings.Contains(err.Error(), "no such function: unixepoch") { + t.Skip("Skipping the 'unixepoch' function is not implemented in "+ + "this version of sqlite3:", libVersion, libVersionNumber, sourceID) + } + t.Fatal(err) + } + + const createTableStmt = ` + CREATE TABLE timestamps ( + ts TIMESTAMP NOT NULL + );` + if _, err := db.Exec(createTableStmt); err != nil { + t.Fatal(err) + } + + stmt, err := db.Prepare(`INSERT INTO timestamps VALUES (?);`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + // Computationally expensive query that consumes many rows. This is needed + // to test cancellation because queries are not interrupted immediately. + // Instead, queries are only halted at certain checkpoints where the + // sqlite3.isInterrupted is checked and true. + queryStmt := ` + SELECT + SUM(unixepoch(datetime(ts + 10, 'unixepoch', 'localtime'))) AS c1, + SUM(unixepoch(datetime(ts + 20, 'unixepoch', 'localtime'))) AS c2, + SUM(unixepoch(datetime(ts + 30, 'unixepoch', 'localtime'))) AS c3, + SUM(unixepoch(datetime(ts + 40, 'unixepoch', 'localtime'))) AS c4 + FROM + timestamps + WHERE datetime(ts, 'unixepoch', 'localtime') + LIKE + ?;` + + query := func(t *testing.T, timeout time.Duration) (int, error) { + // Create a complicated pattern to match timestamps + const pattern = "%2%0%2%4%-%-%:%:%" + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + rows, err := db.QueryContext(ctx, queryStmt, pattern) + if err != nil { + return 0, err + } + var count int + for rows.Next() { + var n int64 + if err := rows.Scan(&n, &n, &n, &n); err != nil { + return count, err + } + count++ + } + return count, rows.Err() + } + + average := func(n int, fn func()) time.Duration { + start := time.Now() + for i := 0; i < n; i++ { + fn() + } + return time.Since(start) / time.Duration(n) + } + + createRows := func(n int) { + t.Logf("Creating %d rows", n) + if _, err := db.Exec(`DELETE FROM timestamps; VACUUM;`); err != nil { + t.Fatal(err) + } + ts := time.Date(2024, 6, 6, 8, 9, 10, 12345, time.UTC).Unix() + rr := rand.New(rand.NewSource(1234)) + for i := 0; i < n; i++ { + if _, err := stmt.Exec(ts + rr.Int63n(10_000) - 5_000); err != nil { + t.Fatal(err) + } + } + } + + const TargetRuntime = 200 * time.Millisecond + const N = 5_000 // Number of rows to insert at a time + + // Create enough rows that the query takes ~200ms to run. + start := time.Now() + createRows(N) + baseAvg := average(4, func() { + if _, err := query(t, time.Hour); err != nil { + t.Fatal(err) + } + }) + t.Log("Base average:", baseAvg) + rowCount := N * (int(TargetRuntime/baseAvg) + 1) + createRows(rowCount) + t.Log("Table setup time:", time.Since(start)) + + // Set the timeout to 1/10 of the average query time. + avg := average(2, func() { + n, err := query(t, time.Hour) + if err != nil { + t.Fatal(err) + } + if n == 0 { + t.Fatal("scanned zero rows") + } + }) + // Guard against the timeout being too short to reliably test. + if avg < TargetRuntime/2 { + t.Fatalf("Average query runtime should be around %s got: %s ", + TargetRuntime, avg) + } + timeout := (avg / 10).Round(100 * time.Microsecond) + t.Logf("Average: %s Timeout: %s", avg, timeout) + + for i := 0; i < 10; i++ { + tt := time.Now() + n, err := query(t, timeout) + if !errors.Is(err, context.DeadlineExceeded) { + fn := t.Errorf + if err != nil { + fn = t.Fatalf + } + fn("expected error %v got %v", context.DeadlineExceeded, err) + } + d := time.Since(tt) + t.Logf("%d: rows: %d duration: %s", i, n, d) + if d > timeout*4 { + t.Errorf("query was cancelled after %s but did not abort until: %s", timeout, d) + } + } +} + func TestExecCancel(t *testing.T) { db, err := sql.Open("sqlite3", ":memory:") if err != nil { diff --git a/sqlite3_test.go b/sqlite3_test.go index 94de7386..e3dcace3 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -10,6 +10,7 @@ package sqlite3 import ( "bytes" + "context" "database/sql" "database/sql/driver" "errors" @@ -2030,7 +2031,7 @@ func BenchmarkCustomFunctions(b *testing.B) { } func TestSuite(t *testing.T) { - initializeTestDB(t) + initializeTestDB(t, false) defer freeTestDB() for _, test := range tests { @@ -2039,7 +2040,7 @@ func TestSuite(t *testing.T) { } func BenchmarkSuite(b *testing.B) { - initializeTestDB(b) + initializeTestDB(b, true) defer freeTestDB() for _, benchmark := range benchmarks { @@ -2068,8 +2069,13 @@ type TestDB struct { var db *TestDB -func initializeTestDB(t testing.TB) { - tempFilename := TempFilename(t) +func initializeTestDB(t testing.TB, memory bool) { + var tempFilename string + if memory { + tempFilename = ":memory:" + } else { + tempFilename = TempFilename(t) + } d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999") if err != nil { os.Remove(tempFilename) @@ -2084,9 +2090,11 @@ func freeTestDB() { if err != nil { panic(err) } - err = os.Remove(db.tempFilename) - if err != nil { - panic(err) + if db.tempFilename != "" && db.tempFilename != ":memory:" { + err := os.Remove(db.tempFilename) + if err != nil { + panic(err) + } } } @@ -2106,7 +2114,9 @@ var tests = []testing.InternalTest{ var benchmarks = []testing.InternalBenchmark{ {Name: "BenchmarkExec", F: benchmarkExec}, + {Name: "BenchmarkExecContext", F: benchmarkExecContext}, {Name: "BenchmarkQuery", F: benchmarkQuery}, + {Name: "BenchmarkQueryContext", F: benchmarkQueryContext}, {Name: "BenchmarkParams", F: benchmarkParams}, {Name: "BenchmarkStmt", F: benchmarkStmt}, {Name: "BenchmarkRows", F: benchmarkRows}, @@ -2466,6 +2476,16 @@ func benchmarkExec(b *testing.B) { } } +func benchmarkExecContext(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for i := 0; i < b.N; i++ { + if _, err := db.ExecContext(ctx, "select 1"); err != nil { + panic(err) + } + } +} + // benchmarkQuery is benchmark for query func benchmarkQuery(b *testing.B) { for i := 0; i < b.N; i++ { @@ -2480,6 +2500,65 @@ func benchmarkQuery(b *testing.B) { } } +// benchmarkQueryContext is benchmark for QueryContext +func benchmarkQueryContext(b *testing.B) { + const createTableStmt = ` + CREATE TABLE IF NOT EXISTS query_context( + id INTEGER PRIMARY KEY + ); + DELETE FROM query_context; + VACUUM;` + test := func(ctx context.Context, b *testing.B) { + if _, err := db.Exec(createTableStmt); err != nil { + b.Fatal(err) + } + for i := 0; i < 10; i++ { + _, err := db.Exec("INSERT INTO query_context VALUES (?);", int64(i)) + if err != nil { + db.Fatal(err) + } + } + stmt, err := db.PrepareContext(ctx, `SELECT id FROM query_context;`) + if err != nil { + b.Fatal(err) + } + b.Cleanup(func() { stmt.Close() }) + + var n int + for i := 0; i < b.N; i++ { + rows, err := stmt.QueryContext(ctx) + if err != nil { + b.Fatal(err) + } + for rows.Next() { + if err := rows.Scan(&n); err != nil { + b.Fatal(err) + } + } + if err := rows.Err(); err != nil { + b.Fatal(err) + } + } + } + + // When the context does not have a Done channel we should use + // the fast path that directly handles the query instead of + // handling it in a goroutine. This benchmark also serves to + // highlight the performance impact of using a cancelable + // context. + b.Run("Background", func(b *testing.B) { + test(context.Background(), b) + }) + + // Benchmark a query with a context that can be canceled. This + // requires using a goroutine and is thus much slower. + b.Run("WithCancel", func(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + test(ctx, b) + }) +} + // benchmarkParams is benchmark for params func benchmarkParams(b *testing.B) { for i := 0; i < b.N; i++ {