Skip to content

Commit

Permalink
Merge pull request #1047 from michaelshobbs/feature/context-interfaces
Browse files Browse the repository at this point in the history
implement ConnPrepareContext/StmtQueryContext/StmtExecContext interfaces
  • Loading branch information
otan authored Sep 2, 2021
2 parents 2140507 + 9fa33e2 commit 8667c6b
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 2 deletions.
4 changes: 4 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1360,6 +1360,10 @@ func (st *stmt) Close() (err error) {
}

func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
return st.query(v)
}

func (st *stmt) query(v []driver.Value) (r *rows, err error) {
if st.cn.getBad() {
return nil, driver.ErrBadConn
}
Expand Down
79 changes: 78 additions & 1 deletion conn_go18.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ import (
"time"
)

const (
watchCancelDialContextTimeout = time.Second * 10
)

// Implement the "QueryerContext" interface
func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
list := make([]driver.Value, len(args))
Expand Down Expand Up @@ -43,6 +47,14 @@ func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.Nam
return cn.Exec(query, list)
}

// Implement the "ConnPrepareContext" interface
func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if finish := cn.watchCancel(ctx); finish != nil {
defer finish()
}
return cn.Prepare(query)
}

// Implement the "ConnBeginTx" interface
func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
var mode string
Expand Down Expand Up @@ -109,7 +121,7 @@ func (cn *conn) watchCancel(ctx context.Context) func() {
// so it must not be used for the additional network
// request to cancel the query.
// Create a new context to pass into the dial.
ctxCancel, cancel := context.WithTimeout(context.Background(), time.Second*10)
ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout)
defer cancel()

_ = cn.cancel(ctxCancel)
Expand Down Expand Up @@ -172,3 +184,68 @@ func (cn *conn) cancel(ctx context.Context) error {
return err
}
}

// Implement the "StmtQueryContext" interface
func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
list := make([]driver.Value, len(args))
for i, nv := range args {
list[i] = nv.Value
}
finish := st.watchCancel(ctx)
r, err := st.query(list)
if err != nil {
if finish != nil {
finish()
}
return nil, err
}
r.finish = finish
return r, nil
}

// Implement the "StmtExecContext" interface
func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
list := make([]driver.Value, len(args))
for i, nv := range args {
list[i] = nv.Value
}

if finish := st.watchCancel(ctx); finish != nil {
defer finish()
}

return st.Exec(list)
}

// watchCancel is implemented on stmt in order to not mark the parent conn as bad
func (st *stmt) watchCancel(ctx context.Context) func() {
if done := ctx.Done(); done != nil {
finished := make(chan struct{})
go func() {
select {
case <-done:
// At this point the function level context is canceled,
// so it must not be used for the additional network
// request to cancel the query.
// Create a new context to pass into the dial.
ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout)
defer cancel()

_ = st.cancel(ctxCancel)
finished <- struct{}{}
case <-finished:
}
}()
return func() {
select {
case <-finished:
case finished <- struct{}{}:
}
}
}
return nil
}

func (st *stmt) cancel(ctx context.Context) error {
return st.cn.cancel(ctx)
}
164 changes: 164 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1806,3 +1806,167 @@ func TestCopyInStmtAffectedRows(t *testing.T) {
res.RowsAffected()
res.LastInsertId()
}

func TestConnPrepareContext(t *testing.T) {
db := openTestConn(t)
defer db.Close()

tests := []struct {
name string
ctx func() (context.Context, context.CancelFunc)
sql string
err error
}{
{
name: "context.Background",
ctx: func() (context.Context, context.CancelFunc) {
return context.Background(), nil
},
sql: "SELECT 1",
err: nil,
},
{
name: "context.WithTimeout exceeded",
ctx: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Microsecond)
},
sql: "SELECT 1",
err: context.DeadlineExceeded,
},
{
name: "context.WithTimeout",
ctx: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Minute)
},
sql: "SELECT 1",
err: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := tt.ctx()
if cancel != nil {
defer cancel()
}
_, err := db.PrepareContext(ctx, tt.sql)
switch {
case (err != nil) != (tt.err != nil):
t.Fatalf("conn.PrepareContext() unexpected nil err got = %v, expected = %v", err, tt.err)
case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()):
t.Errorf("conn.PrepareContext() got = %v, expected = %v", err.Error(), tt.err.Error())
}
})
}
}

