From fdbbeb3f2654123c2f7ae6a129978aea1c75b5b5 Mon Sep 17 00:00:00 2001 From: Lz Date: Sun, 28 Apr 2024 20:30:49 +0800 Subject: [PATCH 1/4] !feat(dtc):renamed Context and added DTC --- CHANGELOG.md | 6 +- context.go => conn.go | 26 ++--- context_stmt.go => conn_stmt.go | 6 +- context_stmt_test.go => conn_stmt_test.go | 0 connector.go | 2 +- db.go | 14 +-- dtc.go | 108 ++++++++++++++++++++ dtc_test.go | 118 ++++++++++++++++++++++ queryer_mapr.go | 8 +- 9 files changed, 259 insertions(+), 29 deletions(-) rename context.go => conn.go (70%) rename context_stmt.go => conn_stmt.go (86%) rename context_stmt_test.go => conn_stmt_test.go (100%) create mode 100644 dtc.go create mode 100644 dtc_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 8031c00..5cfae1f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,10 +5,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [1.4.7] +## [1.5.0] +### Changed +- renamed `Context` with `Conn` (#44) + ### Added - added `Connector` interface (#43) - added nullable `Time` with better json support (#44) +- added `DistributedTransactions` (#44) ### Changed - renamed `BitBool` with shorter name `Bool` (#44) diff --git a/context.go b/conn.go similarity index 70% rename from context.go rename to conn.go index 6ffe80d..05226d0 100644 --- a/context.go +++ b/conn.go @@ -8,7 +8,7 @@ import ( "time" ) -type Context struct { +type Conn struct { *sql.DB sync.Mutex _ noCopy @@ -20,11 +20,11 @@ type Context struct { Index int } -func (db *Context) Query(query string, args ...any) (*Rows, error) { +func (db *Conn) Query(query string, args ...any) (*Rows, error) { return db.QueryContext(context.Background(), query, args...) } -func (db *Context) QueryBuilder(ctx context.Context, b *Builder) (*Rows, error) { +func (db *Conn) QueryBuilder(ctx context.Context, b *Builder) (*Rows, error) { query, args, err := b.Build() if err != nil { return nil, err @@ -33,7 +33,7 @@ func (db *Context) QueryBuilder(ctx context.Context, b *Builder) (*Rows, error) return db.QueryContext(ctx, query, args...) } -func (db *Context) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) { +func (db *Conn) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) { var rows *sql.Rows var stmt *Stmt var err error @@ -59,11 +59,11 @@ func (db *Context) QueryContext(ctx context.Context, query string, args ...any) return &Rows{Rows: rows, stmt: stmt, query: query}, nil } -func (db *Context) QueryRow(query string, args ...any) *Row { +func (db *Conn) QueryRow(query string, args ...any) *Row { return db.QueryRowContext(context.Background(), query, args...) } -func (db *Context) QueryRowBuilder(ctx context.Context, b *Builder) *Row { +func (db *Conn) QueryRowBuilder(ctx context.Context, b *Builder) *Row { query, args, err := b.Build() if err != nil { return &Row{ @@ -75,7 +75,7 @@ func (db *Context) QueryRowBuilder(ctx context.Context, b *Builder) *Row { return db.QueryRowContext(ctx, query, args...) } -func (db *Context) QueryRowContext(ctx context.Context, query string, args ...any) *Row { +func (db *Conn) QueryRowContext(ctx context.Context, query string, args ...any) *Row { var rows *sql.Rows var stmt *Stmt var err error @@ -108,11 +108,11 @@ func (db *Context) QueryRowContext(ctx context.Context, query string, args ...an } } -func (db *Context) Exec(query string, args ...any) (sql.Result, error) { +func (db *Conn) Exec(query string, args ...any) (sql.Result, error) { return db.ExecContext(context.Background(), query, args...) } -func (db *Context) ExecBuilder(ctx context.Context, b *Builder) (sql.Result, error) { +func (db *Conn) ExecBuilder(ctx context.Context, b *Builder) (sql.Result, error) { query, args, err := b.Build() if err != nil { return nil, err @@ -121,7 +121,7 @@ func (db *Context) ExecBuilder(ctx context.Context, b *Builder) (sql.Result, err return db.ExecContext(ctx, query, args...) } -func (db *Context) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { +func (db *Conn) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { if len(args) > 0 { stmt, err := db.prepareStmt(ctx, query) if err != nil { @@ -135,12 +135,12 @@ func (db *Context) ExecContext(ctx context.Context, query string, args ...any) ( return db.DB.ExecContext(context.Background(), query, args...) } -func (db *Context) Begin(opts *sql.TxOptions) (*Tx, error) { +func (db *Conn) Begin(opts *sql.TxOptions) (*Tx, error) { return db.BeginTx(context.TODO(), opts) } -func (db *Context) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { +func (db *Conn) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { tx, err := db.DB.BeginTx(ctx, opts) if err != nil { return nil, err @@ -149,7 +149,7 @@ func (db *Context) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error return &Tx{Tx: tx, stmts: make(map[string]*sql.Stmt)}, nil } -func (db *Context) Transaction(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx *Tx) error) error { +func (db *Conn) Transaction(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx *Tx) error) error { tx, err := db.BeginTx(ctx, opts) if err != nil { return err diff --git a/context_stmt.go b/conn_stmt.go similarity index 86% rename from context_stmt.go rename to conn_stmt.go index 3c94352..8860772 100644 --- a/context_stmt.go +++ b/conn_stmt.go @@ -21,7 +21,7 @@ func (s *Stmt) Reuse() { s.isUsing = false } -func (db *Context) prepareStmt(ctx context.Context, query string) (*Stmt, error) { +func (db *Conn) prepareStmt(ctx context.Context, query string) (*Stmt, error) { db.stmtsMutex.Lock() defer db.stmtsMutex.Unlock() s, ok := db.stmts[query] @@ -48,7 +48,7 @@ func (db *Context) prepareStmt(ctx context.Context, query string) (*Stmt, error) return s, nil } -func (db *Context) closeStaleStmt() { +func (db *Conn) closeStaleStmt() { db.stmtsMutex.Lock() defer db.stmtsMutex.Unlock() @@ -64,7 +64,7 @@ func (db *Context) closeStaleStmt() { } -func (db *Context) checkIdleStmt() { +func (db *Conn) checkIdleStmt() { delay := time.NewTicker(db.stmtMaxIdleTime) defer delay.Stop() diff --git a/context_stmt_test.go b/conn_stmt_test.go similarity index 100% rename from context_stmt_test.go rename to conn_stmt_test.go diff --git a/connector.go b/connector.go index cc6107c..514e683 100644 --- a/connector.go +++ b/connector.go @@ -6,7 +6,7 @@ import ( ) // Connector represents a database connector that provides methods for executing queries and commands. -// Context and Tx both implement this interface. +// Conn and Tx both implement this interface. type Connector interface { // Query executes a query that returns multiple rows. // It takes a query string and optional arguments. diff --git a/db.go b/db.go index 13d8208..5c7aed9 100644 --- a/db.go +++ b/db.go @@ -16,12 +16,12 @@ var ( // DB represents a database connection pool with sharding support. type DB struct { - *Context + *Conn _ noCopy //nolint: unused mu sync.RWMutex dhts map[string]*shardid.DHT - dbs []*Context + dbs []*Conn } // Open creates a new DB instance with the provided database connections. @@ -31,7 +31,7 @@ func Open(dbs ...*sql.DB) *DB { } for i, db := range dbs { - ctx := &Context{ + ctx := &Conn{ DB: db, Index: i, stmts: make(map[string]*Stmt), @@ -41,7 +41,7 @@ func Open(dbs ...*sql.DB) *DB { go ctx.checkIdleStmt() } - d.Context = d.dbs[0] + d.Conn = d.dbs[0] return d } @@ -54,7 +54,7 @@ func (db *DB) Add(dbs ...*sql.DB) { n := len(db.dbs) for i, d := range dbs { - ctx := &Context{ + ctx := &Conn{ DB: d, Index: n + i, stmts: make(map[string]*Stmt), @@ -66,7 +66,7 @@ func (db *DB) Add(dbs ...*sql.DB) { } // On selects the database context based on the shardid ID. -func (db *DB) On(id shardid.ID) *Context { +func (db *DB) On(id shardid.ID) *Conn { db.mu.RLock() defer db.mu.RUnlock() @@ -90,7 +90,7 @@ func (db *DB) GetDHT(name string) *shardid.DHT { } // OnDHT selects the database context based on the DHT (Distributed Hash Table) key. -func (db *DB) OnDHT(key string, names ...string) (*Context, error) { +func (db *DB) OnDHT(key string, names ...string) (*Conn, error) { db.mu.RLock() defer db.mu.RUnlock() diff --git a/dtc.go b/dtc.go new file mode 100644 index 0000000..8c58634 --- /dev/null +++ b/dtc.go @@ -0,0 +1,108 @@ +package sqle + +import ( + "context" + "database/sql" +) + +// DTC Distributed Transaction Coordinator +type DTC struct { + ctx context.Context + opts *sql.TxOptions + + sessions []*session +} + +// session represents a transaction session. +type session struct { + committed bool + conn *Conn + tx *Tx + exec []func(ctx context.Context, c Connector) error + revert []func(ctx context.Context, c Connector) error +} + +// NewDTC creates a new instance of DTC. +func NewDTC(ctx context.Context, opts *sql.TxOptions) *DTC { + return &DTC{ + ctx: ctx, + opts: opts, + } +} + +// Prepare adds a new transaction session to the DTC. +func (d *DTC) Prepare(conn *Conn, exec func(ctx context.Context, c Connector) error, revert func(ctx context.Context, c Connector) error) { + for _, s := range d.sessions { + if s.conn.Index == conn.Index { + s.exec = append(s.exec, exec) + s.revert = append(s.revert, revert) + return + } + } + + s := &session{ + committed: false, + conn: conn, + exec: []func(ctx context.Context, c Connector) error{ + exec, + }, + revert: []func(ctx context.Context, c Connector) error{ + revert, + }, + } + + d.sessions = append(d.sessions, s) + +} + +// Commit commits all the prepared transactions in the DTC. +func (d *DTC) Commit() error { + for _, s := range d.sessions { + tx, err := s.conn.BeginTx(d.ctx, d.opts) + if err != nil { + return err + } + + s.tx = tx + + for _, exec := range s.exec { + err = exec(d.ctx, tx) + if err != nil { + return err + } + } + } + + for _, s := range d.sessions { + err := s.tx.Commit() + if err != nil { + return err + } + + s.committed = true + } + + return nil +} + +// Rollback rolls back all the prepared transactions in the DTC. +func (d *DTC) Rollback() []error { + var errs []error + + for _, s := range d.sessions { + if s.committed { + for _, revert := range s.revert { + if err := revert(d.ctx, s.conn); err != nil { + errs = append(errs, err) + } + } + + } else { + if err := s.tx.Rollback(); err != nil { + errs = append(errs, err) + } + } + } + + return errs +} diff --git a/dtc_test.go b/dtc_test.go new file mode 100644 index 0000000..4a1b53b --- /dev/null +++ b/dtc_test.go @@ -0,0 +1,118 @@ +package sqle + +import ( + "context" + "database/sql" + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDTCWithDB(t *testing.T) { + os.Remove("dtc_1.db") + + d, err := sql.Open("sqlite3", "file:dtc_1.db?cache=shared&mode=rwc") + require.NoError(t, err) + + _, err = d.Exec("CREATE TABLE `dtc_1` (`id` int , `email` varchar(50),`created_at` DATETIME, PRIMARY KEY (`id`))") + require.NoError(t, err) + + db := Open(d) + + var tests = []struct { + name string + setup func() *DTC + assert func(ra *require.Assertions) + }{ + { + name: "multiple_txs_commit_should_work", + setup: func() *DTC { + dtc := NewDTC(context.Background(), nil) + + dtc.Prepare(db.dbs[0], func(ctx context.Context, c Connector) error { + _, err := c.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 1, "1@mail.com") + + return err + }, nil) + + dtc.Prepare(db.dbs[0], func(ctx context.Context, c Connector) error { + _, err := c.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 2, "2@mail.com") + + return err + }, nil) + + dtc.Prepare(db.dbs[0], func(ctx context.Context, c Connector) error { + _, err := c.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 3, "3@mail.com") + + return err + }, nil) + + return dtc + }, + assert: func(ra *require.Assertions) { + var id int + err := db.QueryRow("SELECT id FROM `dtc_1` WHERE id=?", 1).Scan(&id) + ra.NoError(err) + ra.Equal(1, id) + + err = db.QueryRow("SELECT id FROM `dtc_1` WHERE id=?", 2).Scan(&id) + ra.NoError(err) + ra.Equal(2, id) + + err = db.QueryRow("SELECT id FROM `dtc_1` WHERE id=?", 3).Scan(&id) + ra.NoError(err) + ra.Equal(3, id) + }, + }, + { + name: "multiple_txs_rollback_should_work", + setup: func() *DTC { + dtc := NewDTC(context.Background(), nil) + dtc.Prepare(db.dbs[0], func(ctx context.Context, c Connector) error { + _, err := c.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 11, "1@mail.com") + + return err + }, nil) + + dtc.Prepare(db.dbs[0], func(ctx context.Context, c Connector) error { + _, err := c.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 12, "2@mail.com") + + return err + }, nil) + + dtc.Prepare(db.dbs[0], func(ctx context.Context, c Connector) error { + _, err := c.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 13) + + return err + }, nil) + + return dtc + }, + assert: func(ra *require.Assertions) { + var id int + err := db.QueryRow("SELECT id FROM `dtc_1` WHERE id=?", 11).Scan(&id) + ra.ErrorIs(err, sql.ErrNoRows) + + err = db.QueryRow("SELECT id FROM `dtc_1` WHERE id=?", 12).Scan(&id) + ra.ErrorIs(err, sql.ErrNoRows) + + err = db.QueryRow("SELECT id FROM `dtc_1` WHERE id=?", 13).Scan(&id) + ra.ErrorIs(err, sql.ErrNoRows) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dtc := test.setup() + + err := dtc.Commit() + if err != nil { + dtc.Rollback() + } + + test.assert(require.New(t)) + }) + } +} diff --git a/queryer_mapr.go b/queryer_mapr.go index b1e3a8c..d08e8df 100644 --- a/queryer_mapr.go +++ b/queryer_mapr.go @@ -11,7 +11,7 @@ import ( // MapR is a Map/Reduce Query Provider based on databases. type MapR[T any] struct { - dbs []*Context + dbs []*Conn } // First executes the query and returns the first result. @@ -28,7 +28,7 @@ func (q *MapR[T]) First(ctx context.Context, rotatedTables []string, b *Builder) for _, r := range rotatedTables { qr := strings.ReplaceAll(query, "", r) for _, db := range q.dbs { - w.Add(func(db *Context, qr string) func(context.Context) (T, error) { + w.Add(func(db *Conn, qr string) func(context.Context) (T, error) { return func(ctx context.Context) (T, error) { var t T err := db.QueryRowContext(ctx, qr, args...).Bind(&t) @@ -59,7 +59,7 @@ func (q *MapR[T]) Count(ctx context.Context, rotatedTables []string, b *Builder) for _, r := range rotatedTables { qr := strings.ReplaceAll(query, "", r) for _, db := range q.dbs { - w.Add(func(db *Context, qr string) func(context.Context) (int64, error) { + w.Add(func(db *Conn, qr string) func(context.Context) (int64, error) { return func(ctx context.Context) (int64, error) { var i int64 err := db.QueryRowContext(ctx, qr, args...).Scan(&i) @@ -102,7 +102,7 @@ func (q *MapR[T]) Query(ctx context.Context, rotatedTables []string, b *Builder, for _, r := range rotatedTables { qr := strings.ReplaceAll(query, "", r) for _, db := range q.dbs { - w.Add(func(db *Context, qr string) func(context.Context) ([]T, error) { + w.Add(func(db *Conn, qr string) func(context.Context) ([]T, error) { return func(context.Context) ([]T, error) { var t []T rows, err := db.QueryContext(ctx, qr, args...) From c5a6c42c418c681d0e1e679271b64fa3ecfc9d43 Mon Sep 17 00:00:00 2001 From: Lz Date: Sun, 28 Apr 2024 21:01:07 +0800 Subject: [PATCH 2/4] fix(dtc): fixed commit and added revert testing --- dtc.go | 2 +- dtc_test.go | 200 +++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 200 insertions(+), 2 deletions(-) diff --git a/dtc.go b/dtc.go index 8c58634..01baab5 100644 --- a/dtc.go +++ b/dtc.go @@ -33,7 +33,7 @@ func NewDTC(ctx context.Context, opts *sql.TxOptions) *DTC { // Prepare adds a new transaction session to the DTC. func (d *DTC) Prepare(conn *Conn, exec func(ctx context.Context, c Connector) error, revert func(ctx context.Context, c Connector) error) { for _, s := range d.sessions { - if s.conn.Index == conn.Index { + if s.conn == conn { s.exec = append(s.exec, exec) s.revert = append(s.revert, revert) return diff --git a/dtc_test.go b/dtc_test.go index 4a1b53b..0d308aa 100644 --- a/dtc_test.go +++ b/dtc_test.go @@ -10,7 +10,7 @@ import ( ) func TestDTCWithDB(t *testing.T) { - os.Remove("dtc_1.db") + os.Remove("dtc_db.db") d, err := sql.Open("sqlite3", "file:dtc_1.db?cache=shared&mode=rwc") require.NoError(t, err) @@ -116,3 +116,201 @@ func TestDTCWithDB(t *testing.T) { }) } } + +func TestDTCWithDBs(t *testing.T) { + os.Remove("dtc_dbs_1.db") + os.Remove("dtc_dbs_2.db") + os.Remove("dtc_dbs_3.db") + + d1, err := sql.Open("sqlite3", "file:dtc_dbs_1.db?cache=shared&mode=rwc") + require.NoError(t, err) + + _, err = d1.Exec("CREATE TABLE `dtc_1` (`id` int , `email` varchar(50),`created_at` DATETIME, PRIMARY KEY (`id`))") + require.NoError(t, err) + + d2, err := sql.Open("sqlite3", "file:dtc_dbs_2.db?cache=shared&mode=rwc") + require.NoError(t, err) + + _, err = d2.Exec("CREATE TABLE `dtc_2` (`id` int , `email` varchar(50),`created_at` DATETIME, PRIMARY KEY (`id`))") + require.NoError(t, err) + + d3, err := sql.Open("sqlite3", "file:dtc_dbs_3.db?cache=shared&mode=rwc") + require.NoError(t, err) + + _, err = d3.Exec("CREATE TABLE `dtc_3` (`id` int , `email` varchar(50),`created_at` DATETIME, PRIMARY KEY (`id`))") + require.NoError(t, err) + + db1 := Open(d1) + db2 := Open(d2) + db3 := Open(d3) + + var tests = []struct { + name string + setup func() *DTC + assert func(ra *require.Assertions) + }{ + { + name: "multiple_txs_commit_should_work", + setup: func() *DTC { + dtc := NewDTC(context.Background(), nil) + + dtc.Prepare(db1.dbs[0], func(ctx context.Context, c Connector) error { + _, err := c.Exec("INSERT INTO `dtc_1` (`id`,`email`) VALUES(?,?)", 1, "1@mail.com") + + return err + }, nil) + + dtc.Prepare(db2.dbs[0], func(ctx context.Context, c Connector) error { + _, err := c.Exec("INSERT INTO `dtc_2` (`id`,`email`) VALUES(?,?)", 2, "2@mail.com") + + return err + }, nil) + + dtc.Prepare(db3.dbs[0], func(ctx context.Context, c Connector) error { + _, err := c.Exec("INSERT INTO `dtc_3` (`id`,`email`) VALUES(?,?)", 3, "3@mail.com") + + return err + }, nil) + + return dtc + }, + assert: func(ra *require.Assertions) { + var id int + err := db1.QueryRow("SELECT id FROM `dtc_1` WHERE id=?", 1).Scan(&id) + ra.NoError(err) + ra.Equal(1, id) + + err = db2.QueryRow("SELECT id FROM `dtc_2` WHERE id=?", 2).Scan(&id) + ra.NoError(err) + ra.Equal(2, id) + + err = db3.QueryRow("SELECT id FROM `dtc_3` WHERE id=?", 3).Scan(&id) + ra.NoError(err) + ra.Equal(3, id) + }, + }, + { + name: "multiple_txs_rollback_should_work", + setup: func() *DTC { + dtc := NewDTC(context.Background(), nil) + dtc.Prepare(db1.dbs[0], func(ctx context.Context, c Connector) error { + _, err := c.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 11, "1@mail.com") + + return err + }, nil) + + dtc.Prepare(db2.dbs[0], func(ctx context.Context, c Connector) error { + _, err := c.Exec("INSERT INTO `dtc_2`(`id`,`email`) VALUES(?,?)", 12, "2@mail.com") + + return err + }, nil) + + dtc.Prepare(db3.dbs[0], func(ctx context.Context, c Connector) error { + _, err := c.Exec("INSERT INTO `dtc_3`(`id`,`email`) VALUES(?,?)", 13) + + return err + }, nil) + + return dtc + }, + assert: func(ra *require.Assertions) { + var id int + err := db1.QueryRow("SELECT id FROM `dtc_1` WHERE id=?", 11).Scan(&id) + ra.ErrorIs(err, sql.ErrNoRows) + + err = db2.QueryRow("SELECT id FROM `dtc_2` WHERE id=?", 12).Scan(&id) + ra.ErrorIs(err, sql.ErrNoRows) + + err = db3.QueryRow("SELECT id FROM `dtc_3` WHERE id=?", 13).Scan(&id) + ra.ErrorIs(err, sql.ErrNoRows) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dtc := test.setup() + + err := dtc.Commit() + if err != nil { + dtc.Rollback() + } + + test.assert(require.New(t)) + }) + } +} + +func TestDTCRevert(t *testing.T) { + os.Remove("dtc_revert_1.db") + os.Remove("dtc_revert_2.db") + os.Remove("dtc_revert_3.db") + + d1, err := sql.Open("sqlite3", "file:dtc_revert_1.db?cache=shared&mode=rwc") + require.NoError(t, err) + + _, err = d1.Exec("CREATE TABLE `dtc_1` (`id` int , `email` varchar(50),`created_at` DATETIME, PRIMARY KEY (`id`))") + require.NoError(t, err) + + d2, err := sql.Open("sqlite3", "file:dtc_revert_2.db?cache=shared&mode=rwc") + require.NoError(t, err) + + _, err = d2.Exec("CREATE TABLE `dtc_2` (`id` int , `email` varchar(50),`created_at` DATETIME, PRIMARY KEY (`id`))") + require.NoError(t, err) + + d3, err := sql.Open("sqlite3", "file:dtc_revert_3.db?cache=shared&mode=rwc") + require.NoError(t, err) + + _, err = d3.Exec("CREATE TABLE `dtc_3` (`id` int , `email` varchar(50),`created_at` DATETIME, PRIMARY KEY (`id`))") + require.NoError(t, err) + + db1 := Open(d1) + db2 := Open(d2) + db3 := Open(d3) + + dtc := NewDTC(context.Background(), nil) + + dtc.Prepare(db1.dbs[0], func(ctx context.Context, c Connector) error { + _, err := c.Exec("INSERT INTO `dtc_1` (`id`,`email`) VALUES(?,?)", 1, "1@mail.com") + return err + }, func(ctx context.Context, c Connector) error { + _, err := c.Exec("DELETE FROM `dtc_1` WHERE id=?", 1) + return err + }) + + dtc.Prepare(db2.dbs[0], func(ctx context.Context, c Connector) error { + _, err := c.Exec("INSERT INTO `dtc_2` (`id`,`email`) VALUES(?,?)", 2, "2@mail.com") + + return err + }, func(ctx context.Context, c Connector) error { + _, err := c.Exec("DELETE FROM `dtc_2` WHERE id=?", 2) + return err + }) + + dtc.Prepare(db3.dbs[0], func(ctx context.Context, c Connector) error { + _, err := c.Exec("INSERT INTO `dtc_3` (`id`,`email`) VALUES(?,?)", 3, "3@mail.com") + + return err + }, func(ctx context.Context, c Connector) error { + _, err := c.Exec("DELETE FROM `dtc_3` WHERE id=?", 3) + return err + }) + + ra := require.New(t) + err = dtc.Commit() + ra.NoError(err) + + errs := dtc.Rollback() + ra.Len(errs, 0) + + var id int + err = db1.QueryRow("SELECT id FROM `dtc_1` WHERE id=?", 11).Scan(&id) + ra.ErrorIs(err, sql.ErrNoRows) + + err = db2.QueryRow("SELECT id FROM `dtc_2` WHERE id=?", 12).Scan(&id) + ra.ErrorIs(err, sql.ErrNoRows) + + err = db3.QueryRow("SELECT id FROM `dtc_3` WHERE id=?", 13).Scan(&id) + ra.ErrorIs(err, sql.ErrNoRows) + +} From c969125b49685ad7d985f71e9b00c825332e3223 Mon Sep 17 00:00:00 2001 From: Lz Date: Sun, 28 Apr 2024 21:08:18 +0800 Subject: [PATCH 3/4] !fix(context): renamed Context to Client" --- conn.go => client.go | 26 +++++----- conn_stmt.go => client_stmt.go | 6 +-- conn_stmt_test.go => client_stmt_test.go | 0 db.go | 14 +++--- dtc.go | 16 +++---- dtc_test.go | 60 ++++++++++++------------ queryer_mapr.go | 8 ++-- 7 files changed, 65 insertions(+), 65 deletions(-) rename conn.go => client.go (70%) rename conn_stmt.go => client_stmt.go (86%) rename conn_stmt_test.go => client_stmt_test.go (100%) diff --git a/conn.go b/client.go similarity index 70% rename from conn.go rename to client.go index 05226d0..a7d4c27 100644 --- a/conn.go +++ b/client.go @@ -8,7 +8,7 @@ import ( "time" ) -type Conn struct { +type Client struct { *sql.DB sync.Mutex _ noCopy @@ -20,11 +20,11 @@ type Conn struct { Index int } -func (db *Conn) Query(query string, args ...any) (*Rows, error) { +func (db *Client) Query(query string, args ...any) (*Rows, error) { return db.QueryContext(context.Background(), query, args...) } -func (db *Conn) QueryBuilder(ctx context.Context, b *Builder) (*Rows, error) { +func (db *Client) QueryBuilder(ctx context.Context, b *Builder) (*Rows, error) { query, args, err := b.Build() if err != nil { return nil, err @@ -33,7 +33,7 @@ func (db *Conn) QueryBuilder(ctx context.Context, b *Builder) (*Rows, error) { return db.QueryContext(ctx, query, args...) } -func (db *Conn) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) { +func (db *Client) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) { var rows *sql.Rows var stmt *Stmt var err error @@ -59,11 +59,11 @@ func (db *Conn) QueryContext(ctx context.Context, query string, args ...any) (*R return &Rows{Rows: rows, stmt: stmt, query: query}, nil } -func (db *Conn) QueryRow(query string, args ...any) *Row { +func (db *Client) QueryRow(query string, args ...any) *Row { return db.QueryRowContext(context.Background(), query, args...) } -func (db *Conn) QueryRowBuilder(ctx context.Context, b *Builder) *Row { +func (db *Client) QueryRowBuilder(ctx context.Context, b *Builder) *Row { query, args, err := b.Build() if err != nil { return &Row{ @@ -75,7 +75,7 @@ func (db *Conn) QueryRowBuilder(ctx context.Context, b *Builder) *Row { return db.QueryRowContext(ctx, query, args...) } -func (db *Conn) QueryRowContext(ctx context.Context, query string, args ...any) *Row { +func (db *Client) QueryRowContext(ctx context.Context, query string, args ...any) *Row { var rows *sql.Rows var stmt *Stmt var err error @@ -108,11 +108,11 @@ func (db *Conn) QueryRowContext(ctx context.Context, query string, args ...any) } } -func (db *Conn) Exec(query string, args ...any) (sql.Result, error) { +func (db *Client) Exec(query string, args ...any) (sql.Result, error) { return db.ExecContext(context.Background(), query, args...) } -func (db *Conn) ExecBuilder(ctx context.Context, b *Builder) (sql.Result, error) { +func (db *Client) ExecBuilder(ctx context.Context, b *Builder) (sql.Result, error) { query, args, err := b.Build() if err != nil { return nil, err @@ -121,7 +121,7 @@ func (db *Conn) ExecBuilder(ctx context.Context, b *Builder) (sql.Result, error) return db.ExecContext(ctx, query, args...) } -func (db *Conn) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { +func (db *Client) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { if len(args) > 0 { stmt, err := db.prepareStmt(ctx, query) if err != nil { @@ -135,12 +135,12 @@ func (db *Conn) ExecContext(ctx context.Context, query string, args ...any) (sql return db.DB.ExecContext(context.Background(), query, args...) } -func (db *Conn) Begin(opts *sql.TxOptions) (*Tx, error) { +func (db *Client) Begin(opts *sql.TxOptions) (*Tx, error) { return db.BeginTx(context.TODO(), opts) } -func (db *Conn) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { +func (db *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { tx, err := db.DB.BeginTx(ctx, opts) if err != nil { return nil, err @@ -149,7 +149,7 @@ func (db *Conn) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { return &Tx{Tx: tx, stmts: make(map[string]*sql.Stmt)}, nil } -func (db *Conn) Transaction(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx *Tx) error) error { +func (db *Client) Transaction(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx *Tx) error) error { tx, err := db.BeginTx(ctx, opts) if err != nil { return err diff --git a/conn_stmt.go b/client_stmt.go similarity index 86% rename from conn_stmt.go rename to client_stmt.go index 8860772..fa31618 100644 --- a/conn_stmt.go +++ b/client_stmt.go @@ -21,7 +21,7 @@ func (s *Stmt) Reuse() { s.isUsing = false } -func (db *Conn) prepareStmt(ctx context.Context, query string) (*Stmt, error) { +func (db *Client) prepareStmt(ctx context.Context, query string) (*Stmt, error) { db.stmtsMutex.Lock() defer db.stmtsMutex.Unlock() s, ok := db.stmts[query] @@ -48,7 +48,7 @@ func (db *Conn) prepareStmt(ctx context.Context, query string) (*Stmt, error) { return s, nil } -func (db *Conn) closeStaleStmt() { +func (db *Client) closeStaleStmt() { db.stmtsMutex.Lock() defer db.stmtsMutex.Unlock() @@ -64,7 +64,7 @@ func (db *Conn) closeStaleStmt() { } -func (db *Conn) checkIdleStmt() { +func (db *Client) checkIdleStmt() { delay := time.NewTicker(db.stmtMaxIdleTime) defer delay.Stop() diff --git a/conn_stmt_test.go b/client_stmt_test.go similarity index 100% rename from conn_stmt_test.go rename to client_stmt_test.go diff --git a/db.go b/db.go index 5c7aed9..b201fac 100644 --- a/db.go +++ b/db.go @@ -16,12 +16,12 @@ var ( // DB represents a database connection pool with sharding support. type DB struct { - *Conn + *Client _ noCopy //nolint: unused mu sync.RWMutex dhts map[string]*shardid.DHT - dbs []*Conn + dbs []*Client } // Open creates a new DB instance with the provided database connections. @@ -31,7 +31,7 @@ func Open(dbs ...*sql.DB) *DB { } for i, db := range dbs { - ctx := &Conn{ + ctx := &Client{ DB: db, Index: i, stmts: make(map[string]*Stmt), @@ -41,7 +41,7 @@ func Open(dbs ...*sql.DB) *DB { go ctx.checkIdleStmt() } - d.Conn = d.dbs[0] + d.Client = d.dbs[0] return d } @@ -54,7 +54,7 @@ func (db *DB) Add(dbs ...*sql.DB) { n := len(db.dbs) for i, d := range dbs { - ctx := &Conn{ + ctx := &Client{ DB: d, Index: n + i, stmts: make(map[string]*Stmt), @@ -66,7 +66,7 @@ func (db *DB) Add(dbs ...*sql.DB) { } // On selects the database context based on the shardid ID. -func (db *DB) On(id shardid.ID) *Conn { +func (db *DB) On(id shardid.ID) *Client { db.mu.RLock() defer db.mu.RUnlock() @@ -90,7 +90,7 @@ func (db *DB) GetDHT(name string) *shardid.DHT { } // OnDHT selects the database context based on the DHT (Distributed Hash Table) key. -func (db *DB) OnDHT(key string, names ...string) (*Conn, error) { +func (db *DB) OnDHT(key string, names ...string) (*Client, error) { db.mu.RLock() defer db.mu.RUnlock() diff --git a/dtc.go b/dtc.go index 01baab5..acd0658 100644 --- a/dtc.go +++ b/dtc.go @@ -16,10 +16,10 @@ type DTC struct { // session represents a transaction session. type session struct { committed bool - conn *Conn + client *Client tx *Tx - exec []func(ctx context.Context, c Connector) error - revert []func(ctx context.Context, c Connector) error + exec []func(context.Context, Connector) error + revert []func(context.Context, Connector) error } // NewDTC creates a new instance of DTC. @@ -31,9 +31,9 @@ func NewDTC(ctx context.Context, opts *sql.TxOptions) *DTC { } // Prepare adds a new transaction session to the DTC. -func (d *DTC) Prepare(conn *Conn, exec func(ctx context.Context, c Connector) error, revert func(ctx context.Context, c Connector) error) { +func (d *DTC) Prepare(client *Client, exec func(ctx context.Context, conn Connector) error, revert func(ctx context.Context, conn Connector) error) { for _, s := range d.sessions { - if s.conn == conn { + if s.client == client { s.exec = append(s.exec, exec) s.revert = append(s.revert, revert) return @@ -42,7 +42,7 @@ func (d *DTC) Prepare(conn *Conn, exec func(ctx context.Context, c Connector) er s := &session{ committed: false, - conn: conn, + client: client, exec: []func(ctx context.Context, c Connector) error{ exec, }, @@ -58,7 +58,7 @@ func (d *DTC) Prepare(conn *Conn, exec func(ctx context.Context, c Connector) er // Commit commits all the prepared transactions in the DTC. func (d *DTC) Commit() error { for _, s := range d.sessions { - tx, err := s.conn.BeginTx(d.ctx, d.opts) + tx, err := s.client.BeginTx(d.ctx, d.opts) if err != nil { return err } @@ -92,7 +92,7 @@ func (d *DTC) Rollback() []error { for _, s := range d.sessions { if s.committed { for _, revert := range s.revert { - if err := revert(d.ctx, s.conn); err != nil { + if err := revert(d.ctx, s.client); err != nil { errs = append(errs, err) } } diff --git a/dtc_test.go b/dtc_test.go index 0d308aa..2459c1c 100644 --- a/dtc_test.go +++ b/dtc_test.go @@ -30,20 +30,20 @@ func TestDTCWithDB(t *testing.T) { setup: func() *DTC { dtc := NewDTC(context.Background(), nil) - dtc.Prepare(db.dbs[0], func(ctx context.Context, c Connector) error { - _, err := c.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 1, "1@mail.com") + dtc.Prepare(db.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 1, "1@mail.com") return err }, nil) - dtc.Prepare(db.dbs[0], func(ctx context.Context, c Connector) error { - _, err := c.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 2, "2@mail.com") + dtc.Prepare(db.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 2, "2@mail.com") return err }, nil) - dtc.Prepare(db.dbs[0], func(ctx context.Context, c Connector) error { - _, err := c.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 3, "3@mail.com") + dtc.Prepare(db.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 3, "3@mail.com") return err }, nil) @@ -69,20 +69,20 @@ func TestDTCWithDB(t *testing.T) { name: "multiple_txs_rollback_should_work", setup: func() *DTC { dtc := NewDTC(context.Background(), nil) - dtc.Prepare(db.dbs[0], func(ctx context.Context, c Connector) error { - _, err := c.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 11, "1@mail.com") + dtc.Prepare(db.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 11, "1@mail.com") return err }, nil) - dtc.Prepare(db.dbs[0], func(ctx context.Context, c Connector) error { - _, err := c.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 12, "2@mail.com") + dtc.Prepare(db.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 12, "2@mail.com") return err }, nil) - dtc.Prepare(db.dbs[0], func(ctx context.Context, c Connector) error { - _, err := c.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 13) + dtc.Prepare(db.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 13) return err }, nil) @@ -154,20 +154,20 @@ func TestDTCWithDBs(t *testing.T) { setup: func() *DTC { dtc := NewDTC(context.Background(), nil) - dtc.Prepare(db1.dbs[0], func(ctx context.Context, c Connector) error { - _, err := c.Exec("INSERT INTO `dtc_1` (`id`,`email`) VALUES(?,?)", 1, "1@mail.com") + dtc.Prepare(db1.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_1` (`id`,`email`) VALUES(?,?)", 1, "1@mail.com") return err }, nil) - dtc.Prepare(db2.dbs[0], func(ctx context.Context, c Connector) error { - _, err := c.Exec("INSERT INTO `dtc_2` (`id`,`email`) VALUES(?,?)", 2, "2@mail.com") + dtc.Prepare(db2.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_2` (`id`,`email`) VALUES(?,?)", 2, "2@mail.com") return err }, nil) - dtc.Prepare(db3.dbs[0], func(ctx context.Context, c Connector) error { - _, err := c.Exec("INSERT INTO `dtc_3` (`id`,`email`) VALUES(?,?)", 3, "3@mail.com") + dtc.Prepare(db3.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_3` (`id`,`email`) VALUES(?,?)", 3, "3@mail.com") return err }, nil) @@ -193,20 +193,20 @@ func TestDTCWithDBs(t *testing.T) { name: "multiple_txs_rollback_should_work", setup: func() *DTC { dtc := NewDTC(context.Background(), nil) - dtc.Prepare(db1.dbs[0], func(ctx context.Context, c Connector) error { - _, err := c.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 11, "1@mail.com") + dtc.Prepare(db1.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_1`(`id`,`email`) VALUES(?,?)", 11, "1@mail.com") return err }, nil) - dtc.Prepare(db2.dbs[0], func(ctx context.Context, c Connector) error { - _, err := c.Exec("INSERT INTO `dtc_2`(`id`,`email`) VALUES(?,?)", 12, "2@mail.com") + dtc.Prepare(db2.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_2`(`id`,`email`) VALUES(?,?)", 12, "2@mail.com") return err }, nil) - dtc.Prepare(db3.dbs[0], func(ctx context.Context, c Connector) error { - _, err := c.Exec("INSERT INTO `dtc_3`(`id`,`email`) VALUES(?,?)", 13) + dtc.Prepare(db3.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_3`(`id`,`email`) VALUES(?,?)", 13) return err }, nil) @@ -270,16 +270,16 @@ func TestDTCRevert(t *testing.T) { dtc := NewDTC(context.Background(), nil) - dtc.Prepare(db1.dbs[0], func(ctx context.Context, c Connector) error { - _, err := c.Exec("INSERT INTO `dtc_1` (`id`,`email`) VALUES(?,?)", 1, "1@mail.com") + dtc.Prepare(db1.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_1` (`id`,`email`) VALUES(?,?)", 1, "1@mail.com") return err }, func(ctx context.Context, c Connector) error { _, err := c.Exec("DELETE FROM `dtc_1` WHERE id=?", 1) return err }) - dtc.Prepare(db2.dbs[0], func(ctx context.Context, c Connector) error { - _, err := c.Exec("INSERT INTO `dtc_2` (`id`,`email`) VALUES(?,?)", 2, "2@mail.com") + dtc.Prepare(db2.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_2` (`id`,`email`) VALUES(?,?)", 2, "2@mail.com") return err }, func(ctx context.Context, c Connector) error { @@ -287,8 +287,8 @@ func TestDTCRevert(t *testing.T) { return err }) - dtc.Prepare(db3.dbs[0], func(ctx context.Context, c Connector) error { - _, err := c.Exec("INSERT INTO `dtc_3` (`id`,`email`) VALUES(?,?)", 3, "3@mail.com") + dtc.Prepare(db3.dbs[0], func(ctx context.Context, conn Connector) error { + _, err := conn.Exec("INSERT INTO `dtc_3` (`id`,`email`) VALUES(?,?)", 3, "3@mail.com") return err }, func(ctx context.Context, c Connector) error { diff --git a/queryer_mapr.go b/queryer_mapr.go index d08e8df..bf6c00d 100644 --- a/queryer_mapr.go +++ b/queryer_mapr.go @@ -11,7 +11,7 @@ import ( // MapR is a Map/Reduce Query Provider based on databases. type MapR[T any] struct { - dbs []*Conn + dbs []*Client } // First executes the query and returns the first result. @@ -28,7 +28,7 @@ func (q *MapR[T]) First(ctx context.Context, rotatedTables []string, b *Builder) for _, r := range rotatedTables { qr := strings.ReplaceAll(query, "", r) for _, db := range q.dbs { - w.Add(func(db *Conn, qr string) func(context.Context) (T, error) { + w.Add(func(db *Client, qr string) func(context.Context) (T, error) { return func(ctx context.Context) (T, error) { var t T err := db.QueryRowContext(ctx, qr, args...).Bind(&t) @@ -59,7 +59,7 @@ func (q *MapR[T]) Count(ctx context.Context, rotatedTables []string, b *Builder) for _, r := range rotatedTables { qr := strings.ReplaceAll(query, "", r) for _, db := range q.dbs { - w.Add(func(db *Conn, qr string) func(context.Context) (int64, error) { + w.Add(func(db *Client, qr string) func(context.Context) (int64, error) { return func(ctx context.Context) (int64, error) { var i int64 err := db.QueryRowContext(ctx, qr, args...).Scan(&i) @@ -102,7 +102,7 @@ func (q *MapR[T]) Query(ctx context.Context, rotatedTables []string, b *Builder, for _, r := range rotatedTables { qr := strings.ReplaceAll(query, "", r) for _, db := range q.dbs { - w.Add(func(db *Conn, qr string) func(context.Context) ([]T, error) { + w.Add(func(db *Client, qr string) func(context.Context) ([]T, error) { return func(context.Context) ([]T, error) { var t []T rows, err := db.QueryContext(ctx, qr, args...) From 119d86058c60b41a545aa836ec1fbe153f1a2368 Mon Sep 17 00:00:00 2001 From: Lz Date: Mon, 29 Apr 2024 08:43:53 +0800 Subject: [PATCH 4/4] chore(docs): updated CHANGELOG.md --- CHANGELOG.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5cfae1f..b85f01c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,15 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [1.5.0] ### Changed -- renamed `Context` with `Conn` (#44) +- !renamed `Context` with `Client` (#45) ### Added - added `Connector` interface (#43) - added nullable `Time` with better json support (#44) -- added `DistributedTransactions` (#44) +- added `DTC` (#45) ### Changed -- renamed `BitBool` with shorter name `Bool` (#44) +- !renamed `BitBool` with shorter name `Bool` (#44) ## [1.4.6] - 2014-04-23 ### Changed