diff --git a/client/conn.go b/client/conn.go index 683cc66c9..5cbe30025 100644 --- a/client/conn.go +++ b/client/conn.go @@ -1,6 +1,7 @@ package client import ( + "bytes" "context" "crypto/tls" "fmt" @@ -223,30 +224,30 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall return nil, errors.Trace(err) } - var buf []byte var err error var result *Result - defer utils.ByteSlicePut(buf) + + bs := utils.ByteSliceGet(16) + defer utils.ByteSlicePut(bs) for { - buf, err = c.ReadPacketReuseMem(utils.ByteSliceGet(16)[:0]) + bs.B, err = c.ReadPacketReuseMem(bs.B[:0]) if err != nil { return nil, errors.Trace(err) } - switch buf[0] { + switch bs.B[0] { case OK_HEADER: - result, err = c.handleOKPacket(buf) + result, err = c.handleOKPacket(bs.B) case ERR_HEADER: - err = c.handleErrorPacket(append([]byte{}, buf...)) + err = c.handleErrorPacket(bytes.Repeat(bs.B, 1)) result = nil case LocalInFile_HEADER: err = ErrMalformPacket result = nil default: - result, err = c.readResultset(buf, false) + result, err = c.readResultset(bs.B, false) } - // call user-defined callback perResultCallback(result, err) diff --git a/client/req.go b/client/req.go index df3cee214..eaa64b014 100644 --- a/client/req.go +++ b/client/req.go @@ -21,11 +21,11 @@ func (c *Conn) writeCommandBuf(command byte, arg []byte) error { length := len(arg) + 1 data := utils.ByteSliceGet(length + 4) - data[4] = command + data.B[4] = command - copy(data[5:], arg) + copy(data.B[5:], arg) - err := c.WritePacket(data) + err := c.WritePacket(data.B) utils.ByteSlicePut(data) diff --git a/client/resp.go b/client/resp.go index 0f5215ebf..0c94b398a 100644 --- a/client/resp.go +++ b/client/resp.go @@ -216,38 +216,42 @@ func (c *Conn) readOK() (*Result, error) { } func (c *Conn) readResult(binary bool) (*Result, error) { - firstPkgBuf, err := c.ReadPacketReuseMem(utils.ByteSliceGet(16)[:0]) - defer utils.ByteSlicePut(firstPkgBuf) - + bs := utils.ByteSliceGet(16) + defer utils.ByteSlicePut(bs) + var err error + bs.B, err = c.ReadPacketReuseMem(bs.B[:0]) if err != nil { return nil, errors.Trace(err) } - if firstPkgBuf[0] == OK_HEADER { - return c.handleOKPacket(firstPkgBuf) - } else if firstPkgBuf[0] == ERR_HEADER { - return nil, c.handleErrorPacket(append([]byte{}, firstPkgBuf...)) - } else if firstPkgBuf[0] == LocalInFile_HEADER { + switch bs.B[0] { + case OK_HEADER: + return c.handleOKPacket(bs.B) + case ERR_HEADER: + return nil, c.handleErrorPacket(bytes.Repeat(bs.B, 1)) + case LocalInFile_HEADER: return nil, ErrMalformPacket + default: + return c.readResultset(bs.B, binary) } - - return c.readResultset(firstPkgBuf, binary) } func (c *Conn) readResultStreaming(binary bool, result *Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback) error { - firstPkgBuf, err := c.ReadPacketReuseMem(utils.ByteSliceGet(16)[:0]) - defer utils.ByteSlicePut(firstPkgBuf) - + bs := utils.ByteSliceGet(16) + defer utils.ByteSlicePut(bs) + var err error + bs.B, err = c.ReadPacketReuseMem(bs.B[:0]) if err != nil { return errors.Trace(err) } - if firstPkgBuf[0] == OK_HEADER { + switch bs.B[0] { + case OK_HEADER: // https://dev.mysql.com/doc/internals/en/com-query-response.html // 14.6.4.1 COM_QUERY Response // If the number of columns in the resultset is 0, this is a OK_Packet. - okResult, err := c.handleOKPacket(firstPkgBuf) + okResult, err := c.handleOKPacket(bs.B) if err != nil { return errors.Trace(err) } @@ -262,13 +266,13 @@ func (c *Conn) readResultStreaming(binary bool, result *Result, perRowCb SelectP result.Reset(0) } return nil - } else if firstPkgBuf[0] == ERR_HEADER { - return c.handleErrorPacket(append([]byte{}, firstPkgBuf...)) - } else if firstPkgBuf[0] == LocalInFile_HEADER { + case ERR_HEADER: + return c.handleErrorPacket(bytes.Repeat(bs.B, 1)) + case LocalInFile_HEADER: return ErrMalformPacket + default: + return c.readResultsetStreaming(bs.B, binary, result, perRowCb, perResCb) } - - return c.readResultsetStreaming(firstPkgBuf, binary, result, perRowCb, perResCb) } func (c *Conn) readResultset(data []byte, binary bool) (*Result, error) { diff --git a/utils/byte_slice_pool.go b/utils/byte_slice_pool.go index cd51544ff..8dace1491 100644 --- a/utils/byte_slice_pool.go +++ b/utils/byte_slice_pool.go @@ -2,35 +2,29 @@ package utils import "sync" +type ByteSlice struct { + B []byte +} + var ( byteSlicePool = sync.Pool{ New: func() interface{} { - return []byte{} + return new(ByteSlice) }, } - byteSliceChan = make(chan []byte, 10) ) -func ByteSliceGet(length int) (data []byte) { - select { - case data = <-byteSliceChan: - default: - data = byteSlicePool.Get().([]byte)[:0] - } - - if cap(data) < length { - data = make([]byte, length) +func ByteSliceGet(length int) *ByteSlice { + data := byteSlicePool.Get().(*ByteSlice) + if cap(data.B) < length { + data.B = make([]byte, length) } else { - data = data[:length] + data.B = data.B[:length] } - return data } -func ByteSlicePut(data []byte) { - select { - case byteSliceChan <- data: - default: - byteSlicePool.Put(data) //nolint:staticcheck - } +func ByteSlicePut(data *ByteSlice) { + data.B = data.B[:0] + byteSlicePool.Put(data) } diff --git a/utils/byte_slice_pool_test.go b/utils/byte_slice_pool_test.go new file mode 100644 index 000000000..2f713d590 --- /dev/null +++ b/utils/byte_slice_pool_test.go @@ -0,0 +1,12 @@ +package utils + +import "testing" + +func BenchmarkByteSlicePool(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + b := ByteSliceGet(16) + b.B = append(b.B[:0], 0, 1) + ByteSlicePut(b) + } +} diff --git a/utils/bytes_buffer_pool_test.go b/utils/bytes_buffer_pool_test.go new file mode 100644 index 000000000..820073e47 --- /dev/null +++ b/utils/bytes_buffer_pool_test.go @@ -0,0 +1,12 @@ +package utils + +import "testing" + +func BenchmarkBytesBufferPool(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + b := BytesBufferGet() + b.WriteString("01") + BytesBufferPut(b) + } +}