diff --git a/dsn.go b/dsn.go index a306d66a3..3605012f7 100644 --- a/dsn.go +++ b/dsn.go @@ -47,6 +47,7 @@ type Config struct { pubKey *rsa.PublicKey // Server public key TLSConfig string // TLS configuration name tls *tls.Config // TLS configuration + tlsIsPreferred bool // TLS is preferred, not a must Timeout time.Duration // Dial timeout ReadTimeout time.Duration // I/O read timeout WriteTimeout time.Duration // I/O write timeout @@ -124,10 +125,13 @@ func (cfg *Config) normalize() error { // don't set anything case "true": cfg.tls = &tls.Config{} - case "skip-verify", "preferred": + case "skip-verify": cfg.tls = &tls.Config{InsecureSkipVerify: true} + case "preferred": + cfg.tls = &tls.Config{InsecureSkipVerify: true} + cfg.tlsIsPreferred = true default: - cfg.tls = getTLSConfigClone(cfg.TLSConfig) + cfg.tls, cfg.tlsIsPreferred = getTLSConfigCloneAndPreferred(cfg.TLSConfig) if cfg.tls == nil { return errors.New("invalid value / unknown config name: " + cfg.TLSConfig) } diff --git a/dsn_test.go b/dsn_test.go index fc6eea9c8..fd2ca84df 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -221,6 +221,66 @@ func TestDSNWithCustomTLS(t *testing.T) { } } +func TestRegisterPreferredTLSConfig(t *testing.T) { + tlsMust := "tls-must" + tlsPreferred := "tls-preferred" + tlsCfg := tls.Config{ServerName: tlsMust} + tlsCfg2 := tls.Config{ServerName: tlsPreferred} + + err := RegisterTLSConfig(tlsMust, &tlsCfg) + if err != nil { + t.Error(err.Error()) + } + defer DeregisterTLSConfig(tlsMust) + err = RegisterPreferredTLSConfig(tlsPreferred, &tlsCfg2) + if err != nil { + t.Error(err.Error()) + } + defer DeregisterTLSConfig(tlsPreferred) + + checkCfgAndPreferred := func(key, expectedName string, expectedPreferred bool) { + cfg, preferred := getTLSConfigCloneAndPreferred(key) + if cfg.ServerName != expectedName { + t.Errorf("did not get the correct TLS ServerName (%s): %s.", expectedName, cfg.ServerName) + } + if expectedPreferred && !preferred { + t.Errorf("this should not be a preferred TLS config: %s", key) + } + if !expectedPreferred && preferred { + t.Errorf("this should be a preferred TLS config: %s", key) + } + } + + checkCfgAndPreferred(tlsMust, tlsMust, false) + checkCfgAndPreferred(tlsPreferred, tlsPreferred, true) + + // test RegisterTLSConfig overwrites existing config and preferred + tlsAnother := "tls-another" + tlsCfg3 := tls.Config{ServerName: tlsAnother} + // register the name tlsPreferred in non-preferred mode! + err = RegisterTLSConfig(tlsPreferred, &tlsCfg3) + if err != nil { + t.Error(err.Error()) + } + checkCfgAndPreferred(tlsPreferred, tlsAnother, false) + + // overwrite it back to preferred + err = RegisterPreferredTLSConfig(tlsPreferred, &tlsCfg3) + if err != nil { + t.Error(err.Error()) + } + checkCfgAndPreferred(tlsPreferred, tlsAnother, true) + + DeregisterTLSConfig(tlsPreferred) + cfg, preferred := getTLSConfigCloneAndPreferred(tlsPreferred) + if cfg != nil { + t.Error("found TLS config after deregister") + } + if preferred { + t.Error("preferred should be false after deregister") + } +} + func TestDSNTLSConfig(t *testing.T) { expectedServerName := "example.com" dsn := "tcp(example.com:1234)/?tls=true" diff --git a/packets.go b/packets.go index 003584c25..be129a6ed 100644 --- a/packets.go +++ b/packets.go @@ -223,7 +223,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro return nil, "", ErrOldProtocol } if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { - if mc.cfg.TLSConfig == "preferred" { + if mc.cfg.tlsIsPreferred { mc.cfg.tls = nil } else { return nil, "", ErrNoTLS diff --git a/utils.go b/utils.go index 15dbd8d16..d91025f1d 100644 --- a/utils.go +++ b/utils.go @@ -25,8 +25,9 @@ import ( // Registry for custom tls.Configs var ( - tlsConfigLock sync.RWMutex - tlsConfigRegistry map[string]*tls.Config + tlsConfigLock sync.RWMutex + tlsConfigRegistry map[string]*tls.Config + tlsConfigPreferred map[string]struct{} ) // RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. @@ -55,6 +56,17 @@ var ( // }) // db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") func RegisterTLSConfig(key string, config *tls.Config) error { + return registerTLSConfig(key, config, false) +} + +// RegisterPreferredTLSConfig is like a RegisterTLSConfig, but when the MySQL +// server does not support TLS the driver will try to connect without TLS. +// It can be used as a customized "preferred" TLS configuration in the DSN. +func RegisterPreferredTLSConfig(key string, config *tls.Config) error { + return registerTLSConfig(key, config, true) +} + +func registerTLSConfig(key string, config *tls.Config, preferred bool) error { if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" { return fmt.Errorf("key '%s' is reserved", key) } @@ -63,8 +75,16 @@ func RegisterTLSConfig(key string, config *tls.Config) error { if tlsConfigRegistry == nil { tlsConfigRegistry = make(map[string]*tls.Config) } - tlsConfigRegistry[key] = config + + if preferred { + if tlsConfigPreferred == nil { + tlsConfigPreferred = make(map[string]struct{}) + } + tlsConfigPreferred[key] = struct{}{} + } else { + delete(tlsConfigPreferred, key) + } tlsConfigLock.Unlock() return nil } @@ -75,14 +95,18 @@ func DeregisterTLSConfig(key string) { if tlsConfigRegistry != nil { delete(tlsConfigRegistry, key) } + if tlsConfigPreferred != nil { + delete(tlsConfigPreferred, key) + } tlsConfigLock.Unlock() } -func getTLSConfigClone(key string) (config *tls.Config) { +func getTLSConfigCloneAndPreferred(key string) (config *tls.Config, preferred bool) { tlsConfigLock.RLock() if v, ok := tlsConfigRegistry[key]; ok { config = v.Clone() } + _, preferred = tlsConfigPreferred[key] tlsConfigLock.RUnlock() return }