diff --git a/sqlite3.go b/sqlite3.go index ce985ec8..5d535c72 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -145,6 +145,25 @@ void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l) { sqlite3_result_blob(ctx, b, l, SQLITE_TRANSIENT); } +static int _sqlite3_prepare_v2(sqlite3 *db, const char *zSql, int nBytes, + sqlite3_stmt **ppStmt, int *oBytes) { + const char *tail; + int rv = _sqlite3_prepare_v2_internal(db, zSql, nBytes, ppStmt, &tail); + if (rv != SQLITE_OK) { + return rv; + } + // TODO: combine this with the below logic + if (tail == NULL) { + // NB: this should not happen + *oBytes = nBytes; + return rv; + } + // Set oBytes to the number of bytes consumed instead of using the **pzTail + // out param since that requires storing a Go pointer in a C pointer, which + // is not allowed by CGO and will cause runtime.cgoCheckPointer to fail. + *oBytes = tail - zSql; + return rv; +} int _sqlite3_create_function( sqlite3 *db, @@ -919,44 +938,56 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro } func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - start := 0 + var ( + // stmtArgs []driver.NamedValue + start int + s SQLiteStmt // escapes to the heap so reuse it + sz C.int // number of query bytes consumed: escapes to the heap + ) + query = strings.TrimSpace(query) for { - stmtArgs := make([]driver.NamedValue, 0, len(args)) - s, err := c.prepare(ctx, query) - if err != nil { - return nil, err + s = SQLiteStmt{c: c, cls: true} // reset + sz = 0 + rv := C._sqlite3_prepare_v2(c.db, (*C.char)(unsafe.Pointer(stringData(query))), + C.int(len(query)), &s.s, &sz) + if rv != C.SQLITE_OK { + return nil, c.lastError() } - s.(*SQLiteStmt).cls = true + query = strings.TrimSpace(query[sz:]) + na := s.NumInput() if len(args)-start < na { - s.Close() - return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start) + s.finalize() + return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)) } // consume the number of arguments used in the current - // statement and append all named arguments not contained - // therein - stmtArgs = append(stmtArgs, args[start:start+na]...) + // statement and append all named arguments not + // contained therein + // if stmtArgs == nil { + // stmtArgs = make([]driver.NamedValue, 0, na) + // } + // stmtArgs = append(stmtArgs[:0], args[start:start+na]...) + stmtArgs := args[start : start+na : start+na] for i := range args { if (i < start || i >= na) && args[i].Name != "" { stmtArgs = append(stmtArgs, args[i]) } } + start += na for i := range stmtArgs { stmtArgs[i].Ordinal = i + 1 } - rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs) + + rows, err := s.query(ctx, stmtArgs) if err != nil && err != driver.ErrSkip { - s.Close() - return rows, err + s.finalize() + return nil, err } - start += na - tail := s.(*SQLiteStmt).t - if tail == "" { + if len(query) == 0 { return rows, nil } rows.Close() - s.Close() - query = tail + s.finalize() } } @@ -1914,6 +1945,13 @@ func (s *SQLiteStmt) Close() error { return nil } +func (s *SQLiteStmt) finalize() { + if s.s != nil { + C.sqlite3_finalize(s.s) + s.s = nil + } +} + // NumInput return a number of parameters. func (s *SQLiteStmt) NumInput() int { return int(C.sqlite3_bind_parameter_count(s.s)) @@ -1921,26 +1959,92 @@ func (s *SQLiteStmt) NumInput() int { var placeHolder = []byte{0} +func hasNamedArgs(args []driver.NamedValue) bool { + for _, v := range args { + if v.Name != "" { + return true + } + } + return false +} + +// TODO: return the column count as well func (s *SQLiteStmt) bind(args []driver.NamedValue) error { rv := C.sqlite3_reset(s.s) if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { return s.c.lastError() } + if hasNamedArgs(args) { + return s.bindIndices(args) + } + + for _, arg := range args { + n := C.int(arg.Ordinal) + switch v := arg.Value.(type) { + case nil: + rv = C.sqlite3_bind_null(s.s, n) + case string: + p := stringData(v) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(p)), C.int(len(v))) + case int64: + rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v)) + case bool: + val := 0 + if v { + val = 1 + } + rv = C.sqlite3_bind_int(s.s, n, C.int(val)) + case float64: + rv = C.sqlite3_bind_double(s.s, n, C.double(v)) + case []byte: + if v == nil { + rv = C.sqlite3_bind_null(s.s, n) + } else { + ln := len(v) + if ln == 0 { + v = placeHolder + } + rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln)) + } + case time.Time: + b := []byte(v.Format(SQLiteTimestampFormats[0])) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) + } + if rv != C.SQLITE_OK { + return s.c.lastError() + } + } + return nil +} + +func (s *SQLiteStmt) bindIndices(args []driver.NamedValue) error { + // Find the longest named parameter name. + n := 0 + for _, v := range args { + if m := len(v.Name); m > n { + n = m + } + } + buf := make([]byte, 0, n+2) // +2 for placeholder and null terminator + + // TODO: Reduce the size of this slive by using uint32 or uint16. + // By default SQLITE_MAX_FUNCTION_ARG is 100 so uint16 should work. bindIndices := make([][3]int, len(args)) - prefixes := []string{":", "@", "$"} for i, v := range args { bindIndices[i][0] = args[i].Ordinal if v.Name != "" { - for j := range prefixes { - cname := C.CString(prefixes[j] + v.Name) - bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, cname)) - C.free(unsafe.Pointer(cname)) + for j, c := range []byte{':', '@', '$'} { + buf = append(buf[:0], c) + buf = append(buf, v.Name...) + buf = append(buf, 0) + bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, (*C.char)(unsafe.Pointer(&buf[0])))) } args[i].Ordinal = bindIndices[i][0] } } + var rv C.int for i, arg := range args { for j := range bindIndices[i] { if bindIndices[i][j] == 0 { @@ -1951,20 +2055,16 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error { case nil: rv = C.sqlite3_bind_null(s.s, n) case string: - if len(v) == 0 { - rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0)) - } else { - b := []byte(v) - rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) - } + p := stringData(v) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(p)), C.int(len(v))) case int64: rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v)) case bool: + val := 0 if v { - rv = C.sqlite3_bind_int(s.s, n, 1) - } else { - rv = C.sqlite3_bind_int(s.s, n, 0) + val = 1 } + rv = C.sqlite3_bind_int(s.s, n, C.int(val)) case float64: rv = C.sqlite3_bind_double(s.s, n, C.double(v)) case []byte: @@ -1989,6 +2089,74 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error { return nil } +// func (s *SQLiteStmt) bind(args []driver.NamedValue) error { +// rv := C.sqlite3_reset(s.s) +// if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { +// return s.c.lastError() +// } +// +// bindIndices := make([][3]int, len(args)) +// prefixes := []string{":", "@", "$"} +// for i, v := range args { +// bindIndices[i][0] = args[i].Ordinal +// if v.Name != "" { +// for j := range prefixes { +// cname := C.CString(prefixes[j] + v.Name) +// bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, cname)) +// C.free(unsafe.Pointer(cname)) +// } +// args[i].Ordinal = bindIndices[i][0] +// } +// } +// +// for i, arg := range args { +// for j := range bindIndices[i] { +// if bindIndices[i][j] == 0 { +// continue +// } +// n := C.int(bindIndices[i][j]) +// switch v := arg.Value.(type) { +// case nil: +// rv = C.sqlite3_bind_null(s.s, n) +// case string: +// if len(v) == 0 { +// rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0)) +// } else { +// b := []byte(v) +// rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) +// } +// case int64: +// rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v)) +// case bool: +// if v { +// rv = C.sqlite3_bind_int(s.s, n, 1) +// } else { +// rv = C.sqlite3_bind_int(s.s, n, 0) +// } +// case float64: +// rv = C.sqlite3_bind_double(s.s, n, C.double(v)) +// case []byte: +// if v == nil { +// rv = C.sqlite3_bind_null(s.s, n) +// } else { +// ln := len(v) +// if ln == 0 { +// v = placeHolder +// } +// rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln)) +// } +// case time.Time: +// b := []byte(v.Format(SQLiteTimestampFormats[0])) +// rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) +// } +// if rv != C.SQLITE_OK { +// return s.c.lastError() +// } +// } +// } +// return nil +// } + // Query the statement with arguments. Return records. func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) { list := make([]driver.NamedValue, len(args)) diff --git a/sqlite3_test.go b/sqlite3_test.go index 63c939d3..06284aa4 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -10,6 +10,7 @@ package sqlite3 import ( "bytes" + "context" "database/sql" "database/sql/driver" "errors" @@ -1141,6 +1142,66 @@ func TestQueryer(t *testing.T) { } } +// Test that an empty query does not error or result in an infinite loop. +func TestEmptyQuery(t *testing.T) { + db, err := sql.Open("sqlite3", t.TempDir()+"/test.sqlite3") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + if err := db.QueryRow("").Scan(); err != nil { + t.Fatal(err) + } +} + +// This test ensures that we implement the bug where only the last SQL +// statement passed to Query is executed. See issues #950 and #1161. +// +// This bug should probably be fixed at some point since any code that +// relies on it is subtly wrong and exposing this bug to users would be +// to their benefit, but for now let's ensure that any changes to Query +// maintain the current/buggy behavior. +func TestMulipleStatementQueryBug(t *testing.T) { + db, err := sql.Open("sqlite3", t.TempDir()+"/test.sqlite3") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + t.Run("Rows", func(t *testing.T) { + rows, err := db.Query(`SELECT ?; SELECT ?;`, 1, 2) + if err != nil { + t.Fatal(err) + } + var vals []int64 + for rows.Next() { + var v int64 + if err := rows.Scan(&v); err != nil { + t.Fatal(err) + } + vals = append(vals, v) + } + if err := rows.Err(); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(vals, []int64{2}) { + t.Fatalf("Expected only the last query to be executed got: %v: want: %v", + vals, []int64{2}) + } + }) + + t.Run("QueryRow", func(t *testing.T) { + var v int64 + if err := db.QueryRow(`SELECT ?; SELECT ?;`, 1, 2).Scan(&v); err != nil { + t.Fatal(err) + } + if v != 2 { + t.Fatalf("Expected only the last query to be executed got: %v: want: %v", + v, 2) + } + }) +} + func TestStress(t *testing.T) { tempFilename := TempFilename(t) defer os.Remove(tempFilename) @@ -2030,7 +2091,7 @@ func BenchmarkCustomFunctions(b *testing.B) { } func TestSuite(t *testing.T) { - initializeTestDB(t) + initializeTestDB(t, false) defer freeTestDB() for _, test := range tests { @@ -2039,7 +2100,7 @@ func TestSuite(t *testing.T) { } func BenchmarkSuite(b *testing.B) { - initializeTestDB(b) + initializeTestDB(b, true) defer freeTestDB() for _, benchmark := range benchmarks { @@ -2068,8 +2129,13 @@ type TestDB struct { var db *TestDB -func initializeTestDB(t testing.TB) { - tempFilename := TempFilename(t) +func initializeTestDB(t testing.TB, memory bool) { + var tempFilename string + if memory { + tempFilename = ":memory:" + } else { + tempFilename = TempFilename(t) + } d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999") if err != nil { os.Remove(tempFilename) @@ -2084,9 +2150,11 @@ func freeTestDB() { if err != nil { panic(err) } - err = os.Remove(db.tempFilename) - if err != nil { - panic(err) + if db.tempFilename != "" && db.tempFilename != ":memory:" { + err := os.Remove(db.tempFilename) + if err != nil { + panic(err) + } } } @@ -2107,6 +2175,8 @@ var tests = []testing.InternalTest{ var benchmarks = []testing.InternalBenchmark{ {Name: "BenchmarkExec", F: benchmarkExec}, {Name: "BenchmarkQuery", F: benchmarkQuery}, + {Name: "BenchmarkQuerySimple", F: benchmarkQuerySimple}, + {Name: "BenchmarkQueryContext", F: benchmarkQueryContext}, {Name: "BenchmarkParams", F: benchmarkParams}, {Name: "BenchmarkStmt", F: benchmarkStmt}, {Name: "BenchmarkRows", F: benchmarkRows}, @@ -2479,6 +2549,75 @@ func benchmarkQuery(b *testing.B) { } } +// Simple query to benchmark the construction/execution of sqlite3 queries. +func benchmarkQuerySimple(b *testing.B) { + var n int64 + for i := 0; i < b.N; i++ { + if err := db.QueryRow("select ?;", int64(1)).Scan(&n); err != nil { + panic(err) + } + } +} + +// benchmarkQueryContext is benchmark for QueryContext +func benchmarkQueryContext(b *testing.B) { + const createTableStmt = ` + CREATE TABLE IF NOT EXISTS query_context( + id INTEGER PRIMARY KEY + ); + DELETE FROM query_context; + VACUUM;` + test := func(ctx context.Context, b *testing.B) { + if _, err := db.Exec(createTableStmt); err != nil { + b.Fatal(err) + } + for i := 0; i < 10; i++ { + _, err := db.Exec("INSERT INTO query_context VALUES (?);", int64(i)) + if err != nil { + db.Fatal(err) + } + } + stmt, err := db.PrepareContext(ctx, `SELECT id FROM query_context;`) + if err != nil { + b.Fatal(err) + } + b.Cleanup(func() { stmt.Close() }) + + var n int + for i := 0; i < b.N; i++ { + rows, err := stmt.QueryContext(ctx) + if err != nil { + b.Fatal(err) + } + for rows.Next() { + if err := rows.Scan(&n); err != nil { + b.Fatal(err) + } + } + if err := rows.Err(); err != nil { + b.Fatal(err) + } + } + } + + // When the context does not have a Done channel we should use + // the fast path that directly handles the query instead of + // handling it in a goroutine. This benchmark also serves to + // highlight the performance impact of using a cancelable + // context. + b.Run("Background", func(b *testing.B) { + test(context.Background(), b) + }) + + // Benchmark a query with a context that can be canceled. This + // requires using a goroutine and is thus much slower. + b.Run("WithCancel", func(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + test(ctx, b) + }) +} + // benchmarkParams is benchmark for params func benchmarkParams(b *testing.B) { for i := 0; i < b.N; i++ {