From 1f1ff36e60ae84d4a9b164b96f3a46c0ede27b38 Mon Sep 17 00:00:00 2001 From: Charlie Vieth Date: Sun, 15 Dec 2024 20:28:27 -0500 Subject: [PATCH] use extended error codes and handle closed db connections in lastError This commit changes the error handling logic so that it respects the offending result code (instead of only relying on sqlite3_errcode) and changes the db connection to always report the extended result code (which eliminates the need to call sqlite3_extended_errcode). These changes make it possible to correctly and safely handle errors when the underlying db connection has been closed. --- backup.go | 5 +- error.go | 28 ++++++++- sqlite3.go | 100 +++++++++++++++++++++--------- sqlite3_opt_unlock_notify.c | 20 +++--- sqlite3_opt_unlock_notify_test.go | 22 ++++--- sqlite3_opt_userauth.go | 8 +-- sqlite3_opt_vtable.go | 6 +- sqlite3_trace.go | 2 +- 8 files changed, 128 insertions(+), 63 deletions(-) diff --git a/backup.go b/backup.go index ecbb4697..16a72484 100644 --- a/backup.go +++ b/backup.go @@ -36,7 +36,10 @@ func (destConn *SQLiteConn) Backup(dest string, srcConn *SQLiteConn, src string) runtime.SetFinalizer(bb, (*SQLiteBackup).Finish) return bb, nil } - return nil, destConn.lastError() + if destConn.db != nil { + return nil, destConn.lastError(int(C.sqlite3_extended_errcode(destConn.db))) + } + return nil, Error{Code: 1, ExtendedCode: 1, err: "backup: destination connection is nil"} } // Step to backs up for one step. Calls the underlying `sqlite3_backup_step` diff --git a/error.go b/error.go index a9d41544..fb5922dd 100644 --- a/error.go +++ b/error.go @@ -13,13 +13,16 @@ package sqlite3 #endif */ import "C" -import "syscall" +import ( + "sync" + "syscall" +) // ErrNo inherit errno. type ErrNo int // ErrNoMask is mask code. -const ErrNoMask C.int = 0xff +const ErrNoMask = 0xff // ErrNoExtended is extended errno. type ErrNoExtended int @@ -85,7 +88,7 @@ func (err Error) Error() string { if err.err != "" { str = err.err } else { - str = C.GoString(C.sqlite3_errstr(C.int(err.Code))) + str = errorString(int(err.Code)) } if err.SystemErrno != 0 { str += ": " + err.SystemErrno.Error() @@ -148,3 +151,22 @@ const ( ErrNoticeRecoverRollback = ErrNoExtended(ErrNotice | 2<<8) ErrWarningAutoIndex = ErrNoExtended(ErrWarning | 1<<8) ) + +var errStrCache sync.Map // int => string + +// errorString returns the result of sqlite3_errstr for result code rv, +// which may be cached. +func errorString(rv int) string { + if v, ok := errStrCache.Load(rv); ok { + return v.(string) + } + s := C.GoString(C.sqlite3_errstr(C.int(rv))) + // Prevent the cache from growing unbounded by ignoring invalid + // error codes. + if s != "unknown error" { + if v, loaded := errStrCache.LoadOrStore(rv, s); loaded { + s = v.(string) + } + } + return s +} diff --git a/sqlite3.go b/sqlite3.go index f379fef9..78d39ab9 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -221,12 +221,15 @@ int _sqlite3_exec_no_args(sqlite3 *db, const char *zSql, int nBytes, int64_t *ro do { rv = _sqlite3_step_internal(stmt); } while (rv == SQLITE_ROW); + if (rv != SQLITE_OK && rv != SQLITE_DONE) { + return rv; + } // Only record the number of changes made by the last statement. *changes = sqlite3_changes64(db); *rowid = sqlite3_last_insert_rowid(db); - sqlite3_finalize(stmt); + rv = sqlite3_finalize(stmt); if (rv != SQLITE_OK && rv != SQLITE_DONE) { return rv; } @@ -386,6 +389,21 @@ static go_sqlite3_text_column _sqlite3_column_blob(sqlite3_stmt *stmt, int idx) } return r; } + +typedef struct { + const char *msg; + int code; +} go_sqlite3_db_error; + +static go_sqlite3_db_error _sqlite3_db_error(sqlite3 *db) { + if (db) { + return (go_sqlite3_db_error){ + .msg = sqlite3_errmsg(db), + .code = sqlite3_extended_errcode(db) + }; + } + return (go_sqlite3_db_error){ 0 }; +} */ import "C" import ( @@ -742,7 +760,7 @@ func (c *SQLiteConn) RegisterCollation(name string, cmp func(string, string) int defer C.free(unsafe.Pointer(cname)) rv := C.sqlite3_create_collation(c.db, cname, C.SQLITE_UTF8, handle, (*[0]byte)(unsafe.Pointer(C.compareTrampoline))) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil } @@ -884,7 +902,7 @@ func (c *SQLiteConn) RegisterFunc(name string, impl any, pure bool) error { } rv := sqlite3CreateFunction(c.db, cname, C.int(numArgs), C.int(opts), newHandle(c, &fi), C.callbackTrampoline, nil, nil) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil } @@ -1013,7 +1031,7 @@ func (c *SQLiteConn) RegisterAggregator(name string, impl any, pure bool) error } rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil } @@ -1025,32 +1043,49 @@ func (c *SQLiteConn) AutoCommit() bool { return int(C.sqlite3_get_autocommit(c.db)) != 0 } -func (c *SQLiteConn) lastError() error { - return lastError(c.db) +func (c *SQLiteConn) lastError(rv int) error { + return lastError(c.db, rv) } -// Note: may be called with db == nil -func lastError(db *C.sqlite3) error { - rv := C.sqlite3_errcode(db) // returns SQLITE_NOMEM if db == nil - if rv == C.SQLITE_OK { +func lastError(db *C.sqlite3, rv int) error { + if rv == SQLITE_OK { return nil } - extrv := C.sqlite3_extended_errcode(db) // returns SQLITE_NOMEM if db == nil - errStr := C.GoString(C.sqlite3_errmsg(db)) // returns "out of memory" if db == nil + extrv := rv + // Convert the extended result code to a basic result code. + rv &= ErrNoMask // https://www.sqlite.org/c3ref/system_errno.html // sqlite3_system_errno is only meaningful if the error code was SQLITE_CANTOPEN, // or it was SQLITE_IOERR and the extended code was not SQLITE_IOERR_NOMEM var systemErrno syscall.Errno - if rv == C.SQLITE_CANTOPEN || (rv == C.SQLITE_IOERR && extrv != C.SQLITE_IOERR_NOMEM) { + if db != nil && (rv == C.SQLITE_CANTOPEN || + (rv == C.SQLITE_IOERR && extrv != C.SQLITE_IOERR_NOMEM)) { systemErrno = syscall.Errno(C.sqlite3_system_errno(db)) } + var msg string + if db != nil { + // TODO: Get the DB's error code as well and combine calls + msg = C.GoString(C.sqlite3_errmsg(db)) + // WARN: We need this check for Exec to return a nil error + // when the SQL statement is empty or simply a comment. + // That we need this special logic is likely indicative of + // a larger error with the logic here or in Exec. + if msg == "not an error" { + rv = int(C.sqlite3_errcode(db)) + if rv == SQLITE_OK { + return nil + } + } + } else { + msg = errorString(extrv) + } return Error{ Code: ErrNo(rv), ExtendedCode: ErrNoExtended(extrv), SystemErrno: systemErrno, - err: errStr, + err: msg, } } @@ -1089,7 +1124,7 @@ func (c *SQLiteConn) execArgs(ctx context.Context, query string, args []driver.N rv := C._sqlite3_prepare_v2(c.db, (*C.char)(unsafe.Pointer(stringData(query))), C.int(len(query)), &s.s, &sz) if rv != C.SQLITE_OK { - return nil, c.lastError() + return nil, c.lastError(int(rv)) } query = strings.TrimSpace(query[sz:]) @@ -1141,7 +1176,7 @@ func (c *SQLiteConn) execNoArgsSync(query string) (_ driver.Result, err error) { rv := C._sqlite3_exec_no_args(c.db, (*C.char)(unsafe.Pointer(stringData(query))), C.int(len(query)), &rowid, &changes) if rv != C.SQLITE_OK { - err = c.lastError() + err = c.lastError(int(rv)) } return &SQLiteResult{id: int64(rowid), changes: int64(changes)}, err } @@ -1206,7 +1241,7 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name if rv == C.GO_SQLITE_MULTIPLE_QUERIES { return nil, errors.New("query contains multiple SQL statements") } - return nil, c.lastError() + return nil, c.lastError(int(rv)) } // The sqlite3_stmt will be nil if the SQL was valid but did not @@ -1742,7 +1777,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { if rv != 0 { // Save off the error _before_ closing the database. // This is safe even if db is nil. - err := lastError(db) + err := lastError(db, int(rv)) if db != nil { C.sqlite3_close_v2(db) } @@ -1751,13 +1786,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { if db == nil { return nil, errors.New("sqlite succeeded without returning a database") } + rv = C.sqlite3_extended_result_codes(db, 1) + if rv != SQLITE_OK { + C.sqlite3_close_v2(db) + return nil, lastError(db, int(rv)) + } exec := func(s string) error { cs := C.CString(s) rv := C.sqlite3_exec(db, cs, nil, nil, nil) C.free(unsafe.Pointer(cs)) if rv != C.SQLITE_OK { - return lastError(db) + return lastError(db, int(rv)) } return nil } @@ -2064,7 +2104,7 @@ func (c *SQLiteConn) Close() (err error) { runtime.SetFinalizer(c, nil) rv := C.sqlite3_close_v2(c.db) if rv != C.SQLITE_OK { - err = c.lastError() + err = lastError(nil, int(rv)) } deleteHandles(c) c.db = nil @@ -2090,7 +2130,7 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er var s *C.sqlite3_stmt rv := C._sqlite3_prepare_v2_internal(c.db, (*C.char)(unsafe.Pointer(p)), C.int(len(query)), &s, nil) if rv != C.SQLITE_OK { - return nil, c.lastError() + return nil, c.lastError(int(rv)) } ss := &SQLiteStmt{c: c, s: s} runtime.SetFinalizer(ss, (*SQLiteStmt).Close) @@ -2158,7 +2198,7 @@ func (c *SQLiteConn) SetFileControlInt(dbName string, op int, arg int) error { cArg := C.int(arg) rv := C.sqlite3_file_control(c.db, cDBName, C.int(op), unsafe.Pointer(&cArg)) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil } @@ -2181,7 +2221,7 @@ func (s *SQLiteStmt) Close() error { } rv := C.sqlite3_finalize(stmt) if rv != C.SQLITE_OK { - return conn.lastError() + return conn.lastError(int(rv)) } return nil } @@ -2213,7 +2253,7 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error { if s.reset { rv := C.sqlite3_reset(s.s) if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { - return s.c.lastError() + return s.c.lastError(int(rv)) } } else { s.reset = true @@ -2258,7 +2298,7 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error { rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(p)), C.int(len(ts))) } if rv != C.SQLITE_OK { - return s.c.lastError() + return s.c.lastError(int(rv)) } } return nil @@ -2327,7 +2367,7 @@ func (s *SQLiteStmt) bindIndices(args []driver.NamedValue) error { rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(p)), C.int(len(ts))) } if rv != C.SQLITE_OK { - return s.c.lastError() + return s.c.lastError(int(rv)) } } } @@ -2437,7 +2477,7 @@ func (s *SQLiteStmt) execSync(args []driver.NamedValue) (driver.Result, error) { var rowid, changes C.longlong rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes) if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { - err := s.c.lastError() + err := s.c.lastError(int(rv)) C.sqlite3_reset(s.s) C.sqlite3_clear_bindings(s.s) return nil, err @@ -2473,8 +2513,8 @@ func (rc *SQLiteRows) Close() error { } rv := C.sqlite3_reset(s.s) if rv != C.SQLITE_OK { - s.mu.Unlock() - return s.c.lastError() + rc.s.mu.Unlock() + return rc.s.c.lastError(int(rv)) } s.mu.Unlock() return nil @@ -2573,7 +2613,7 @@ func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error { if rv != C.SQLITE_ROW { rv = C.sqlite3_reset(rc.s.s) if rv != C.SQLITE_OK { - return rc.s.c.lastError() + return rc.s.c.lastError(int(rv)) } return nil } diff --git a/sqlite3_opt_unlock_notify.c b/sqlite3_opt_unlock_notify.c index fc37b336..55f6f0dd 100644 --- a/sqlite3_opt_unlock_notify.c +++ b/sqlite3_opt_unlock_notify.c @@ -4,11 +4,14 @@ // license that can be found in the LICENSE file. #ifdef SQLITE_ENABLE_UNLOCK_NOTIFY -#include #include "sqlite3-binding.h" extern int unlock_notify_wait(sqlite3 *db); +static inline int is_locked(int rv) { + return rv == SQLITE_LOCKED || rv == SQLITE_LOCKED_SHAREDCACHE; +} + int _sqlite3_step_blocking(sqlite3_stmt *stmt) { @@ -18,10 +21,7 @@ _sqlite3_step_blocking(sqlite3_stmt *stmt) db = sqlite3_db_handle(stmt); for (;;) { rv = sqlite3_step(stmt); - if (rv != SQLITE_LOCKED) { - break; - } - if (sqlite3_extended_errcode(db) != SQLITE_LOCKED_SHAREDCACHE) { + if (!is_locked(rv)) { break; } rv = unlock_notify_wait(db); @@ -43,10 +43,7 @@ _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* chan db = sqlite3_db_handle(stmt); for (;;) { rv = sqlite3_step(stmt); - if (rv!=SQLITE_LOCKED) { - break; - } - if (sqlite3_extended_errcode(db) != SQLITE_LOCKED_SHAREDCACHE) { + if (!is_locked(rv)) { break; } rv = unlock_notify_wait(db); @@ -68,10 +65,7 @@ _sqlite3_prepare_v2_blocking(sqlite3 *db, const char *zSql, int nBytes, sqlite3_ for (;;) { rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail); - if (rv!=SQLITE_LOCKED) { - break; - } - if (sqlite3_extended_errcode(db) != SQLITE_LOCKED_SHAREDCACHE) { + if (!is_locked(rv)) { break; } rv = unlock_notify_wait(db); diff --git a/sqlite3_opt_unlock_notify_test.go b/sqlite3_opt_unlock_notify_test.go index 3a9168cd..65107c38 100644 --- a/sqlite3_opt_unlock_notify_test.go +++ b/sqlite3_opt_unlock_notify_test.go @@ -51,12 +51,12 @@ func TestUnlockNotify(t *testing.T) { wg.Add(1) timer := time.NewTimer(500 * time.Millisecond) go func() { + defer wg.Done() <-timer.C err := tx.Commit() if err != nil { t.Fatal("Failed to commit transaction:", err) } - wg.Done() }() rows, err := db.Query("SELECT count(*) from foo") @@ -111,33 +111,39 @@ func TestUnlockNotifyMany(t *testing.T) { wg.Add(1) timer := time.NewTimer(500 * time.Millisecond) go func() { + defer wg.Done() <-timer.C err := tx.Commit() if err != nil { t.Fatal("Failed to commit transaction:", err) } - wg.Done() }() const concurrentQueries = 1000 wg.Add(concurrentQueries) for i := 0; i < concurrentQueries; i++ { go func() { + defer wg.Done() rows, err := db.Query("SELECT count(*) from foo") if err != nil { - t.Fatal("Unable to query foo table:", err) + t.Error("Unable to query foo table:", err) + return } if rows.Next() { var count int if err := rows.Scan(&count); err != nil { - t.Fatal("Failed to Scan rows", err) + t.Error("Failed to Scan rows", err) + return + } + if count != 1 { + t.Errorf("count=%d want=%d", count, 1) } } if err := rows.Err(); err != nil { - t.Fatal("Failed at the call to Next:", err) + t.Error("Failed at the call to Next:", err) + return } - wg.Done() }() } wg.Wait() @@ -177,16 +183,17 @@ func TestUnlockNotifyDeadlock(t *testing.T) { wg.Add(1) timer := time.NewTimer(500 * time.Millisecond) go func() { + defer wg.Done() <-timer.C err := tx.Commit() if err != nil { t.Fatal("Failed to commit transaction:", err) } - wg.Done() }() wg.Add(1) go func() { + defer wg.Done() tx2, err := db.Begin() if err != nil { t.Fatal("Failed to begin transaction:", err) @@ -201,7 +208,6 @@ func TestUnlockNotifyDeadlock(t *testing.T) { if err != nil { t.Fatal("Failed to commit transaction:", err) } - wg.Done() }() rows, err := tx.Query("SELECT count(*) from foo") diff --git a/sqlite3_opt_userauth.go b/sqlite3_opt_userauth.go index 20536acd..4fda59da 100644 --- a/sqlite3_opt_userauth.go +++ b/sqlite3_opt_userauth.go @@ -91,7 +91,7 @@ func (c *SQLiteConn) Authenticate(username, password string) error { case C.SQLITE_OK: return nil default: - return c.lastError() + return c.lastError(int(rv)) } } @@ -139,7 +139,7 @@ func (c *SQLiteConn) AuthUserAdd(username, password string, admin bool) error { case C.SQLITE_OK: return nil default: - return c.lastError() + return c.lastError(int(rv)) } } @@ -189,7 +189,7 @@ func (c *SQLiteConn) AuthUserChange(username, password string, admin bool) error case C.SQLITE_OK: return nil default: - return c.lastError() + return c.lastError(int(rv)) } } @@ -237,7 +237,7 @@ func (c *SQLiteConn) AuthUserDelete(username string) error { case C.SQLITE_OK: return nil default: - return c.lastError() + return c.lastError(int(rv)) } } diff --git a/sqlite3_opt_vtable.go b/sqlite3_opt_vtable.go index 9b164b3e..57c7303e 100644 --- a/sqlite3_opt_vtable.go +++ b/sqlite3_opt_vtable.go @@ -692,7 +692,7 @@ func (c *SQLiteConn) DeclareVTab(sql string) error { defer C.free(unsafe.Pointer(zSQL)) rv := C.sqlite3_declare_vtab(c.db, zSQL) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil } @@ -707,13 +707,13 @@ func (c *SQLiteConn) CreateModule(moduleName string, module Module) error { case EponymousOnlyModule: rv := C._sqlite3_create_module_eponymous_only(c.db, mname, C.uintptr_t(uintptr(newHandle(c, &udm)))) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil case Module: rv := C._sqlite3_create_module(c.db, mname, C.uintptr_t(uintptr(newHandle(c, &udm)))) if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil } diff --git a/sqlite3_trace.go b/sqlite3_trace.go index 6c47cce1..d1eabacf 100644 --- a/sqlite3_trace.go +++ b/sqlite3_trace.go @@ -282,7 +282,7 @@ func (c *SQLiteConn) setSQLiteTrace(sqliteEventMask uint) error { // passing the database connection handle as callback context. if rv != C.SQLITE_OK { - return c.lastError() + return c.lastError(int(rv)) } return nil }