From a66fe30908ccded09ebfe1fdd8f7a34d83666911 Mon Sep 17 00:00:00 2001 From: Charlie Vieth Date: Sun, 8 Dec 2024 19:47:14 -0500 Subject: [PATCH 1/2] make sure Close always removes the runtime finalizer This commit fixes a bug in {SQLiteConn,SQLiteStmt}.Close that would lead to the runtime finalizer not being removed if there was an error closing the connection or statement. This commit also makes it safe to call SQLiteConn.Close multiple times. --- sqlite3.go | 37 +++++++++++++++---------------------- sqlite3_test.go | 8 ++++++++ 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/sqlite3.go b/sqlite3.go index 3025a500..c02af5ab 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -1782,26 +1782,20 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { } // Close the connection. -func (c *SQLiteConn) Close() error { +func (c *SQLiteConn) Close() (err error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.db == nil { + return nil // Already closed + } + runtime.SetFinalizer(c, nil) rv := C.sqlite3_close_v2(c.db) if rv != C.SQLITE_OK { - return c.lastError() + err = c.lastError() } deleteHandles(c) - c.mu.Lock() c.db = nil - c.mu.Unlock() - runtime.SetFinalizer(c, nil) - return nil -} - -func (c *SQLiteConn) dbConnOpen() bool { - if c == nil { - return false - } - c.mu.Lock() - defer c.mu.Unlock() - return c.db != nil + return err } // Prepare the query string. Return a new statement. @@ -1901,16 +1895,15 @@ func (s *SQLiteStmt) Close() error { return nil } s.closed = true - if !s.c.dbConnOpen() { - return errors.New("sqlite statement with already closed database connection") - } - rv := C.sqlite3_finalize(s.s) + conn := s.c + stmt := s.s + s.c = nil s.s = nil + runtime.SetFinalizer(s, nil) + rv := C.sqlite3_finalize(stmt) if rv != C.SQLITE_OK { - return s.c.lastError() + return conn.lastError() } - s.c = nil - runtime.SetFinalizer(s, nil) return nil } diff --git a/sqlite3_test.go b/sqlite3_test.go index 94de7386..7f66df5e 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -338,6 +338,14 @@ func TestClose(t *testing.T) { if err == nil { t.Fatal("Failed to operate closed statement") } + // Closing a statement should not error even if the db is closed. + if err := stmt.Close(); err != nil { + t.Fatal(err) + } + // Second close should be a no-op + if err := stmt.Close(); err != nil { + t.Fatal(err) + } } func TestInsert(t *testing.T) { From e85adcf795b6467a973f26bbfd33bd2702cf3889 Mon Sep 17 00:00:00 2001 From: Charlie Vieth Date: Sun, 15 Dec 2024 20:28:27 -0500 Subject: [PATCH 2/2] use extended error codes and handle closed db connections in lastError This commit changes the error handling logic so that it respects the offending result code (instead of only relying on sqlite3_errcode) and changes the db connection to always report the extended result code (which eliminates the need to call sqlite3_extended_errcode). These changes make it possible to correctly and safely handle errors when the underlying db connection has been closed. --- backup.go | 5 ++- error.go | 28 ++++++++++++-- sqlite3.go | 61 ++++++++++++++++++------------- sqlite3_opt_unlock_notify.c | 20 ++++------ sqlite3_opt_unlock_notify_test.go | 22 +++++++---- sqlite3_opt_userauth.go | 8 ++-- sqlite3_opt_vtable.go | 6 +-- sqlite3_trace.go | 2 +- 8 files changed, 94 insertions(+), 58 deletions(-) diff --git a/backup.go b/backup.go index ecbb4697..16a72484 100644 --- a/backup.go +++ b/backup.go @@ -36,7 +36,10 @@ func (destConn *SQLiteConn) Backup(dest string, srcConn *SQLiteConn, src string) runtime.SetFinalizer(bb, (*SQLiteBackup).Finish) return bb, nil } - return nil, destConn.lastError() + if destConn.db != nil { + return nil, destConn.lastError(int(C.sqlite3_extended_errcode(destConn.db))) + } + return nil, Error{Code: 1, ExtendedCode: 1, err: "backup: destination connection is nil"} } // Step to backs up for one step. Calls the underlying `sqlite3_backup_step` diff --git a/error.go b/error.go index 58ab252e..438bb56e 100644 --- a/error.go +++ b/error.go @@ -13,13 +13,16 @@ package sqlite3 #endif */ import "C" -import "syscall" +import ( + "sync" + "syscall" +) // ErrNo inherit errno. type ErrNo int // ErrNoMask is mask code. -const ErrNoMask C.int = 0xff +const ErrNoMask = 0xff // ErrNoExtended is extended errno. type ErrNoExtended int @@ -85,7 +88,7 @@ func (err Error) Error() string { if err.err != "" { str = err.err } else { - str = C.GoString(C.sqlite3_errstr(C.int(err.Code))) + str = errorString(int(err.Code)) } if err.SystemErrno != 0 { str += ": " + err.SystemErrno.Error() @@ -148,3 +151,22 @@ var ( ErrNoticeRecoverRollback = ErrNotice.Extend(2) ErrWarningAutoIndex = ErrWarning.Extend(1) ) + +var errStrCache sync.Map // int => string + +// errorString returns the result of sqlite3_errstr for result code rv, +// which may be cached. +func errorString(rv int) string { + if v, ok := errStrCache.Load(rv); ok { + return v.(string) + } + s := C.GoString(C.sqlite3_errstr(C.int(rv))) + // Prevent the cache from growing unbounded by ignoring invalid + // error codes. + if s != "unknown error" { + if v, loaded := errStrCache.LoadOrStore(rv, s); loaded { + s = v.(string) + } + } + return s +} diff --git a/sqlite3.go b/sqlite3.go index c02af5ab..fd9c3827 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -540,7 +540,7 @@ func (c *SQLiteConn) RegisterCollation(name string, cmp func(string, string) int defer C.free(unsafe.Pointer(cname)) rv := C.sqlite3_create_collation(c.db, cname, C.SQLITE_UTF8, handle, (*[0]byte)(unsafe.Pointer(C.compareTrampoline))) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil } @@ -675,7 +675,7 @@ func (c *SQLiteConn) RegisterFunc(name string, impl any, pure bool) error { } rv := sqlite3CreateFunction(c.db, cname, C.int(numArgs), C.int(opts), newHandle(c, &fi), C.callbackTrampoline, nil, nil) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil } @@ -804,7 +804,7 @@ func (c *SQLiteConn) RegisterAggregator(name string, impl any, pure bool) error } rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil } @@ -816,32 +816,38 @@ func (c *SQLiteConn) AutoCommit() bool { return int(C.sqlite3_get_autocommit(c.db)) != 0 } -func (c *SQLiteConn) lastError() error { - return lastError(c.db) +func (c *SQLiteConn) lastError(rv int) error { + return lastError(c.db, rv) } -// Note: may be called with db == nil -func lastError(db *C.sqlite3) error { - rv := C.sqlite3_errcode(db) // returns SQLITE_NOMEM if db == nil - if rv == C.SQLITE_OK { +func lastError(db *C.sqlite3, rv int) error { + if rv == SQLITE_OK { return nil } - extrv := C.sqlite3_extended_errcode(db) // returns SQLITE_NOMEM if db == nil - errStr := C.GoString(C.sqlite3_errmsg(db)) // returns "out of memory" if db == nil + extrv := rv + // Convert the extended result code to a basic result code. + rv &= ErrNoMask // https://www.sqlite.org/c3ref/system_errno.html // sqlite3_system_errno is only meaningful if the error code was SQLITE_CANTOPEN, // or it was SQLITE_IOERR and the extended code was not SQLITE_IOERR_NOMEM var systemErrno syscall.Errno - if rv == C.SQLITE_CANTOPEN || (rv == C.SQLITE_IOERR && extrv != C.SQLITE_IOERR_NOMEM) { + if db != nil && (rv == C.SQLITE_CANTOPEN || + (rv == C.SQLITE_IOERR && extrv != C.SQLITE_IOERR_NOMEM)) { systemErrno = syscall.Errno(C.sqlite3_system_errno(db)) } + var msg string + if db != nil { + msg = C.GoString(C.sqlite3_errmsg(db)) + } else { + msg = errorString(extrv) + } return Error{ Code: ErrNo(rv), ExtendedCode: ErrNoExtended(extrv), SystemErrno: systemErrno, - err: errStr, + err: msg, } } @@ -1467,7 +1473,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { if rv != 0 { // Save off the error _before_ closing the database. // This is safe even if db is nil. - err := lastError(db) + err := lastError(db, int(rv)) if db != nil { C.sqlite3_close_v2(db) } @@ -1476,13 +1482,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { if db == nil { return nil, errors.New("sqlite succeeded without returning a database") } + rv = C.sqlite3_extended_result_codes(db, 1) + if rv != SQLITE_OK { + C.sqlite3_close_v2(db) + return nil, lastError(db, int(rv)) + } exec := func(s string) error { cs := C.CString(s) rv := C.sqlite3_exec(db, cs, nil, nil, nil) C.free(unsafe.Pointer(cs)) if rv != C.SQLITE_OK { - return lastError(db) + return lastError(db, int(rv)) } return nil } @@ -1791,7 +1802,7 @@ func (c *SQLiteConn) Close() (err error) { runtime.SetFinalizer(c, nil) rv := C.sqlite3_close_v2(c.db) if rv != C.SQLITE_OK { - err = c.lastError() + err = lastError(nil, int(rv)) } deleteHandles(c) c.db = nil @@ -1810,7 +1821,7 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er var tail *C.char rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s, &tail) if rv != C.SQLITE_OK { - return nil, c.lastError() + return nil, c.lastError(int(rv)) } var t string if tail != nil && *tail != '\000' { @@ -1882,7 +1893,7 @@ func (c *SQLiteConn) SetFileControlInt(dbName string, op int, arg int) error { cArg := C.int(arg) rv := C.sqlite3_file_control(c.db, cDBName, C.int(op), unsafe.Pointer(&cArg)) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil } @@ -1902,7 +1913,7 @@ func (s *SQLiteStmt) Close() error { runtime.SetFinalizer(s, nil) rv := C.sqlite3_finalize(stmt) if rv != C.SQLITE_OK { - return conn.lastError() + return conn.lastError(int(rv)) } return nil } @@ -1917,7 +1928,7 @@ var placeHolder = []byte{0} func (s *SQLiteStmt) bind(args []driver.NamedValue) error { rv := C.sqlite3_reset(s.s) if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { - return s.c.lastError() + return s.c.lastError(int(rv)) } bindIndices := make([][3]int, len(args)) @@ -1975,7 +1986,7 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error { rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) } if rv != C.SQLITE_OK { - return s.c.lastError() + return s.c.lastError(int(rv)) } } } @@ -2085,7 +2096,7 @@ func (s *SQLiteStmt) execSync(args []driver.NamedValue) (driver.Result, error) { var rowid, changes C.longlong rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes) if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { - err := s.c.lastError() + err := s.c.lastError(int(rv)) C.sqlite3_reset(s.s) C.sqlite3_clear_bindings(s.s) return nil, err @@ -2121,8 +2132,8 @@ func (rc *SQLiteRows) Close() error { } rv := C.sqlite3_reset(s.s) if rv != C.SQLITE_OK { - s.mu.Unlock() - return s.c.lastError() + rc.s.mu.Unlock() + return rc.s.c.lastError(int(rv)) } s.mu.Unlock() return nil @@ -2199,7 +2210,7 @@ func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error { if rv != C.SQLITE_ROW { rv = C.sqlite3_reset(rc.s.s) if rv != C.SQLITE_OK { - return rc.s.c.lastError() + return rc.s.c.lastError(int(rv)) } return nil } diff --git a/sqlite3_opt_unlock_notify.c b/sqlite3_opt_unlock_notify.c index 3a00f43d..047b4aac 100644 --- a/sqlite3_opt_unlock_notify.c +++ b/sqlite3_opt_unlock_notify.c @@ -4,7 +4,6 @@ // license that can be found in the LICENSE file. #ifdef SQLITE_ENABLE_UNLOCK_NOTIFY -#include #ifndef USE_LIBSQLITE3 #include "sqlite3-binding.h" #else @@ -13,6 +12,10 @@ extern int unlock_notify_wait(sqlite3 *db); +static inline int is_locked(int rv) { + return rv == SQLITE_LOCKED || rv == SQLITE_LOCKED_SHAREDCACHE; +} + int _sqlite3_step_blocking(sqlite3_stmt *stmt) { @@ -22,10 +25,7 @@ _sqlite3_step_blocking(sqlite3_stmt *stmt) db = sqlite3_db_handle(stmt); for (;;) { rv = sqlite3_step(stmt); - if (rv != SQLITE_LOCKED) { - break; - } - if (sqlite3_extended_errcode(db) != SQLITE_LOCKED_SHAREDCACHE) { + if (!is_locked(rv)) { break; } rv = unlock_notify_wait(db); @@ -47,10 +47,7 @@ _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* chan db = sqlite3_db_handle(stmt); for (;;) { rv = sqlite3_step(stmt); - if (rv!=SQLITE_LOCKED) { - break; - } - if (sqlite3_extended_errcode(db) != SQLITE_LOCKED_SHAREDCACHE) { + if (!is_locked(rv)) { break; } rv = unlock_notify_wait(db); @@ -72,10 +69,7 @@ _sqlite3_prepare_v2_blocking(sqlite3 *db, const char *zSql, int nBytes, sqlite3_ for (;;) { rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail); - if (rv!=SQLITE_LOCKED) { - break; - } - if (sqlite3_extended_errcode(db) != SQLITE_LOCKED_SHAREDCACHE) { + if (!is_locked(rv)) { break; } rv = unlock_notify_wait(db); diff --git a/sqlite3_opt_unlock_notify_test.go b/sqlite3_opt_unlock_notify_test.go index 3a9168cd..65107c38 100644 --- a/sqlite3_opt_unlock_notify_test.go +++ b/sqlite3_opt_unlock_notify_test.go @@ -51,12 +51,12 @@ func TestUnlockNotify(t *testing.T) { wg.Add(1) timer := time.NewTimer(500 * time.Millisecond) go func() { + defer wg.Done() <-timer.C err := tx.Commit() if err != nil { t.Fatal("Failed to commit transaction:", err) } - wg.Done() }() rows, err := db.Query("SELECT count(*) from foo") @@ -111,33 +111,39 @@ func TestUnlockNotifyMany(t *testing.T) { wg.Add(1) timer := time.NewTimer(500 * time.Millisecond) go func() { + defer wg.Done() <-timer.C err := tx.Commit() if err != nil { t.Fatal("Failed to commit transaction:", err) } - wg.Done() }() const concurrentQueries = 1000 wg.Add(concurrentQueries) for i := 0; i < concurrentQueries; i++ { go func() { + defer wg.Done() rows, err := db.Query("SELECT count(*) from foo") if err != nil { - t.Fatal("Unable to query foo table:", err) + t.Error("Unable to query foo table:", err) + return } if rows.Next() { var count int if err := rows.Scan(&count); err != nil { - t.Fatal("Failed to Scan rows", err) + t.Error("Failed to Scan rows", err) + return + } + if count != 1 { + t.Errorf("count=%d want=%d", count, 1) } } if err := rows.Err(); err != nil { - t.Fatal("Failed at the call to Next:", err) + t.Error("Failed at the call to Next:", err) + return } - wg.Done() }() } wg.Wait() @@ -177,16 +183,17 @@ func TestUnlockNotifyDeadlock(t *testing.T) { wg.Add(1) timer := time.NewTimer(500 * time.Millisecond) go func() { + defer wg.Done() <-timer.C err := tx.Commit() if err != nil { t.Fatal("Failed to commit transaction:", err) } - wg.Done() }() wg.Add(1) go func() { + defer wg.Done() tx2, err := db.Begin() if err != nil { t.Fatal("Failed to begin transaction:", err) @@ -201,7 +208,6 @@ func TestUnlockNotifyDeadlock(t *testing.T) { if err != nil { t.Fatal("Failed to commit transaction:", err) } - wg.Done() }() rows, err := tx.Query("SELECT count(*) from foo") diff --git a/sqlite3_opt_userauth.go b/sqlite3_opt_userauth.go index 76d84016..1ae2a184 100644 --- a/sqlite3_opt_userauth.go +++ b/sqlite3_opt_userauth.go @@ -95,7 +95,7 @@ func (c *SQLiteConn) Authenticate(username, password string) error { case C.SQLITE_OK: return nil default: - return c.lastError() + return c.lastError(int(rv)) } } @@ -143,7 +143,7 @@ func (c *SQLiteConn) AuthUserAdd(username, password string, admin bool) error { case C.SQLITE_OK: return nil default: - return c.lastError() + return c.lastError(int(rv)) } } @@ -193,7 +193,7 @@ func (c *SQLiteConn) AuthUserChange(username, password string, admin bool) error case C.SQLITE_OK: return nil default: - return c.lastError() + return c.lastError(int(rv)) } } @@ -241,7 +241,7 @@ func (c *SQLiteConn) AuthUserDelete(username string) error { case C.SQLITE_OK: return nil default: - return c.lastError() + return c.lastError(int(rv)) } } diff --git a/sqlite3_opt_vtable.go b/sqlite3_opt_vtable.go index 9b164b3e..57c7303e 100644 --- a/sqlite3_opt_vtable.go +++ b/sqlite3_opt_vtable.go @@ -692,7 +692,7 @@ func (c *SQLiteConn) DeclareVTab(sql string) error { defer C.free(unsafe.Pointer(zSQL)) rv := C.sqlite3_declare_vtab(c.db, zSQL) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil } @@ -707,13 +707,13 @@ func (c *SQLiteConn) CreateModule(moduleName string, module Module) error { case EponymousOnlyModule: rv := C._sqlite3_create_module_eponymous_only(c.db, mname, C.uintptr_t(uintptr(newHandle(c, &udm)))) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil case Module: rv := C._sqlite3_create_module(c.db, mname, C.uintptr_t(uintptr(newHandle(c, &udm)))) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil } diff --git a/sqlite3_trace.go b/sqlite3_trace.go index 6c47cce1..d1eabacf 100644 --- a/sqlite3_trace.go +++ b/sqlite3_trace.go @@ -282,7 +282,7 @@ func (c *SQLiteConn) setSQLiteTrace(sqliteEventMask uint) error { // passing the database connection handle as callback context. if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil }