Skip to content

Commit

Permalink
Merge pull request go-sql-driver#1 from bLamarche413/cr4
Browse files Browse the repository at this point in the history
Cr4 -- removing buffer from mysqlConn
  • Loading branch information
bLamarche413 authored Oct 8, 2018
2 parents f339392 + 6ceaef6 commit 97afd8d
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 53 deletions.
7 changes: 4 additions & 3 deletions benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func BenchmarkRoundtripBin(b *testing.B) {
length = max
}
test := sample[0:length]
rows := tb.checkRows(stmt.Query(test))
rows := tb.checkRows(stmt.Query(test)) //run benchmark tests to test that bit of code
if !rows.Next() {
rows.Close()
b.Fatalf("crashed")
Expand All @@ -231,9 +231,10 @@ func BenchmarkInterpolation(b *testing.B) {
},
maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
buf: newBuffer(nil),
}
mc.reader = &mc.buf

buf := newBuffer(nil)
mc.reader = &buf

args := []driver.Value{
int64(42424242),
Expand Down
16 changes: 5 additions & 11 deletions buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ func (b *buffer) readNext(need int) ([]byte, error) {
// If possible, a slice from the existing buffer is returned.
// Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeBuffer(length int) []byte {
func (b *buffer) reuseBuffer(length int) []byte {
if length == -1 {
return b.takeCompleteBuffer()
}

if b.length > 0 {
return nil
}
Expand All @@ -126,16 +130,6 @@ func (b *buffer) takeBuffer(length int) []byte {
return make([]byte, length)
}

// shortcut which can be used if the requested buffer is guaranteed to be
// smaller than defaultBufSize
// Only one buffer (total) can be used at a time.
func (b *buffer) takeSmallBuffer(length int) []byte {
if b.length == 0 {
return b.buf[:length]
}
return nil
}

// takeCompleteBuffer returns the complete existing buffer.
// This can be used if the necessary buffer size is unknown.
// Only one buffer (total) can be used at a time.
Expand Down
4 changes: 4 additions & 0 deletions compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ func (cr *compressedReader) readNext(need int) ([]byte, error) {
return data, nil
}

func (cr *compressedReader) reuseBuffer(length int) []byte {
return cr.buf.reuseBuffer(length)
}

func (cr *compressedReader) uncompressPacket() error {
header, err := cr.buf.readNext(7) // size of compressed header

Expand Down
4 changes: 4 additions & 0 deletions compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ func (mb *mockBuf) readNext(need int) ([]byte, error) {
return data, nil
}

func (mb *mockBuf) reuseBuffer(length int) []byte {
return make([]byte, length) //just give them a new buffer
}

// compressHelper compresses uncompressedPacket and checks state variables
func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte {
// get status variables
Expand Down
5 changes: 3 additions & 2 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ type mysqlContext interface {
}

type mysqlConn struct {
buf buffer
netConn net.Conn
reader packetReader
writer io.Writer
Expand All @@ -55,6 +54,7 @@ type mysqlConn struct {

type packetReader interface {
readNext(need int) ([]byte, error)
reuseBuffer(length int) []byte
}

// Handles parameters set in DSN after the connection is established
Expand Down Expand Up @@ -197,7 +197,8 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
return "", driver.ErrSkip
}

buf := mc.buf.takeCompleteBuffer()
buf := mc.reader.reuseBuffer(-1)

if buf == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
Expand Down
12 changes: 6 additions & 6 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ import (

func TestInterpolateParams(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(nil),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
}
mc.reader = &mc.buf
buf := newBuffer(nil)
mc.reader = &buf

q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"})
if err != nil {
Expand All @@ -36,13 +36,13 @@ func TestInterpolateParams(t *testing.T) {

func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(nil),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
}
mc.reader = &mc.buf
buf := newBuffer(nil)
mc.reader = &buf

q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)})
if err != driver.ErrSkip {
Expand All @@ -54,14 +54,14 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
// https://github.com/go-sql-driver/mysql/pull/490
func TestInterpolateParamsPlaceholderInString(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(nil),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
}

mc.reader = &mc.buf
buf := newBuffer(nil)
mc.reader = &buf

q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
Expand Down
14 changes: 7 additions & 7 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,16 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
s.startWatcher()
}

mc.buf = newBuffer(mc.netConn)

// packet reader and writer in handshake are never compressed
mc.reader = &mc.buf
mc.writer = mc.netConn
buf := newBuffer(mc.netConn)

// Set I/O timeouts
mc.buf.timeout = mc.cfg.ReadTimeout
buf.timeout = mc.cfg.ReadTimeout
mc.writeTimeout = mc.cfg.WriteTimeout

// packet reader and writer in handshake are never compressed
mc.reader = &buf
mc.writer = mc.netConn

// Reading Handshake Initialization Packet
cipher, err := mc.readInitPacket()
if err != nil {
Expand All @@ -124,7 +124,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
}

if mc.cfg.Compress {
mc.reader = newCompressedReader(&mc.buf, mc)
mc.reader = newCompressedReader(&buf, mc)
mc.writer = newCompressedWriter(mc.writer, mc)
}

Expand Down
36 changes: 25 additions & 11 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
}

// Calculate packet length and get buffer with that size
data := mc.buf.takeSmallBuffer(pktLen + 4)
data := mc.reader.reuseBuffer(pktLen + 4)

if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
Expand Down Expand Up @@ -326,7 +327,10 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
return err
}
mc.netConn = tlsConn
mc.buf.nc = tlsConn
nc := tlsConn

newBuf := newBuffer(nc)
mc.reader = &newBuf

mc.writer = mc.netConn
}
Expand Down Expand Up @@ -373,7 +377,8 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {

// Calculate the packet length and add a tailing 0
pktLen := len(scrambleBuff) + 1
data := mc.buf.takeSmallBuffer(4 + pktLen)
data := mc.reader.reuseBuffer(4 + pktLen)

if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
Expand All @@ -392,7 +397,8 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
func (mc *mysqlConn) writeClearAuthPacket() error {
// Calculate the packet length and add a tailing 0
pktLen := len(mc.cfg.Passwd) + 1
data := mc.buf.takeSmallBuffer(4 + pktLen)
data := mc.reader.reuseBuffer(4 + pktLen)

if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
Expand All @@ -415,7 +421,8 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {

// Calculate the packet length and add a tailing 0
pktLen := len(scrambleBuff)
data := mc.buf.takeSmallBuffer(4 + pktLen)
data := mc.reader.reuseBuffer(4 + pktLen)

if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
Expand All @@ -437,7 +444,8 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
mc.sequence = 0
mc.compressionSequence = 0

data := mc.buf.takeSmallBuffer(4 + 1)
data := mc.reader.reuseBuffer(4 + 1)

if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
Expand All @@ -457,7 +465,8 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
mc.compressionSequence = 0

pktLen := 1 + len(arg)
data := mc.buf.takeBuffer(pktLen + 4)
data := mc.reader.reuseBuffer(pktLen + 4)

if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
Expand All @@ -479,7 +488,8 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
mc.sequence = 0
mc.compressionSequence = 0

data := mc.buf.takeSmallBuffer(4 + 1 + 4)
data := mc.reader.reuseBuffer(4 + 1 + 4)

if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
Expand Down Expand Up @@ -946,9 +956,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
var data []byte

if len(args) == 0 {
data = mc.buf.takeBuffer(minPktLen)
data = mc.reader.reuseBuffer(minPktLen)

} else {
data = mc.buf.takeCompleteBuffer()
data = mc.reader.reuseBuffer(-1)
}
if data == nil {
// can not take the buffer. Something must be wrong with the connection
Expand Down Expand Up @@ -1127,7 +1138,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
// In that case we must build the data packet with the new values buffer
if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...)
mc.buf.buf = data

bufBuf := mc.reader.reuseBuffer(-1)
bufBuf = data
fmt.Println(bufBuf) //dont know how to make it compile w/o some op here on bufBuf
}

pos += len(paramValues)
Expand Down
27 changes: 14 additions & 13 deletions packets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,11 @@ var _ net.Conn = new(mockConn)

func TestReadPacketSingleByte(t *testing.T) {
conn := new(mockConn)
buf := newBuffer(conn)
mc := &mysqlConn{
buf: newBuffer(conn),
reader: &buf,
}

mc.reader = &mc.buf

conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
conn.maxReads = 1
packet, err := mc.readPacket()
Expand All @@ -111,10 +110,10 @@ func TestReadPacketSingleByte(t *testing.T) {

func TestReadPacketWrongSequenceID(t *testing.T) {
conn := new(mockConn)
buf := newBuffer(conn)
mc := &mysqlConn{
buf: newBuffer(conn),
reader: &buf,
}
mc.reader = &mc.buf

// too low sequence id
conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
Expand All @@ -128,7 +127,8 @@ func TestReadPacketWrongSequenceID(t *testing.T) {
// reset
conn.reads = 0
mc.sequence = 0
mc.buf = newBuffer(conn)
newBuf := newBuffer(conn)
mc.reader = &newBuf

// too high sequence id
conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff}
Expand All @@ -140,12 +140,11 @@ func TestReadPacketWrongSequenceID(t *testing.T) {

func TestReadPacketSplit(t *testing.T) {
conn := new(mockConn)
buf := newBuffer(conn)
mc := &mysqlConn{
buf: newBuffer(conn),
reader: &buf,
}

mc.reader = &mc.buf

data := make([]byte, maxPacketSize*2+4*3)
const pkt2ofs = maxPacketSize + 4
const pkt3ofs = 2 * (maxPacketSize + 4)
Expand Down Expand Up @@ -247,11 +246,11 @@ func TestReadPacketSplit(t *testing.T) {

func TestReadPacketFail(t *testing.T) {
conn := new(mockConn)
buf := newBuffer(conn)
mc := &mysqlConn{
buf: newBuffer(conn),
reader: &buf,
closech: make(chan struct{}),
}
mc.reader = &mc.buf

// illegal empty (stand-alone) packet
conn.data = []byte{0x00, 0x00, 0x00, 0x00}
Expand All @@ -264,7 +263,8 @@ func TestReadPacketFail(t *testing.T) {
// reset
conn.reads = 0
mc.sequence = 0
mc.buf = newBuffer(conn)
newBuf := newBuffer(conn)
mc.reader = &newBuf

// fail to read header
conn.closed = true
Expand All @@ -277,7 +277,8 @@ func TestReadPacketFail(t *testing.T) {
conn.closed = false
conn.reads = 0
mc.sequence = 0
mc.buf = newBuffer(conn)
newBuf = newBuffer(conn)
mc.reader = &newBuf

// fail to read body
conn.maxReads = 1
Expand Down

0 comments on commit 97afd8d

Please sign in to comment.