Skip to content

Commit

Permalink
make sure Close always removes the runtime finalizer
Browse files Browse the repository at this point in the history
This commit fixes a bug in {SQLiteConn,SQLiteStmt}.Close that would lead
to the runtime finalizer not being removed if there was an error closing
the connection or statement. This commit also makes it safe to call
SQLiteConn.Close multiple times.
  • Loading branch information
charlievieth committed Dec 16, 2024
1 parent 7658c06 commit 4f27a02
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 78 deletions.
5 changes: 4 additions & 1 deletion backup.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ func (destConn *SQLiteConn) Backup(dest string, srcConn *SQLiteConn, src string)
runtime.SetFinalizer(bb, (*SQLiteBackup).Finish)
return bb, nil
}
return nil, destConn.lastError()
if destConn.db != nil {
return nil, destConn.lastError(int(C.sqlite3_extended_errcode(destConn.db)))
}
return nil, Error{Code: 1, ExtendedCode: 1, err: "backup: destination connection is nil"}
}

// Step to backs up for one step. Calls the underlying `sqlite3_backup_step`
Expand Down
24 changes: 21 additions & 3 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@ package sqlite3
#endif
*/
import "C"
import "syscall"
import (
"sync"
"syscall"
)

// ErrNo inherit errno.
type ErrNo int

// ErrNoMask is mask code.
const ErrNoMask C.int = 0xff
const ErrNoMask = 0xff

// ErrNoExtended is extended errno.
type ErrNoExtended int
Expand Down Expand Up @@ -85,7 +88,7 @@ func (err Error) Error() string {
if err.err != "" {
str = err.err
} else {
str = C.GoString(C.sqlite3_errstr(C.int(err.Code)))
str = errorString(int(err.Code))
}
if err.SystemErrno != 0 {
str += ": " + err.SystemErrno.Error()
Expand Down Expand Up @@ -148,3 +151,18 @@ var (
ErrNoticeRecoverRollback = ErrNotice.Extend(2)
ErrWarningAutoIndex = ErrWarning.Extend(1)
)

var errStrCache sync.Map // int => string

// errorString returns the result of sqlite3_errstr for result code
// rv which may be cached.
func errorString(rv int) string {
if v, ok := errStrCache.Load(rv); ok {
return v.(string)
}
s := C.GoString(C.sqlite3_errstr(C.int(rv)))
if v, loaded := errStrCache.LoadOrStore(rv, s); loaded {
return v.(string)
}
return s
}
94 changes: 49 additions & 45 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ func (c *SQLiteConn) RegisterCollation(name string, cmp func(string, string) int
defer C.free(unsafe.Pointer(cname))
rv := C.sqlite3_create_collation(c.db, cname, C.SQLITE_UTF8, handle, (*[0]byte)(unsafe.Pointer(C.compareTrampoline)))
if rv != C.SQLITE_OK {
return c.lastError()
return c.lastError(int(rv))
}
return nil
}
Expand Down Expand Up @@ -675,7 +675,7 @@ func (c *SQLiteConn) RegisterFunc(name string, impl any, pure bool) error {
}
rv := sqlite3CreateFunction(c.db, cname, C.int(numArgs), C.int(opts), newHandle(c, &fi), C.callbackTrampoline, nil, nil)
if rv != C.SQLITE_OK {
return c.lastError()
return c.lastError(int(rv))
}
return nil
}
Expand Down Expand Up @@ -804,7 +804,7 @@ func (c *SQLiteConn) RegisterAggregator(name string, impl any, pure bool) error
}
rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
if rv != C.SQLITE_OK {
return c.lastError()
return c.lastError(int(rv))
}
return nil
}
Expand All @@ -816,32 +816,38 @@ func (c *SQLiteConn) AutoCommit() bool {
return int(C.sqlite3_get_autocommit(c.db)) != 0
}

