Skip to content

Commit

Permalink
Add Multi-Results support
Browse files Browse the repository at this point in the history
  • Loading branch information
jszwec committed Jan 9, 2017
1 parent 2e00b5c commit f9f8661
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 39 deletions.
11 changes: 4 additions & 7 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,15 +331,13 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc

rows := newTextRows(mc)
if resLen == 0 {
// no columns, no more data
return emptyRows{}, nil
}
// Columns
rows.columns, err = mc.readColumns(resLen)
rows.rs.columns, err = mc.readColumns(resLen)
return rows, err
}
}
Expand All @@ -357,9 +355,8 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
rows := newTextRows(mc)
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}

if resLen > 0 {
// Columns
Expand Down
90 changes: 90 additions & 0 deletions driver_go18_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// +build go1.8

package mysql

import (
"reflect"
"testing"
)

func TestMultiResultSet(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
type result struct {
values [][]int
columns []string
}

expected := []result{
{
values: [][]int{{1, 2}, {3, 4}},
columns: []string{"col1", "col2"},
},
{
values: [][]int{{1, 2, 3}, {4, 5, 6}},
columns: []string{"col1", "col2", "col3"},
},
}

query := `
SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6`

rows := dbt.mustQuery(query)
defer rows.Close()

var res1 result
for rows.Next() {
var res [2]int
if err := rows.Scan(&res[0], &res[1]); err != nil {
dbt.Fatal(err)
}
res1.values = append(res1.values, res[:])
}
if rows.Next() {
dbt.Error("unexpected row")
}
cols, err := rows.Columns()
if err != nil {
dbt.Fatal(err)
}
res1.columns = cols
if !reflect.DeepEqual(expected[0], res1) {
dbt.Error("want =", expected[0], "got =", res1)
}

if !rows.NextResultSet() {
dbt.Fatal("expected next result set")
}

var res2 result
cols, err = rows.Columns()
if err != nil {
dbt.Fatal(err)
}
res2.columns = cols

for rows.Next() {
var res [3]int
if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil {
dbt.Fatal(err)
}
res2.values = append(res2.values, res[:])
}

if !reflect.DeepEqual(expected[1], res2) {
dbt.Error("want =", expected[1], "got =", res2)
}

if rows.Next() {
dbt.Error("unexpected row")
}

if rows.NextResultSet() {
dbt.Error("unexpected next result set")
}

if err := rows.Err(); err != nil {
dbt.Error(err)
}
})
}
43 changes: 22 additions & 21 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
clientTransactions |
clientLocalFiles |
clientPluginAuth |
clientMultiStatements |
clientMultiResults |
mc.flags&clientLongFlag

