Skip to content

Commit

Permalink
Add Multi-Results support
Browse files Browse the repository at this point in the history
  • Loading branch information
jszwec committed Jan 10, 2017
1 parent 2e00b5c commit ca9ddf5
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 36 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Hanno Braun <mail at hannobraun.com>
Henri Yandell <flamefew at gmail.com>
Hirotaka Yamamoto <ymmt2005 at gmail.com>
INADA Naoki <songofacandy at gmail.com>
Jacek Szwec <szwec.jacek at gmail.com>
James Harr <james.harr at gmail.com>
Jian Zhen <zhenjl at gmail.com>
Joshua Prunier <joshua.prunier at gmail.com>
Expand Down
5 changes: 2 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,13 +333,12 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
if err == nil {
rows := new(textRows)
rows.mc = mc

if resLen == 0 {
// no columns, no more data
return emptyRows{}, nil
}
// Columns
rows.columns, err = mc.readColumns(resLen)
rows.rs.columns, err = mc.readColumns(resLen)
return rows, err
}
}
Expand All @@ -359,7 +358,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
if err == nil {
rows := new(textRows)
rows.mc = mc
rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}

if resLen > 0 {
// Columns
Expand Down
100 changes: 100 additions & 0 deletions driver_18_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// +build go1.8

package mysql

import (
"reflect"
"testing"
)

func TestMultiResultSet(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
type result struct {
values [][]int
columns []string
}

expected := []result{
{
values: [][]int{{1, 2}, {3, 4}},
columns: []string{"col1", "col2"},
},
{
values: [][]int{{1, 2, 3}, {4, 5, 6}},
columns: []string{"col1", "col2", "col3"},
},
}

query := `
SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
SELECT 0 UNION SELECT 1; -- ignore this result set
SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;`

rows := dbt.mustQuery(query)
defer rows.Close()

var res1 result
for rows.Next() {
var res [2]int
if err := rows.Scan(&res[0], &res[1]); err != nil {
dbt.Fatal(err)
}
res1.values = append(res1.values, res[:])
}

if rows.Next() {
dbt.Error("unexpected row")
}

cols, err := rows.Columns()
if err != nil {
dbt.Fatal(err)
}
res1.columns = cols

if !reflect.DeepEqual(expected[0], res1) {
dbt.Error("want =", expected[0], "got =", res1)
}

if !rows.NextResultSet() {
dbt.Fatal("expected next result set")
}

// ignoring one result set

if !rows.NextResultSet() {
dbt.Fatal("expected next result set")
}

var res2 result
cols, err = rows.Columns()
if err != nil {
dbt.Fatal(err)
}
res2.columns = cols

for rows.Next() {
var res [3]int
if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil {
dbt.Fatal(err)
}
res2.values = append(res2.values, res[:])
}

if !reflect.DeepEqual(expected[1], res2) {
dbt.Error("want =", expected[1], "got =", res2)
}

if rows.Next() {
dbt.Error("unexpected row")
}

if rows.NextResultSet() {
dbt.Error("unexpected next result set")
}

if err := rows.Err(); err != nil {
dbt.Error(err)
}
})
}
43 changes: 22 additions & 21 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
clientTransactions |
clientLocalFiles |
clientPluginAuth |
clientMultiStatements |
clientMultiResults |
mc.flags&clientLongFlag

