Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi-statement support #4486

Merged
merged 2 commits into from
Dec 22, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 91 additions & 55 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"vitess.io/vitess/go/sync2"
"vitess.io/vitess/go/vt/log"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
)

const (
Expand Down Expand Up @@ -723,65 +724,28 @@ func (c *Conn) handleNextCommand(handler Handler) error {
queryStart := time.Now()
query := c.parseComQuery(data)
c.recycleReadPacket()
fieldSent := false
// sendFinished is set if the response should just be an OK packet.
sendFinished := false

err := handler.ComQuery(c, query, func(qr *sqltypes.Result) error {
if sendFinished {
// Failsafe: Unreachable if server is well-behaved.
return io.EOF
}

if !fieldSent {
fieldSent = true

if len(qr.Fields) == 0 {
sendFinished = true

// A successful callback with no fields means that this was a
// DML or other write-only operation.
//
// We should not send any more packets after this, but make sure
// to extract the affected rows and last insert id from the result
// struct here since clients expect it.
return c.writeOKPacket(qr.RowsAffected, qr.InsertID, c.StatusFlags, handler.WarningCount(c))
}
if err := c.writeFields(qr); err != nil {
return err
var queries []string
if c.Capabilities&CapabilityClientMultiStatements != 0 {
queries, err = sqlparser.SplitStatementToPieces(query)
if err != nil {
log.Errorf("Conn %v: Error splitting query: %v", c, err)
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Errorf("Conn %v: Error writing query error: %v", c, werr)
return werr
}
}

return c.writeRows(qr)
})

// If no field was sent, we expect an error.
if !fieldSent {
// This is just a failsafe. Should never happen.
if err == nil || err == io.EOF {
err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error"))
}
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Errorf("Error writing query error to %s: %v", c, werr)
return werr
}
} else {
if err != nil {
// We can't send an error in the middle of a stream.
// All we can do is abort the send, which will cause a 2013.
log.Errorf("Error in the middle of a stream to %s: %v", c, err)
return err
queries = []string{query}
}
for index, sql := range queries {
more := false
if index != len(queries)-1 {
more = true
}

// Send the end packet only sendFinished is false (results were streamed).
// In this case the affectedRows and lastInsertID are always 0 since it
// was a read operation.
if !sendFinished {
if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil {
log.Errorf("Error writing result to %s: %v", c, err)
return err
}
if err := c.execQuery(sql, handler, more); err != nil {
return err
}
}

Expand All @@ -807,7 +771,9 @@ func (c *Conn) handleNextCommand(handler Handler) error {
}
}
case ComSetOption:
if operation, ok := c.parseComSetOption(data); ok {
operation, ok := c.parseComSetOption(data)
c.recycleReadPacket()
if ok {
switch operation {
case 0:
c.Capabilities |= CapabilityClientMultiStatements
Expand Down Expand Up @@ -843,6 +809,76 @@ func (c *Conn) handleNextCommand(handler Handler) error {
return nil
}

func (c *Conn) execQuery(query string, handler Handler, more bool) error {
fieldSent := false
// sendFinished is set if the response should just be an OK packet.
sendFinished := false

err := handler.ComQuery(c, query, func(qr *sqltypes.Result) error {
flag := c.StatusFlags
if more {
flag |= ServerMoreResultsExists
}
if sendFinished {
// Failsafe: Unreachable if server is well-behaved.
return io.EOF
}

if !fieldSent {
fieldSent = true

if len(qr.Fields) == 0 {
sendFinished = true

// A successful callback with no fields means that this was a
// DML or other write-only operation.
//
// We should not send any more packets after this, but make sure
// to extract the affected rows and last insert id from the result
// struct here since clients expect it.
return c.writeOKPacket(qr.RowsAffected, qr.InsertID, flag, handler.WarningCount(c))
}
if err := c.writeFields(qr); err != nil {
return err
}
}

return c.writeRows(qr)
})

// If no field was sent, we expect an error.
if !fieldSent {
// This is just a failsafe. Should never happen.
if err == nil || err == io.EOF {
err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error"))
}
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Errorf("Error writing query error to %s: %v", c, werr)
return werr
}
} else {
if err != nil {
// We can't send an error in the middle of a stream.
// All we can do is abort the send, which will cause a 2013.
log.Errorf("Error in the middle of a stream to %s: %v", c, err)
return err
}

// Send the end packet only sendFinished is false (results were streamed).
// In this case the affectedRows and lastInsertID are always 0 since it
// was a read operation.
if !sendFinished {
if err := c.writeEndResult(more, 0, 0, handler.WarningCount(c)); err != nil {
log.Errorf("Error writing result to %s: %v", c, err)
return err
}
}
}

return nil
}

//
// Packet parsing methods, for generic packets.
//
Expand Down
26 changes: 26 additions & 0 deletions test/mysql_server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ class TestMySQL(unittest.TestCase):
"""This test makes sure the MySQL server connector is correct.
"""

MYSQL_OPTION_MULTI_STATEMENTS_ON = 0
MYSQL_OPTION_MULTI_STATEMENTS_OFF = 1

def test_mysql_connector(self):
with open(table_acl_config, 'w') as fd:
fd.write("""{
Expand Down Expand Up @@ -160,6 +163,29 @@ def test_mysql_connector(self):
cursor.execute('select * from vt_insert_test', {})
cursor.close()

# Test multi-statement support. It should only work when
# COM_SET_OPTION has set the options to 0
conn.set_server_option(self.MYSQL_OPTION_MULTI_STATEMENTS_ON)
cursor = conn.cursor()
cursor.execute("select 1; select 2")
self.assertEquals(((1L,),), cursor.fetchall())
self.assertEquals(1, cursor.nextset())
self.assertEquals(((2L,),), cursor.fetchall())
self.assertEquals(None, cursor.nextset())
cursor.close()
conn.set_server_option(self.MYSQL_OPTION_MULTI_STATEMENTS_OFF)

# Multi-statement support should not work without the
# option enabled
cursor = conn.cursor()
try:
cursor.execute("select 1; select 2")
self.fail('Execute went through')
except MySQLdb.OperationalError, e:
s = str(e)
self.assertIn('syntax error', s)
cursor.close()

# verify that queries work end-to-end with large grpc messages
largeComment = 'L' * ((4 * 1024 * 1024) + 1)
cursor = conn.cursor()
Expand Down