Skip to content

Commit

Permalink
Merge pull request #251 from brokercap/v2.2
Browse files Browse the repository at this point in the history
v2.2.0
  • Loading branch information
jc3wish authored Oct 6, 2023
2 parents 6ced480 + fff7c0c commit 72938fb
Show file tree
Hide file tree
Showing 50 changed files with 2,819 additions and 240 deletions.
196 changes: 182 additions & 14 deletions Bristol/mysql/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ import (
"bufio"
"crypto/tls"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"net"
"strconv"
"strings"
"time"
)

Expand All @@ -21,18 +24,18 @@ type mysqlConn struct {
insertId uint64
lastCmdTime time.Time
keepaliveTimer *time.Timer
status uint16
status uint16
}

type config struct {
user string
passwd string
net string
addr string
dbname string
params map[string]string
authPluginName string
tlsConfig *tls.Config
user string
passwd string
net string
addr string
dbname string
params map[string]string
authPluginName string
tlsConfig *tls.Config
}

type serverSettings struct {
Expand Down Expand Up @@ -181,11 +184,13 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
}

func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {

if len(args) > 0 {
return nil, driver.ErrSkip
var err error
query, err = mc.interpolateParams(query, args)
if err != nil {
return nil, err
}
}

mc.affectedRows = 0
mc.insertId = 0

Expand Down Expand Up @@ -236,6 +241,162 @@ func (mc *mysqlConn) exec(query string) (e error) {
return
}

func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
if len(args) > 0 {
var err error
query, err = mc.interpolateParams(query, args)
if err != nil {
return nil, err
}
}
return mc.query(query)
}

func (mc *mysqlConn) query(query string) (dataRows driver.Rows, e error) {
e = mc.writeCommandPacket(COM_QUERY, query)
if e != nil {
return
}

// Read Result
var resLen int
resLen, e = mc.readResultSetHeaderPacket()
if e != nil {
return
}
if resLen == 0 {
return nil, driver.ErrSkip
}
rows := mysqlRows{new(rowsContent)}
rows.content.columns, e = mc.readColumns(resLen)
if e != nil {
return
}
e = mc.readStringRows(rows.content)
return rows, e
}

func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
if strings.Count(query, "?") != len(args) {
return "", driver.ErrSkip
}
buf := make([]byte, 0)
argPos := 0
whereIndex := strings.Index(strings.ToUpper(query), "WHERE")

for i := 0; i < len(query); i++ {
q := strings.IndexByte(query[i:], '?')
if q == -1 {
buf = append(buf, query[i:]...)
break
}
buf = append(buf, query[i:i+q]...)
i += q

arg := args[argPos]
argPos++

if arg == nil {
buf = append(buf, "NULL"...)
continue
}

switch v := arg.(type) {
case int8, int16, int32, int:
int64N, _ := strconv.ParseInt(fmt.Sprint(v), 10, 64)
buf = strconv.AppendInt(buf, int64N, 10)
break
case uint8, uint16, uint32, uint:
uint64N, _ := strconv.ParseUint(fmt.Sprint(v), 10, 64)
buf = strconv.AppendUint(buf, uint64N, 10)
break
case int64:
buf = strconv.AppendInt(buf, v, 10)
case uint64:
// Handle uint64 explicitly because our custom ConvertValue emits unsigned values
buf = strconv.AppendUint(buf, v, 10)
case float32:
buf = strconv.AppendFloat(buf, float64(v), 'g', -1, 64)
case float64:
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
case bool:
if v {
buf = append(buf, '1')
} else {
buf = append(buf, '0')
}
case time.Time:
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
} else {
// 测试下来 在timestamp(0)的情况下,2006-01-02 15:04:05.999999 写,binlog 解析出来和写进去存在 1 秒 误差,暂不知道具体的原因
// 建议使用string传进来,而不要使用time.Time类型
timeStr := v.Format("2006-01-02 15:04:05.999999")
buf = append(buf, '\'')
if mc.isSupportedBackslash() {
buf = escapeBytesBackslash(buf, []byte(timeStr))
} else {
buf = escapeBytesQuotes(buf, []byte(timeStr))
}
buf = append(buf, '\'')
}
case json.RawMessage:
if v == nil {
buf = append(buf, "NULL"...)
continue
}
buf = append(buf, '\'')
if mc.isSupportedBackslash() {
buf = escapeBytesBackslash(buf, v)
} else {
buf = escapeBytesQuotes(buf, v)
}
buf = append(buf, '\'')
case []byte:
if v == nil {
buf = append(buf, "NULL"...)
} else {
buf = appendArgsBufByBytes(buf, v, mc.isSupportedBackslash())
}
case string:
buf = appendArgsBufByString(buf, v, mc.isSupportedBackslash())
break
case []string:
if whereIndex > 0 && q > whereIndex {
for ii, val := range v {
if ii > 0 {
buf = append(buf, ',')
}
buf = appendArgsBufByString(buf, val, mc.isSupportedBackslash())
}
} else {
c, _ := json.Marshal(v)
buf = appendArgsBufByBytes(buf, c, mc.isSupportedBackslash())
}
break
case []int, []uint, []int8, []int16, []uint16, []int32, []uint32, []int64, []uint64:
if whereIndex > 0 && q > whereIndex {
whereInStr := strings.Replace(strings.Trim(fmt.Sprint(v), "[]"), " ", ",", -1)
buf = appendArgsBufByString(buf, whereInStr, mc.isSupportedBackslash())
} else {
c, _ := json.Marshal(v)
buf = appendArgsBufByBytes(buf, c, mc.isSupportedBackslash())
}
break
default:
return "", driver.ErrSkip
}

