Skip to content

Commit

Permalink
A way to register preferred TLS config
Browse files Browse the repository at this point in the history
Signed-off-by: lance6716 <[email protected]>
  • Loading branch information
lance6716 committed Nov 11, 2022
1 parent fa1e4ed commit ad744fc
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 7 deletions.
8 changes: 6 additions & 2 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ type Config struct {
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
tls *tls.Config // TLS configuration
tlsIsPreferred bool // TLS is preferred, not a must
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
Expand Down Expand Up @@ -124,10 +125,13 @@ func (cfg *Config) normalize() error {
// don't set anything
case "true":
cfg.tls = &tls.Config{}
case "skip-verify", "preferred":
case "skip-verify":
cfg.tls = &tls.Config{InsecureSkipVerify: true}
case "preferred":
cfg.tls = &tls.Config{InsecureSkipVerify: true}
cfg.tlsIsPreferred = true
default:
cfg.tls = getTLSConfigClone(cfg.TLSConfig)
cfg.tls, cfg.tlsIsPreferred = getTLSConfigCloneAndPreferred(cfg.TLSConfig)
if cfg.tls == nil {
return errors.New("invalid value / unknown config name: " + cfg.TLSConfig)
}
Expand Down
60 changes: 60 additions & 0 deletions dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,66 @@ func TestDSNWithCustomTLS(t *testing.T) {
}
}

func TestRegisterPreferredTLSConfig(t *testing.T) {
tlsMust := "tls-must"
tlsPreferred := "tls-preferred"
tlsCfg := tls.Config{ServerName: tlsMust}
tlsCfg2 := tls.Config{ServerName: tlsPreferred}

err := RegisterTLSConfig(tlsMust, &tlsCfg)
if err != nil {
t.Error(err.Error())
}
defer DeregisterTLSConfig(tlsMust)
err = RegisterPreferredTLSConfig(tlsPreferred, &tlsCfg2)
if err != nil {
t.Error(err.Error())
}
defer DeregisterTLSConfig(tlsPreferred)

checkCfgAndPreferred := func(key, expectedName string, expectedPreferred bool) {
cfg, preferred := getTLSConfigCloneAndPreferred(key)
if cfg.ServerName != expectedName {
t.Errorf("did not get the correct TLS ServerName (%s): %s.", expectedName, cfg.ServerName)
}
if expectedPreferred && !preferred {
t.Errorf("this should not be a preferred TLS config: %s", key)
}
if !expectedPreferred && preferred {
t.Errorf("this should be a preferred TLS config: %s", key)
}
}

checkCfgAndPreferred(tlsMust, tlsMust, false)
checkCfgAndPreferred(tlsPreferred, tlsPreferred, true)

// test RegisterTLSConfig overwrites existing config and preferred
tlsAnother := "tls-another"
tlsCfg3 := tls.Config{ServerName: tlsAnother}
// register the name tlsPreferred in non-preferred mode!
err = RegisterTLSConfig(tlsPreferred, &tlsCfg3)
if err != nil {
t.Error(err.Error())
}
checkCfgAndPreferred(tlsPreferred, tlsAnother, false)

// overwrite it back to preferred
err = RegisterPreferredTLSConfig(tlsPreferred, &tlsCfg3)
if err != nil {
t.Error(err.Error())
}
checkCfgAndPreferred(tlsPreferred, tlsAnother, true)

DeregisterTLSConfig(tlsPreferred)
cfg, preferred := getTLSConfigCloneAndPreferred(tlsPreferred)
if cfg != nil {
t.Error("found TLS config after deregister")
}
if preferred {
t.Error("preferred should be false after deregister")
}
}

func TestDSNTLSConfig(t *testing.T) {
expectedServerName := "example.com"
dsn := "tcp(example.com:1234)/?tls=true"
Expand Down
2 changes: 1 addition & 1 deletion packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
return nil, "", ErrOldProtocol
}
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
if mc.cfg.TLSConfig == "preferred" {
if mc.cfg.tlsIsPreferred {
mc.cfg.tls = nil
} else {
return nil, "", ErrNoTLS
Expand Down
32 changes: 28 additions & 4 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ import (

// Registry for custom tls.Configs
var (
tlsConfigLock sync.RWMutex
tlsConfigRegistry map[string]*tls.Config
tlsConfigLock sync.RWMutex
tlsConfigRegistry map[string]*tls.Config
tlsConfigPreferred map[string]struct{}
)

// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
Expand Down Expand Up @@ -55,6 +56,17 @@ var (
// })
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
func RegisterTLSConfig(key string, config *tls.Config) error {
return registerTLSConfig(key, config, false)
}

// RegisterPreferredTLSConfig is like a RegisterTLSConfig, but when the MySQL
// server does not support TLS the driver will try to connect without TLS.
// It can be used as a customized "preferred" TLS configuration in the DSN.
func RegisterPreferredTLSConfig(key string, config *tls.Config) error {
return registerTLSConfig(key, config, true)
}

func registerTLSConfig(key string, config *tls.Config, preferred bool) error {
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" {
return fmt.Errorf("key '%s' is reserved", key)
}
Expand All @@ -63,8 +75,16 @@ func RegisterTLSConfig(key string, config *tls.Config) error {
if tlsConfigRegistry == nil {
tlsConfigRegistry = make(map[string]*tls.Config)
}

tlsConfigRegistry[key] = config

if preferred {
if tlsConfigPreferred == nil {
tlsConfigPreferred = make(map[string]struct{})
}
tlsConfigPreferred[key] = struct{}{}
} else {
delete(tlsConfigPreferred, key)
}
tlsConfigLock.Unlock()
return nil
}
Expand All @@ -75,14 +95,18 @@ func DeregisterTLSConfig(key string) {
if tlsConfigRegistry != nil {
delete(tlsConfigRegistry, key)
}
if tlsConfigPreferred != nil {
delete(tlsConfigPreferred, key)
}
tlsConfigLock.Unlock()
}

func getTLSConfigClone(key string) (config *tls.Config) {
func getTLSConfigCloneAndPreferred(key string) (config *tls.Config, preferred bool) {
tlsConfigLock.RLock()
if v, ok := tlsConfigRegistry[key]; ok {
config = v.Clone()
}
_, preferred = tlsConfigPreferred[key]
tlsConfigLock.RUnlock()
return
}
Expand Down

0 comments on commit ad744fc

Please sign in to comment.