Skip to content

Commit

Permalink
Merge pull request #837 from kardianos/master
Browse files Browse the repository at this point in the history
pq: support returning multiple result sets from a single query
  • Loading branch information
maddyblue authored Mar 14, 2019
2 parents 9eb73ef + ddb9971 commit 7aad666
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 31 deletions.
73 changes: 42 additions & 31 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
// res might be non-nil here if we received a previous
// CommandComplete, but that's fine; just overwrite it
res = &rows{cn: cn}
res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)
res.rowsHeader = parsePortalRowDescribe(r)

// To work around a bug in QueryRow in Go 1.2 and earlier, wait
// until the first DataRow has been received.
Expand Down Expand Up @@ -861,17 +861,15 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
cn.readParseResponse()
cn.readBindResponse()
rows := &rows{cn: cn}
rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse()
rows.rowsHeader = cn.readPortalDescribeResponse()
cn.postExecuteWorkaround()
return rows, nil
}
st := cn.prepareTo(query, "")
st.exec(args)
return &rows{
cn: cn,
colNames: st.colNames,
colTyps: st.colTyps,
colFmts: st.colFmts,
cn: cn,
rowsHeader: st.rowsHeader,
}, nil
}

Expand Down Expand Up @@ -1180,12 +1178,10 @@ var colFmtDataAllBinary = []byte{0, 1, 0, 1}
var colFmtDataAllText = []byte{0, 0}

type stmt struct {
cn *conn
name string
colNames []string
colFmts []format
cn *conn
name string
rowsHeader
colFmtData []byte
colTyps []fieldDesc
paramTyps []oid.Oid
closed bool
}
Expand Down Expand Up @@ -1231,10 +1227,8 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {

st.exec(v)
return &rows{
cn: st.cn,
colNames: st.colNames,
colTyps: st.colTyps,
colFmts: st.colFmts,
cn: st.cn,
rowsHeader: st.rowsHeader,
}, nil
}

Expand Down Expand Up @@ -1344,16 +1338,22 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
return driver.RowsAffected(n), commandTag
}

type rows struct {
cn *conn
finish func()
type rowsHeader struct {
colNames []string
colTyps []fieldDesc
colFmts []format
done bool
rb readBuf
result driver.Result
tag string
}

type rows struct {
cn *conn
finish func()
rowsHeader
done bool
rb readBuf
result driver.Result
tag string

next *rowsHeader
}

func (rs *rows) Close() error {
Expand Down Expand Up @@ -1440,7 +1440,8 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
}
return
case 'T':
rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb)
next := parsePortalRowDescribe(&rs.rb)
rs.next = &next
return io.EOF
default:
errorf("unexpected message after execute: %q", t)
Expand All @@ -1449,10 +1450,16 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
}

func (rs *rows) HasNextResultSet() bool {
return !rs.done
hasNext := rs.next != nil && !rs.done
return hasNext
}

func (rs *rows) NextResultSet() error {
if rs.next == nil {
return io.EOF
}
rs.rowsHeader = *rs.next
rs.next = nil
return nil
}

Expand Down Expand Up @@ -1630,13 +1637,13 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [
}
}

func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []fieldDesc) {
func (cn *conn) readPortalDescribeResponse() rowsHeader {
t, r := cn.recv1()
switch t {
case 'T':
return parsePortalRowDescribe(r)
case 'n':
return nil, nil, nil
return rowsHeader{}
case 'E':
err := parseError(r)
cn.readReadyForQuery()
Expand Down Expand Up @@ -1742,11 +1749,11 @@ func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDe
return
}

func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []fieldDesc) {
func parsePortalRowDescribe(r *readBuf) rowsHeader {
n := r.int16()
colNames = make([]string, n)
colFmts = make([]format, n)
colTyps = make([]fieldDesc, n)
colNames := make([]string, n)
colFmts := make([]format, n)
colTyps := make([]fieldDesc, n)
for i := range colNames {
colNames[i] = r.string()
r.next(6)
Expand All @@ -1755,7 +1762,11 @@ func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, co
colTyps[i].Mod = r.int32()
colFmts[i] = format(r.int16())
}
return
return rowsHeader{
colNames: colNames,
colFmts: colFmts,
colTyps: colTyps,
}
}

// parseEnviron tries to mimic some of libpq's environment handling
Expand Down
46 changes: 46 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1657,3 +1657,49 @@ func TestQuickClose(t *testing.T) {
t.Fatal(err)
}
}

func TestMultipleResult(t *testing.T) {
db := openTestConn(t)
defer db.Close()

rows, err := db.Query(`
begin;
select * from information_schema.tables limit 1;
select * from information_schema.columns limit 2;
commit;
`)
if err != nil {
t.Fatal(err)
}
type set struct {
cols []string
rowCount int
}
buf := []*set{}
for {
cols, err := rows.Columns()
if err != nil {
t.Fatal(err)
}
s := &set{
cols: cols,
}
buf = append(buf, s)

for rows.Next() {
s.rowCount++
}
if !rows.NextResultSet() {
break
}
}
if len(buf) != 2 {
t.Fatalf("got %d sets, expected 2", len(buf))
}
if len(buf[0].cols) == len(buf[1].cols) || len(buf[1].cols) == 0 {
t.Fatal("invalid cols size, expected different column count and greater then zero")
}
if buf[0].rowCount != 1 || buf[1].rowCount != 2 {
t.Fatal("incorrect number of rows returned")
}
}

0 comments on commit 7aad666

Please sign in to comment.