Skip to content

Commit

Permalink
WIP: unstage me
Browse files Browse the repository at this point in the history
  • Loading branch information
charlievieth committed Nov 24, 2024
1 parent e82107b commit 828da04
Show file tree
Hide file tree
Showing 2 changed files with 347 additions and 40 deletions.
234 changes: 201 additions & 33 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,25 @@ void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l) {
sqlite3_result_blob(ctx, b, l, SQLITE_TRANSIENT);
}
static int _sqlite3_prepare_v2(sqlite3 *db, const char *zSql, int nBytes,
sqlite3_stmt **ppStmt, int *oBytes) {
const char *tail;
int rv = _sqlite3_prepare_v2_internal(db, zSql, nBytes, ppStmt, &tail);
if (rv != SQLITE_OK) {
return rv;
}
// TODO: combine this with the below logic
if (tail == NULL) {
// NB: this should not happen
*oBytes = nBytes;
return rv;
}
// Set oBytes to the number of bytes consumed instead of using the **pzTail
// out param since that requires storing a Go pointer in a C pointer, which
// is not allowed by CGO and will cause runtime.cgoCheckPointer to fail.
*oBytes = tail - zSql;
return rv;
}
int _sqlite3_create_function(
sqlite3 *db,
Expand Down Expand Up @@ -919,44 +938,56 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
}

func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
start := 0
var (
// stmtArgs []driver.NamedValue
start int
s SQLiteStmt // escapes to the heap so reuse it
sz C.int // number of query bytes consumed: escapes to the heap
)
query = strings.TrimSpace(query)
for {
stmtArgs := make([]driver.NamedValue, 0, len(args))
s, err := c.prepare(ctx, query)
if err != nil {
return nil, err
s = SQLiteStmt{c: c, cls: true} // reset
sz = 0
rv := C._sqlite3_prepare_v2(c.db, (*C.char)(unsafe.Pointer(stringData(query))),
C.int(len(query)), &s.s, &sz)
if rv != C.SQLITE_OK {
return nil, c.lastError()
}
s.(*SQLiteStmt).cls = true
query = strings.TrimSpace(query[sz:])

na := s.NumInput()
if len(args)-start < na {
s.Close()
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start)
s.finalize()
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
}
// consume the number of arguments used in the current
// statement and append all named arguments not contained
// therein
stmtArgs = append(stmtArgs, args[start:start+na]...)
// statement and append all named arguments not
// contained therein
// if stmtArgs == nil {
// stmtArgs = make([]driver.NamedValue, 0, na)
// }
// stmtArgs = append(stmtArgs[:0], args[start:start+na]...)
stmtArgs := args[start : start+na : start+na]
for i := range args {
if (i < start || i >= na) && args[i].Name != "" {
stmtArgs = append(stmtArgs, args[i])
}
}
start += na
for i := range stmtArgs {
stmtArgs[i].Ordinal = i + 1
}
rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs)

rows, err := s.query(ctx, stmtArgs)
if err != nil && err != driver.ErrSkip {
s.Close()
return rows, err
s.finalize()
return nil, err
}
start += na
tail := s.(*SQLiteStmt).t
if tail == "" {
if len(query) == 0 {
return rows, nil
}
rows.Close()
s.Close()
query = tail
s.finalize()
}
}

Expand Down Expand Up @@ -1914,33 +1945,106 @@ func (s *SQLiteStmt) Close() error {
return nil
}

func (s *SQLiteStmt) finalize() {
if s.s != nil {
C.sqlite3_finalize(s.s)
s.s = nil
}
}

// NumInput return a number of parameters.
func (s *SQLiteStmt) NumInput() int {
return int(C.sqlite3_bind_parameter_count(s.s))
}

var placeHolder = []byte{0}

func hasNamedArgs(args []driver.NamedValue) bool {
for _, v := range args {
if v.Name != "" {
return true
}
}
return false
}

// TODO: return the column count as well
func (s *SQLiteStmt) bind(args []driver.NamedValue) error {
rv := C.sqlite3_reset(s.s)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
return s.c.lastError()
}

if hasNamedArgs(args) {
return s.bindIndices(args)
}