func (c *SQLiteConn) lastError() error {
return lastError(c.db)
func (c *SQLiteConn) lastError(rv int) error {
return lastError(c.db, rv)
}

// Note: may be called with db == nil
func lastError(db *C.sqlite3) error {
rv := C.sqlite3_errcode(db) // returns SQLITE_NOMEM if db == nil
if rv == C.SQLITE_OK {
func lastError(db *C.sqlite3, rv int) error {
if rv == SQLITE_OK {
return nil
}
extrv := C.sqlite3_extended_errcode(db) // returns SQLITE_NOMEM if db == nil
errStr := C.GoString(C.sqlite3_errmsg(db)) // returns "out of memory" if db == nil
extrv := rv
// Convert the extended result code to a basic result code.
rv &= ErrNoMask

// https://www.sqlite.org/c3ref/system_errno.html
// sqlite3_system_errno is only meaningful if the error code was SQLITE_CANTOPEN,
// or it was SQLITE_IOERR and the extended code was not SQLITE_IOERR_NOMEM
var systemErrno syscall.Errno
if rv == C.SQLITE_CANTOPEN || (rv == C.SQLITE_IOERR && extrv != C.SQLITE_IOERR_NOMEM) {
if db != nil && (rv == C.SQLITE_CANTOPEN ||
(rv == C.SQLITE_IOERR && extrv != C.SQLITE_IOERR_NOMEM)) {
systemErrno = syscall.Errno(C.sqlite3_system_errno(db))
}

var msg string
if db != nil {
msg = C.GoString(C.sqlite3_errmsg(db))
} else {
msg = errorString(extrv)
}
return Error{
Code: ErrNo(rv),
ExtendedCode: ErrNoExtended(extrv),
SystemErrno: systemErrno,
err: errStr,
err: msg,
}
}

Expand Down Expand Up @@ -1467,7 +1473,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if rv != 0 {
// Save off the error _before_ closing the database.
// This is safe even if db is nil.
err := lastError(db)
err := lastError(db, int(rv))
if db != nil {
C.sqlite3_close_v2(db)
}
Expand All @@ -1476,13 +1482,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if db == nil {
return nil, errors.New("sqlite succeeded without returning a database")
}
rv = C.sqlite3_extended_result_codes(db, 1)
if rv != SQLITE_OK {
C.sqlite3_close_v2(db)
return nil, lastError(db, int(rv))
}

exec := func(s string) error {
cs := C.CString(s)
rv := C.sqlite3_exec(db, cs, nil, nil, nil)
C.free(unsafe.Pointer(cs))
if rv != C.SQLITE_OK {
return lastError(db)
return lastError(db, int(rv))
}
return nil
}
Expand Down Expand Up @@ -1782,26 +1793,20 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
}

// Close the connection.
func (c *SQLiteConn) Close() error {
func (c *SQLiteConn) Close() (err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.db == nil {
return nil // Already closed
}
runtime.SetFinalizer(c, nil)
rv := C.sqlite3_close_v2(c.db)
if rv != C.SQLITE_OK {
return c.lastError()
err = lastError(nil, int(rv))
}
deleteHandles(c)
c.mu.Lock()
c.db = nil
c.mu.Unlock()
runtime.SetFinalizer(c, nil)
return nil
}

func (c *SQLiteConn) dbConnOpen() bool {
if c == nil {
return false
}
c.mu.Lock()
defer c.mu.Unlock()
return c.db != nil
return err
}

// Prepare the query string. Return a new statement.
Expand All @@ -1816,7 +1821,7 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er
var tail *C.char
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s, &tail)
if rv != C.SQLITE_OK {
return nil, c.lastError()
return nil, c.lastError(int(rv))
}
var t string
if tail != nil && *tail != '\000' {
Expand Down Expand Up @@ -1888,7 +1893,7 @@ func (c *SQLiteConn) SetFileControlInt(dbName string, op int, arg int) error {
cArg := C.int(arg)
rv := C.sqlite3_file_control(c.db, cDBName, C.int(op), unsafe.Pointer(&cArg))
if rv != C.SQLITE_OK {
return c.lastError()
return c.lastError(int(rv))
}
return nil
}
Expand All @@ -1901,16 +1906,15 @@ func (s *SQLiteStmt) Close() error {
return nil
}
s.closed = true
if !s.c.dbConnOpen() {
return errors.New("sqlite statement with already closed database connection")
}
rv := C.sqlite3_finalize(s.s)
conn := s.c
stmt := s.s
s.c = nil
s.s = nil
runtime.SetFinalizer(s, nil)
rv := C.sqlite3_finalize(stmt)
if rv != C.SQLITE_OK {
return s.c.lastError()
return conn.lastError(int(rv))
}
s.c = nil
runtime.SetFinalizer(s, nil)
return nil
}

Expand All @@ -1924,7 +1928,7 @@ var placeHolder = []byte{0}
func (s *SQLiteStmt) bind(args []driver.NamedValue) error {
rv := C.sqlite3_reset(s.s)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
return s.c.lastError()
return s.c.lastError(int(rv))
}

bindIndices := make([][3]int, len(args))
Expand Down Expand Up @@ -1982,7 +1986,7 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error {
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
}
if rv != C.SQLITE_OK {
return s.c.lastError()
return s.c.lastError(int(rv))
}
}
}
Expand Down Expand Up @@ -2092,7 +2096,7 @@ func (s *SQLiteStmt) execSync(args []driver.NamedValue) (driver.Result, error) {
var rowid, changes C.longlong
rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
err := s.c.lastError()
err := s.c.lastError(int(rv))
C.sqlite3_reset(s.s)
C.sqlite3_clear_bindings(s.s)
return nil, err
Expand Down Expand Up @@ -2128,8 +2132,8 @@ func (rc *SQLiteRows) Close() error {
}
rv := C.sqlite3_reset(s.s)
if rv != C.SQLITE_OK {
s.mu.Unlock()
return s.c.lastError()
rc.s.mu.Unlock()
return rc.s.c.lastError(int(rv))
}
s.mu.Unlock()
return nil
Expand Down Expand Up @@ -2206,7 +2210,7 @@ func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error {
if rv != C.SQLITE_ROW {
rv = C.sqlite3_reset(rc.s.s)
if rv != C.SQLITE_OK {
return rc.s.c.lastError()
return rc.s.c.lastError(int(rv))
}
return nil
}
Expand Down
20 changes: 7 additions & 13 deletions sqlite3_opt_unlock_notify.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
// license that can be found in the LICENSE file.

#ifdef SQLITE_ENABLE_UNLOCK_NOTIFY
#include <stdio.h>
#ifndef USE_LIBSQLITE3
#include "sqlite3-binding.h"
#else
Expand All @@ -13,6 +12,10 @@

extern int unlock_notify_wait(sqlite3 *db);

static inline int is_locked(int rv) {
return rv == SQLITE_LOCKED || rv == SQLITE_LOCKED_SHAREDCACHE;
}

int
_sqlite3_step_blocking(sqlite3_stmt *stmt)
{
Expand All @@ -22,10 +25,7 @@ _sqlite3_step_blocking(sqlite3_stmt *stmt)
db = sqlite3_db_handle(stmt);
for (;;) {
rv = sqlite3_step(stmt);
if (rv != SQLITE_LOCKED) {
break;
}
if (sqlite3_extended_errcode(db) != SQLITE_LOCKED_SHAREDCACHE) {
if (!is_locked(rv)) {
break;
}
rv = unlock_notify_wait(db);
Expand All @@ -47,10 +47,7 @@ _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* chan
db = sqlite3_db_handle(stmt);
for (;;) {
rv = sqlite3_step(stmt);
if (rv!=SQLITE_LOCKED) {
break;
}
if (sqlite3_extended_errcode(db) != SQLITE_LOCKED_SHAREDCACHE) {
if (!is_locked(rv)) {
break;
}
rv = unlock_notify_wait(db);
Expand All @@ -72,10 +69,7 @@ _sqlite3_prepare_v2_blocking(sqlite3 *db, const char *zSql, int nBytes, sqlite3_

for (;;) {
rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
if (rv!=SQLITE_LOCKED) {
break;
}
if (sqlite3_extended_errcode(db) != SQLITE_LOCKED_SHAREDCACHE) {
if (!is_locked(rv)) {
break;
}
rv = unlock_notify_wait(db);
Expand Down
Loading

0 comments on commit 4f27a02

Please sign in to comment.