diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index cd474767b..b2ab5e82a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -79,6 +79,7 @@ jobs: ; TestConcurrent fails if max_connections is too large max_connections=50 local_infile=1 + performance_schema=on - name: setup database run: | mysql --user 'root' --host '127.0.0.1' -e 'create database gotest;' diff --git a/README.md b/README.md index ddb5cefc7..2e81fefd9 100644 --- a/README.md +++ b/README.md @@ -393,6 +393,15 @@ Default: 0 I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. +##### `connectionAttributes` + +``` +Type: comma-delimited string of user-defined "key:value" pairs +Valid Values: (:,:,...) +Default: none +``` + +[Connection attributes](https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html) are key-value pairs that application programs can pass to the server at connect time. ##### System Variables diff --git a/connection.go b/connection.go index a7da9e7e2..67cea1fcb 100644 --- a/connection.go +++ b/connection.go @@ -27,6 +27,7 @@ type mysqlConn struct { affectedRows uint64 insertId uint64 cfg *Config + connector *connector maxAllowedPacket int maxWriteSize int writeTimeout time.Duration diff --git a/connector.go b/connector.go index a5c988e13..6acf3dd50 100644 --- a/connector.go +++ b/connector.go @@ -11,11 +11,54 @@ package mysql import ( "context" "database/sql/driver" + "fmt" "net" + "os" + "strconv" + "strings" ) type connector struct { - cfg *Config // immutable private copy. + cfg *Config // immutable private copy. + encodedAttributes string // Encoded connection attributes. +} + +func encodeConnectionAttributes(textAttributes string) string { + connAttrsBuf := make([]byte, 0, 251) + + // default connection attributes + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid())) + + // user-defined connection attributes + for _, connAttr := range strings.Split(textAttributes, ",") { + attr := strings.SplitN(connAttr, ":", 2) + if len(attr) != 2 { + continue + } + for _, v := range attr { + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v) + } + } + + return string(connAttrsBuf) +} + +func newConnector(cfg *Config) (*connector, error) { + encodedAttributes := encodeConnectionAttributes(cfg.ConnectionAttributes) + if len(encodedAttributes) > 250 { + return nil, fmt.Errorf("connection attributes are longer than 250 bytes: %dbytes (%q)", len(encodedAttributes), cfg.ConnectionAttributes) + } + return &connector{ + cfg: cfg, + encodedAttributes: encodedAttributes, + }, nil } // Connect implements driver.Connector interface. @@ -29,6 +72,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { maxWriteSize: maxPacketSize - 1, closech: make(chan struct{}), cfg: c.cfg, + connector: c, } mc.parseTime = mc.cfg.ParseTime diff --git a/connector_test.go b/connector_test.go index 976903c5b..bedb44ce2 100644 --- a/connector_test.go +++ b/connector_test.go @@ -8,13 +8,16 @@ import ( ) func TestConnectorReturnsTimeout(t *testing.T) { - connector := &connector{&Config{ + connector, err := newConnector(&Config{ Net: "tcp", Addr: "1.1.1.1:1234", Timeout: 10 * time.Millisecond, - }} + }) + if err != nil { + t.Fatal(err) + } - _, err := connector.Connect(context.Background()) + _, err = connector.Connect(context.Background()) if err == nil { t.Fatal("error expected") } diff --git a/const.go b/const.go index 64e2bced6..0f2621a6f 100644 --- a/const.go +++ b/const.go @@ -8,12 +8,24 @@ package mysql +import "runtime" + const ( defaultAuthPlugin = "mysql_native_password" defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355 minProtocolVersion = 10 maxPacketSize = 1<<24 - 1 timeFormat = "2006-01-02 15:04:05.999999" + + // Connection attributes + // See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available + connAttrClientName = "_client_name" + connAttrClientNameValue = "Go-MySQL-Driver" + connAttrOS = "_os" + connAttrOSValue = runtime.GOOS + connAttrPlatform = "_platform" + connAttrPlatformValue = runtime.GOARCH + connAttrPid = "_pid" ) // MySQL constants documentation: diff --git a/driver.go b/driver.go index 8b0c3ec0a..c19e04207 100644 --- a/driver.go +++ b/driver.go @@ -85,8 +85,9 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { if err != nil { return nil, err } - c := &connector{ - cfg: cfg, + c, err := newConnector(cfg) + if err != nil { + return nil, err } return c.Connect(context.Background()) } @@ -103,7 +104,7 @@ func NewConnector(cfg *Config) (driver.Connector, error) { if err := cfg.normalize(); err != nil { return nil, err } - return &connector{cfg: cfg}, nil + return newConnector(cfg) } // OpenConnector implements driver.DriverContext. @@ -112,7 +113,5 @@ func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) { if err != nil { return nil, err } - return &connector{ - cfg: cfg, - }, nil + return newConnector(cfg) } diff --git a/driver_test.go b/driver_test.go index 118c0d7ba..7c25aa905 100644 --- a/driver_test.go +++ b/driver_test.go @@ -3214,3 +3214,50 @@ func TestConnectorTimeoutsWatchCancel(t *testing.T) { t.Errorf("connection not closed") } } + +func TestConnectionAttributes(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + attr1 := "attr1" + value1 := "value1" + attr2 := "foo" + value2 := "boo" + dsn += fmt.Sprintf("&connectionAttributes=%s:%s,%s:%s", attr1, value1, attr2, value2) + + var db *sql.DB + if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { + db, err = sql.Open("mysql", dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + } + + dbt := &DBTest{t, db} + + var attrValue string + queryString := "SELECT ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID() and ATTR_NAME = ?" + rows := dbt.mustQuery(queryString, connAttrClientName) + if rows.Next() { + rows.Scan(&attrValue) + if attrValue != connAttrClientNameValue { + dbt.Errorf("expected %q, got %q", connAttrClientNameValue, attrValue) + } + } else { + dbt.Errorf("no data") + } + rows.Close() + + rows = dbt.mustQuery(queryString, attr2) + if rows.Next() { + rows.Scan(&attrValue) + if attrValue != value2 { + dbt.Errorf("expected %q, got %q", value2, attrValue) + } + } else { + dbt.Errorf("no data") + } + rows.Close() +} diff --git a/dsn.go b/dsn.go index ded459c94..7c788517c 100644 --- a/dsn.go +++ b/dsn.go @@ -34,23 +34,24 @@ var ( // If a new Config is created instead of being parsed from a DSN string, // the NewConfig function should be used, which sets default values. type Config struct { - User string // Username - Passwd string // Password (requires User) - Net string // Network type - Addr string // Network address (requires Net) - DBName string // Database name - Params map[string]string // Connection parameters - Collation string // Connection collation - Loc *time.Location // Location for time.Time values - MaxAllowedPacket int // Max packet size allowed - ServerPubKey string // Server public key name - pubKey *rsa.PublicKey // Server public key - TLSConfig string // TLS configuration name - TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig - Timeout time.Duration // Dial timeout - ReadTimeout time.Duration // I/O read timeout - WriteTimeout time.Duration // I/O write timeout - Logger Logger // Logger + User string // Username + Passwd string // Password (requires User) + Net string // Network type + Addr string // Network address (requires Net) + DBName string // Database name + Params map[string]string // Connection parameters + ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs + Collation string // Connection collation + Loc *time.Location // Location for time.Time values + MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + pubKey *rsa.PublicKey // Server public key + TLSConfig string // TLS configuration name + TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig + Timeout time.Duration // Dial timeout + ReadTimeout time.Duration // I/O read timeout + WriteTimeout time.Duration // I/O write timeout + Logger Logger // Logger AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE AllowCleartextPasswords bool // Allows the cleartext client side plugin @@ -560,6 +561,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return } + + // Connection attributes + case "connectionAttributes": + cfg.ConnectionAttributes = value + default: // lazy init if cfg.Params == nil { diff --git a/packets.go b/packets.go index 8fd67997b..d6a11fd21 100644 --- a/packets.go +++ b/packets.go @@ -285,6 +285,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string clientLocalFiles | clientPluginAuth | clientMultiResults | + clientConnectAttrs | mc.flags&clientLongFlag if mc.cfg.ClientFoundRows { @@ -318,6 +319,13 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string pktLen += n + 1 } + // 1 byte to store length of all key-values + // NOTE: Actually, this is length encoded integer. + // But we support only len(connAttrBuf) < 251 for now because takeSmallBuffer + // doesn't support buffer size more than 4096 bytes. + // TODO(methane): Rewrite buffer management. + pktLen += 1 + len(mc.connector.encodedAttributes) + // Calculate packet length and get buffer with that size data, err := mc.buf.takeSmallBuffer(pktLen + 4) if err != nil { @@ -394,6 +402,11 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string data[pos] = 0x00 pos++ + // Connection Attributes + data[pos] = byte(len(mc.connector.encodedAttributes)) + pos++ + pos += copy(data[pos:], []byte(mc.connector.encodedAttributes)) + // Send Auth packet return mc.writePacket(data[:pos]) } diff --git a/packets_test.go b/packets_test.go index cacec1c68..f429087e9 100644 --- a/packets_test.go +++ b/packets_test.go @@ -96,9 +96,14 @@ var _ net.Conn = new(mockConn) func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { conn := new(mockConn) + connector, err := newConnector(NewConfig()) + if err != nil { + panic(err) + } mc := &mysqlConn{ buf: newBuffer(conn), - cfg: NewConfig(), + cfg: connector.cfg, + connector: connector, netConn: conn, closech: make(chan struct{}), maxAllowedPacket: defaultMaxAllowedPacket, diff --git a/utils.go b/utils.go index 15dbd8d16..753ebd65c 100644 --- a/utils.go +++ b/utils.go @@ -616,6 +616,11 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte { byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) } +func appendLengthEncodedString(b []byte, s string) []byte { + b = appendLengthEncodedInteger(b, uint64(len(s))) + return append(b, s...) +} + // reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize. // If cap(buf) is not enough, reallocate new buffer. func reserveBuffer(buf []byte, appendSize int) []byte {