From 79405af20d27fda899c3c9997cd171a23efa2f1e Mon Sep 17 00:00:00 2001 From: go-jet Date: Thu, 24 Oct 2024 12:30:20 +0200 Subject: [PATCH] Statement cache clean up. --- go.mod | 2 +- internal/jet/db/db.go | 17 +++++++++--- internal/jet/db/tx.go | 2 +- tests/mysql/main_test.go | 36 ++++++++++++++++--------- tests/postgres/main_test.go | 48 +++++++++++++++++---------------- tests/postgres/sample_test.go | 2 +- tests/sqlite/delete_test.go | 2 +- tests/sqlite/main_test.go | 50 +++++++++++++++++++++-------------- 8 files changed, 96 insertions(+), 63 deletions(-) diff --git a/go.mod b/go.mod index 890dc0d6..4b48c663 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/go-jet/jet/v2 -go 1.18 +go 1.20 require ( github.com/go-sql-driver/mysql v1.8.1 diff --git a/internal/jet/db/db.go b/internal/jet/db/db.go index 94dfecad..da930034 100644 --- a/internal/jet/db/db.go +++ b/internal/jet/db/db.go @@ -3,6 +3,7 @@ package db import ( "context" "database/sql" + "errors" "fmt" "sync" ) @@ -157,8 +158,8 @@ func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error return prepStmt, nil } -// Clear will close all cached prepared statements -func (d *DB) Clear() error { +// ClearStatementsCache will close all cached prepared statements and clear statements cache +func (d *DB) ClearStatementsCache() error { d.lock.Lock() defer d.lock.Unlock() @@ -168,15 +169,23 @@ func (d *DB) Clear() error { closeErr := statement.Close() if closeErr != nil { - err = closeErr + err = errors.Join(err, closeErr) } } d.statements = make(map[string]*sql.Stmt) if err != nil { - return fmt.Errorf("some of the prepared statements failed to close, last err: %w", err) + return errors.Join(errors.New("jet: some of the prepared statements failed to close"), err) } return nil } + +// Close will clear the statements cache and close the underlying db connection +func (d *DB) Close() error { + clearErr := d.ClearStatementsCache() + closeErr := d.DB.Close() + + return errors.Join(clearErr, closeErr) +} diff --git a/internal/jet/db/tx.go b/internal/jet/db/tx.go index b64b2315..f850508b 100644 --- a/internal/jet/db/tx.go +++ b/internal/jet/db/tx.go @@ -74,7 +74,7 @@ func (t *Tx) Prepare(query string) (*sql.Stmt, error) { // automatically upon the completion of the transaction, whether it's committed or rolled back. func (t *Tx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { if !t.db.statementsCaching { - return t.PrepareContext(ctx, query) + return t.Tx.PrepareContext(ctx, query) } prepStmt, ok := t.statements[query] diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index ee02a774..b4cffbaa 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -3,6 +3,7 @@ package mysql import ( "context" "database/sql" + "fmt" jetmysql "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/dbconfig" @@ -18,11 +19,13 @@ import ( var db *jetmysql.DB var source string +var skipStatementsCaching bool const MariaDB = "MariaDB" func init() { source = os.Getenv("MY_SQL_SOURCE") + skipStatementsCaching = os.Getenv("JET_TESTS_NO_STMT_CACHE") == "true" } func sourceIsMariaDB() bool { @@ -32,20 +35,29 @@ func sourceIsMariaDB() bool { func TestMain(m *testing.M) { defer profile.Start().Stop() - var err error - sqlDB, err := sql.Open("mysql", dbconfig.MySQLConnectionString(sourceIsMariaDB(), "")) - if err != nil { - panic("Failed to connect to test db" + err.Error()) - } - - db = jetmysql.NewDB(sqlDB).WithStatementsCaching(true) - defer db.Close() + for _, cachingEnabled := range []bool{false, true} { - for i := 0; i < 2; i++ { - ret := m.Run() - if ret != 0 { - os.Exit(ret) + if cachingEnabled && skipStatementsCaching { + continue //skipped by global env variable } + + func() { + fmt.Printf("\nRunning mysql tests caching enabled: %t \n", cachingEnabled) + + var err error + sqlDB, err := sql.Open("mysql", dbconfig.MySQLConnectionString(sourceIsMariaDB(), "")) + if err != nil { + panic("Failed to connect to test db" + err.Error()) + } + + db = jetmysql.NewDB(sqlDB).WithStatementsCaching(cachingEnabled) + defer db.Close() + + ret := m.Run() + if ret != 0 { + os.Exit(ret) + } + }() } } diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index 1f02e695..991c6d35 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -23,11 +23,13 @@ var db *postgres.DB var testRoot string var source string +var skipStatementsCaching bool const CockroachDB = "COCKROACH_DB" func init() { source = os.Getenv("PG_SOURCE") + skipStatementsCaching = os.Getenv("JET_TESTS_NO_STMT_CACHE") == "true" } func sourceIsCockroachDB() bool { @@ -45,39 +47,39 @@ func TestMain(m *testing.M) { setTestRoot() - for _, driverName := range []string{"pgx", "postgres"} { - fmt.Printf("\nRunning postgres tests for '%s' driver\n", driverName) + for _, cachingEnabled := range []bool{false, true} { - func() { + if cachingEnabled && skipStatementsCaching { + continue //skipped by global env variable + } - connectionString := dbconfig.PostgresConnectString + for _, driverName := range []string{"pgx", "postgres"} { - if sourceIsCockroachDB() { - connectionString = dbconfig.CockroachConnectString - } + fmt.Printf("\nRunning postgres tests for driver: %s, caching enabled: %t \n", driverName, cachingEnabled) - sqlDB, err := sql.Open(driverName, connectionString) - if err != nil { - fmt.Println(err.Error()) - panic("Failed to connect to test db") - } - db = postgres.NewDB(sqlDB).WithStatementsCaching(true) - defer db.Close() + func() { + connectionString := dbconfig.PostgresConnectString + + if sourceIsCockroachDB() { + connectionString = dbconfig.CockroachConnectString + } + + sqlDB, err := sql.Open(driverName, connectionString) + if err != nil { + fmt.Println(err.Error()) + panic("Failed to connect to test db") + } + db = postgres.NewDB(sqlDB).WithStatementsCaching(cachingEnabled) + defer db.Close() - for i := 0; i < 2; i++ { ret := m.Run() if ret != 0 { os.Exit(ret) } - } - - err = db.Clear() - - if err != nil { - os.Exit(-2) - } - }() + }() + } } + } func setTestRoot() { diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index 73b652c9..7ac05068 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -1,8 +1,8 @@ package postgres import ( - "github.com/go-jet/jet/v2/qrm" "github.com/go-jet/jet/v2/internal/utils/ptr" + "github.com/go-jet/jet/v2/qrm" "github.com/google/uuid" "testing" diff --git a/tests/sqlite/delete_test.go b/tests/sqlite/delete_test.go index 3c8c0c4a..f700902b 100644 --- a/tests/sqlite/delete_test.go +++ b/tests/sqlite/delete_test.go @@ -69,7 +69,7 @@ func TestDeleteContextDeadlineExceeded(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond) defer cancel() - time.Sleep(10 * time.Millisecond) + time.Sleep(20 * time.Millisecond) testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx qrm.DB) { var dest []model.Link diff --git a/tests/sqlite/main_test.go b/tests/sqlite/main_test.go index 593d6300..8b04fb8a 100644 --- a/tests/sqlite/main_test.go +++ b/tests/sqlite/main_test.go @@ -19,36 +19,46 @@ import ( var db *sqlite.DB var sampleDB *sqlite.DB -var testRoot string + +var skipStatementsCaching bool + +func init() { + skipStatementsCaching = os.Getenv("JET_TESTS_NO_STMT_CACHE") == "true" +} func TestMain(m *testing.M) { defer profile.Start().Stop() - sqlDB, err := sql.Open("sqlite3", "file:"+dbconfig.SakilaDBPath) - throw.OnError(err) - db = sqlite.NewDB(sqlDB).WithStatementsCaching(true) - defer db.Close() + for _, cachingEnabled := range []bool{false, true} { - _, err = db.Exec(fmt.Sprintf("ATTACH DATABASE '%s' as 'chinook';", dbconfig.ChinookDBPath)) - throw.OnError(err) + if cachingEnabled && skipStatementsCaching { + continue //skipped by global env variable + } - sqlSampleDB, err := sql.Open("sqlite3", dbconfig.TestSampleDBPath) - throw.OnError(err) - sampleDB = sqlite.NewDB(sqlSampleDB).WithStatementsCaching(true) - defer sampleDB.Close() + func() { - for i := 0; i < 2; i++ { - ret := m.Run() - if ret != 0 { - os.Exit(ret) - } - } + sqlDB, err := sql.Open("sqlite3", "file:"+dbconfig.SakilaDBPath) + throw.OnError(err) + db = sqlite.NewDB(sqlDB).WithStatementsCaching(cachingEnabled) + defer db.Close() + + _, err = db.Exec(fmt.Sprintf("ATTACH DATABASE '%s' as 'chinook';", dbconfig.ChinookDBPath)) + throw.OnError(err) + + sqlSampleDB, err := sql.Open("sqlite3", dbconfig.TestSampleDBPath) + throw.OnError(err) + sampleDB = sqlite.NewDB(sqlSampleDB).WithStatementsCaching(cachingEnabled) + defer sampleDB.Close() - err = sampleDB.Clear() + ret := m.Run() + if ret != 0 { + os.Exit(ret) + } + + }() - if err != nil { - panic(err) } + } var loggedSQL string