From c34a16e58980521c452f4cba8ae1056681c03263 Mon Sep 17 00:00:00 2001 From: Charlie Vieth Date: Thu, 2 Feb 2023 17:17:49 -0500 Subject: [PATCH 1/2] sqlite3: reduce C to Go string conversions in SQLiteConn.{query,exec} MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit fixes an issue where SQLiteConn.{exec,query} would use an exponential amount of memory when processing a prepared statement that consisted of multiple SQL statements ("stmt 1; stmt 2; ..."). Previously both exec/query used SQLiteConn.prepare() which converts the "tail" pointer set by sqlite3_prepare_v2() back into a Go string, which then has to be converted back to a C string for the next call to prepare() (assuming there are multiple SQL statements in the provided query). This commit fixes this by changing both exec and query to use the returned "tail" pointer for the next call to exec/query. It also changes prepare() to use pointer arithmetic to calculate the offset of the remaining "tail" portion of the query as a substring of the original query. This saves a call to C.GoString() and an allocation. Benchmarks: ``` goos: darwin goarch: arm64 pkg: github.com/mattn/go-sqlite3 │ base.10.txt │ new.10.txt │ │ sec/op │ sec/op vs base │ Suite/BenchmarkExec-10 1.351µ ± 1% 1.247µ ± 1% -7.74% (p=0.000 n=10) Suite/BenchmarkQuery-10 3.830µ ± 1% 3.558µ ± 1% -7.11% (p=0.000 n=10) Suite/BenchmarkParams-10 4.221µ ± 0% 4.228µ ± 1% ~ (p=1.000 n=10) Suite/BenchmarkStmt-10 2.906µ ± 1% 2.864µ ± 1% -1.45% (p=0.001 n=10) Suite/BenchmarkRows-10 149.1µ ± 4% 148.2µ ± 1% -0.61% (p=0.023 n=10) Suite/BenchmarkStmtRows-10 147.3µ ± 1% 145.6µ ± 0% -1.16% (p=0.000 n=10) Suite/BenchmarkExecStep-10 1898.9µ ± 3% 889.0µ ± 1% -53.18% (p=0.000 n=10) Suite/BenchmarkQueryStep-10 1848.0µ ± 1% 894.6µ ± 1% -51.59% (p=0.000 n=10) geomean 38.56µ 31.30µ -18.84% │ base.10.txt │ new.10.txt │ │ B/op │ B/op vs base │ Suite/BenchmarkExec-10 184.0 ± 0% 176.0 ± 0% -4.35% (p=0.000 n=10) Suite/BenchmarkQuery-10 864.0 ± 0% 856.0 ± 0% -0.93% (p=0.000 n=10) Suite/BenchmarkParams-10 1.289Ki ± 0% 1.281Ki ± 0% -0.61% (p=0.000 n=10) Suite/BenchmarkStmt-10 1.078Ki ± 0% 1.078Ki ± 0% ~ (p=1.000 n=10) ¹ Suite/BenchmarkRows-10 34.45Ki ± 0% 34.45Ki ± 0% -0.02% (p=0.000 n=10) Suite/BenchmarkStmtRows-10 34.40Ki ± 0% 34.40Ki ± 0% ~ (p=1.000 n=10) ¹ Suite/BenchmarkExecStep-10 5334.61Ki ± 0% 70.41Ki ± 0% -98.68% (p=0.000 n=10) Suite/BenchmarkQueryStep-10 5397.4Ki ± 0% 133.2Ki ± 0% -97.53% (p=0.000 n=10) geomean 17.06Ki 6.208Ki -63.62% ¹ all samples are equal │ base.10.txt │ new.10.txt │ │ allocs/op │ allocs/op vs base │ Suite/BenchmarkExec-10 13.00 ± 0% 12.00 ± 0% -7.69% (p=0.000 n=10) Suite/BenchmarkQuery-10 46.00 ± 0% 45.00 ± 0% -2.17% (p=0.000 n=10) Suite/BenchmarkParams-10 54.00 ± 0% 53.00 ± 0% -1.85% (p=0.000 n=10) Suite/BenchmarkStmt-10 49.00 ± 0% 49.00 ± 0% ~ (p=1.000 n=10) ¹ Suite/BenchmarkRows-10 2.042k ± 0% 2.041k ± 0% -0.05% (p=0.000 n=10) Suite/BenchmarkStmtRows-10 2.038k ± 0% 2.038k ± 0% ~ (p=1.000 n=10) ¹ Suite/BenchmarkExecStep-10 13.000k ± 0% 8.004k ± 0% -38.43% (p=0.000 n=10) Suite/BenchmarkQueryStep-10 11.013k ± 0% 6.017k ± 0% -45.36% (p=0.000 n=10) geomean 418.6 359.8 -14.04% ¹ all samples are equal ``` --- sqlite3.go | 102 ++++++++++++++++++++++++++++++++++-------------- sqlite3_test.go | 21 ++++++++++ 2 files changed, 94 insertions(+), 29 deletions(-) diff --git a/sqlite3.go b/sqlite3.go index ed2a9e2a..6d6a30d6 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -30,6 +30,7 @@ package sqlite3 #endif #include #include +#include #ifdef __CYGWIN__ # include @@ -90,6 +91,16 @@ _sqlite3_exec(sqlite3* db, const char* pcmd, long long* rowid, long long* change return rv; } +static const char * +_trim_leading_spaces(const char *str) { + if (str) { + while (isspace(*str)) { + str++; + } + } + return str; +} + #ifdef SQLITE_ENABLE_UNLOCK_NOTIFY extern int _sqlite3_step_blocking(sqlite3_stmt *stmt); extern int _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* changes); @@ -110,7 +121,11 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan static int _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail) { - return _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail); + int rv = _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail); + if (pzTail) { + *pzTail = _trim_leading_spaces(*pzTail); + } + return rv; } #else @@ -133,7 +148,11 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan static int _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail) { - return sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail); + int rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail); + if (pzTail) { + *pzTail = _trim_leading_spaces(*pzTail); + } + return rv; } #endif @@ -858,25 +877,34 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err } func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + pquery := C.CString(query) + op := pquery // original pointer + defer C.free(unsafe.Pointer(op)) + + var stmtArgs []driver.NamedValue + var tail *C.char + s := new(SQLiteStmt) // escapes to the heap so reuse it + defer s.finalize() start := 0 for { - s, err := c.prepare(ctx, query) - if err != nil { - return nil, err + *s = SQLiteStmt{c: c} // reset + rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail) + if rv != C.SQLITE_OK { + return nil, c.lastError() } + var res driver.Result - if s.(*SQLiteStmt).s != nil { - stmtArgs := make([]driver.NamedValue, 0, len(args)) + if s.s != nil { na := s.NumInput() if len(args)-start < na { - s.Close() + s.finalize() return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)) } // consume the number of arguments used in the current // statement and append all named arguments not // contained therein if len(args[start:start+na]) > 0 { - stmtArgs = append(stmtArgs, args[start:start+na]...) + stmtArgs = append(stmtArgs[:0], args[start:start+na]...) for i := range args { if (i < start || i >= na) && args[i].Name != "" { stmtArgs = append(stmtArgs, args[i]) @@ -886,23 +914,23 @@ func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.Named stmtArgs[i].Ordinal = i + 1 } } - res, err = s.(*SQLiteStmt).exec(ctx, stmtArgs) + var err error + res, err = s.exec(ctx, stmtArgs) if err != nil && err != driver.ErrSkip { - s.Close() + s.finalize() return nil, err } start += na } - tail := s.(*SQLiteStmt).t - s.Close() - if tail == "" { + s.finalize() + if tail == nil || *tail == '\000' { if res == nil { // https://github.com/mattn/go-sqlite3/issues/963 res = &SQLiteResult{0, 0} } return res, nil } - query = tail + pquery = tail } } @@ -919,14 +947,21 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro } func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + pquery := C.CString(query) + op := pquery // original pointer + defer C.free(unsafe.Pointer(op)) + + var stmtArgs []driver.NamedValue + var tail *C.char + s := new(SQLiteStmt) // escapes to the heap so reuse it start := 0 for { - stmtArgs := make([]driver.NamedValue, 0, len(args)) - s, err := c.prepare(ctx, query) - if err != nil { - return nil, err + *s = SQLiteStmt{c: c, cls: true} // reset + rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail) + if rv != C.SQLITE_OK { + return nil, c.lastError() } - s.(*SQLiteStmt).cls = true + na := s.NumInput() if len(args)-start < na { return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start) @@ -934,7 +969,7 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name // consume the number of arguments used in the current // statement and append all named arguments not contained // therein - stmtArgs = append(stmtArgs, args[start:start+na]...) + stmtArgs = append(stmtArgs[:0], args[start:start+na]...) for i := range args { if (i < start || i >= na) && args[i].Name != "" { stmtArgs = append(stmtArgs, args[i]) @@ -943,19 +978,18 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name for i := range stmtArgs { stmtArgs[i].Ordinal = i + 1 } - rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs) + rows, err := s.query(ctx, stmtArgs) if err != nil && err != driver.ErrSkip { - s.Close() + s.finalize() return rows, err } start += na - tail := s.(*SQLiteStmt).t - if tail == "" { + if tail == nil || *tail == '\000' { return rows, nil } rows.Close() - s.Close() - query = tail + s.finalize() + pquery = tail } } @@ -1818,8 +1852,11 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er return nil, c.lastError() } var t string - if tail != nil && *tail != '\000' { - t = strings.TrimSpace(C.GoString(tail)) + if tail != nil && *tail != 0 { + n := int(uintptr(unsafe.Pointer(tail))) - int(uintptr(unsafe.Pointer(pquery))) + if 0 <= n && n < len(query) { + t = strings.TrimSpace(query[n:]) + } } ss := &SQLiteStmt{c: c, s: s, t: t} runtime.SetFinalizer(ss, (*SQLiteStmt).Close) @@ -1913,6 +1950,13 @@ func (s *SQLiteStmt) Close() error { return nil } +func (s *SQLiteStmt) finalize() { + if s.s != nil { + C.sqlite3_finalize(s.s) + s.s = nil + } +} + // NumInput return a number of parameters. func (s *SQLiteStmt) NumInput() int { return int(C.sqlite3_bind_parameter_count(s.s)) diff --git a/sqlite3_test.go b/sqlite3_test.go index 63c939d3..2211ad1c 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -2111,6 +2111,8 @@ var benchmarks = []testing.InternalBenchmark{ {Name: "BenchmarkStmt", F: benchmarkStmt}, {Name: "BenchmarkRows", F: benchmarkRows}, {Name: "BenchmarkStmtRows", F: benchmarkStmtRows}, + {Name: "BenchmarkExecStep", F: benchmarkExecStep}, + {Name: "BenchmarkQueryStep", F: benchmarkQueryStep}, } func (db *TestDB) mustExec(sql string, args ...any) sql.Result { @@ -2568,3 +2570,22 @@ func benchmarkStmtRows(b *testing.B) { } } } + +var largeSelectStmt = strings.Repeat("select 1;\n", 1_000) + +func benchmarkExecStep(b *testing.B) { + for n := 0; n < b.N; n++ { + if _, err := db.Exec(largeSelectStmt); err != nil { + b.Fatal(err) + } + } +} + +func benchmarkQueryStep(b *testing.B) { + var i int + for n := 0; n < b.N; n++ { + if err := db.QueryRow(largeSelectStmt).Scan(&i); err != nil { + b.Fatal(err) + } + } +} From f5e855b246d78a66eb30148f42ac90649d0673ee Mon Sep 17 00:00:00 2001 From: Charlie Vieth Date: Thu, 20 Apr 2023 00:16:13 -0400 Subject: [PATCH 2/2] sqlite3: handle trailing comments and multiple SQL statements in Queries This commit fixes *SQLiteConn.Query to properly handle trailing comments after a SQL query statement. Previously, trailing comments could lead to an infinite loop. It also changes Query to error if the provided SQL statement contains multiple queries ("SELECT 1; SELECT 2;") - previously only the last query was executed ("SELECT 1; SELECT 2;" would yield only 2). This may be a breaking change as previously: Query consumed all of its args - despite only using the last query (Query now only uses the args required to satisfy the first query and errors if there is a mismatch); Query used only the last query and there may be code using this library that depends on this behavior. Personally, I believe the behavior introduced by this commit is correct and any code relying on the prior undocumented behavior incorrect, but it could still be a break. --- sqlite3.go | 92 ++++++++++-------------- sqlite3_test.go | 182 +++++++++++++++++++++++++++++++++++------------- 2 files changed, 168 insertions(+), 106 deletions(-) diff --git a/sqlite3.go b/sqlite3.go index 6d6a30d6..91e428c8 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -30,7 +30,6 @@ package sqlite3 #endif #include #include -#include #ifdef __CYGWIN__ # include @@ -91,16 +90,6 @@ _sqlite3_exec(sqlite3* db, const char* pcmd, long long* rowid, long long* change return rv; } -static const char * -_trim_leading_spaces(const char *str) { - if (str) { - while (isspace(*str)) { - str++; - } - } - return str; -} - #ifdef SQLITE_ENABLE_UNLOCK_NOTIFY extern int _sqlite3_step_blocking(sqlite3_stmt *stmt); extern int _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* changes); @@ -121,11 +110,7 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan static int _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail) { - int rv = _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail); - if (pzTail) { - *pzTail = _trim_leading_spaces(*pzTail); - } - return rv; + return _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail); } #else @@ -148,12 +133,9 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan static int _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail) { - int rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail); - if (pzTail) { - *pzTail = _trim_leading_spaces(*pzTail); - } - return rv; + return sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail); } + #endif void _sqlite3_result_text(sqlite3_context* ctx, const char* s) { @@ -951,46 +933,44 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name op := pquery // original pointer defer C.free(unsafe.Pointer(op)) - var stmtArgs []driver.NamedValue var tail *C.char - s := new(SQLiteStmt) // escapes to the heap so reuse it - start := 0 - for { - *s = SQLiteStmt{c: c, cls: true} // reset - rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail) - if rv != C.SQLITE_OK { - return nil, c.lastError() + s := &SQLiteStmt{c: c, cls: true} + rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail) + if rv != C.SQLITE_OK { + return nil, c.lastError() + } + if s.s == nil { + return &SQLiteRows{s: &SQLiteStmt{closed: true}, closed: true}, nil + } + na := s.NumInput() + if n := len(args); n != na { + s.finalize() + if n < na { + return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)) } + return nil, fmt.Errorf("too many args to execute query: want %d got %d", na, len(args)) + } + rows, err := s.query(ctx, args) + if err != nil && err != driver.ErrSkip { + s.finalize() + return rows, err + } - na := s.NumInput() - if len(args)-start < na { - return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start) - } - // consume the number of arguments used in the current - // statement and append all named arguments not contained - // therein - stmtArgs = append(stmtArgs[:0], args[start:start+na]...) - for i := range args { - if (i < start || i >= na) && args[i].Name != "" { - stmtArgs = append(stmtArgs, args[i]) - } - } - for i := range stmtArgs { - stmtArgs[i].Ordinal = i + 1 - } - rows, err := s.query(ctx, stmtArgs) - if err != nil && err != driver.ErrSkip { - s.finalize() - return rows, err + // Consume the rest of the query + for pquery = tail; pquery != nil && *pquery != 0; pquery = tail { + var stmt *C.sqlite3_stmt + rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &stmt, &tail) + if rv != C.SQLITE_OK { + rows.Close() + return nil, c.lastError() } - start += na - if tail == nil || *tail == '\000' { - return rows, nil + if stmt != nil { + rows.Close() + return nil, errors.New("cannot insert multiple commands into a prepared statement") // WARN } - rows.Close() - s.finalize() - pquery = tail } + + return rows, nil } // Begin transaction. @@ -2044,7 +2024,7 @@ func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) { return s.query(context.Background(), list) } -func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { +func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (*SQLiteRows, error) { if err := s.bind(args); err != nil { return nil, err } diff --git a/sqlite3_test.go b/sqlite3_test.go index 2211ad1c..089cf21a 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -18,6 +18,7 @@ import ( "math/rand" "net/url" "os" + "path/filepath" "reflect" "regexp" "runtime" @@ -1080,67 +1081,158 @@ func TestExecer(t *testing.T) { defer db.Close() _, err = db.Exec(` - create table foo (id integer); -- one comment - insert into foo(id) values(?); - insert into foo(id) values(?); - insert into foo(id) values(?); -- another comment + CREATE TABLE foo (id INTEGER); -- one comment + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?); -- another comment `, 1, 2, 3) if err != nil { t.Error("Failed to call db.Exec:", err) } } -func TestQueryer(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) +func testQuery(t *testing.T, seed bool, test func(t *testing.T, db *sql.DB)) { + db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.sqlite3")) if err != nil { t.Fatal("Failed to open database:", err) } defer db.Close() - _, err = db.Exec(` - create table foo (id integer); - `) - if err != nil { - t.Error("Failed to call db.Query:", err) + if seed { + if _, err := db.Exec(`create table foo (id integer);`); err != nil { + t.Fatal(err) + } + _, err := db.Exec(` + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?); + INSERT INTO foo(id) VALUES(?); + `, 3, 2, 1) + if err != nil { + t.Fatal(err) + } } - _, err = db.Exec(` - insert into foo(id) values(?); - insert into foo(id) values(?); - insert into foo(id) values(?); - `, 3, 2, 1) - if err != nil { - t.Error("Failed to call db.Exec:", err) - } - rows, err := db.Query(` - select id from foo order by id; - `) - if err != nil { - t.Error("Failed to call db.Query:", err) - } - defer rows.Close() - n := 0 - for rows.Next() { - var id int - err = rows.Scan(&id) + // Capture panic so tests can continue + defer func() { + if e := recover(); e != nil { + buf := make([]byte, 32*1024) + n := runtime.Stack(buf, false) + t.Fatalf("\npanic: %v\n\n%s\n", e, buf[:n]) + } + }() + test(t, db) +} + +func testQueryValues(t *testing.T, query string, args ...interface{}) []interface{} { + var values []interface{} + testQuery(t, true, func(t *testing.T, db *sql.DB) { + rows, err := db.Query(query, args...) if err != nil { - t.Error("Failed to db.Query:", err) + t.Fatal(err) } - if id != n+1 { - t.Error("Failed to db.Query: not matched results") + if rows == nil { + t.Fatal("nil rows") } - n = n + 1 + for i := 0; rows.Next(); i++ { + if i > 1_000 { + t.Fatal("To many iterations of rows.Next():", i) + } + var v interface{} + if err := rows.Scan(&v); err != nil { + t.Fatal(err) + } + values = append(values, v) + } + if err := rows.Err(); err != nil { + t.Fatal(err) + } + if err := rows.Close(); err != nil { + t.Fatal(err) + } + }) + return values +} + +func TestQuery(t *testing.T) { + queries := []struct { + query string + args []interface{} + }{ + {"SELECT id FROM foo ORDER BY id;", nil}, + {"SELECT id FROM foo WHERE id != ? ORDER BY id;", []interface{}{4}}, + {"SELECT id FROM foo WHERE id IN (?, ?, ?) ORDER BY id;", []interface{}{1, 2, 3}}, + + // Comments + {"SELECT id FROM foo ORDER BY id; -- comment", nil}, + {"SELECT id FROM foo ORDER BY id;\n -- comment\n", nil}, + { + `-- FOO + SELECT id FROM foo ORDER BY id; -- BAR + /* BAZ */`, + nil, + }, } - if err := rows.Err(); err != nil { - t.Errorf("Post-scan failed: %v\n", err) + want := []interface{}{ + int64(1), + int64(2), + int64(3), + } + for _, q := range queries { + t.Run("", func(t *testing.T) { + got := testQueryValues(t, q.query, q.args...) + if !reflect.DeepEqual(got, want) { + t.Fatalf("Query(%q, %v) = %v; want: %v", q.query, q.args, got, want) + } + }) } - if n != 3 { - t.Errorf("Expected 3 rows but retrieved %v", n) +} + +func TestQueryNoSQL(t *testing.T) { + got := testQueryValues(t, "") + if got != nil { + t.Fatalf("Query(%q, %v) = %v; want: %v", "", nil, got, nil) } } +func testQueryError(t *testing.T, query string, args ...interface{}) { + testQuery(t, true, func(t *testing.T, db *sql.DB) { + rows, err := db.Query(query, args...) + if err == nil { + t.Error("Expected an error got:", err) + } + if rows != nil { + t.Error("Returned rows should be nil on error!") + // Attempt to iterate over rows to make sure they don't panic. + for i := 0; rows.Next(); i++ { + if i > 1_000 { + t.Fatal("To many iterations of rows.Next():", i) + } + } + if err := rows.Err(); err != nil { + t.Error(err) + } + rows.Close() + } + }) +} + +func TestQueryNotEnoughArgs(t *testing.T) { + testQueryError(t, "SELECT FROM foo WHERE id = ? OR id = ?);", 1) +} + +func TestQueryTooManyArgs(t *testing.T) { + // TODO: test error message / kind + testQueryError(t, "SELECT FROM foo WHERE id = ?);", 1, 2) +} + +func TestQueryMultipleStatements(t *testing.T) { + testQueryError(t, "SELECT 1; SELECT 2;") +} + +func TestQueryInvalidTable(t *testing.T) { + testQueryError(t, "SELECT COUNT(*) FROM does_not_exist;") +} + func TestStress(t *testing.T) { tempFilename := TempFilename(t) defer os.Remove(tempFilename) @@ -2112,7 +2204,6 @@ var benchmarks = []testing.InternalBenchmark{ {Name: "BenchmarkRows", F: benchmarkRows}, {Name: "BenchmarkStmtRows", F: benchmarkStmtRows}, {Name: "BenchmarkExecStep", F: benchmarkExecStep}, - {Name: "BenchmarkQueryStep", F: benchmarkQueryStep}, } func (db *TestDB) mustExec(sql string, args ...any) sql.Result { @@ -2580,12 +2671,3 @@ func benchmarkExecStep(b *testing.B) { } } } - -func benchmarkQueryStep(b *testing.B) { - var i int - for n := 0; n < b.N; n++ { - if err := db.QueryRow(largeSelectStmt).Scan(&i); err != nil { - b.Fatal(err) - } - } -}