Expand Down Expand Up @@ -698,6 +699,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
func (rows *textRows) readRow(dest []driver.Value) error {
mc := rows.mc

if rows.rs.done {
return io.EOF
}

data, err := mc.readPacket()
if err != nil {
return err
Expand All @@ -707,15 +712,11 @@ func (rows *textRows) readRow(dest []driver.Value) error {
if data[0] == iEOF && len(data) == 5 {
// server_status [2 bytes]
rows.mc.status = readStatus(data[3:])
err = rows.mc.discardResults()
if err == nil {
err = io.EOF
} else {
// connection unusable
rows.mc.Close()
rows.rs.done = true
if !rows.HasNextResultSet() {
rows.mc = nil
}
rows.mc = nil
return err
return io.EOF
}
if data[0] == iERR {
rows.mc = nil
Expand All @@ -736,7 +737,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
if !mc.parseTime {
continue
} else {
switch rows.columns[i].fieldType {
switch rows.rs.columns[i].fieldType {
case fieldTypeTimestamp, fieldTypeDateTime,
fieldTypeDate, fieldTypeNewDate:
dest[i], err = parseDateTime(
Expand Down Expand Up @@ -1145,14 +1146,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
}

// Convert to byte-coded string
switch rows.columns[i].fieldType {
switch rows.rs.columns[i].fieldType {
case fieldTypeNULL:
dest[i] = nil
continue

// Numeric Types
case fieldTypeTiny:
if rows.columns[i].flags&flagUnsigned != 0 {
if rows.rs.columns[i].flags&flagUnsigned != 0 {
dest[i] = int64(data[pos])
} else {
dest[i] = int64(int8(data[pos]))
Expand All @@ -1161,7 +1162,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
continue

case fieldTypeShort, fieldTypeYear:
if rows.columns[i].flags&flagUnsigned != 0 {
if rows.rs.columns[i].flags&flagUnsigned != 0 {
dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
} else {
dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
Expand All @@ -1170,7 +1171,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
continue

case fieldTypeInt24, fieldTypeLong:
if rows.columns[i].flags&flagUnsigned != 0 {
if rows.rs.columns[i].flags&flagUnsigned != 0 {
dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
} else {
dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
Expand All @@ -1179,7 +1180,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
continue

case fieldTypeLongLong:
if rows.columns[i].flags&flagUnsigned != 0 {
if rows.rs.columns[i].flags&flagUnsigned != 0 {
val := binary.LittleEndian.Uint64(data[pos : pos+8])
if val > math.MaxInt64 {
dest[i] = uint64ToString(val)
Expand Down Expand Up @@ -1233,37 +1234,37 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
case isNull:
dest[i] = nil
continue
case rows.columns[i].fieldType == fieldTypeTime:
case rows.rs.columns[i].fieldType == fieldTypeTime:
// database/sql does not support an equivalent to TIME, return a string
var dstlen uint8
switch decimals := rows.columns[i].decimals; decimals {
switch decimals := rows.rs.columns[i].decimals; decimals {
case 0x00, 0x1f:
dstlen = 8
case 1, 2, 3, 4, 5, 6:
dstlen = 8 + 1 + decimals
default:
return fmt.Errorf(
"protocol error, illegal decimals value %d",
rows.columns[i].decimals,
rows.rs.columns[i].decimals,
)
}
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
case rows.mc.parseTime:
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
default:
var dstlen uint8
if rows.columns[i].fieldType == fieldTypeDate {
if rows.rs.columns[i].fieldType == fieldTypeDate {
dstlen = 10
} else {
switch decimals := rows.columns[i].decimals; decimals {
switch decimals := rows.rs.columns[i].decimals; decimals {
case 0x00, 0x1f:
dstlen = 19
case 1, 2, 3, 4, 5, 6:
dstlen = 19 + 1 + decimals
default:
return fmt.Errorf(
"protocol error, illegal decimals value %d",
rows.columns[i].decimals,
rows.rs.columns[i].decimals,
)
}
}
Expand All @@ -1279,7 +1280,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {

// Please report if this happens!
default:
return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType)
return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
}
}

Expand Down
59 changes: 52 additions & 7 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,14 @@ type mysqlField struct {
decimals byte
}

type mysqlRows struct {
mc *mysqlConn
type resultSet struct {
columns []mysqlField
done bool
}

type mysqlRows struct {
mc *mysqlConn
rs *resultSet
}

type binaryRows struct {
Expand All @@ -34,21 +39,33 @@ type textRows struct {
mysqlRows
}

func newTextRows(mc *mysqlConn) *textRows {
return &textRows{
mysqlRows{
mc: mc,
rs: new(resultSet),
},
}
}

type emptyRows struct{}

func (rows *mysqlRows) Columns() []string {
columns := make([]string, len(rows.columns))
if rows.rs == nil {
return []string{}
}
columns := make([]string, len(rows.rs.columns))
if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
for i := range columns {
if tableName := rows.columns[i].tableName; len(tableName) > 0 {
columns[i] = tableName + "." + rows.columns[i].name
if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 {
columns[i] = tableName + "." + rows.rs.columns[i].name
} else {
columns[i] = rows.columns[i].name
columns[i] = rows.rs.columns[i].name
}
}
} else {
for i := range columns {
columns[i] = rows.columns[i].name
columns[i] = rows.rs.columns[i].name
}
}
return columns
Expand Down Expand Up @@ -99,6 +116,26 @@ func (rows *textRows) Next(dest []driver.Value) error {
return io.EOF
}

func (rows *textRows) HasNextResultSet() (b bool) {
if rows.mc == nil {
return false
}
return rows.mc.status&statusMoreResultsExists != 0
}

func (rows *textRows) NextResultSet() error {
if !rows.HasNextResultSet() {
return io.EOF
}
rows.rs = new(resultSet)
resLen, err := rows.mc.readResultSetHeaderPacket()
if err != nil {
return err
}
rows.rs.columns, err = rows.mc.readColumns(resLen)
return err
}

func (rows emptyRows) Columns() []string {
return nil
}
Expand All @@ -110,3 +147,11 @@ func (rows emptyRows) Close() error {
func (rows emptyRows) Next(dest []driver.Value) error {
return io.EOF
}

func (rows emptyRows) HasNextResultSet() bool {
return false
}

func (rows emptyRows) NextResultSet() error {
return io.EOF
}
10 changes: 6 additions & 4 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,19 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
return nil, err
}

rows := new(binaryRows)
rows := &binaryRows{
mysqlRows{rs: new(resultSet)},
}

if resLen > 0 {
rows.mc = mc
// Columns
// If not cached, read them and cache them
if stmt.columns == nil {
rows.columns, err = mc.readColumns(resLen)
stmt.columns = rows.columns
rows.rs.columns, err = mc.readColumns(resLen)
stmt.columns = rows.rs.columns
} else {
rows.columns = stmt.columns
rows.rs.columns = stmt.columns
err = mc.readUntilEOF()
}
}
Expand Down

0 comments on commit f9f8661

Please sign in to comment.