diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 8f24ae53555..2cf2ab99786 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -31,6 +31,7 @@ import ( "vitess.io/vitess/go/sync2" "vitess.io/vitess/go/vt/log" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/sqlparser" ) const ( @@ -723,65 +724,28 @@ func (c *Conn) handleNextCommand(handler Handler) error { queryStart := time.Now() query := c.parseComQuery(data) c.recycleReadPacket() - fieldSent := false - // sendFinished is set if the response should just be an OK packet. - sendFinished := false - - err := handler.ComQuery(c, query, func(qr *sqltypes.Result) error { - if sendFinished { - // Failsafe: Unreachable if server is well-behaved. - return io.EOF - } - - if !fieldSent { - fieldSent = true - - if len(qr.Fields) == 0 { - sendFinished = true - // A successful callback with no fields means that this was a - // DML or other write-only operation. - // - // We should not send any more packets after this, but make sure - // to extract the affected rows and last insert id from the result - // struct here since clients expect it. - return c.writeOKPacket(qr.RowsAffected, qr.InsertID, c.StatusFlags, handler.WarningCount(c)) - } - if err := c.writeFields(qr); err != nil { - return err + var queries []string + if c.Capabilities&CapabilityClientMultiStatements != 0 { + queries, err = sqlparser.SplitStatementToPieces(query) + if err != nil { + log.Errorf("Conn %v: Error splitting query: %v", c, err) + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Errorf("Conn %v: Error writing query error: %v", c, werr) + return werr } } - - return c.writeRows(qr) - }) - - // If no field was sent, we expect an error. - if !fieldSent { - // This is just a failsafe. Should never happen. - if err == nil || err == io.EOF { - err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) - } - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Errorf("Error writing query error to %s: %v", c, werr) - return werr - } } else { - if err != nil { - // We can't send an error in the middle of a stream. - // All we can do is abort the send, which will cause a 2013. - log.Errorf("Error in the middle of a stream to %s: %v", c, err) - return err + queries = []string{query} + } + for index, sql := range queries { + more := false + if index != len(queries)-1 { + more = true } - - // Send the end packet only sendFinished is false (results were streamed). - // In this case the affectedRows and lastInsertID are always 0 since it - // was a read operation. - if !sendFinished { - if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil { - log.Errorf("Error writing result to %s: %v", c, err) - return err - } + if err := c.execQuery(sql, handler, more); err != nil { + return err } } @@ -807,7 +771,9 @@ func (c *Conn) handleNextCommand(handler Handler) error { } } case ComSetOption: - if operation, ok := c.parseComSetOption(data); ok { + operation, ok := c.parseComSetOption(data) + c.recycleReadPacket() + if ok { switch operation { case 0: c.Capabilities |= CapabilityClientMultiStatements @@ -843,6 +809,76 @@ func (c *Conn) handleNextCommand(handler Handler) error { return nil } +func (c *Conn) execQuery(query string, handler Handler, more bool) error { + fieldSent := false + // sendFinished is set if the response should just be an OK packet. + sendFinished := false + + err := handler.ComQuery(c, query, func(qr *sqltypes.Result) error { + flag := c.StatusFlags + if more { + flag |= ServerMoreResultsExists + } + if sendFinished { + // Failsafe: Unreachable if server is well-behaved. + return io.EOF + } + + if !fieldSent { + fieldSent = true + + if len(qr.Fields) == 0 { + sendFinished = true + + // A successful callback with no fields means that this was a + // DML or other write-only operation. + // + // We should not send any more packets after this, but make sure + // to extract the affected rows and last insert id from the result + // struct here since clients expect it. + return c.writeOKPacket(qr.RowsAffected, qr.InsertID, flag, handler.WarningCount(c)) + } + if err := c.writeFields(qr); err != nil { + return err + } + } + + return c.writeRows(qr) + }) + + // If no field was sent, we expect an error. + if !fieldSent { + // This is just a failsafe. Should never happen. + if err == nil || err == io.EOF { + err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) + } + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Errorf("Error writing query error to %s: %v", c, werr) + return werr + } + } else { + if err != nil { + // We can't send an error in the middle of a stream. + // All we can do is abort the send, which will cause a 2013. + log.Errorf("Error in the middle of a stream to %s: %v", c, err) + return err + } + + // Send the end packet only sendFinished is false (results were streamed). + // In this case the affectedRows and lastInsertID are always 0 since it + // was a read operation. + if !sendFinished { + if err := c.writeEndResult(more, 0, 0, handler.WarningCount(c)); err != nil { + log.Errorf("Error writing result to %s: %v", c, err) + return err + } + } + } + + return nil +} + // // Packet parsing methods, for generic packets. // diff --git a/test/mysql_server_test.py b/test/mysql_server_test.py index b723eb33ec7..d65380133ba 100755 --- a/test/mysql_server_test.py +++ b/test/mysql_server_test.py @@ -100,6 +100,9 @@ class TestMySQL(unittest.TestCase): """This test makes sure the MySQL server connector is correct. """ + MYSQL_OPTION_MULTI_STATEMENTS_ON = 0 + MYSQL_OPTION_MULTI_STATEMENTS_OFF = 1 + def test_mysql_connector(self): with open(table_acl_config, 'w') as fd: fd.write("""{ @@ -160,6 +163,29 @@ def test_mysql_connector(self): cursor.execute('select * from vt_insert_test', {}) cursor.close() + # Test multi-statement support. It should only work when + # COM_SET_OPTION has set the options to 0 + conn.set_server_option(self.MYSQL_OPTION_MULTI_STATEMENTS_ON) + cursor = conn.cursor() + cursor.execute("select 1; select 2") + self.assertEquals(((1L,),), cursor.fetchall()) + self.assertEquals(1, cursor.nextset()) + self.assertEquals(((2L,),), cursor.fetchall()) + self.assertEquals(None, cursor.nextset()) + cursor.close() + conn.set_server_option(self.MYSQL_OPTION_MULTI_STATEMENTS_OFF) + + # Multi-statement support should not work without the + # option enabled + cursor = conn.cursor() + try: + cursor.execute("select 1; select 2") + self.fail('Execute went through') + except MySQLdb.OperationalError, e: + s = str(e) + self.assertIn('syntax error', s) + cursor.close() + # verify that queries work end-to-end with large grpc messages largeComment = 'L' * ((4 * 1024 * 1024) + 1) cursor = conn.cursor()