Skip to content

Commit

Permalink
use extended error codes and handle closed db connections in lastError
Browse files Browse the repository at this point in the history
This commit changes the error handling logic so that it respects the
offending result code (instead of only relying on sqlite3_errcode) and
changes the db connection to always report the extended result code
(which eliminates the need to call sqlite3_extended_errcode). These
changes make it possible to correctly and safely handle errors when the
underlying db connection has been closed.
  • Loading branch information
charlievieth committed Dec 26, 2024
1 parent f2431e5 commit 1f1ff36
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 63 deletions.
5 changes: 4 additions & 1 deletion backup.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ func (destConn *SQLiteConn) Backup(dest string, srcConn *SQLiteConn, src string)
runtime.SetFinalizer(bb, (*SQLiteBackup).Finish)
return bb, nil
}
return nil, destConn.lastError()
if destConn.db != nil {
return nil, destConn.lastError(int(C.sqlite3_extended_errcode(destConn.db)))
}
return nil, Error{Code: 1, ExtendedCode: 1, err: "backup: destination connection is nil"}
}

// Step to backs up for one step. Calls the underlying `sqlite3_backup_step`
Expand Down
28 changes: 25 additions & 3 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@ package sqlite3
#endif
*/
import "C"
import "syscall"
import (
"sync"
"syscall"
)

// ErrNo inherit errno.
type ErrNo int

// ErrNoMask is mask code.
const ErrNoMask C.int = 0xff
const ErrNoMask = 0xff

// ErrNoExtended is extended errno.
type ErrNoExtended int
Expand Down Expand Up @@ -85,7 +88,7 @@ func (err Error) Error() string {
if err.err != "" {
str = err.err
} else {
str = C.GoString(C.sqlite3_errstr(C.int(err.Code)))
str = errorString(int(err.Code))
}
if err.SystemErrno != 0 {
str += ": " + err.SystemErrno.Error()
Expand Down Expand Up @@ -148,3 +151,22 @@ const (
ErrNoticeRecoverRollback = ErrNoExtended(ErrNotice | 2<<8)
ErrWarningAutoIndex = ErrNoExtended(ErrWarning | 1<<8)
)

var errStrCache sync.Map // int => string

// errorString returns the result of sqlite3_errstr for result code rv,
// which may be cached.
func errorString(rv int) string {
if v, ok := errStrCache.Load(rv); ok {
return v.(string)
}
s := C.GoString(C.sqlite3_errstr(C.int(rv)))
// Prevent the cache from growing unbounded by ignoring invalid
// error codes.
if s != "unknown error" {
if v, loaded := errStrCache.LoadOrStore(rv, s); loaded {
s = v.(string)
}
}
return s
}
100 changes: 70 additions & 30 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,15 @@ int _sqlite3_exec_no_args(sqlite3 *db, const char *zSql, int nBytes, int64_t *ro
do {
rv = _sqlite3_step_internal(stmt);
} while (rv == SQLITE_ROW);
if (rv != SQLITE_OK && rv != SQLITE_DONE) {
return rv;
}
// Only record the number of changes made by the last statement.
*changes = sqlite3_changes64(db);
*rowid = sqlite3_last_insert_rowid(db);
sqlite3_finalize(stmt);
rv = sqlite3_finalize(stmt);
if (rv != SQLITE_OK && rv != SQLITE_DONE) {
return rv;
}
Expand Down Expand Up @@ -386,6 +389,21 @@ static go_sqlite3_text_column _sqlite3_column_blob(sqlite3_stmt *stmt, int idx)
}
return r;
}
typedef struct {
const char *msg;
int code;
} go_sqlite3_db_error;
static go_sqlite3_db_error _sqlite3_db_error(sqlite3 *db) {
if (db) {
return (go_sqlite3_db_error){
.msg = sqlite3_errmsg(db),
.code = sqlite3_extended_errcode(db)
};
}
return (go_sqlite3_db_error){ 0 };
}
*/
import "C"
import (
Expand Down Expand Up @@ -742,7 +760,7 @@ func (c *SQLiteConn) RegisterCollation(name string, cmp func(string, string) int
defer C.free(unsafe.Pointer(cname))
rv := C.sqlite3_create_collation(c.db, cname, C.SQLITE_UTF8, handle, (*[0]byte)(unsafe.Pointer(C.compareTrampoline)))
if rv != C.SQLITE_OK {
return c.lastError()
return c.lastError(int(rv))
}
return nil
}
Expand Down Expand Up @@ -884,7 +902,7 @@ func (c *SQLiteConn) RegisterFunc(name string, impl any, pure bool) error {
}
rv := sqlite3CreateFunction(c.db, cname, C.int(numArgs), C.int(opts), newHandle(c, &fi), C.callbackTrampoline, nil, nil)
if rv != C.SQLITE_OK {
return c.lastError()
return c.lastError(int(rv))
}
return nil
}
Expand Down Expand Up @@ -1013,7 +1031,7 @@ func (c *SQLiteConn) RegisterAggregator(name string, impl any, pure bool) error
}
rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
if rv != C.SQLITE_OK {
return c.lastError()
return c.lastError(int(rv))
}
return nil
}
Expand All @@ -1025,32 +1043,49 @@ func (c *SQLiteConn) AutoCommit() bool {
return int(C.sqlite3_get_autocommit(c.db)) != 0
}

