Skip to content

Commit

Permalink
Merge pull request #673 from atomicules/tls-in-driver
Browse files Browse the repository at this point in the history
Allow TLS connections in the driver
lance6716 authored May 5, 2022
2 parents 5a98cc1 + b490dc8 commit 145f684
Showing 3 changed files with 167 additions and 32 deletions.
30 changes: 21 additions & 9 deletions client/tls.go
Original file line number Diff line number Diff line change
@@ -13,16 +13,28 @@ func NewClientTLSConfig(caPem, certPem, keyPem []byte, insecureSkipVerify bool,
panic("failed to add ca PEM")
}

cert, err := tls.X509KeyPair(certPem, keyPem)
if err != nil {
panic(err)
}
var config *tls.Config

config := &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: pool,
InsecureSkipVerify: insecureSkipVerify,
ServerName: serverName,
// Allow cert and key to be optional
// Send through `make([]byte, 0)` for "nil"
if string(certPem) != "" && string(keyPem) != "" {
cert, err := tls.X509KeyPair(certPem, keyPem)
if err != nil {
panic(err)
}
config = &tls.Config{
RootCAs: pool,
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: insecureSkipVerify,
ServerName: serverName,
}
} else {
config = &tls.Config{
RootCAs: pool,
InsecureSkipVerify: insecureSkipVerify,
ServerName: serverName,
}
}

return config
}
142 changes: 119 additions & 23 deletions driver/driver.go
Original file line number Diff line number Diff line change
@@ -3,51 +3,120 @@
package driver

import (
"crypto/tls"
"database/sql"
sqldriver "database/sql/driver"
"fmt"
"io"
"strings"
"net/url"
"regexp"
"sync"

"github.com/go-mysql-org/go-mysql/client"
"github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/errors"
"github.com/siddontang/go/hack"
)

var customTLSMutex sync.Mutex

// Map of dsn address (makes more sense than full dsn?) to tls Config
var customTLSConfigMap = make(map[string]*tls.Config)

type driver struct {
}