if len(buf)+4 > MAX_PACKET_SIZE {
return "", driver.ErrSkip
}
}
if argPos != len(args) {
return "", driver.ErrSkip
}
return string(buf), nil
}

// Gets the value of the given MySQL System Variable
func (mc *mysqlConn) getSystemVar(name string) (val string, e error) {
// Send command
Expand Down Expand Up @@ -289,11 +450,18 @@ func (mc *mysqlConn) markBadConn(err error) error {
return driver.ErrBadConn
}

func NewConnect(uri string) MysqlConnection{
func (mc *mysqlConn) isSupportedBackslash() bool {
if mc.status&STATUS_NO_BACK_SLASH_ESCAPES == 0 {
return true
}
return false
}

func NewConnect(uri string) MysqlConnection {
dbopen := &mysqlDriver{}
conn, err := dbopen.Open(uri)
if err != nil {
panic(err)
}
return conn.(MysqlConnection)
}
}
118 changes: 118 additions & 0 deletions Bristol/mysql/connection_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package mysql

import (
"database/sql/driver"
"log"
"testing"
"time"
)

func TestMyconn_Exec_Integration(t *testing.T) {
uri := "root:root@tcp(127.0.0.1:55001)/bifrost_test"
conn := NewConnect(uri)
log.Println("Connect over")
//conn.Close()
//return
connectionId, err := testGetConnectId(conn)
if err != nil {
t.Fatal(err)
}
t.Log("connectionId:", connectionId)

createSQL := `CREATE TABLE IF NOT EXISTS bifrost_test.test_1 (
id int(11) unsigned NOT NULL AUTO_INCREMENT,
testtinyint tinyint(4) NOT NULL DEFAULT '-1',
testsmallint smallint(6) NOT NULL DEFAULT '-2',
testmediumint mediumint(8) NOT NULL DEFAULT '-3',
testint int(11) NOT NULL DEFAULT '-4',
testbigint bigint(20) NOT NULL DEFAULT '-5',
testvarchar varchar(10) NOT NULL,
testchar char(2) NOT NULL,
testenum enum('en1','en2','en3') NOT NULL DEFAULT 'en1',
testset set('set1','set2','set3') NOT NULL DEFAULT 'set1',
testtime time NOT NULL DEFAULT '00:00:00',
testdate date NOT NULL DEFAULT '0000-00-00',
testyear year(4) NOT NULL DEFAULT '1989',
testtimestamp timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP(),
testdatetime datetime NOT NULL DEFAULT '0000-00-00 00:00:00',
testfloat float(9,2) NOT NULL DEFAULT '0.00',
testdouble double(9,2) NOT NULL DEFAULT '0.00',
testdecimal decimal(9,2) NOT NULL DEFAULT '0.00',
testdatatime_null datetime DEFAULT NULL,
PRIMARY KEY (id)
) ENGINE=InnoDB AUTO_INCREMENT=0 DEFAULT CHARSET=utf8 PARTITION BY HASH (id) PARTITIONS 3`
_, err = conn.Exec(createSQL, []driver.Value{})
if err != nil {
t.Fatal(err)
}
insertSQL := `INSERT INTO bifrost_test.test_1
(testtinyint,testsmallint,testmediumint,testint,testbigint,testvarchar,testchar,testenum,testset,testtime,testdate,testyear,testtimestamp,testdatetime,testfloat,testdouble,testdecimal,testdatatime_null)
values
(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
`
args := make([]driver.Value, 18)
args[0] = int8(8)
args[1] = int16(16)
args[2] = int32(24)
args[3] = int32(32)
args[4] = int64(64)
args[5] = "va"
args[6] = "c"
args[7] = "en3"
args[8] = "set1,set3"
args[9] = "10:11:00"
args[10] = "2023-09-17"
args[11] = "2023"
args[12] = time.Now()
args[13] = time.Now()
args[14] = float32(9.32)
args[15] = float64(666.3264)
args[16] = "9.999"
args[17] = nil

r, err := conn.Exec(insertSQL, args)
if err != nil {
t.Fatal(err)
}
t.Log(r.LastInsertId())
t.Log(r.RowsAffected())
}

func TestMyconn_query_Integration(t *testing.T) {
uri := "root:root@tcp(127.0.0.1:55001)/bifrost_test"
conn := NewConnect(uri)
log.Println("Connect over")
//conn.Close()
//return
connectionId, err := testGetConnectId(conn)
if err != nil {
t.Fatal(err)
}
t.Log("connectionId:", connectionId)

selectSQL := "SELECT * FROM bifrost_test.test_1 WHERE id in (?)"
AutoIncrementValue := make([]string, 1)
AutoIncrementValue[0] = "1"
args := make([]driver.Value, 0)
args = append(args, AutoIncrementValue)
//args = append(args, strings.Replace(strings.Trim(fmt.Sprint(AutoIncrementValue), "[]"), " ", "','", -1))
rows, err := conn.Query(selectSQL, args)
if err != nil {
t.Fatal(err)
}
data := make([]map[string]driver.Value, 0)
for {
m := make(map[string]driver.Value, len(rows.Columns()))
dest := make([]driver.Value, len(rows.Columns()), len(rows.Columns()))
err := rows.Next(dest)
if err != nil {
break
}
for i, fieldName := range rows.Columns() {
m[fieldName] = dest[i]
}
data = append(data, m)
}
t.Log(data)
t.Log("id:", data[0]["id"])
}
Loading

0 comments on commit 72938fb

Please sign in to comment.