diff --git a/sqlite.go b/sqlite.go index 14a0e91..41ba69f 100644 --- a/sqlite.go +++ b/sqlite.go @@ -21,7 +21,7 @@ import ( "modernc.org/libc" "modernc.org/libc/sys/types" - "modernc.org/sqlite/lib" + sqlite3 "modernc.org/sqlite/lib" ) var ( @@ -435,8 +435,9 @@ func newStmt(c *conn, sql string) (*stmt, error) { if err != nil { return nil, err } + stm := stmt{c: c, psql: p} - return &stmt{c: c, psql: p}, nil + return &stm, nil } // Close closes the statement. @@ -468,25 +469,25 @@ func toNamedValues(vals []driver.Value) (r []driver.NamedValue) { func (s *stmt) exec(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) { var pstmt uintptr - - donech := make(chan struct{}) - - go func() { - select { - case <-ctx.Done(): - if pstmt != 0 { + done := false + if ctx != nil && ctx.Done() != nil { + donech := make(chan struct{}) + + go func() { + select { + case <-ctx.Done(): + done = true s.c.interrupt(s.c.db) + case <-donech: } - case <-donech: - } - }() + }() - defer func() { - pstmt = 0 - close(donech) - }() + defer func() { + close(donech) + }() + } - for psql := s.psql; *(*byte)(unsafe.Pointer(psql)) != 0; { + for psql := s.psql; *(*byte)(unsafe.Pointer(psql)) != 0 && !done; { if pstmt, err = s.c.prepareV2(&psql); err != nil { return nil, err } @@ -494,13 +495,7 @@ func (s *stmt) exec(ctx context.Context, args []driver.NamedValue) (r driver.Res if pstmt == 0 { continue } - - if err := func() (err error) { - defer func() { - if e := s.c.finalize(pstmt); e != nil && err == nil { - err = e - } - }() + err = func() (err error) { n, err := s.c.bindParameterCount(pstmt) if err != nil { @@ -535,7 +530,13 @@ func (s *stmt) exec(ctx context.Context, args []driver.NamedValue) (r driver.Res } return nil - }(); err != nil { + }() + + if e := s.c.finalize(pstmt); e != nil && err == nil { + err = e + } + + if err != nil { return nil, err } } @@ -567,25 +568,26 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { //TODO StmtQuer func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Rows, err error) { var pstmt uintptr - donech := make(chan struct{}) + done := false + if ctx != nil && ctx.Done() != nil { + donech := make(chan struct{}) - go func() { - select { - case <-ctx.Done(): - if pstmt != 0 { + go func() { + select { + case <-ctx.Done(): + done = true s.c.interrupt(s.c.db) + case <-donech: } - case <-donech: - } - }() + }() - defer func() { - pstmt = 0 - close(donech) - }() + defer func() { + close(donech) + }() + } var allocs []uintptr - for psql := s.psql; *(*byte)(unsafe.Pointer(psql)) != 0; { + for psql := s.psql; *(*byte)(unsafe.Pointer(psql)) != 0 && !done; { if pstmt, err = s.c.prepareV2(&psql); err != nil { return nil, err } @@ -594,13 +596,7 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro continue } - if err = func() (err error) { - defer func() { - if e := s.c.finalize(pstmt); e != nil && err == nil { - err = e - } - }() - + err = func() (err error) { n, err := s.c.bindParameterCount(pstmt) if err != nil { return err @@ -649,7 +645,12 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro pstmt = 0 } return nil - }(); err != nil { + }() + if e := s.c.finalize(pstmt); e != nil && err == nil { + err = e + } + + if err != nil { return nil, err } } @@ -692,19 +693,23 @@ func (t *tx) exec(ctx context.Context, sql string) (err error) { } defer t.c.free(psql) - //TODO use t.conn.ExecContext() instead - donech := make(chan struct{}) - defer close(donech) + if ctx != nil && ctx.Done() != nil { + donech := make(chan struct{}) - go func() { - select { - case <-ctx.Done(): - t.c.interrupt(t.c.db) - case <-donech: - } - }() + go func() { + select { + case <-ctx.Done(): + t.c.interrupt(t.c.db) + case <-donech: + } + }() + + defer func() { + close(donech) + }() + } if rc := sqlite3.Xsqlite3_exec(t.c.tls, t.c.db, psql, 0, 0, 0); rc != sqlite3.SQLITE_OK { return t.c.errstr(rc)