From 19c5721ef2a87b5c860a543d800625f93bb32d39 Mon Sep 17 00:00:00 2001 From: sm Date: Wed, 13 Sep 2023 09:54:50 +0200 Subject: [PATCH] v1.5.2 --- RELEASENOTES.md | 3 +++ driver/connection.go | 41 ++++++++++++++++++++++++++++------------- driver/driver.go | 2 +- 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/RELEASENOTES.md b/RELEASENOTES.md index c955df82..8a4ed640 100644 --- a/RELEASENOTES.md +++ b/RELEASENOTES.md @@ -5,6 +5,9 @@ Release Notes ### Minor revisions +#### v1.5.2 +- fixed race condition in connection conn and stmt methods in case of context cancelling + #### v1.5.1 - updated dependencies diff --git a/driver/connection.go b/driver/connection.go index fde82887..678ced74 100644 --- a/driver/connection.go +++ b/driver/connection.go @@ -417,7 +417,7 @@ func (c *conn) isBad() bool { return errors.Is(c.lastError, driver.ErrBadConn) } func (c *conn) IsValid() bool { return !c.isBad() } // Ping implements the driver.Pinger interface. -func (c *conn) Ping(ctx context.Context) (err error) { +func (c *conn) Ping(ctx context.Context) error { if c.isBad() { return driver.ErrBadConn } @@ -428,6 +428,7 @@ func (c *conn) Ping(ctx context.Context) (err error) { } done := make(chan struct{}) + var err error go func() { _, err = c._queryDirect(ctx, dummyQuery, !c.inTx) close(done) @@ -444,7 +445,7 @@ func (c *conn) Ping(ctx context.Context) (err error) { } // PrepareContext implements the driver.ConnPrepareContext interface. -func (c *conn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) { +func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { if c.isBad() { return nil, driver.ErrBadConn } @@ -455,6 +456,8 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (stmt driver.St } done := make(chan struct{}) + var stmt driver.Stmt + var err error go func() { var pr *prepareResult @@ -487,7 +490,7 @@ func (c *conn) Close() error { } // BeginTx implements the driver.ConnBeginTx interface. -func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) { +func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if c.isBad() { return nil, driver.ErrBadConn } @@ -501,6 +504,8 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx } done := make(chan struct{}) + var tx driver.Tx + var err error go func() { // set isolation level query := strings.Join([]string{setIsolationLevel, level}, " ") @@ -532,7 +537,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx var callStmt = regexp.MustCompile(`(?i)^\s*call\s+.*`) // sql statement beginning with call // QueryContext implements the driver.QueryerContext interface. -func (c *conn) QueryContext(ctx context.Context, query string, nvargs []driver.NamedValue) (rows driver.Rows, err error) { +func (c *conn) QueryContext(ctx context.Context, query string, nvargs []driver.NamedValue) (driver.Rows, error) { if c.isBad() { return nil, driver.ErrBadConn } @@ -549,6 +554,8 @@ func (c *conn) QueryContext(ctx context.Context, query string, nvargs []driver.N } done := make(chan struct{}) + var rows driver.Rows + var err error go func() { rows, err = c._queryDirect(ctx, query, !c.inTx) close(done) @@ -565,7 +572,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, nvargs []driver.N } // ExecContext implements the driver.ExecerContext interface. -func (c *conn) ExecContext(ctx context.Context, query string, nvargs []driver.NamedValue) (r driver.Result, err error) { +func (c *conn) ExecContext(ctx context.Context, query string, nvargs []driver.NamedValue) (driver.Result, error) { if c.isBad() { return nil, driver.ErrBadConn } @@ -579,9 +586,11 @@ func (c *conn) ExecContext(ctx context.Context, query string, nvargs []driver.Na } done := make(chan struct{}) + var result driver.Result + var err error go func() { // handle procesure call without parameters here as well - r, err = c._execDirect(ctx, query, !c.inTx) + result, err = c._execDirect(ctx, query, !c.inTx) close(done) }() @@ -591,7 +600,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, nvargs []driver.Na return nil, ctx.Err() case <-done: c.lastError = err - return r, err + return result, err } } @@ -614,12 +623,14 @@ func (c *conn) HDBVersion() *Version { return c.hdbVersion } func (c *conn) DatabaseName() string { return c._databaseName() } // DBConnectInfo implements the Conn interface. -func (c *conn) DBConnectInfo(ctx context.Context, databaseName string) (ci *DBConnectInfo, err error) { +func (c *conn) DBConnectInfo(ctx context.Context, databaseName string) (*DBConnectInfo, error) { if c.isBad() { return nil, driver.ErrBadConn } done := make(chan struct{}) + var ci *DBConnectInfo + var err error go func() { ci, err = c._dbConnectInfo(ctx, databaseName) close(done) @@ -726,7 +737,7 @@ func (s *stmt) Close() error { return c._dropStatementID(context.Background(), s.pr.stmtID) } -func (s *stmt) QueryContext(ctx context.Context, nvargs []driver.NamedValue) (rows driver.Rows, err error) { +func (s *stmt) QueryContext(ctx context.Context, nvargs []driver.NamedValue) (driver.Rows, error) { c := s.conn if c.isBad() { return nil, driver.ErrBadConn @@ -741,6 +752,8 @@ func (s *stmt) QueryContext(ctx context.Context, nvargs []driver.NamedValue) (ro } done := make(chan struct{}) + var rows driver.Rows + var err error go func() { rows, err = s.conn._query(ctx, s.pr, nvargs, !s.conn.inTx) close(done) @@ -756,7 +769,7 @@ func (s *stmt) QueryContext(ctx context.Context, nvargs []driver.NamedValue) (ro } } -func (s *stmt) ExecContext(ctx context.Context, nvargs []driver.NamedValue) (r driver.Result, err error) { +func (s *stmt) ExecContext(ctx context.Context, nvargs []driver.NamedValue) (driver.Result, error) { c := s.conn if c.isBad() { @@ -772,11 +785,13 @@ func (s *stmt) ExecContext(ctx context.Context, nvargs []driver.NamedValue) (r d } done := make(chan struct{}) + var result driver.Result + var err error go func() { if s.pr.isProcedureCall() { - r, s.rows, err = s.conn._execCall(ctx, s.pr, nvargs) + result, s.rows, err = s.conn._execCall(ctx, s.pr, nvargs) } else { - r, err = s.exec(ctx, nvargs) + result, err = s.exec(ctx, nvargs) } close(done) }() @@ -787,7 +802,7 @@ func (s *stmt) ExecContext(ctx context.Context, nvargs []driver.NamedValue) (r d return nil, ctx.Err() case <-done: c.lastError = err - return r, err + return result, err } } diff --git a/driver/driver.go b/driver/driver.go index aabc2d1d..725b54ea 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -10,7 +10,7 @@ import ( ) // DriverVersion is the version number of the hdb driver. -const DriverVersion = "1.5.1" +const DriverVersion = "1.5.2" // DriverName is the driver name to use with sql.Open for hdb databases. const DriverName = "hdb"