func TestStmtQueryContext(t *testing.T) {
db := openTestConn(t)
defer db.Close()

tests := []struct {
name string
ctx func() (context.Context, context.CancelFunc)
sql string
err error
}{
{
name: "context.Background",
ctx: func() (context.Context, context.CancelFunc) {
return context.Background(), nil
},
sql: "SELECT pg_sleep(1);",
err: nil,
},
{
name: "context.WithTimeout exceeded",
ctx: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), 1*time.Second)
},
sql: "SELECT pg_sleep(10);",
err: &Error{Message: "canceling statement due to user request"},
},
{
name: "context.WithTimeout",
ctx: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Minute)
},
sql: "SELECT pg_sleep(1);",
err: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := tt.ctx()
if cancel != nil {
defer cancel()
}
stmt, err := db.PrepareContext(ctx, tt.sql)
if err != nil {
t.Fatal(err)
}
_, err = stmt.QueryContext(ctx)
switch {
case (err != nil) != (tt.err != nil):
t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, expected = %v", err, tt.err)
case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()):
t.Errorf("stmt.QueryContext() got = %v, expected = %v", err.Error(), tt.err.Error())
}
})
}
}

func TestStmtExecContext(t *testing.T) {
db := openTestConn(t)
defer db.Close()

tests := []struct {
name string
ctx func() (context.Context, context.CancelFunc)
sql string
err error
}{
{
name: "context.Background",
ctx: func() (context.Context, context.CancelFunc) {
return context.Background(), nil
},
sql: "SELECT pg_sleep(1);",
err: nil,
},
{
name: "context.WithTimeout exceeded",
ctx: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), 1*time.Second)
},
sql: "SELECT pg_sleep(10);",
err: &Error{Message: "canceling statement due to user request"},
},
{
name: "context.WithTimeout",
ctx: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Minute)
},
sql: "SELECT pg_sleep(1);",
err: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := tt.ctx()
if cancel != nil {
defer cancel()
}
stmt, err := db.PrepareContext(ctx, tt.sql)
if err != nil {
t.Fatal(err)
}
_, err = stmt.ExecContext(ctx)
switch {
case (err != nil) != (tt.err != nil):
t.Fatalf("stmt.ExecContext() unexpected nil err got = %v, expected = %v", err, tt.err)
case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()):
t.Errorf("stmt.ExecContext() got = %v, expected = %v", err.Error(), tt.err.Error())
}
})
}
}
36 changes: 35 additions & 1 deletion issues_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package pq

import "testing"
import (
"context"
"testing"
"time"
)

func TestIssue494(t *testing.T) {
db := openTestConn(t)
Expand All @@ -24,3 +28,33 @@ func TestIssue494(t *testing.T) {
t.Fatal("expected error")
}
}

func TestIssue1046(t *testing.T) {
ctxTimeout := time.Second * 2

db := openTestConn(t)
defer db.Close()

ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout)
defer cancel()

stmt, err := db.PrepareContext(ctx, `SELECT pg_sleep(10) AS id`)
if err != nil {
t.Fatal(err)
}

var d []uint8
err = stmt.QueryRowContext(ctx).Scan(&d)
dl, _ := ctx.Deadline()
since := time.Since(dl)
if since > ctxTimeout {
t.Logf("FAIL %s: query returned after context deadline: %v\n", t.Name(), since)
t.Fail()
}
expectedErr := &Error{Message: "canceling statement due to user request"}
if err == nil || err.Error() != expectedErr.Error() {
t.Logf("ctx.Err(): [%T]%+v\n", ctx.Err(), ctx.Err())
t.Logf("got err: [%T] %+v expected err: [%T] %+v", err, err, expectedErr, expectedErr)
t.Fail()
}
}

0 comments on commit 8667c6b

Please sign in to comment.