Skip to content

Commit

Permalink
Move MySQL packet parsing to individual functions (#10430)
Browse files Browse the repository at this point in the history
  • Loading branch information
greedy52 committed Apr 14, 2022
1 parent e9617b7 commit a6b30fe
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 185 deletions.
104 changes: 104 additions & 0 deletions lib/srv/db/mysql/protocol/command.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package protocol

import (
"bytes"

"github.com/gravitational/trace"
)

// Query represents the COM_QUERY command.
//
// https://dev.mysql.com/doc/internals/en/com-query.html
// https://mariadb.com/kb/en/com_query/
type Query struct {
packet

// query is the query text.
query string
}

// Query returns the query text.
func (p *Query) Query() string {
return p.query
}

// Quit represents the COM_QUIT command.
//
// https://dev.mysql.com/doc/internals/en/com-quit.html
// https://mariadb.com/kb/en/com_quit/
type Quit struct {
packet
}

// ChangeUser represents the COM_CHANGE_USER command.
//
// https://dev.mysql.com/doc/internals/en/com-change-user.html
// https://mariadb.com/kb/en/com_change_user/
type ChangeUser struct {
packet

// user is the requested user.
user string
}

// User returns the requested user.
func (p *ChangeUser) User() string {
return p.user
}

// parseQueryPacket parses packet bytes and returns a Packet if successful.
func parseQueryPacket(packetBytes []byte) (Packet, error) {
// Be a bit paranoid and make sure the packet is not truncated.
if len(packetBytes) < packetHeaderAndTypeSize {
return nil, trace.BadParameter("failed to parse COM_QUERY packet: %v", packetBytes)
}

// 4-byte packet header + 1-byte payload header, then query text.
return &Query{
packet: packet{bytes: packetBytes},
query: string(packetBytes[packetHeaderAndTypeSize:]),
}, nil
}

// parseQuitPacket parses packet bytes and returns a Packet if successful.
func parseQuitPacket(packetBytes []byte) (Packet, error) {
return &Quit{
packet: packet{bytes: packetBytes},
}, nil
}

// parseChangeUserPacket parses packet bytes and returns a Packet if
// successful.
func parseChangeUserPacket(packetBytes []byte) (Packet, error) {
if len(packetBytes) < packetHeaderAndTypeSize {
return nil, trace.BadParameter("failed to parse COM_CHANGE_USER packet: %v", packetBytes)
}

// User is the first null-terminated string in the payload:
// https://dev.mysql.com/doc/internals/en/com-change-user.html#packet-COM_CHANGE_USER
idx := bytes.IndexByte(packetBytes[packetHeaderAndTypeSize:], 0x00)
if idx < -1 {
return nil, trace.BadParameter("failed to parse COM_CHANGE_USER packet: %v", packetBytes)
}

return &ChangeUser{
packet: packet{bytes: packetBytes},
user: string(packetBytes[packetHeaderAndTypeSize : packetHeaderAndTypeSize+idx]),
}, nil
}
175 changes: 28 additions & 147 deletions lib/srv/db/mysql/protocol/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package protocol

import (
"bytes"
"io"

"github.com/gravitational/trace"
Expand Down Expand Up @@ -46,60 +45,6 @@ type Generic struct {
packet
}

// OK represents the OK packet.
//
// https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
type OK struct {
packet
}

// Error represents the ERR packet.
//
// https://dev.mysql.com/doc/internals/en/packet-ERR_Packet.html
type Error struct {
packet
// message is the error message
message string
}

// Error returns the error message.
func (p *Error) Error() string {
return p.message
}

// Query represents the COM_QUERY command.
//
// https://dev.mysql.com/doc/internals/en/com-query.html
type Query struct {
packet
// query is the query text.
query string
}

// Query returns the query text.
func (p *Query) Query() string {
return p.query
}

// Quit represents the COM_QUIT command.
//
// https://dev.mysql.com/doc/internals/en/com-quit.html
type Quit struct {
packet
}

// ChangeUser represents the COM_CHANGE_USER command.
type ChangeUser struct {
packet
// user is the requested user.
user string
}

// User returns the requested user.
func (p *ChangeUser) User() string {
return p.user
}

// ParsePacket reads a protocol packet from the connection and returns it
// in a parsed form. See ReadPacket below for the packet structure.
func ParsePacket(conn io.Reader) (Packet, error) {
Expand All @@ -108,103 +53,18 @@ func ParsePacket(conn io.Reader) (Packet, error) {
return nil, trace.Wrap(err)
}

packet := packet{packetBytes}

switch packetType {
case mysql.OK_HEADER:
return &OK{packet: packet}, nil

case mysql.ERR_HEADER:
// Figure out where in the packet the error message is.
//
// Depending on the protocol version, the packet may include additional
// fields. In protocol version 4.1 it includes '#' marker:
//
// https://dev.mysql.com/doc/internals/en/packet-ERR_Packet.html
minLen := packetHeaderSize + packetTypeSize + 2 // 4-byte header + 1-byte type + 2-byte error code
if len(packetBytes) > minLen && packetBytes[minLen] == '#' {
minLen += 6 // 1-byte marker '#' + 5-byte state
}
// Be a bit paranoid and make sure the packet is not truncated.
if len(packetBytes) < minLen {
return nil, trace.BadParameter("failed to parse ERR packet: %v", packetBytes)
if int(packetType) < len(packetParsersByType) && packetParsersByType[packetType] != nil {
packet, err := packetParsersByType[packetType](packetBytes)
if err != nil {
return nil, trace.Wrap(err)
}
return &Error{packet: packet, message: string(packetBytes[minLen:])}, nil

case mysql.COM_QUERY:
// Be a bit paranoid and make sure the packet is not truncated.
if len(packetBytes) < packetHeaderAndTypeSize {
return nil, trace.BadParameter("failed to parse COM_QUERY packet: %v", packetBytes)
}
// 4-byte packet header + 1-byte payload header, then query text.
return &Query{packet: packet, query: string(packetBytes[packetHeaderAndTypeSize:])}, nil

case mysql.COM_QUIT:
return &Quit{packet: packet}, nil

case mysql.COM_CHANGE_USER:
if len(packetBytes) < packetHeaderAndTypeSize {
return nil, trace.BadParameter("failed to parse COM_CHANGE_USER packet: %v", packetBytes)
}
// User is the first null-terminated string in the payload:
// https://dev.mysql.com/doc/internals/en/com-change-user.html#packet-COM_CHANGE_USER
idx := bytes.IndexByte(packetBytes[packetHeaderAndTypeSize:], 0x00)
if idx < 0 {
return nil, trace.BadParameter("failed to parse COM_CHANGE_USER packet: %v", packetBytes)
}
return &ChangeUser{packet: packet, user: string(packetBytes[packetHeaderAndTypeSize : packetHeaderAndTypeSize+idx])}, nil

case mysql.COM_STMT_PREPARE:
packet, ok := parseStatementPreparePacket(packet)
if !ok {
return nil, trace.BadParameter("failed to parse COM_STMT_PREPARE packet: %v", packetBytes)
}
return packet, nil

case mysql.COM_STMT_SEND_LONG_DATA:
packet, ok := parseStatementSendLongDataPacket(packet)
if !ok {
return nil, trace.BadParameter("failed to parse COM_STMT_SEND_LONG_DATA packet: %v", packetBytes)
}
return packet, nil

case mysql.COM_STMT_EXECUTE:
packet, ok := parseStatementExecutePacket(packet)
if !ok {
return nil, trace.BadParameter("failed to parse COM_STMT_EXECUTE packet: %v", packetBytes)
}
return packet, nil

case mysql.COM_STMT_CLOSE:
packet, ok := parseStatementClosePacket(packet)
if !ok {
return nil, trace.BadParameter("failed to parse COM_STMT_CLOSE packet: %v", packetBytes)
}
return packet, nil

case mysql.COM_STMT_RESET:
packet, ok := parseStatementResetPacket(packet)
if !ok {
return nil, trace.BadParameter("failed to parse COM_STMT_RESET packet: %v", packetBytes)
}
return packet, nil

case mysql.COM_STMT_FETCH:
packet, ok := parseStatementFetchPacket(packet)
if !ok {
return nil, trace.BadParameter("failed to parse COM_STMT_FETCH packet: %v", packetBytes)
}
return packet, nil

case packetTypeStatementBulkExecute:
packet, ok := parseStatementBulkExecutePacket(packet)
if !ok {
return nil, trace.BadParameter("failed to parse COM_STMT_BULK_EXECUTE packet: %v", packetBytes)
}
return packet, nil
}

return &Generic{packet: packet}, nil
return &Generic{
packet: packet{bytes: packetBytes},
}, nil
}

// ReadPacket reads a protocol packet from the connection.
Expand Down Expand Up @@ -262,6 +122,27 @@ func WritePacket(pkt []byte, conn io.Writer) (int, error) {
return n, nil
}

// packetParsersByType is a slice of packet parser functions by packet type.
var packetParsersByType = []func([]byte) (Packet, error){
// Server responses.
mysql.OK_HEADER: parseOKPacket,
mysql.ERR_HEADER: parseErrorPacket,

// Text protocol commands.
mysql.COM_QUERY: parseQueryPacket,
mysql.COM_QUIT: parseQuitPacket,
mysql.COM_CHANGE_USER: parseChangeUserPacket,

// Prepared statement commands.
mysql.COM_STMT_PREPARE: parseStatementPreparePacket,
mysql.COM_STMT_SEND_LONG_DATA: parseStatementSendLongDataPacket,
mysql.COM_STMT_EXECUTE: parseStatementExecutePacket,
mysql.COM_STMT_CLOSE: parseStatementClosePacket,
mysql.COM_STMT_RESET: parseStatementResetPacket,
mysql.COM_STMT_FETCH: parseStatementFetchPacket,
packetTypeStatementBulkExecute: parseStatementBulkExecutePacket,
}

const (
// packetHeaderSize is the size of the packet header.
packetHeaderSize = 4
Expand Down
26 changes: 26 additions & 0 deletions lib/srv/db/mysql/protocol/packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,28 @@ func TestParsePacket(t *testing.T) {
input: iotest.ErrReader(&net.OpError{}),
expectErrorIs: trace.IsConnectionProblem,
},
{
name: "not enough data for header",
input: bytes.NewBuffer([]byte{0x00}),
expectErrorIs: isUnexpectedEOFError,
},
{
name: "not enough data for payload",
input: bytes.NewBuffer([]byte{0xff, 0xff, 0xff, 0x00, 0x01}),
expectErrorIs: isUnexpectedEOFError,
},
{
name: "unrecognized type",
input: bytes.NewBuffer([]byte{
0x01, 0x00, 0x00, 0x00, // header
0x44, // type
}),
expectedPacket: &Generic{
packet: packet{
bytes: []byte{0x01, 0x00, 0x00, 0x00, 0x44},
},
},
},
{
name: "OK_HEADER",
input: bytes.NewBuffer(sampleOKPacket.Bytes()),
Expand Down Expand Up @@ -320,3 +342,7 @@ func TestParsePacket(t *testing.T) {
})
}
}

func isUnexpectedEOFError(err error) bool {
return trace.Unwrap(err) == io.ErrUnexpectedEOF
}
Loading

0 comments on commit a6b30fe

Please sign in to comment.