Expand Down Expand Up @@ -698,6 +699,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
func (rows *textRows) readRow(dest []driver.Value) error {
mc := rows.mc

if rows.rs.done {
return io.EOF
}

data, err := mc.readPacket()
if err != nil {
return err
Expand All @@ -707,15 +712,11 @@ func (rows *textRows) readRow(dest []driver.Value) error {
if data[0] == iEOF && len(data) == 5 {
// server_status [2 bytes]
rows.mc.status = readStatus(data[3:])
err = rows.mc.discardResults()
if err == nil {
err = io.EOF
} else {
// connection unusable
rows.mc.Close()
rows.rs.done = true
if !rows.HasNextResultSet() {
rows.mc = nil
}
rows.mc = nil
return err
return io.EOF
}
if data[0] == iERR {
rows.mc = nil
Expand All @@ -736,7 +737,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
if !mc.parseTime {
continue
} else {
switch rows.columns[i].fieldType {
switch rows.rs.columns[i].fieldType {
case fieldTypeTimestamp, fieldTypeDateTime,
fieldTypeDate, fieldTypeNewDate:
dest[i], err = parseDateTime(
Expand Down Expand Up @@ -1145,14 +1146,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
}

// Convert to byte-coded string
switch rows.columns[i].fieldType {
switch rows.rs.columns[i].fieldType {
case fieldTypeNULL:
dest[i] = nil
continue

// Numeric Types
case fieldTypeTiny:
if rows.columns[i].flags&flagUnsigned != 0 {
if rows.rs.columns[i].flags&flagUnsigned != 0 {
dest[i] = int64(data[pos])
} else {
dest[i] = int64(int8(data[pos]))
Expand All @@ -1161,7 +1162,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
continue

case fieldTypeShort, fieldTypeYear:
if rows.columns[i].flags&flagUnsigned != 0 {
if rows.rs.columns[i].flags&flagUnsigned != 0 {
dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
} else {
dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
Expand All @@ -1170,7 +1171,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
continue

case fieldTypeInt24, fieldTypeLong:
if rows.columns[i].flags&flagUnsigned != 0 {
if rows.rs.columns[i].flags&flagUnsigned != 0 {
dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
} else {
dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
Expand All @@ -1179,7 +1180,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
continue

case fieldTypeLongLong:
if rows.columns[i].flags&flagUnsigned != 0 {
if rows.rs.columns[i].flags&flagUnsigned != 0 {
val := binary.LittleEndian.Uint64(data[pos : pos+8])
if val > math.MaxInt64 {
dest[i] = uint64ToString(val)
Expand Down Expand Up @@ -1233,37 +1234,37 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
case isNull:
dest[i] = nil
continue
case rows.columns[i].fieldType == fieldTypeTime:
case rows.rs.columns[i].fieldType == fieldTypeTime:
// database/sql does not support an equivalent to TIME, return a string
var dstlen uint8
switch decimals := rows.columns[i].decimals; decimals {
switch decimals := rows.rs.columns[i].decimals; decimals {
case 0x00, 0x1f:
dstlen = 8
case 1, 2, 3, 4, 5, 6:
dstlen = 8 + 1 + decimals
default:
return fmt.Errorf(
"protocol error, illegal decimals value %d",
rows.columns[i].decimals,
rows.rs.columns[i].decimals,
)
}
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
case rows.mc.parseTime:
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
default:
var dstlen uint8
if rows.columns[i].fieldType == fieldTypeDate {
if rows.rs.columns[i].fieldType == fieldTypeDate {
dstlen = 10
} else {
switch decimals := rows.columns[i].decimals; decimals {
switch decimals := rows.rs.columns[i].decimals; decimals {
case 0x00, 0x1f:
dstlen = 19
case 1, 2, 3, 4, 5, 6:
dstlen = 19 + 1 + decimals
default:
return fmt.Errorf(
"protocol error, illegal decimals value %d",
rows.columns[i].decimals,
rows.rs.columns[i].decimals,
)
}
}
Expand All @@ -1279,7 +1280,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {

// Please report if this happens!
default:
return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType)
return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
}
}

Expand Down
69 changes: 60 additions & 9 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,14 @@ type mysqlField struct {
decimals byte
}

type mysqlRows struct {
mc *mysqlConn
type resultSet struct {
columns []mysqlField
done bool
}

type mysqlRows struct {
mc *mysqlConn
rs resultSet
}

type binaryRows struct {
Expand All @@ -37,24 +42,24 @@ type textRows struct {
type emptyRows struct{}

func (rows *mysqlRows) Columns() []string {
columns := make([]string, len(rows.columns))
columns := make([]string, len(rows.rs.columns))
if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
for i := range columns {
if tableName := rows.columns[i].tableName; len(tableName) > 0 {
columns[i] = tableName + "." + rows.columns[i].name
if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 {
columns[i] = tableName + "." + rows.rs.columns[i].name
} else {
columns[i] = rows.columns[i].name
columns[i] = rows.rs.columns[i].name
}
}
} else {
for i := range columns {
columns[i] = rows.columns[i].name
columns[i] = rows.rs.columns[i].name
}
}
return columns
}

func (rows *mysqlRows) Close() error {
func (rows *mysqlRows) Close() (err error) {
mc := rows.mc
if mc == nil {
return nil
Expand All @@ -64,7 +69,9 @@ func (rows *mysqlRows) Close() error {
}

// Remove unread packets from stream
err := mc.readUntilEOF()
if !rows.rs.done {
err = mc.readUntilEOF()
}
if err == nil {
if err = mc.discardResults(); err != nil {
return err
Expand Down Expand Up @@ -99,6 +106,42 @@ func (rows *textRows) Next(dest []driver.Value) error {
return io.EOF
}

func (rows *textRows) HasNextResultSet() (b bool) {
if rows.mc == nil {
return false
}
return rows.mc.status&statusMoreResultsExists != 0
}

func (rows *textRows) NextResultSet() error {
if rows.mc == nil {
return io.EOF
}
if rows.mc.netConn == nil {
return ErrInvalidConn
}

// Remove unread packets from stream
if !rows.rs.done {
if err := rows.mc.readUntilEOF(); err != nil {
return err
}
}

if !rows.HasNextResultSet() {
return io.EOF
}
rows.rs = resultSet{}

resLen, err := rows.mc.readResultSetHeaderPacket()
if err != nil {
return err
}

rows.rs.columns, err = rows.mc.readColumns(resLen)
return err
}

func (rows emptyRows) Columns() []string {
return nil
}
Expand All @@ -110,3 +153,11 @@ func (rows emptyRows) Close() error {
func (rows emptyRows) Next(dest []driver.Value) error {
return io.EOF
}

func (rows emptyRows) HasNextResultSet() bool {
return false
}

func (rows emptyRows) NextResultSet() error {
return io.EOF
}
Loading

0 comments on commit ca9ddf5

Please sign in to comment.