diff --git a/sqlite.go b/sqlite.go
index 3fe9c78..2c7e13d 100644
--- a/sqlite.go
+++ b/sqlite.go
@@ -21,7 +21,7 @@ import (
 
 	"modernc.org/libc"
 	"modernc.org/libc/sys/types"
-	"modernc.org/sqlite/lib"
+	sqlite3 "modernc.org/sqlite/lib"
 )
 
 var (
@@ -435,8 +435,9 @@ func newStmt(c *conn, sql string) (*stmt, error) {
 	if err != nil {
 		return nil, err
 	}
+	stm := stmt{c: c, psql: p}
 
-	return &stmt{c: c, psql: p}, nil
+	return &stm, nil
 }
 
 // Close closes the statement.
@@ -468,25 +469,25 @@ func toNamedValues(vals []driver.Value) (r []driver.NamedValue) {
 
 func (s *stmt) exec(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) {
 	var pstmt uintptr
-
-	donech := make(chan struct{})
-
-	go func() {
-		select {
-		case <-ctx.Done():
-			if pstmt != 0 {
+	done := false
+	if ctx != nil && ctx.Done() != nil {
+		donech := make(chan struct{})
+
+		go func() {
+			select {
+			case <-ctx.Done():
+				done = true
 				s.c.interrupt(s.c.db)
+			case <-donech:
 			}
-		case <-donech:
-		}
-	}()
+		}()
 
-	defer func() {
-		pstmt = 0
-		close(donech)
-	}()
+		defer func() {
+			close(donech)
+		}()
+	}
 
-	for psql := s.psql; *(*byte)(unsafe.Pointer(psql)) != 0; {
+	for psql := s.psql; *(*byte)(unsafe.Pointer(psql)) != 0 && !done; {
 		if pstmt, err = s.c.prepareV2(&psql); err != nil {
 			return nil, err
 		}
@@ -494,13 +495,7 @@ func (s *stmt) exec(ctx context.Context, args []driver.NamedValue) (r driver.Res
 		if pstmt == 0 {
 			continue
 		}
-
-		if err := func() (err error) {
-			defer func() {
-				if e := s.c.finalize(pstmt); e != nil && err == nil {
-					err = e
-				}
-			}()
+		err = func() (err error) {
 
 			n, err := s.c.bindParameterCount(pstmt)
 			if err != nil {
@@ -535,7 +530,13 @@ func (s *stmt) exec(ctx context.Context, args []driver.NamedValue) (r driver.Res
 			}
 
 			return nil
-		}(); err != nil {
+		}()
+
+		if e := s.c.finalize(pstmt); e != nil && err == nil {
+			err = e
+		}
+
+		if err != nil {
 			return nil, err
 		}
 	}
@@ -567,25 +568,26 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { //TODO StmtQuer
 func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Rows, err error) {
 	var pstmt uintptr
 
-	donech := make(chan struct{})
+	done := false
+	if ctx != nil && ctx.Done() != nil {
+		donech := make(chan struct{})
 
-	go func() {
-		select {
-		case <-ctx.Done():
-			if pstmt != 0 {
+		go func() {
+			select {
+			case <-ctx.Done():
+				done = true
 				s.c.interrupt(s.c.db)
+			case <-donech:
 			}
-		case <-donech:
-		}
-	}()
+		}()
 
-	defer func() {
-		pstmt = 0
-		close(donech)
-	}()
+		defer func() {
+			close(donech)
+		}()
+	}
 
 	var allocs []uintptr
-	for psql := s.psql; *(*byte)(unsafe.Pointer(psql)) != 0; {
+	for psql := s.psql; *(*byte)(unsafe.Pointer(psql)) != 0 && !done; {
 		if pstmt, err = s.c.prepareV2(&psql); err != nil {
 			return nil, err
 		}
@@ -594,13 +596,7 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro
 			continue
 		}
 
-		if err = func() (err error) {
-			defer func() {
-				if e := s.c.finalize(pstmt); e != nil && err == nil {
-					err = e
-				}
-			}()
-
+		err = func() (err error) {
 			n, err := s.c.bindParameterCount(pstmt)
 			if err != nil {
 				return err
@@ -649,7 +645,12 @@ func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Ro
 				pstmt = 0
 			}
 			return nil
-		}(); err != nil {
+		}()
+		if e := s.c.finalize(pstmt); e != nil && err == nil {
+			err = e
+		}
+
+		if err != nil {
 			return nil, err
 		}
 	}
@@ -692,19 +693,23 @@ func (t *tx) exec(ctx context.Context, sql string) (err error) {
 	}
 
 	defer t.c.free(psql)
-
 	//TODO use t.conn.ExecContext() instead
-	donech := make(chan struct{})
 
-	defer close(donech)
+	if ctx != nil && ctx.Done() != nil {
+		donech := make(chan struct{})
 
-	go func() {
-		select {
-		case <-ctx.Done():
-			t.c.interrupt(t.c.db)
-		case <-donech:
-		}
-	}()
+		go func() {
+			select {
+			case <-ctx.Done():
+				t.c.interrupt(t.c.db)
+			case <-donech:
+			}
+		}()
+
+		defer func() {
+			close(donech)
+		}()
+	}
 
 	if rc := sqlite3.Xsqlite3_exec(t.c.tls, t.c.db, psql, 0, 0, 0); rc != sqlite3.SQLITE_OK {
 		return t.c.errstr(rc)