// Open: DSN user:password@addr[?db]
func (d driver) Open(dsn string) (sqldriver.Conn, error) {
lastIndex := strings.LastIndex(dsn, "@")
seps := []string{dsn[:lastIndex], dsn[lastIndex+1:]}
if len(seps) != 2 {
return nil, errors.Errorf("invalid dsn, must user:password@addr[?db]")
type connInfo struct {
standardDSN bool
addr string
user string
password string
db string
params url.Values
}

// ParseDSN takes a DSN string and splits it up into struct containing addr,
// user, password and db.
// It returns an error if unable to parse.
// The struct also contains a boolean indicating if the DSN is in legacy or
// standard form.
//
// Legacy form uses a `?` is used as the path separator: user:password@addr[?db]
// Standard form uses a `/`: user:password@addr/db?param=value
//
// Optional parameters are supported in the standard DSN form
func parseDSN(dsn string) (connInfo, error) {
var matchErr error
ci := connInfo{}

// If a "/" occurs after "@" and then no more "@" or "/" occur after that
ci.standardDSN, matchErr = regexp.MatchString("@[^@]+/[^@/]+", dsn)
if matchErr != nil {
return ci, errors.Errorf("invalid dsn, must be user:password@addr[/db[?param=X]]")
}

// Add a prefix so we can parse with url.Parse
dsn = "mysql://" + dsn
parsedDSN, parseErr := url.Parse(dsn)
if parseErr != nil {
return ci, errors.Errorf("invalid dsn, must be user:password@addr[/db[?param=X]]")
}

var user string
var password string
var addr string
var db string
ci.addr = parsedDSN.Host
ci.user = parsedDSN.User.Username()
// We ignore the second argument as that is just a flag for existence of a password
// If not set we get empty string anyway
ci.password, _ = parsedDSN.User.Password()

if ss := strings.Split(seps[0], ":"); len(ss) == 2 {
user, password = ss[0], ss[1]
} else if len(ss) == 1 {
user = ss[0]
if ci.standardDSN {
ci.db = parsedDSN.Path[1:]
ci.params = parsedDSN.Query()
} else {
return nil, errors.Errorf("invalid dsn, must user:password@addr[?db]")
ci.db = parsedDSN.RawQuery
// This is the equivalent to a "nil" list of parameters
ci.params = url.Values{}
}

if ss := strings.Split(seps[1], "?"); len(ss) == 2 {
addr, db = ss[0], ss[1]
} else if len(ss) == 1 {
addr = ss[0]
} else {
return nil, errors.Errorf("invalid dsn, must user:password@addr[?db]")
return ci, nil
}

// Open takes a supplied DSN string and opens a connection
// See ParseDSN for more information on the form of the DSN
func (d driver) Open(dsn string) (sqldriver.Conn, error) {
var c *client.Conn

ci, err := parseDSN(dsn)

if err != nil {
return nil, err
}

c, err := client.Connect(addr, user, password, db)
if ci.standardDSN {
if ci.params["ssl"] != nil {
tlsConfigName := ci.params.Get("ssl")
switch tlsConfigName {
case "true":
// This actually does insecureSkipVerify
// But not even sure if it makes sense to handle false? According to
// client_test.go it doesn't - it'd result in an error
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, func(c *client.Conn) { c.UseSSL(true) })
case "custom":
// I was too concerned about mimicking what go-sql-driver/mysql does which will
// allow any name for a custom tls profile and maps the query parameter value to
// that TLSConfig variable... there is no need to be that clever.
// Instead of doing that, let's store required custom TLSConfigs in a map that
// uses the DSN address as the key
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, func(c *client.Conn) { c.SetTLSConfig(customTLSConfigMap[ci.addr]) })
default:
return nil, errors.Errorf("Supported options are ssl=true or ssl=custom")
}
} else {
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db)
}
} else {
// No more processing here. Let's only support url parameters with the newer style DSN
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db)
}
if err != nil {
return nil, err
}
@@ -229,3 +298,30 @@ func (r *rows) Next(dest []sqldriver.Value) error {
func init() {
sql.Register("mysql", driver{})
}

// SetCustomTLSConfig sets a custom TLSConfig for the address (host:port) of the supplied DSN.
// It requires a full import of the driver (not by side-effects only).
// Example of supplying a custom CA, no client cert, no key, validating the
// certificate, and supplying a serverName for the validation:
//
// driver.SetCustomTLSConfig(CaPem, make([]byte, 0), make([]byte, 0), false, "my.domain.name")
//
func SetCustomTLSConfig(dsn string, caPem []byte, certPem []byte, keyPem []byte, insecureSkipVerify bool, serverName string) error {
// Extract addr from dsn
parsed, err := url.Parse(dsn)
if err != nil {
return errors.Errorf("Unable to parse DSN. Need to extract address to use as key for storing custom TLS config")
}
addr := parsed.Host

// I thought about using serverName instead of addr below, but decided against that as
// having multiple CA certs for one hostname is likely when you have services running on
// different ports.

customTLSMutex.Lock()
// Basic pass-through function so we can just import the driver
customTLSConfigMap[addr] = client.NewClientTLSConfig(caPem, certPem, keyPem, insecureSkipVerify, serverName)
customTLSMutex.Unlock()

return nil
}
27 changes: 27 additions & 0 deletions driver/driver_test.go
Original file line number Diff line number Diff line change
@@ -3,6 +3,8 @@ package driver
import (
"flag"
"fmt"
"net/url"
"reflect"
"testing"

"github.com/jmoiron/sqlx"
@@ -78,3 +80,28 @@ func (s *testDriverSuite) TestTransaction(c *C) {
err = tx.Commit()
c.Assert(err, IsNil)
}

func TestParseDSN(t *testing.T) {
// List of DSNs to test and expected results
// Use different numbered domains to more readily see what has failed - since we
// test in a loop we get the same line number on error
testDSNs := map[string]connInfo{
"user:password@localhost?db": connInfo{standardDSN: false, addr: "localhost", user: "user", password: "password", db: "db", params: url.Values{}},
"[email protected]?db": connInfo{standardDSN: false, addr: "1.domain.com", user: "user", password: "", db: "db", params: url.Values{}},
"user:[email protected]/db": connInfo{standardDSN: true, addr: "2.domain.com", user: "user", password: "password", db: "db", params: url.Values{}},
"user:[email protected]/db?ssl=true": connInfo{standardDSN: true, addr: "3.domain.com", user: "user", password: "password", db: "db", params: url.Values{"ssl": []string{"true"}}},
"user:[email protected]/db?ssl=custom": connInfo{standardDSN: true, addr: "4.domain.com", user: "user", password: "password", db: "db", params: url.Values{"ssl": []string{"custom"}}},
"user:[email protected]/db?unused=param": connInfo{standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"unused": []string{"param"}}},
}

for supplied, expected := range testDSNs {
actual, err := parseDSN(supplied)
if err != nil {
t.Errorf("TestParseDSN failed. Got error: %s", err)
}
// Compare that with expected
if !reflect.DeepEqual(actual, expected) {
t.Errorf("TestParseDSN failed.\nExpected:\n%#v\nGot:\n%#v", expected, actual)
}
}
}

0 comments on commit 145f684

Please sign in to comment.