From d2d4030dd7d7b763bac1c13171bccd325091fbe1 Mon Sep 17 00:00:00 2001 From: Charlie Vieth Date: Sun, 15 Dec 2024 20:28:27 -0500 Subject: [PATCH] use extended error codes and handle closed db connections in lastError This commit changes the error handling logic so that it respects the offending result code (instead of only relying on sqlite3_errcode) and changes the db connection to always report the extended result code (which eliminates the need to call sqlite3_extended_errcode). These changes make it possible to correctly and safely handle errors when the underlying db connection has been closed. --- backup.go | 5 ++- error.go | 24 +++++++++++-- sqlite3.go | 59 ++++++++++++++++++------------- sqlite3_opt_unlock_notify.c | 20 ++++------- sqlite3_opt_unlock_notify_test.go | 22 +++++++----- sqlite3_opt_userauth.go | 8 ++--- sqlite3_opt_vtable.go | 6 ++-- sqlite3_trace.go | 2 +- 8 files changed, 89 insertions(+), 57 deletions(-) diff --git a/backup.go b/backup.go index ecbb4697..16a72484 100644 --- a/backup.go +++ b/backup.go @@ -36,7 +36,10 @@ func (destConn *SQLiteConn) Backup(dest string, srcConn *SQLiteConn, src string) runtime.SetFinalizer(bb, (*SQLiteBackup).Finish) return bb, nil } - return nil, destConn.lastError() + if destConn.db != nil { + return nil, destConn.lastError(int(C.sqlite3_extended_errcode(destConn.db))) + } + return nil, Error{Code: 1, ExtendedCode: 1, err: "backup: destination connection is nil"} } // Step to backs up for one step. Calls the underlying `sqlite3_backup_step` diff --git a/error.go b/error.go index 58ab252e..416f4e85 100644 --- a/error.go +++ b/error.go @@ -13,13 +13,16 @@ package sqlite3 #endif */ import "C" -import "syscall" +import ( + "sync" + "syscall" +) // ErrNo inherit errno. type ErrNo int // ErrNoMask is mask code. -const ErrNoMask C.int = 0xff +const ErrNoMask = 0xff // ErrNoExtended is extended errno. type ErrNoExtended int @@ -85,7 +88,7 @@ func (err Error) Error() string { if err.err != "" { str = err.err } else { - str = C.GoString(C.sqlite3_errstr(C.int(err.Code))) + str = errorString(int(err.Code)) } if err.SystemErrno != 0 { str += ": " + err.SystemErrno.Error() @@ -148,3 +151,18 @@ var ( ErrNoticeRecoverRollback = ErrNotice.Extend(2) ErrWarningAutoIndex = ErrWarning.Extend(1) ) + +var errStrCache sync.Map // int => string + +// errorString returns the result of sqlite3_errstr for result code +// rv which may be cached. +func errorString(rv int) string { + if v, ok := errStrCache.Load(rv); ok { + return v.(string) + } + s := C.GoString(C.sqlite3_errstr(C.int(rv))) + if v, loaded := errStrCache.LoadOrStore(rv, s); loaded { + return v.(string) + } + return s +} diff --git a/sqlite3.go b/sqlite3.go index cead9921..d03b9fda 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -540,7 +540,7 @@ func (c *SQLiteConn) RegisterCollation(name string, cmp func(string, string) int defer C.free(unsafe.Pointer(cname)) rv := C.sqlite3_create_collation(c.db, cname, C.SQLITE_UTF8, handle, (*[0]byte)(unsafe.Pointer(C.compareTrampoline))) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil } @@ -675,7 +675,7 @@ func (c *SQLiteConn) RegisterFunc(name string, impl any, pure bool) error { } rv := sqlite3CreateFunction(c.db, cname, C.int(numArgs), C.int(opts), newHandle(c, &fi), C.callbackTrampoline, nil, nil) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil } @@ -804,7 +804,7 @@ func (c *SQLiteConn) RegisterAggregator(name string, impl any, pure bool) error } rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil } @@ -816,32 +816,38 @@ func (c *SQLiteConn) AutoCommit() bool { return int(C.sqlite3_get_autocommit(c.db)) != 0 } -func (c *SQLiteConn) lastError() error { - return lastError(c.db) +func (c *SQLiteConn) lastError(rv int) error { + return lastError(c.db, rv) } -// Note: may be called with db == nil -func lastError(db *C.sqlite3) error { - rv := C.sqlite3_errcode(db) // returns SQLITE_NOMEM if db == nil - if rv == C.SQLITE_OK { +func lastError(db *C.sqlite3, rv int) error { + if rv == SQLITE_OK { return nil } - extrv := C.sqlite3_extended_errcode(db) // returns SQLITE_NOMEM if db == nil - errStr := C.GoString(C.sqlite3_errmsg(db)) // returns "out of memory" if db == nil + extrv := rv + // Convert the extended result code to a basic result code. + rv &= ErrNoMask // https://www.sqlite.org/c3ref/system_errno.html // sqlite3_system_errno is only meaningful if the error code was SQLITE_CANTOPEN, // or it was SQLITE_IOERR and the extended code was not SQLITE_IOERR_NOMEM var systemErrno syscall.Errno - if rv == C.SQLITE_CANTOPEN || (rv == C.SQLITE_IOERR && extrv != C.SQLITE_IOERR_NOMEM) { + if db != nil && (rv == C.SQLITE_CANTOPEN || + (rv == C.SQLITE_IOERR && extrv != C.SQLITE_IOERR_NOMEM)) { systemErrno = syscall.Errno(C.sqlite3_system_errno(db)) } + var msg string + if db != nil { + msg = C.GoString(C.sqlite3_errmsg(db)) + } else { + msg = errorString(extrv) + } return Error{ Code: ErrNo(rv), ExtendedCode: ErrNoExtended(extrv), SystemErrno: systemErrno, - err: errStr, + err: msg, } } @@ -1467,7 +1473,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { if rv != 0 { // Save off the error _before_ closing the database. // This is safe even if db is nil. - err := lastError(db) + err := lastError(db, int(rv)) if db != nil { C.sqlite3_close_v2(db) } @@ -1476,13 +1482,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { if db == nil { return nil, errors.New("sqlite succeeded without returning a database") } + rv = C.sqlite3_extended_result_codes(db, 1) + if rv != SQLITE_OK { + C.sqlite3_close_v2(db) + return nil, lastError(db, int(rv)) + } exec := func(s string) error { cs := C.CString(s) rv := C.sqlite3_exec(db, cs, nil, nil, nil) C.free(unsafe.Pointer(cs)) if rv != C.SQLITE_OK { - return lastError(db) + return lastError(db, int(rv)) } return nil } @@ -1791,7 +1802,7 @@ func (c *SQLiteConn) Close() (err error) { runtime.SetFinalizer(c, nil) rv := C.sqlite3_close_v2(c.db) if rv != C.SQLITE_OK { - err = c.lastError() + err = lastError(nil, int(rv)) } deleteHandles(c) c.db = nil @@ -1810,7 +1821,7 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er var tail *C.char rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s, &tail) if rv != C.SQLITE_OK { - return nil, c.lastError() + return nil, c.lastError(int(rv)) } var t string if tail != nil && *tail != '\000' { @@ -1882,7 +1893,7 @@ func (c *SQLiteConn) SetFileControlInt(dbName string, op int, arg int) error { cArg := C.int(arg) rv := C.sqlite3_file_control(c.db, cDBName, C.int(op), unsafe.Pointer(&cArg)) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil } @@ -1902,7 +1913,7 @@ func (s *SQLiteStmt) Close() error { runtime.SetFinalizer(s, nil) rv := C.sqlite3_finalize(stmt) if rv != C.SQLITE_OK { - return conn.lastError() + return conn.lastError(int(rv)) } return nil } @@ -1917,7 +1928,7 @@ var placeHolder = []byte{0} func (s *SQLiteStmt) bind(args []driver.NamedValue) error { rv := C.sqlite3_reset(s.s) if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { - return s.c.lastError() + return s.c.lastError(int(rv)) } bindIndices := make([][3]int, len(args)) @@ -1975,7 +1986,7 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error { rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) } if rv != C.SQLITE_OK { - return s.c.lastError() + return s.c.lastError(int(rv)) } } } @@ -2087,7 +2098,7 @@ func (s *SQLiteStmt) execSync(args []driver.NamedValue) (driver.Result, error) { var rowid, changes C.longlong rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes) if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { - err := s.c.lastError() + err := s.c.lastError(int(rv)) C.sqlite3_reset(s.s) C.sqlite3_clear_bindings(s.s) return nil, err @@ -2118,7 +2129,7 @@ func (rc *SQLiteRows) Close() error { rv := C.sqlite3_reset(rc.s.s) if rv != C.SQLITE_OK { rc.s.mu.Unlock() - return rc.s.c.lastError() + return rc.s.c.lastError(int(rv)) } rc.s.mu.Unlock() rc.s = nil @@ -2197,7 +2208,7 @@ func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error { if rv != C.SQLITE_ROW { rv = C.sqlite3_reset(rc.s.s) if rv != C.SQLITE_OK { - return rc.s.c.lastError() + return rc.s.c.lastError(int(rv)) } return nil } diff --git a/sqlite3_opt_unlock_notify.c b/sqlite3_opt_unlock_notify.c index fc37b336..55f6f0dd 100644 --- a/sqlite3_opt_unlock_notify.c +++ b/sqlite3_opt_unlock_notify.c @@ -4,11 +4,14 @@ // license that can be found in the LICENSE file. #ifdef SQLITE_ENABLE_UNLOCK_NOTIFY -#include #include "sqlite3-binding.h" extern int unlock_notify_wait(sqlite3 *db); +static inline int is_locked(int rv) { + return rv == SQLITE_LOCKED || rv == SQLITE_LOCKED_SHAREDCACHE; +} + int _sqlite3_step_blocking(sqlite3_stmt *stmt) { @@ -18,10 +21,7 @@ _sqlite3_step_blocking(sqlite3_stmt *stmt) db = sqlite3_db_handle(stmt); for (;;) { rv = sqlite3_step(stmt); - if (rv != SQLITE_LOCKED) { - break; - } - if (sqlite3_extended_errcode(db) != SQLITE_LOCKED_SHAREDCACHE) { + if (!is_locked(rv)) { break; } rv = unlock_notify_wait(db); @@ -43,10 +43,7 @@ _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* chan db = sqlite3_db_handle(stmt); for (;;) { rv = sqlite3_step(stmt); - if (rv!=SQLITE_LOCKED) { - break; - } - if (sqlite3_extended_errcode(db) != SQLITE_LOCKED_SHAREDCACHE) { + if (!is_locked(rv)) { break; } rv = unlock_notify_wait(db); @@ -68,10 +65,7 @@ _sqlite3_prepare_v2_blocking(sqlite3 *db, const char *zSql, int nBytes, sqlite3_ for (;;) { rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail); - if (rv!=SQLITE_LOCKED) { - break; - } - if (sqlite3_extended_errcode(db) != SQLITE_LOCKED_SHAREDCACHE) { + if (!is_locked(rv)) { break; } rv = unlock_notify_wait(db); diff --git a/sqlite3_opt_unlock_notify_test.go b/sqlite3_opt_unlock_notify_test.go index 3a9168cd..65107c38 100644 --- a/sqlite3_opt_unlock_notify_test.go +++ b/sqlite3_opt_unlock_notify_test.go @@ -51,12 +51,12 @@ func TestUnlockNotify(t *testing.T) { wg.Add(1) timer := time.NewTimer(500 * time.Millisecond) go func() { + defer wg.Done() <-timer.C err := tx.Commit() if err != nil { t.Fatal("Failed to commit transaction:", err) } - wg.Done() }() rows, err := db.Query("SELECT count(*) from foo") @@ -111,33 +111,39 @@ func TestUnlockNotifyMany(t *testing.T) { wg.Add(1) timer := time.NewTimer(500 * time.Millisecond) go func() { + defer wg.Done() <-timer.C err := tx.Commit() if err != nil { t.Fatal("Failed to commit transaction:", err) } - wg.Done() }() const concurrentQueries = 1000 wg.Add(concurrentQueries) for i := 0; i < concurrentQueries; i++ { go func() { + defer wg.Done() rows, err := db.Query("SELECT count(*) from foo") if err != nil { - t.Fatal("Unable to query foo table:", err) + t.Error("Unable to query foo table:", err) + return } if rows.Next() { var count int if err := rows.Scan(&count); err != nil { - t.Fatal("Failed to Scan rows", err) + t.Error("Failed to Scan rows", err) + return + } + if count != 1 { + t.Errorf("count=%d want=%d", count, 1) } } if err := rows.Err(); err != nil { - t.Fatal("Failed at the call to Next:", err) + t.Error("Failed at the call to Next:", err) + return } - wg.Done() }() } wg.Wait() @@ -177,16 +183,17 @@ func TestUnlockNotifyDeadlock(t *testing.T) { wg.Add(1) timer := time.NewTimer(500 * time.Millisecond) go func() { + defer wg.Done() <-timer.C err := tx.Commit() if err != nil { t.Fatal("Failed to commit transaction:", err) } - wg.Done() }() wg.Add(1) go func() { + defer wg.Done() tx2, err := db.Begin() if err != nil { t.Fatal("Failed to begin transaction:", err) @@ -201,7 +208,6 @@ func TestUnlockNotifyDeadlock(t *testing.T) { if err != nil { t.Fatal("Failed to commit transaction:", err) } - wg.Done() }() rows, err := tx.Query("SELECT count(*) from foo") diff --git a/sqlite3_opt_userauth.go b/sqlite3_opt_userauth.go index 76d84016..1ae2a184 100644 --- a/sqlite3_opt_userauth.go +++ b/sqlite3_opt_userauth.go @@ -95,7 +95,7 @@ func (c *SQLiteConn) Authenticate(username, password string) error { case C.SQLITE_OK: return nil default: - return c.lastError() + return c.lastError(int(rv)) } } @@ -143,7 +143,7 @@ func (c *SQLiteConn) AuthUserAdd(username, password string, admin bool) error { case C.SQLITE_OK: return nil default: - return c.lastError() + return c.lastError(int(rv)) } } @@ -193,7 +193,7 @@ func (c *SQLiteConn) AuthUserChange(username, password string, admin bool) error case C.SQLITE_OK: return nil default: - return c.lastError() + return c.lastError(int(rv)) } } @@ -241,7 +241,7 @@ func (c *SQLiteConn) AuthUserDelete(username string) error { case C.SQLITE_OK: return nil default: - return c.lastError() + return c.lastError(int(rv)) } } diff --git a/sqlite3_opt_vtable.go b/sqlite3_opt_vtable.go index 9b164b3e..57c7303e 100644 --- a/sqlite3_opt_vtable.go +++ b/sqlite3_opt_vtable.go @@ -692,7 +692,7 @@ func (c *SQLiteConn) DeclareVTab(sql string) error { defer C.free(unsafe.Pointer(zSQL)) rv := C.sqlite3_declare_vtab(c.db, zSQL) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil } @@ -707,13 +707,13 @@ func (c *SQLiteConn) CreateModule(moduleName string, module Module) error { case EponymousOnlyModule: rv := C._sqlite3_create_module_eponymous_only(c.db, mname, C.uintptr_t(uintptr(newHandle(c, &udm)))) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil case Module: rv := C._sqlite3_create_module(c.db, mname, C.uintptr_t(uintptr(newHandle(c, &udm)))) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil } diff --git a/sqlite3_trace.go b/sqlite3_trace.go index 6c47cce1..d1eabacf 100644 --- a/sqlite3_trace.go +++ b/sqlite3_trace.go @@ -282,7 +282,7 @@ func (c *SQLiteConn) setSQLiteTrace(sqliteEventMask uint) error { // passing the database connection handle as callback context. if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil }