From ddb99717ca7f59f1678a241f0ab22e253868bd79 Mon Sep 17 00:00:00 2001 From: Daniel Theophanes Date: Wed, 13 Mar 2019 18:32:52 -0700 Subject: [PATCH] pq: support returning multiple result sets from a single query Tested with Cockroach v2.1.6 and PostgreSQL 11. --- conn.go | 73 ++++++++++++++++++++++++++++++---------------------- conn_test.go | 46 +++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 31 deletions(-) diff --git a/conn.go b/conn.go index 43c8df29..a12e6e66 100644 --- a/conn.go +++ b/conn.go @@ -704,7 +704,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // res might be non-nil here if we received a previous // CommandComplete, but that's fine; just overwrite it res = &rows{cn: cn} - res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r) + res.rowsHeader = parsePortalRowDescribe(r) // To work around a bug in QueryRow in Go 1.2 and earlier, wait // until the first DataRow has been received. @@ -861,17 +861,15 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { cn.readParseResponse() cn.readBindResponse() rows := &rows{cn: cn} - rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse() + rows.rowsHeader = cn.readPortalDescribeResponse() cn.postExecuteWorkaround() return rows, nil } st := cn.prepareTo(query, "") st.exec(args) return &rows{ - cn: cn, - colNames: st.colNames, - colTyps: st.colTyps, - colFmts: st.colFmts, + cn: cn, + rowsHeader: st.rowsHeader, }, nil } @@ -1180,12 +1178,10 @@ var colFmtDataAllBinary = []byte{0, 1, 0, 1} var colFmtDataAllText = []byte{0, 0} type stmt struct { - cn *conn - name string - colNames []string - colFmts []format + cn *conn + name string + rowsHeader colFmtData []byte - colTyps []fieldDesc paramTyps []oid.Oid closed bool } @@ -1231,10 +1227,8 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { st.exec(v) return &rows{ - cn: st.cn, - colNames: st.colNames, - colTyps: st.colTyps, - colFmts: st.colFmts, + cn: st.cn, + rowsHeader: st.rowsHeader, }, nil } @@ -1344,16 +1338,22 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { return driver.RowsAffected(n), commandTag } -type rows struct { - cn *conn - finish func() +type rowsHeader struct { colNames []string colTyps []fieldDesc colFmts []format - done bool - rb readBuf - result driver.Result - tag string +} + +type rows struct { + cn *conn + finish func() + rowsHeader + done bool + rb readBuf + result driver.Result + tag string + + next *rowsHeader } func (rs *rows) Close() error { @@ -1440,7 +1440,8 @@ func (rs *rows) Next(dest []driver.Value) (err error) { } return case 'T': - rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb) + next := parsePortalRowDescribe(&rs.rb) + rs.next = &next return io.EOF default: errorf("unexpected message after execute: %q", t) @@ -1449,10 +1450,16 @@ func (rs *rows) Next(dest []driver.Value) (err error) { } func (rs *rows) HasNextResultSet() bool { - return !rs.done + hasNext := rs.next != nil && !rs.done + return hasNext } func (rs *rows) NextResultSet() error { + if rs.next == nil { + return io.EOF + } + rs.rowsHeader = *rs.next + rs.next = nil return nil } @@ -1630,13 +1637,13 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [ } } -func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []fieldDesc) { +func (cn *conn) readPortalDescribeResponse() rowsHeader { t, r := cn.recv1() switch t { case 'T': return parsePortalRowDescribe(r) case 'n': - return nil, nil, nil + return rowsHeader{} case 'E': err := parseError(r) cn.readReadyForQuery() @@ -1742,11 +1749,11 @@ func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDe return } -func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []fieldDesc) { +func parsePortalRowDescribe(r *readBuf) rowsHeader { n := r.int16() - colNames = make([]string, n) - colFmts = make([]format, n) - colTyps = make([]fieldDesc, n) + colNames := make([]string, n) + colFmts := make([]format, n) + colTyps := make([]fieldDesc, n) for i := range colNames { colNames[i] = r.string() r.next(6) @@ -1755,7 +1762,11 @@ func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, co colTyps[i].Mod = r.int32() colFmts[i] = format(r.int16()) } - return + return rowsHeader{ + colNames: colNames, + colFmts: colFmts, + colTyps: colTyps, + } } // parseEnviron tries to mimic some of libpq's environment handling diff --git a/conn_test.go b/conn_test.go index e654b85b..0eba3e5a 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1657,3 +1657,49 @@ func TestQuickClose(t *testing.T) { t.Fatal(err) } } + +func TestMultipleResult(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + rows, err := db.Query(` + begin; + select * from information_schema.tables limit 1; + select * from information_schema.columns limit 2; + commit; + `) + if err != nil { + t.Fatal(err) + } + type set struct { + cols []string + rowCount int + } + buf := []*set{} + for { + cols, err := rows.Columns() + if err != nil { + t.Fatal(err) + } + s := &set{ + cols: cols, + } + buf = append(buf, s) + + for rows.Next() { + s.rowCount++ + } + if !rows.NextResultSet() { + break + } + } + if len(buf) != 2 { + t.Fatalf("got %d sets, expected 2", len(buf)) + } + if len(buf[0].cols) == len(buf[1].cols) || len(buf[1].cols) == 0 { + t.Fatal("invalid cols size, expected different column count and greater then zero") + } + if buf[0].rowCount != 1 || buf[1].rowCount != 2 { + t.Fatal("incorrect number of rows returned") + } +}