Skip to content

Commit

Permalink
Statement cache clean up.
Browse files Browse the repository at this point in the history
  • Loading branch information
go-jet committed Oct 24, 2024
1 parent c9b2553 commit 79405af
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 63 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/go-jet/jet/v2

go 1.18
go 1.20

require (
github.com/go-sql-driver/mysql v1.8.1
Expand Down
17 changes: 13 additions & 4 deletions internal/jet/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package db
import (
"context"
"database/sql"
"errors"
"fmt"
"sync"
)
Expand Down Expand Up @@ -157,8 +158,8 @@ func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error
return prepStmt, nil
}

// Clear will close all cached prepared statements
func (d *DB) Clear() error {
// ClearStatementsCache will close all cached prepared statements and clear statements cache
func (d *DB) ClearStatementsCache() error {
d.lock.Lock()
defer d.lock.Unlock()

Expand All @@ -168,15 +169,23 @@ func (d *DB) Clear() error {
closeErr := statement.Close()

if closeErr != nil {
err = closeErr
err = errors.Join(err, closeErr)
}
}

d.statements = make(map[string]*sql.Stmt)

if err != nil {
return fmt.Errorf("some of the prepared statements failed to close, last err: %w", err)
return errors.Join(errors.New("jet: some of the prepared statements failed to close"), err)
}

return nil
}

// Close will clear the statements cache and close the underlying db connection
func (d *DB) Close() error {
clearErr := d.ClearStatementsCache()
closeErr := d.DB.Close()

return errors.Join(clearErr, closeErr)
}
2 changes: 1 addition & 1 deletion internal/jet/db/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (t *Tx) Prepare(query string) (*sql.Stmt, error) {
// automatically upon the completion of the transaction, whether it's committed or rolled back.
func (t *Tx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
if !t.db.statementsCaching {
return t.PrepareContext(ctx, query)
return t.Tx.PrepareContext(ctx, query)
}

prepStmt, ok := t.statements[query]
Expand Down
36 changes: 24 additions & 12 deletions tests/mysql/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mysql
import (
"context"
"database/sql"
"fmt"
jetmysql "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/tests/dbconfig"
Expand All @@ -18,11 +19,13 @@ import (
var db *jetmysql.DB

var source string
var skipStatementsCaching bool

const MariaDB = "MariaDB"

func init() {
source = os.Getenv("MY_SQL_SOURCE")
skipStatementsCaching = os.Getenv("JET_TESTS_NO_STMT_CACHE") == "true"
}

func sourceIsMariaDB() bool {
Expand All @@ -32,20 +35,29 @@ func sourceIsMariaDB() bool {
func TestMain(m *testing.M) {
defer profile.Start().Stop()

var err error
sqlDB, err := sql.Open("mysql", dbconfig.MySQLConnectionString(sourceIsMariaDB(), ""))
if err != nil {
panic("Failed to connect to test db" + err.Error())
}

db = jetmysql.NewDB(sqlDB).WithStatementsCaching(true)
defer db.Close()
for _, cachingEnabled := range []bool{false, true} {

for i := 0; i < 2; i++ {
ret := m.Run()
if ret != 0 {
os.Exit(ret)
if cachingEnabled && skipStatementsCaching {
continue //skipped by global env variable
}

func() {
fmt.Printf("\nRunning mysql tests caching enabled: %t \n", cachingEnabled)

var err error
sqlDB, err := sql.Open("mysql", dbconfig.MySQLConnectionString(sourceIsMariaDB(), ""))
if err != nil {
panic("Failed to connect to test db" + err.Error())
}

db = jetmysql.NewDB(sqlDB).WithStatementsCaching(cachingEnabled)
defer db.Close()

ret := m.Run()
if ret != 0 {
os.Exit(ret)
}
}()
}
}

Expand Down
48 changes: 25 additions & 23 deletions tests/postgres/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ var db *postgres.DB
var testRoot string

var source string
var skipStatementsCaching bool

const CockroachDB = "COCKROACH_DB"

func init() {
source = os.Getenv("PG_SOURCE")
skipStatementsCaching = os.Getenv("JET_TESTS_NO_STMT_CACHE") == "true"
}

func sourceIsCockroachDB() bool {
Expand All @@ -45,39 +47,39 @@ func TestMain(m *testing.M) {

setTestRoot()

for _, driverName := range []string{"pgx", "postgres"} {
fmt.Printf("\nRunning postgres tests for '%s' driver\n", driverName)
for _, cachingEnabled := range []bool{false, true} {

func() {
if cachingEnabled && skipStatementsCaching {
continue //skipped by global env variable
}

connectionString := dbconfig.PostgresConnectString
for _, driverName := range []string{"pgx", "postgres"} {

if sourceIsCockroachDB() {
connectionString = dbconfig.CockroachConnectString
}
fmt.Printf("\nRunning postgres tests for driver: %s, caching enabled: %t \n", driverName, cachingEnabled)

sqlDB, err := sql.Open(driverName, connectionString)
if err != nil {
fmt.Println(err.Error())
panic("Failed to connect to test db")
}
db = postgres.NewDB(sqlDB).WithStatementsCaching(true)
defer db.Close()
func() {
connectionString := dbconfig.PostgresConnectString

if sourceIsCockroachDB() {
connectionString = dbconfig.CockroachConnectString
}

sqlDB, err := sql.Open(driverName, connectionString)
if err != nil {
fmt.Println(err.Error())
panic("Failed to connect to test db")
}
db = postgres.NewDB(sqlDB).WithStatementsCaching(cachingEnabled)
defer db.Close()

for i := 0; i < 2; i++ {
ret := m.Run()
if ret != 0 {
os.Exit(ret)
}
}

err = db.Clear()

if err != nil {
os.Exit(-2)
}
}()
}()
}
}

}

func setTestRoot() {
Expand Down
2 changes: 1 addition & 1 deletion tests/postgres/sample_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package postgres

import (
"github.com/go-jet/jet/v2/qrm"
"github.com/go-jet/jet/v2/internal/utils/ptr"
"github.com/go-jet/jet/v2/qrm"
"github.com/google/uuid"
"testing"

Expand Down
2 changes: 1 addition & 1 deletion tests/sqlite/delete_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func TestDeleteContextDeadlineExceeded(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()

time.Sleep(10 * time.Millisecond)
time.Sleep(20 * time.Millisecond)

testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx qrm.DB) {
var dest []model.Link
Expand Down
50 changes: 30 additions & 20 deletions tests/sqlite/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,36 +19,46 @@ import (

var db *sqlite.DB
var sampleDB *sqlite.DB
var testRoot string

var skipStatementsCaching bool

func init() {
skipStatementsCaching = os.Getenv("JET_TESTS_NO_STMT_CACHE") == "true"
}

func TestMain(m *testing.M) {
defer profile.Start().Stop()

sqlDB, err := sql.Open("sqlite3", "file:"+dbconfig.SakilaDBPath)
throw.OnError(err)
db = sqlite.NewDB(sqlDB).WithStatementsCaching(true)
defer db.Close()
for _, cachingEnabled := range []bool{false, true} {

_, err = db.Exec(fmt.Sprintf("ATTACH DATABASE '%s' as 'chinook';", dbconfig.ChinookDBPath))
throw.OnError(err)
if cachingEnabled && skipStatementsCaching {
continue //skipped by global env variable
}

sqlSampleDB, err := sql.Open("sqlite3", dbconfig.TestSampleDBPath)
throw.OnError(err)
sampleDB = sqlite.NewDB(sqlSampleDB).WithStatementsCaching(true)
defer sampleDB.Close()
func() {

for i := 0; i < 2; i++ {
ret := m.Run()
if ret != 0 {
os.Exit(ret)
}
}
sqlDB, err := sql.Open("sqlite3", "file:"+dbconfig.SakilaDBPath)
throw.OnError(err)
db = sqlite.NewDB(sqlDB).WithStatementsCaching(cachingEnabled)
defer db.Close()

_, err = db.Exec(fmt.Sprintf("ATTACH DATABASE '%s' as 'chinook';", dbconfig.ChinookDBPath))
throw.OnError(err)

sqlSampleDB, err := sql.Open("sqlite3", dbconfig.TestSampleDBPath)
throw.OnError(err)
sampleDB = sqlite.NewDB(sqlSampleDB).WithStatementsCaching(cachingEnabled)
defer sampleDB.Close()

err = sampleDB.Clear()
ret := m.Run()
if ret != 0 {
os.Exit(ret)
}

}()

if err != nil {
panic(err)
}

}

var loggedSQL string
Expand Down

0 comments on commit 79405af

Please sign in to comment.