Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sqlite3: reduce C to Go string conversions in SQLiteConn.{query,exec} #1133

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 74 additions & 50 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_
{
return sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
}

#endif

void _sqlite3_result_text(sqlite3_context* ctx, const char* s) {
Expand Down Expand Up @@ -858,25 +859,34 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err
}

func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
pquery := C.CString(query)
op := pquery // original pointer
defer C.free(unsafe.Pointer(op))

var stmtArgs []driver.NamedValue
var tail *C.char
s := new(SQLiteStmt) // escapes to the heap so reuse it
defer s.finalize()
start := 0
for {
s, err := c.prepare(ctx, query)
if err != nil {
return nil, err
*s = SQLiteStmt{c: c} // reset
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail)
if rv != C.SQLITE_OK {
return nil, c.lastError()
}

var res driver.Result
if s.(*SQLiteStmt).s != nil {
stmtArgs := make([]driver.NamedValue, 0, len(args))
if s.s != nil {
na := s.NumInput()
if len(args)-start < na {
s.Close()
s.finalize()
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
}
// consume the number of arguments used in the current
// statement and append all named arguments not
// contained therein
if len(args[start:start+na]) > 0 {
stmtArgs = append(stmtArgs, args[start:start+na]...)
stmtArgs = append(stmtArgs[:0], args[start:start+na]...)
for i := range args {
if (i < start || i >= na) && args[i].Name != "" {
stmtArgs = append(stmtArgs, args[i])
Expand All @@ -886,23 +896,23 @@ func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.Named
stmtArgs[i].Ordinal = i + 1
}
}
res, err = s.(*SQLiteStmt).exec(ctx, stmtArgs)
var err error
res, err = s.exec(ctx, stmtArgs)
if err != nil && err != driver.ErrSkip {
s.Close()
s.finalize()
return nil, err
}
start += na
}
tail := s.(*SQLiteStmt).t
s.Close()
if tail == "" {
s.finalize()
if tail == nil || *tail == '\000' {
if res == nil {
// https://github.com/mattn/go-sqlite3/issues/963
res = &SQLiteResult{0, 0}
}
return res, nil
}
query = tail
pquery = tail
}
}

Expand All @@ -919,44 +929,48 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
}

func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
start := 0
for {
stmtArgs := make([]driver.NamedValue, 0, len(args))
s, err := c.prepare(ctx, query)
if err != nil {
return nil, err
}
s.(*SQLiteStmt).cls = true
na := s.NumInput()
if len(args)-start < na {
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start)
}
// consume the number of arguments used in the current
// statement and append all named arguments not contained
// therein
stmtArgs = append(stmtArgs, args[start:start+na]...)
for i := range args {
if (i < start || i >= na) && args[i].Name != "" {
stmtArgs = append(stmtArgs, args[i])
}
}
for i := range stmtArgs {
stmtArgs[i].Ordinal = i + 1
pquery := C.CString(query)
op := pquery // original pointer
defer C.free(unsafe.Pointer(op))

var tail *C.char
s := &SQLiteStmt{c: c, cls: true}
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail)
if rv != C.SQLITE_OK {
return nil, c.lastError()
}
if s.s == nil {
return &SQLiteRows{s: &SQLiteStmt{closed: true}, closed: true}, nil
}
na := s.NumInput()
if n := len(args); n != na {
s.finalize()
if n < na {
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
}
rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs)
if err != nil && err != driver.ErrSkip {
s.Close()
return rows, err
return nil, fmt.Errorf("too many args to execute query: want %d got %d", na, len(args))
}
rows, err := s.query(ctx, args)
if err != nil && err != driver.ErrSkip {
s.finalize()
return rows, err
}

// Consume the rest of the query
for pquery = tail; pquery != nil && *pquery != 0; pquery = tail {
var stmt *C.sqlite3_stmt
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &stmt, &tail)
if rv != C.SQLITE_OK {
rows.Close()
return nil, c.lastError()
}
start += na
tail := s.(*SQLiteStmt).t
if tail == "" {
return rows, nil
if stmt != nil {
rows.Close()
return nil, errors.New("cannot insert multiple commands into a prepared statement") // WARN
}
rows.Close()
s.Close()
query = tail
}

return rows, nil
}

// Begin transaction.
Expand Down Expand Up @@ -1818,8 +1832,11 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er
return nil, c.lastError()
}
var t string
if tail != nil && *tail != '\000' {
t = strings.TrimSpace(C.GoString(tail))
if tail != nil && *tail != 0 {
n := int(uintptr(unsafe.Pointer(tail))) - int(uintptr(unsafe.Pointer(pquery)))
if 0 <= n && n < len(query) {
t = strings.TrimSpace(query[n:])
}
}
ss := &SQLiteStmt{c: c, s: s, t: t}
runtime.SetFinalizer(ss, (*SQLiteStmt).Close)
Expand Down Expand Up @@ -1913,6 +1930,13 @@ func (s *SQLiteStmt) Close() error {
return nil
}

func (s *SQLiteStmt) finalize() {
if s.s != nil {
C.sqlite3_finalize(s.s)
s.s = nil
}
}

// NumInput return a number of parameters.
func (s *SQLiteStmt) NumInput() int {
return int(C.sqlite3_bind_parameter_count(s.s))
Expand Down Expand Up @@ -2000,7 +2024,7 @@ func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
return s.query(context.Background(), list)
}

func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (*SQLiteRows, error) {
if err := s.bind(args); err != nil {
return nil, err
}
Expand Down
Loading
Loading