func (c *SQLiteConn) lastError() error {
return lastError(c.db)
func (c *SQLiteConn) lastError(rv int) error {
return lastError(c.db, rv)
}

// Note: may be called with db == nil
func lastError(db *C.sqlite3) error {
rv := C.sqlite3_errcode(db) // returns SQLITE_NOMEM if db == nil
if rv == C.SQLITE_OK {
func lastError(db *C.sqlite3, rv int) error {
if rv == SQLITE_OK {
return nil
}
extrv := C.sqlite3_extended_errcode(db) // returns SQLITE_NOMEM if db == nil
errStr := C.GoString(C.sqlite3_errmsg(db)) // returns "out of memory" if db == nil
extrv := rv
// Convert the extended result code to a basic result code.
rv &= ErrNoMask

// https://www.sqlite.org/c3ref/system_errno.html
// sqlite3_system_errno is only meaningful if the error code was SQLITE_CANTOPEN,
// or it was SQLITE_IOERR and the extended code was not SQLITE_IOERR_NOMEM
var systemErrno syscall.Errno
if rv == C.SQLITE_CANTOPEN || (rv == C.SQLITE_IOERR && extrv != C.SQLITE_IOERR_NOMEM) {
if db != nil && (rv == C.SQLITE_CANTOPEN ||
(rv == C.SQLITE_IOERR && extrv != C.SQLITE_IOERR_NOMEM)) {
systemErrno = syscall.Errno(C.sqlite3_system_errno(db))
}

var msg string
if db != nil {
// TODO: Get the DB's error code as well and combine calls
msg = C.GoString(C.sqlite3_errmsg(db))
// WARN: We need this check for Exec to return a nil error
// when the SQL statement is empty or simply a comment.
// That we need this special logic is likely indicative of
// a larger error with the logic here or in Exec.
if msg == "not an error" {
rv = int(C.sqlite3_errcode(db))
if rv == SQLITE_OK {
return nil
}
}
} else {
msg = errorString(extrv)
}
return Error{
Code: ErrNo(rv),
ExtendedCode: ErrNoExtended(extrv),
SystemErrno: systemErrno,
err: errStr,
err: msg,
}
}

Expand Down Expand Up @@ -1089,7 +1124,7 @@ func (c *SQLiteConn) execArgs(ctx context.Context, query string, args []driver.N
rv := C._sqlite3_prepare_v2(c.db, (*C.char)(unsafe.Pointer(stringData(query))),
C.int(len(query)), &s.s, &sz)
if rv != C.SQLITE_OK {
return nil, c.lastError()
return nil, c.lastError(int(rv))
}
query = strings.TrimSpace(query[sz:])

Expand Down Expand Up @@ -1141,7 +1176,7 @@ func (c *SQLiteConn) execNoArgsSync(query string) (_ driver.Result, err error) {
rv := C._sqlite3_exec_no_args(c.db, (*C.char)(unsafe.Pointer(stringData(query))),
C.int(len(query)), &rowid, &changes)
if rv != C.SQLITE_OK {
err = c.lastError()
err = c.lastError(int(rv))
}
return &SQLiteResult{id: int64(rowid), changes: int64(changes)}, err
}
Expand Down Expand Up @@ -1206,7 +1241,7 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name
if rv == C.GO_SQLITE_MULTIPLE_QUERIES {
return nil, errors.New("query contains multiple SQL statements")
}
return nil, c.lastError()
return nil, c.lastError(int(rv))
}

// The sqlite3_stmt will be nil if the SQL was valid but did not
Expand Down Expand Up @@ -1742,7 +1777,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if rv != 0 {
// Save off the error _before_ closing the database.
// This is safe even if db is nil.
err := lastError(db)
err := lastError(db, int(rv))
if db != nil {
C.sqlite3_close_v2(db)
}
Expand All @@ -1751,13 +1786,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if db == nil {
return nil, errors.New("sqlite succeeded without returning a database")
}
rv = C.sqlite3_extended_result_codes(db, 1)
if rv != SQLITE_OK {
C.sqlite3_close_v2(db)
return nil, lastError(db, int(rv))
}

exec := func(s string) error {
cs := C.CString(s)
rv := C.sqlite3_exec(db, cs, nil, nil, nil)
C.free(unsafe.Pointer(cs))
if rv != C.SQLITE_OK {
return lastError(db)
return lastError(db, int(rv))
}
return nil
}
Expand Down Expand Up @@ -2064,7 +2104,7 @@ func (c *SQLiteConn) Close() (err error) {
runtime.SetFinalizer(c, nil)
rv := C.sqlite3_close_v2(c.db)
if rv != C.SQLITE_OK {
err = c.lastError()
err = lastError(nil, int(rv))
}
deleteHandles(c)
c.db = nil
Expand All @@ -2090,7 +2130,7 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er
var s *C.sqlite3_stmt
rv := C._sqlite3_prepare_v2_internal(c.db, (*C.char)(unsafe.Pointer(p)), C.int(len(query)), &s, nil)
if rv != C.SQLITE_OK {
return nil, c.lastError()
return nil, c.lastError(int(rv))
}
ss := &SQLiteStmt{c: c, s: s}
runtime.SetFinalizer(ss, (*SQLiteStmt).Close)
Expand Down Expand Up @@ -2158,7 +2198,7 @@ func (c *SQLiteConn) SetFileControlInt(dbName string, op int, arg int) error {
cArg := C.int(arg)
rv := C.sqlite3_file_control(c.db, cDBName, C.int(op), unsafe.Pointer(&cArg))
if rv != C.SQLITE_OK {
return c.lastError()
return c.lastError(int(rv))
}
return nil
}
Expand All @@ -2181,7 +2221,7 @@ func (s *SQLiteStmt) Close() error {
}
rv := C.sqlite3_finalize(stmt)
if rv != C.SQLITE_OK {
return conn.lastError()
return conn.lastError(int(rv))
}
return nil
}
Expand Down Expand Up @@ -2213,7 +2253,7 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error {
if s.reset {
rv := C.sqlite3_reset(s.s)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
return s.c.lastError()
return s.c.lastError(int(rv))
}
} else {
s.reset = true
Expand Down Expand Up @@ -2258,7 +2298,7 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error {
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(p)), C.int(len(ts)))
}
if rv != C.SQLITE_OK {
return s.c.lastError()
return s.c.lastError(int(rv))
}
}
return nil
Expand Down Expand Up @@ -2327,7 +2367,7 @@ func (s *SQLiteStmt) bindIndices(args []driver.NamedValue) error {
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(p)), C.int(len(ts)))
}
if rv != C.SQLITE_OK {
return s.c.lastError()
return s.c.lastError(int(rv))
}
}
}
Expand Down Expand Up @@ -2437,7 +2477,7 @@ func (s *SQLiteStmt) execSync(args []driver.NamedValue) (driver.Result, error) {
var rowid, changes C.longlong
rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
err := s.c.lastError()
err := s.c.lastError(int(rv))
C.sqlite3_reset(s.s)
C.sqlite3_clear_bindings(s.s)
return nil, err
Expand Down Expand Up @@ -2473,8 +2513,8 @@ func (rc *SQLiteRows) Close() error {
}
rv := C.sqlite3_reset(s.s)
if rv != C.SQLITE_OK {
s.mu.Unlock()
return s.c.lastError()
rc.s.mu.Unlock()
return rc.s.c.lastError(int(rv))
}
s.mu.Unlock()
return nil
Expand Down Expand Up @@ -2573,7 +2613,7 @@ func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error {
if rv != C.SQLITE_ROW {
rv = C.sqlite3_reset(rc.s.s)
if rv != C.SQLITE_OK {
return rc.s.c.lastError()
return rc.s.c.lastError(int(rv))
}
return nil
}
Expand Down
20 changes: 7 additions & 13 deletions sqlite3_opt_unlock_notify.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
// license that can be found in the LICENSE file.

#ifdef SQLITE_ENABLE_UNLOCK_NOTIFY
#include <stdio.h>
#include "sqlite3-binding.h"

extern int unlock_notify_wait(sqlite3 *db);

static inline int is_locked(int rv) {
return rv == SQLITE_LOCKED || rv == SQLITE_LOCKED_SHAREDCACHE;
}

int
_sqlite3_step_blocking(sqlite3_stmt *stmt)
{
Expand All @@ -18,10 +21,7 @@ _sqlite3_step_blocking(sqlite3_stmt *stmt)
db = sqlite3_db_handle(stmt);
for (;;) {
rv = sqlite3_step(stmt);
if (rv != SQLITE_LOCKED) {
break;
}
if (sqlite3_extended_errcode(db) != SQLITE_LOCKED_SHAREDCACHE) {
if (!is_locked(rv)) {
break;
}
rv = unlock_notify_wait(db);
Expand All @@ -43,10 +43,7 @@ _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* chan
db = sqlite3_db_handle(stmt);
for (;;) {
rv = sqlite3_step(stmt);
if (rv!=SQLITE_LOCKED) {
break;
}
if (sqlite3_extended_errcode(db) != SQLITE_LOCKED_SHAREDCACHE) {
if (!is_locked(rv)) {
break;
}
rv = unlock_notify_wait(db);
Expand All @@ -68,10 +65,7 @@ _sqlite3_prepare_v2_blocking(sqlite3 *db, const char *zSql, int nBytes, sqlite3_

for (;;) {
rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
if (rv!=SQLITE_LOCKED) {
break;
}
if (sqlite3_extended_errcode(db) != SQLITE_LOCKED_SHAREDCACHE) {
if (!is_locked(rv)) {
break;
}
rv = unlock_notify_wait(db);
Expand Down
Loading

0 comments on commit 1f1ff36

Please sign in to comment.