diff --git a/conn_go18.go b/conn_go18.go index 81c9ee47..55f3fd42 100644 --- a/conn_go18.go +++ b/conn_go18.go @@ -74,6 +74,18 @@ func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, return tx, nil } +func (cn *conn) Ping(ctx context.Context) error { + if finish := cn.watchCancel(ctx); finish != nil { + defer finish() + } + rows, err := cn.simpleQuery("SELECT 'lib/pq ping test';") + if err != nil { + return driver.ErrBadConn // https://golang.org/pkg/database/sql/driver/#Pinger + } + rows.Close() + return nil +} + func (cn *conn) watchCancel(ctx context.Context) func() { if done := ctx.Done(); done != nil { finished := make(chan struct{}) diff --git a/go19_test.go b/go19_test.go new file mode 100644 index 00000000..1949249d --- /dev/null +++ b/go19_test.go @@ -0,0 +1,69 @@ +// +build go1.9 + +package pq + +import ( + "context" + "database/sql" + "database/sql/driver" + "reflect" + "testing" +) + +func TestPing(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + db := openTestConn(t) + defer db.Close() + + if _, ok := reflect.TypeOf(db).MethodByName("Conn"); !ok { + t.Skipf("Conn method undefined on type %T, skipping test (requires at least go1.9)", db) + } + + if err := db.PingContext(ctx); err != nil { + t.Fatal("expected Ping to succeed") + } + defer cancel() + + // grab a connection + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + // start a transaction and read backend pid of our connection + tx, err := conn.BeginTx(ctx, &sql.TxOptions{ + Isolation: sql.LevelDefault, + ReadOnly: true, + }) + if err != nil { + t.Fatal(err) + } + + rows, err := tx.Query("SELECT pg_backend_pid()") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + // read the pid from result + var pid int + for rows.Next() { + if err := rows.Scan(&pid); err != nil { + t.Fatal(err) + } + } + if rows.Err() != nil { + t.Fatal(err) + } + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + + // kill the process which handles our connection and test if the ping fails + if _, err := db.Exec("SELECT pg_terminate_backend($1)", pid); err != nil { + t.Fatal(err) + } + if err := conn.PingContext(ctx); err != driver.ErrBadConn { + t.Fatalf("expected error %s, instead got %s", driver.ErrBadConn, err) + } +}