From 0921bbd86e7d2a9cd83571a6fcbf6d189962d5b7 Mon Sep 17 00:00:00 2001 From: jszwec Date: Sun, 8 Jan 2017 20:46:50 -0500 Subject: [PATCH] Add Multi-Results support Fixes #420 --- AUTHORS | 1 + connection.go | 11 ++---- driver_go18_test.go | 93 +++++++++++++++++++++++++++++++++++++++++++++ packets.go | 43 +++++++++++---------- rows.go | 59 ++++++++++++++++++++++++---- statement.go | 10 +++-- 6 files changed, 178 insertions(+), 39 deletions(-) create mode 100644 driver_go18_test.go diff --git a/AUTHORS b/AUTHORS index 100370758..987bc2cfb 100644 --- a/AUTHORS +++ b/AUTHORS @@ -25,6 +25,7 @@ Hanno Braun Henri Yandell Hirotaka Yamamoto INADA Naoki +Jacek Szwec James Harr Jian Zhen Joshua Prunier diff --git a/connection.go b/connection.go index d82c728f3..5ebc7ca84 100644 --- a/connection.go +++ b/connection.go @@ -331,15 +331,13 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro var resLen int resLen, err = mc.readResultSetHeaderPacket() if err == nil { - rows := new(textRows) - rows.mc = mc - + rows := newTextRows(mc) if resLen == 0 { // no columns, no more data return emptyRows{}, nil } // Columns - rows.columns, err = mc.readColumns(resLen) + rows.rs.columns, err = mc.readColumns(resLen) return rows, err } } @@ -357,9 +355,8 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { // Read Result resLen, err := mc.readResultSetHeaderPacket() if err == nil { - rows := new(textRows) - rows.mc = mc - rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}} + rows := newTextRows(mc) + rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} if resLen > 0 { // Columns diff --git a/driver_go18_test.go b/driver_go18_test.go new file mode 100644 index 000000000..1a712d8fd --- /dev/null +++ b/driver_go18_test.go @@ -0,0 +1,93 @@ +// +build go1.8 + +package mysql + +import ( + "reflect" + "testing" +) + +func TestMultiResultSet(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + type result struct { + values [][]int + columns []string + } + + expected := []result{ + { + values: [][]int{{1, 2}, {3, 4}}, + columns: []string{"col1", "col2"}, + }, + { + values: [][]int{{1, 2, 3}, {4, 5, 6}}, + columns: []string{"col1", "col2", "col3"}, + }, + } + + query := ` +SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; +SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6` + + rows := dbt.mustQuery(query) + defer rows.Close() + + var res1 result + for rows.Next() { + var res [2]int + if err := rows.Scan(&res[0], &res[1]); err != nil { + dbt.Fatal(err) + } + res1.values = append(res1.values, res[:]) + } + + if rows.Next() { + dbt.Error("unexpected row") + } + + cols, err := rows.Columns() + if err != nil { + dbt.Fatal(err) + } + res1.columns = cols + + if !reflect.DeepEqual(expected[0], res1) { + dbt.Error("want =", expected[0], "got =", res1) + } + + if !rows.NextResultSet() { + dbt.Fatal("expected next result set") + } + + var res2 result + cols, err = rows.Columns() + if err != nil { + dbt.Fatal(err) + } + res2.columns = cols + + for rows.Next() { + var res [3]int + if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil { + dbt.Fatal(err) + } + res2.values = append(res2.values, res[:]) + } + + if !reflect.DeepEqual(expected[1], res2) { + dbt.Error("want =", expected[1], "got =", res2) + } + + if rows.Next() { + dbt.Error("unexpected row") + } + + if rows.NextResultSet() { + dbt.Error("unexpected next result set") + } + + if err := rows.Err(); err != nil { + dbt.Error(err) + } + }) +} diff --git a/packets.go b/packets.go index aafe9793e..354303c84 100644 --- a/packets.go +++ b/packets.go @@ -231,6 +231,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientTransactions | clientLocalFiles | clientPluginAuth | + clientMultiStatements | clientMultiResults | mc.flags&clientLongFlag @@ -698,6 +699,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { func (rows *textRows) readRow(dest []driver.Value) error { mc := rows.mc + if rows.rs.done { + return io.EOF + } + data, err := mc.readPacket() if err != nil { return err @@ -707,15 +712,11 @@ func (rows *textRows) readRow(dest []driver.Value) error { if data[0] == iEOF && len(data) == 5 { // server_status [2 bytes] rows.mc.status = readStatus(data[3:]) - err = rows.mc.discardResults() - if err == nil { - err = io.EOF - } else { - // connection unusable - rows.mc.Close() + rows.rs.done = true + if !rows.HasNextResultSet() { + rows.mc = nil } - rows.mc = nil - return err + return io.EOF } if data[0] == iERR { rows.mc = nil @@ -736,7 +737,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { if !mc.parseTime { continue } else { - switch rows.columns[i].fieldType { + switch rows.rs.columns[i].fieldType { case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeDate, fieldTypeNewDate: dest[i], err = parseDateTime( @@ -1145,14 +1146,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { } // Convert to byte-coded string - switch rows.columns[i].fieldType { + switch rows.rs.columns[i].fieldType { case fieldTypeNULL: dest[i] = nil continue // Numeric Types case fieldTypeTiny: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(data[pos]) } else { dest[i] = int64(int8(data[pos])) @@ -1161,7 +1162,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeShort, fieldTypeYear: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) } else { dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) @@ -1170,7 +1171,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeInt24, fieldTypeLong: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) } else { dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) @@ -1179,7 +1180,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeLongLong: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { val := binary.LittleEndian.Uint64(data[pos : pos+8]) if val > math.MaxInt64 { dest[i] = uint64ToString(val) @@ -1233,10 +1234,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { case isNull: dest[i] = nil continue - case rows.columns[i].fieldType == fieldTypeTime: + case rows.rs.columns[i].fieldType == fieldTypeTime: // database/sql does not support an equivalent to TIME, return a string var dstlen uint8 - switch decimals := rows.columns[i].decimals; decimals { + switch decimals := rows.rs.columns[i].decimals; decimals { case 0x00, 0x1f: dstlen = 8 case 1, 2, 3, 4, 5, 6: @@ -1244,7 +1245,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { default: return fmt.Errorf( "protocol error, illegal decimals value %d", - rows.columns[i].decimals, + rows.rs.columns[i].decimals, ) } dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) @@ -1252,10 +1253,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) default: var dstlen uint8 - if rows.columns[i].fieldType == fieldTypeDate { + if rows.rs.columns[i].fieldType == fieldTypeDate { dstlen = 10 } else { - switch decimals := rows.columns[i].decimals; decimals { + switch decimals := rows.rs.columns[i].decimals; decimals { case 0x00, 0x1f: dstlen = 19 case 1, 2, 3, 4, 5, 6: @@ -1263,7 +1264,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { default: return fmt.Errorf( "protocol error, illegal decimals value %d", - rows.columns[i].decimals, + rows.rs.columns[i].decimals, ) } } @@ -1279,7 +1280,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // Please report if this happens! default: - return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType) + return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType) } } diff --git a/rows.go b/rows.go index c08255eee..e3a6f687a 100644 --- a/rows.go +++ b/rows.go @@ -21,9 +21,14 @@ type mysqlField struct { decimals byte } -type mysqlRows struct { - mc *mysqlConn +type resultSet struct { columns []mysqlField + done bool +} + +type mysqlRows struct { + mc *mysqlConn + rs *resultSet } type binaryRows struct { @@ -34,21 +39,33 @@ type textRows struct { mysqlRows } +func newTextRows(mc *mysqlConn) *textRows { + return &textRows{ + mysqlRows{ + mc: mc, + rs: new(resultSet), + }, + } +} + type emptyRows struct{} func (rows *mysqlRows) Columns() []string { - columns := make([]string, len(rows.columns)) + if rows.rs == nil { + return []string{} + } + columns := make([]string, len(rows.rs.columns)) if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { for i := range columns { - if tableName := rows.columns[i].tableName; len(tableName) > 0 { - columns[i] = tableName + "." + rows.columns[i].name + if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 { + columns[i] = tableName + "." + rows.rs.columns[i].name } else { - columns[i] = rows.columns[i].name + columns[i] = rows.rs.columns[i].name } } } else { for i := range columns { - columns[i] = rows.columns[i].name + columns[i] = rows.rs.columns[i].name } } return columns @@ -99,6 +116,26 @@ func (rows *textRows) Next(dest []driver.Value) error { return io.EOF } +func (rows *textRows) HasNextResultSet() (b bool) { + if rows.mc == nil { + return false + } + return rows.mc.status&statusMoreResultsExists != 0 +} + +func (rows *textRows) NextResultSet() error { + if !rows.HasNextResultSet() { + return io.EOF + } + rows.rs = new(resultSet) + resLen, err := rows.mc.readResultSetHeaderPacket() + if err != nil { + return err + } + rows.rs.columns, err = rows.mc.readColumns(resLen) + return err +} + func (rows emptyRows) Columns() []string { return nil } @@ -110,3 +147,11 @@ func (rows emptyRows) Close() error { func (rows emptyRows) Next(dest []driver.Value) error { return io.EOF } + +func (rows emptyRows) HasNextResultSet() bool { + return false +} + +func (rows emptyRows) NextResultSet() error { + return io.EOF +} diff --git a/statement.go b/statement.go index 7f9b04585..0ba533550 100644 --- a/statement.go +++ b/statement.go @@ -103,17 +103,19 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { return nil, err } - rows := new(binaryRows) + rows := &binaryRows{ + mysqlRows{rs: new(resultSet)}, + } if resLen > 0 { rows.mc = mc // Columns // If not cached, read them and cache them if stmt.columns == nil { - rows.columns, err = mc.readColumns(resLen) - stmt.columns = rows.columns + rows.rs.columns, err = mc.readColumns(resLen) + stmt.columns = rows.rs.columns } else { - rows.columns = stmt.columns + rows.rs.columns = stmt.columns err = mc.readUntilEOF() } }