Skip to content

Commit

Permalink
Merge pull request #1249 from cyberark/verify-full-pg-mysql
Browse files Browse the repository at this point in the history
Add support for sslmode=verify-full for mysql and pg
  • Loading branch information
doodlesbykumbi authored Jun 23, 2020
2 parents 7e195e4 + 1180aff commit b3a3c23
Show file tree
Hide file tree
Showing 31 changed files with 813 additions and 323 deletions.
1 change: 1 addition & 0 deletions .gitleaks.toml
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ files = [
"test/connector/ssh/id_(.*)", # test ssh handler certs
"test/connector/ssh_agent/id_(.*)", # test ssh-agent handler certs
"test/connector/tcp/mssql/certs/(.*)", # test mssql connector certs
"internal/plugin/connectors/tcp/ssl/testdata/(.*)", # test shared ssl package certs
"test/ssh/id_(.*)", # since-removed ssh test certs
"test/util/ssl/(.*)", # test ssl certs
"internal/plugin/connectors/tcp/mssql/connection_details_test.go", # fake cert string
Expand Down
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

## [Unreleased]

### Added
- MySQL and PostgreSQL connectors support SSL host name verification with
`verify-full` SSL mode. Also adds optional `sslhost` configuration parameter
that is compared to the server's certificate SAN. [#548](https://github.com/cyberark/secretless-broker/issues/548)

## [1.6.0] - 2020-05-04

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ func (h *AuthenticationHandshake) dbSSLMode() *ssl.DbSSLMode {

var ret ssl.DbSSLMode
ret, h.err = ssl.NewDbSSLMode(
h.connectionDetails.Options, false,
h.connectionDetails.SSLOptions, false,
)
h.sslMode = &ret

Expand Down
51 changes: 32 additions & 19 deletions internal/plugin/connectors/tcp/mysql/connection_details.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,70 @@ package mysql

import "strconv"

// DefaultMySQLPort is the default port on which we connect to the MySQL service
// If another port is found within the connectionDetails, we will use that.
const DefaultMySQLPort = uint(3306)

var sslOptions = []string{
"host",
"sslhost",
"sslrootcert",
"sslmode",
"sslkey",
"sslcert",
}

// ConnectionDetails stores the connection info to the real backend database.
// These values are pulled from the SingleUseConnector credentials config
type ConnectionDetails struct {
Host string
Options map[string]string
Password string
Port uint
Username string
Host string
Options map[string]string
Password string
Port uint
SSLOptions map[string]string
Username string
}

// DefaultMySQLPort is the default port on which we connect to the MySQL service
// If another port is found within the connectionDetails, we will use that.
const DefaultMySQLPort = uint(3306)

// NewConnectionDetails is a constructor of ConnectionDetails structure from a
// map of credentials.
func NewConnectionDetails(credentials map[string][]byte) (
*ConnectionDetails, error) {

connDetails := &ConnectionDetails{Options: make(map[string]string)}
connDetails := &ConnectionDetails{
Options: make(map[string]string),
SSLOptions: make(map[string]string),
}

if host := credentials["host"]; host != nil {
if len(credentials["host"]) > 0 {
connDetails.Host = string(credentials["host"])
}

connDetails.Port = DefaultMySQLPort
if credentials["port"] != nil {
if len(credentials["port"]) > 0 {
port64, _ := strconv.ParseUint(string(credentials["port"]), 10, 64)
connDetails.Port = uint(port64)
}

if credentials["username"] != nil {
if len(credentials["username"]) > 0 {
connDetails.Username = string(credentials["username"])
}

if credentials["password"] != nil {
if len(credentials["password"]) > 0 {
connDetails.Password = string(credentials["password"])
}

// Make sure that we process the SSL mode arg even if it's not specified
// otherwise it will get ignored
if _, ok := credentials["sslmode"]; !ok {
credentials["sslmode"] = []byte("")
for _, sslOption := range sslOptions {
if len(credentials[sslOption]) > 0 {
connDetails.SSLOptions[sslOption] = string(credentials[sslOption])
}
delete(credentials, sslOption)
}

delete(credentials, "host")
delete(credentials, "port")
delete(credentials, "username")
delete(credentials, "password")

connDetails.Options = make(map[string]string)
for k, v := range credentials {
connDetails.Options[k] = string(v)
}
Expand Down
16 changes: 9 additions & 7 deletions internal/plugin/connectors/tcp/mysql/connection_details_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ func TestExpectedFields(t *testing.T) {
}

expectedConnDetails := ConnectionDetails{
Host: "myhost",
Options: map[string]string{
Host: "myhost",
Options: map[string]string{},
SSLOptions: map[string]string{
"host": "myhost",
"sslmode": "disable",
},
Password: "mypassword",
Expand Down Expand Up @@ -45,8 +47,9 @@ func TestDefaultPort(t *testing.T) {
Port: DefaultMySQLPort,
Username: "myusername",
Password: "mypassword",
Options: map[string]string{
"sslmode": "",
Options: map[string]string{},
SSLOptions: map[string]string{
"host": "myhost",
},
}

Expand All @@ -69,9 +72,8 @@ func TestUnexpectedFieldsAreSavedAsOptions(t *testing.T) {
}

expectedOptions := map[string]string{
"foo": "5432",
"bar": "data",
"sslmode": "",
"foo": "5432",
"bar": "data",
}

actualConnDetails, err := NewConnectionDetails(credentials)
Expand Down
7 changes: 3 additions & 4 deletions internal/plugin/connectors/tcp/pg/connect_details.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
const DefaultPostgresPort = "5432"

var sslOptions = []string{
"host",
"sslhost",
"sslrootcert",
"sslmode",
"sslkey",
Expand Down Expand Up @@ -75,10 +77,7 @@ func NewConnectionDetails(options map[string][]byte) (*ConnectionDetails, error)

for _, sslOption := range sslOptions {
if len(options[sslOption]) > 0 {
value := string(options[sslOption])
if value != "" {
connectionDetails.SSLOptions[sslOption] = value
}
connectionDetails.SSLOptions[sslOption] = string(options[sslOption])
}
delete(options, sslOption)
}
Expand Down
31 changes: 19 additions & 12 deletions internal/plugin/connectors/tcp/pg/connect_details_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ func TestExpectedFields(t *testing.T) {
}

expectedConnectionDetails := ConnectionDetails{
Host: "myhost",
Port: "1234",
Username: "myusername",
Password: "mypassword",
Options: map[string]string{},
SSLOptions: map[string]string{},
Host: "myhost",
Port: "1234",
Username: "myusername",
Password: "mypassword",
Options: map[string]string{},
SSLOptions: map[string]string{
"host": "myhost",
},
}

actualConnectionDetails, err := NewConnectionDetails(options)
Expand All @@ -38,6 +40,7 @@ func TestSSLOptions(t *testing.T) {
"username": []byte("myusername"),
"password": []byte("mypassword"),

"sslhost": []byte("customhost"),
"sslrootcert": []byte("mysslrootcert"),
"sslmode": []byte("mysslmode"),
"sslkey": []byte("mysslkey"),
Expand All @@ -51,6 +54,8 @@ func TestSSLOptions(t *testing.T) {
Password: "mypassword",
Options: map[string]string{},
SSLOptions: map[string]string{
"host": "myhost",
"sslhost": "customhost",
"sslrootcert": "mysslrootcert",
"sslmode": "mysslmode",
"sslkey": "mysslkey",
Expand All @@ -74,12 +79,14 @@ func TestDefaultPort(t *testing.T) {
}

expectedConnectionDetails := ConnectionDetails{
Host: "myhost",
Port: DefaultPostgresPort,
Username: "myusername",
Password: "mypassword",
Options: map[string]string{},
SSLOptions: map[string]string{},
Host: "myhost",
Port: DefaultPostgresPort,
Username: "myusername",
Password: "mypassword",
Options: map[string]string{},
SSLOptions: map[string]string{
"host": "myhost",
},
}

actualConnectionDetails, err := NewConnectionDetails(options)
Expand Down
64 changes: 35 additions & 29 deletions internal/plugin/connectors/tcp/ssl/ssl.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type DbSSLMode struct {
}

// NewDbSSLMode configures and creates a DbSSLMode
func NewDbSSLMode(o options, requireCanVerifyCAOnly bool) (DbSSLMode, error) {
func NewDbSSLMode(o options, requireCanVerifyCA bool) (DbSSLMode, error) {
// NOTE for the "require" case:
//
// From http://www.postgresql.org/docs/current/static/libpq-ssl.html:
Expand All @@ -34,10 +34,10 @@ func NewDbSSLMode(o options, requireCanVerifyCAOnly bool) (DbSSLMode, error) {
switch mode := o["sslmode"]; mode {
case "disable":
sslMode.UseTLS = false
return sslMode, nil
// "require" is the default.

// "require" is the default.
case "", "require":
// Skip TLS's own verification: it requires full verification since Go 1.3.
// Skip stdlib's verification: it requires full verification since Go 1.3.
sslMode.InsecureSkipVerify = true

// From http://www.postgresql.org/docs/current/static/libpq-ssl.html:
Expand All @@ -51,18 +51,29 @@ func NewDbSSLMode(o options, requireCanVerifyCAOnly bool) (DbSSLMode, error) {

// MySQL on the other hand notes in its docs that it ignores
// SSL certs if supplied in REQUIRED sslmode.
if requireCanVerifyCAOnly && len(o["sslrootcert"]) > 0 {
if requireCanVerifyCA && len(o["sslrootcert"]) > 0 {
sslMode.VerifyCaOnly = true
}

case "verify-ca":
// Skip TLS's own verification: it requires full verification since Go 1.3.
// Skip stdlib's verification: it requires full verification since Go 1.3.
sslMode.InsecureSkipVerify = true
sslMode.VerifyCaOnly = true
//case "verify-full":
// sslMode.ServerName = o["host"]

case "verify-full":
// Use stdlib's verification
sslMode.InsecureSkipVerify = false
sslMode.VerifyCaOnly = false

// 'sslhost', when not empty, takes precedence over 'host'
if len(o["sslhost"]) > 0 {
sslMode.ServerName = o["sslhost"]
} else {
sslMode.ServerName = o["host"]
}

default:
// TODO add verify-full below
return DbSSLMode{}, fmt.Errorf(`unsupported sslmode %q; only "require" (default), "verify-ca", and "disable" supported`, mode)
return DbSSLMode{}, fmt.Errorf(`unsupported sslmode %q; only "require" (default), "verify-ca", "verify-full" and "disable" supported`, mode)
}

return sslMode, nil
Expand All @@ -74,9 +85,16 @@ func HandleSSLUpgrade(connection net.Conn, tlsConf DbSSLMode) (net.Conn, error)
if err != nil {
return nil, err
}
err = sslCertificateAuthority(&tlsConf.Config, tlsConf.Options)
if err != nil {
return nil, err

// Add the root CA certificate specified in the "sslrootcert" setting to the root CA
// pool on the tls configuration.
sslRootCert := []byte(tlsConf.Options["sslrootcert"])
if len(sslRootCert) > 0 {
tlsConf.RootCAs = x509.NewCertPool()

if !tlsConf.RootCAs.AppendCertsFromPEM(sslRootCert) {
return nil, fmt.Errorf("couldn't parse pem in sslrootcert")
}
}

// Accept renegotiation requests initiated by the backend.
Expand All @@ -93,6 +111,10 @@ func HandleSSLUpgrade(connection net.Conn, tlsConf DbSSLMode) (net.Conn, error)
return nil, err
}
}
err = client.Handshake()
if err != nil {
return nil, err
}

return client, nil
}
Expand Down Expand Up @@ -120,22 +142,6 @@ func sslClientCertificates(tlsConf *tls.Config, o options) error {
return nil
}

// sslCertificateAuthority adds the RootCA specified in the "sslrootcert" setting.
func sslCertificateAuthority(tlsConf *tls.Config, o options) error {
// The root certificate is only loaded if the setting is not blank.
if sslrootcert := o["sslrootcert"]; len(sslrootcert) > 0 {
tlsConf.RootCAs = x509.NewCertPool()

cert := []byte(sslrootcert)

if !tlsConf.RootCAs.AppendCertsFromPEM(cert) {
return fmt.Errorf("couldn't parse pem in sslrootcert")
}
}

return nil
}

// sslVerifyCertificateAuthority carries out a TLS handshake to the server and
// verifies the presented certificate against the CA, i.e. the one specified in
// sslrootcert or the system CA if sslrootcert was not specified.
Expand Down
Loading

0 comments on commit b3a3c23

Please sign in to comment.