for _, arg := range args {
n := C.int(arg.Ordinal)
switch v := arg.Value.(type) {
case nil:
rv = C.sqlite3_bind_null(s.s, n)
case string:
p := stringData(v)
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(p)), C.int(len(v)))
case int64:
rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
case bool:
val := 0
if v {
val = 1
}
rv = C.sqlite3_bind_int(s.s, n, C.int(val))
case float64:
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case []byte:
if v == nil {
rv = C.sqlite3_bind_null(s.s, n)
} else {
ln := len(v)
if ln == 0 {
v = placeHolder
}
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln))
}
case time.Time:
b := []byte(v.Format(SQLiteTimestampFormats[0]))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
}
if rv != C.SQLITE_OK {
return s.c.lastError()
}
}
return nil
}

func (s *SQLiteStmt) bindIndices(args []driver.NamedValue) error {
// Find the longest named parameter name.
n := 0
for _, v := range args {
if m := len(v.Name); m > n {
n = m
}
}
buf := make([]byte, 0, n+2) // +2 for placeholder and null terminator

// TODO: Reduce the size of this slive by using uint32 or uint16.
// By default SQLITE_MAX_FUNCTION_ARG is 100 so uint16 should work.
bindIndices := make([][3]int, len(args))
prefixes := []string{":", "@", "$"}
for i, v := range args {
bindIndices[i][0] = args[i].Ordinal
if v.Name != "" {
for j := range prefixes {
cname := C.CString(prefixes[j] + v.Name)
bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, cname))
C.free(unsafe.Pointer(cname))
for j, c := range []byte{':', '@', '$'} {
buf = append(buf[:0], c)
buf = append(buf, v.Name...)
buf = append(buf, 0)
bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, (*C.char)(unsafe.Pointer(&buf[0]))))
}
args[i].Ordinal = bindIndices[i][0]
}
}

var rv C.int
for i, arg := range args {
for j := range bindIndices[i] {
if bindIndices[i][j] == 0 {
Expand All @@ -1951,20 +2055,16 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error {
case nil:
rv = C.sqlite3_bind_null(s.s, n)
case string:
if len(v) == 0 {
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0))
} else {
b := []byte(v)
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
}
p := stringData(v)
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(p)), C.int(len(v)))
case int64:
rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
case bool:
val := 0
if v {
rv = C.sqlite3_bind_int(s.s, n, 1)
} else {
rv = C.sqlite3_bind_int(s.s, n, 0)
val = 1
}
rv = C.sqlite3_bind_int(s.s, n, C.int(val))
case float64:
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case []byte:
Expand All @@ -1989,6 +2089,74 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error {
return nil
}

// func (s *SQLiteStmt) bind(args []driver.NamedValue) error {
// rv := C.sqlite3_reset(s.s)
// if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
// return s.c.lastError()
// }
//
// bindIndices := make([][3]int, len(args))
// prefixes := []string{":", "@", "$"}
// for i, v := range args {
// bindIndices[i][0] = args[i].Ordinal
// if v.Name != "" {
// for j := range prefixes {
// cname := C.CString(prefixes[j] + v.Name)
// bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, cname))
// C.free(unsafe.Pointer(cname))
// }
// args[i].Ordinal = bindIndices[i][0]
// }
// }
//
// for i, arg := range args {
// for j := range bindIndices[i] {
// if bindIndices[i][j] == 0 {
// continue
// }
// n := C.int(bindIndices[i][j])
// switch v := arg.Value.(type) {
// case nil:
// rv = C.sqlite3_bind_null(s.s, n)
// case string:
// if len(v) == 0 {
// rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0))
// } else {
// b := []byte(v)
// rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
// }
// case int64:
// rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
// case bool:
// if v {
// rv = C.sqlite3_bind_int(s.s, n, 1)
// } else {
// rv = C.sqlite3_bind_int(s.s, n, 0)
// }
// case float64:
// rv = C.sqlite3_bind_double(s.s, n, C.double(v))
// case []byte:
// if v == nil {
// rv = C.sqlite3_bind_null(s.s, n)
// } else {
// ln := len(v)
// if ln == 0 {
// v = placeHolder
// }
// rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln))
// }
// case time.Time:
// b := []byte(v.Format(SQLiteTimestampFormats[0]))
// rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
// }
// if rv != C.SQLITE_OK {
// return s.c.lastError()
// }
// }
// }
// return nil
// }

// Query the statement with arguments. Return records.
func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
list := make([]driver.NamedValue, len(args))
Expand Down
Loading

0 comments on commit 828da04

Please sign in to comment.