Skip to content

Commit

Permalink
Fix ALPN SNI Proxy TLS termination for DB connections (#8303)
Browse files Browse the repository at this point in the history
  • Loading branch information
smallinsky authored Sep 24, 2021
1 parent bddc54d commit e8f9220
Show file tree
Hide file tree
Showing 17 changed files with 257 additions and 138 deletions.
3 changes: 2 additions & 1 deletion integration/proxy_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/gravitational/teleport/lib/service"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/alpnproxy"
"github.com/gravitational/teleport/lib/srv/alpnproxy/common"
"github.com/gravitational/teleport/lib/srv/db/postgres"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/testlog"
Expand Down Expand Up @@ -453,7 +454,7 @@ func mustCreateKubeConfigFile(t *testing.T, config clientcmdapi.Config) string {
return configPath
}

func mustStartALPNLocalProxy(t *testing.T, addr string, protocol alpnproxy.Protocol) *alpnproxy.LocalProxy {
func mustStartALPNLocalProxy(t *testing.T, addr string, protocol common.Protocol) *alpnproxy.LocalProxy {
listener, err := net.Listen("tcp", ":0")
require.NoError(t, err)

Expand Down
8 changes: 4 additions & 4 deletions integration/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import (
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib"
"github.com/gravitational/teleport/lib/service"
"github.com/gravitational/teleport/lib/srv/alpnproxy"
alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/mongodb"
"github.com/gravitational/teleport/lib/srv/db/mysql"
Expand Down Expand Up @@ -287,7 +287,7 @@ func TestALPNSNIProxyDatabaseAccess(t *testing.T) {
pack.waitForLeaf(t)

t.Run("mysql", func(t *testing.T) {
lp := mustStartALPNLocalProxy(t, pack.root.cluster.GetProxyAddr(), alpnproxy.ProtocolMySQL)
lp := mustStartALPNLocalProxy(t, pack.root.cluster.GetProxyAddr(), alpncommon.ProtocolMySQL)
t.Run("connect to main cluster via proxy", func(t *testing.T) {
client, err := mysql.MakeTestClient(common.TestClientConfig{
AuthClient: pack.root.cluster.GetSiteAPI(pack.root.cluster.Secrets.SiteName),
Expand Down Expand Up @@ -340,7 +340,7 @@ func TestALPNSNIProxyDatabaseAccess(t *testing.T) {
})

t.Run("postgres", func(t *testing.T) {
lp := mustStartALPNLocalProxy(t, pack.root.cluster.GetProxyAddr(), alpnproxy.ProtocolPostgres)
lp := mustStartALPNLocalProxy(t, pack.root.cluster.GetProxyAddr(), alpncommon.ProtocolPostgres)
t.Run("connect to main cluster via proxy", func(t *testing.T) {
client, err := postgres.MakeTestClient(context.Background(), common.TestClientConfig{
AuthClient: pack.root.cluster.GetSiteAPI(pack.root.cluster.Secrets.SiteName),
Expand Down Expand Up @@ -380,7 +380,7 @@ func TestALPNSNIProxyDatabaseAccess(t *testing.T) {
})

t.Run("mongo", func(t *testing.T) {
lp := mustStartALPNLocalProxy(t, pack.root.cluster.GetProxyAddr(), alpnproxy.ProtocolMongoDB)
lp := mustStartALPNLocalProxy(t, pack.root.cluster.GetProxyAddr(), alpncommon.ProtocolMongoDB)
t.Run("connect to main cluster via proxy", func(t *testing.T) {
client, err := mongodb.MakeTestClient(context.Background(), common.TestClientConfig{
AuthClient: pack.root.cluster.GetSiteAPI(pack.root.cluster.Secrets.SiteName),
Expand Down
73 changes: 67 additions & 6 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ import (
"bytes"
"context"
"crypto/subtle"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"math"
"math/rand"
"net"
"net/url"
Expand All @@ -39,6 +41,14 @@ import (

"github.com/coreos/go-oidc/oauth2"
"github.com/coreos/go-oidc/oidc"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/pborman/uuid"
"github.com/prometheus/client_golang/prometheus"
saml2 "github.com/russellhaering/gosaml2"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/constants"
Expand All @@ -63,12 +73,6 @@ import (
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/interval"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/pborman/uuid"
"github.com/prometheus/client_golang/prometheus"
saml2 "github.com/russellhaering/gosaml2"
"golang.org/x/crypto/ssh"
)

// ServerOption allows setting options as functional arguments to Server
Expand Down Expand Up @@ -3436,3 +3440,60 @@ func isHTTPS(u string) error {

return nil
}

// WithClusterCAs returns a TLS hello callback that returns a copy of the provided
// TLS config with client CAs pool of the specified cluster.
func WithClusterCAs(tlsConfig *tls.Config, ap AccessPoint, currentClusterName string, log logrus.FieldLogger) func(*tls.ClientHelloInfo) (*tls.Config, error) {
return func(info *tls.ClientHelloInfo) (*tls.Config, error) {
var clusterName string
var err error
if info.ServerName != "" {
// Newer clients will set SNI that encodes the cluster name.
clusterName, err = apiutils.DecodeClusterName(info.ServerName)
if err != nil {
if !trace.IsNotFound(err) {
log.Debugf("Ignoring unsupported cluster name name %q.", info.ServerName)
clusterName = ""
}
}
}
pool, err := ClientCertPool(ap, clusterName)
if err != nil {
log.WithError(err).Errorf("Failed to retrieve client pool for %q.", clusterName)
// this falls back to the default config
return nil, nil
}

// Per https://tools.ietf.org/html/rfc5246#section-7.4.4 the total size of
// the known CA subjects sent to the client can't exceed 2^16-1 (due to
// 2-byte length encoding). The crypto/tls stack will panic if this
// happens.
//
// This usually happens on the root cluster with a very large (>500) number
// of leaf clusters. In these cases, the client cert will be signed by the
// current (root) cluster.
//
// If the number of CAs turns out too large for the handshake, drop all but
// the current cluster CA. In the unlikely case where it's wrong, the
// client will be rejected.
var totalSubjectsLen int64
for _, s := range pool.Subjects() {
// Each subject in the list gets a separate 2-byte length prefix.
totalSubjectsLen += 2
totalSubjectsLen += int64(len(s))
}
if totalSubjectsLen >= int64(math.MaxUint16) {
log.Debugf("Number of CAs in client cert pool is too large (%d) and cannot be encoded in a TLS handshake; this is due to a large number of trusted clusters; will use only the CA of the current cluster to validate.", len(pool.Subjects()))

pool, err = ClientCertPool(ap, currentClusterName)
if err != nil {
log.WithError(err).Errorf("Failed to retrieve client pool for %q.", currentClusterName)
// this falls back to the default config
return nil, nil
}
}
tlsCopy := tlsConfig.Clone()
tlsCopy.ClientCAs = pool
return tlsCopy, nil
}
}
4 changes: 2 additions & 2 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ import (
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/shell"
"github.com/gravitational/teleport/lib/srv/alpnproxy"
alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
"github.com/gravitational/teleport/lib/sshutils/scp"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -2109,7 +2109,7 @@ func (tc *TeleportClient) connectToProxy(ctx context.Context) (*ProxyClient, err

func makeProxySSHClientWithTLSWrapper(cfg Config, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
tlsConn, err := tls.Dial("tcp", cfg.WebProxyAddr, &tls.Config{
NextProtos: []string{string(alpnproxy.ProtocolProxySSH)},
NextProtos: []string{string(alpncommon.ProtocolProxySSH)},
InsecureSkipVerify: cfg.InsecureSkipVerify,
})
if err != nil {
Expand Down
59 changes: 6 additions & 53 deletions lib/kube/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@ package proxy

import (
"crypto/tls"
"math"
"net"
"net/http"
"sync"

"github.com/gravitational/teleport"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/limiter"
Expand All @@ -50,6 +48,8 @@ type TLSServerConfig struct {
AccessPoint auth.AccessPoint
// OnHeartbeat is a callback for kubernetes_service heartbeats.
OnHeartbeat func(error)
// Log is the logger.
Log log.FieldLogger
}

// CheckAndSetDefaults checks and sets default values
Expand All @@ -73,6 +73,9 @@ func (c *TLSServerConfig) CheckAndSetDefaults() error {
if c.AccessPoint == nil {
return trace.BadParameter("missing parameter AccessPoint")
}
if c.Log == nil {
c.Log = log.New()
}
return nil
}

Expand Down Expand Up @@ -197,57 +200,7 @@ func (t *TLSServer) Close() error {
// and server's GetConfigForClient reloads the list of trusted
// local and remote certificate authorities
func (t *TLSServer) GetConfigForClient(info *tls.ClientHelloInfo) (*tls.Config, error) {
var clusterName string
var err error
if info.ServerName != "" {
// Newer clients will set SNI that encodes the cluster name.
clusterName, err = apiutils.DecodeClusterName(info.ServerName)
if err != nil {
if !trace.IsNotFound(err) {
log.Debugf("Ignoring unsupported cluster name name %q.", info.ServerName)
clusterName = ""
}
}
}
pool, err := auth.ClientCertPool(t.AccessPoint, clusterName)
if err != nil {
log.Errorf("failed to retrieve client pool: %v", trace.DebugReport(err))
// this falls back to the default config
return nil, nil
}

// Per https://tools.ietf.org/html/rfc5246#section-7.4.4 the total size of
// the known CA subjects sent to the client can't exceed 2^16-1 (due to
// 2-byte length encoding). The crypto/tls stack will panic if this
// happens.
//
// This usually happens on the root cluster with a very large (>500) number
// of leaf clusters. In these cases, the client cert will be signed by the
// current (root) cluster.
//
// If the number of CAs turns out too large for the handshake, drop all but
// the current cluster CA. In the unlikely case where it's wrong, the
// client will be rejected.
var totalSubjectsLen int64
for _, s := range pool.Subjects() {
// Each subject in the list gets a separate 2-byte length prefix.
totalSubjectsLen += 2
totalSubjectsLen += int64(len(s))
}
if totalSubjectsLen >= int64(math.MaxUint16) {
log.Debugf("number of CAs in client cert pool is too large (%d) and cannot be encoded in a TLS handshake; this is due to a large number of trusted clusters; will use only the CA of the current cluster to validate", len(pool.Subjects()))

pool, err = auth.ClientCertPool(t.AccessPoint, t.ClusterName)
if err != nil {
log.Errorf("failed to retrieve client pool: %v", trace.DebugReport(err))
// this falls back to the default config
return nil, nil
}
}

tlsCopy := t.TLS.Clone()
tlsCopy.ClientCAs = pool
return tlsCopy, nil
return auth.WithClusterCAs(t.TLS, t.AccessPoint, t.ClusterName, t.Log)(info)
}

// GetServerInfo returns a services.Server object for heartbeats (aka
Expand Down
3 changes: 3 additions & 0 deletions lib/kube/proxy/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"time"

"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
Expand Down Expand Up @@ -95,6 +96,7 @@ func TestMTLSClientCAs(t *testing.T) {

srv := &TLSServer{
TLSServerConfig: TLSServerConfig{
Log: logrus.New(),
ForwarderConfig: ForwarderConfig{
ClusterName: mainClusterName,
},
Expand Down Expand Up @@ -185,6 +187,7 @@ func TestGetServerInfo(t *testing.T) {

srv := &TLSServer{
TLSServerConfig: TLSServerConfig{
Log: logrus.New(),
ForwarderConfig: ForwarderConfig{
Clock: clockwork.NewFakeClock(),
ClusterName: "kube-cluster",
Expand Down
1 change: 0 additions & 1 deletion lib/service/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ func (process *TeleportProcess) connectToAuthService(role types.SystemRole) (*Co
return nil, trace.Wrap(err)
}
process.log.Debugf("Connected client: %v", connector.ClientIdentity)
process.log.Debugf("Connected server: %v", connector.ServerIdentity)
process.addConnector(connector)

return connector, nil
Expand Down
32 changes: 20 additions & 12 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ import (
"github.com/gravitational/teleport/lib/srv"
"github.com/gravitational/teleport/lib/srv/alpnproxy"
alpnproxyauth "github.com/gravitational/teleport/lib/srv/alpnproxy/auth"
alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
"github.com/gravitational/teleport/lib/srv/app"
"github.com/gravitational/teleport/lib/srv/db"
"github.com/gravitational/teleport/lib/srv/desktop"
Expand Down Expand Up @@ -2990,18 +2991,18 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {

if alpnRouter != nil && !cfg.Proxy.DisableDatabaseProxy {
alpnRouter.Add(alpnproxy.HandlerDecs{
MatchFunc: alpnproxy.MatchByProtocol(alpnproxy.ProtocolMySQL),
MatchFunc: alpnproxy.MatchByProtocol(alpncommon.ProtocolMySQL),
Handler: dbProxyServer.MySQLProxy().HandleConnection,
})
alpnRouter.Add(alpnproxy.HandlerDecs{
MatchFunc: alpnproxy.MatchByProtocol(alpnproxy.ProtocolPostgres),
MatchFunc: alpnproxy.MatchByProtocol(alpncommon.ProtocolPostgres),
Handler: dbProxyServer.PostgresProxy().HandleConnection,
})
alpnRouter.Add(alpnproxy.HandlerDecs{
// Add MongoDB teleport ALPN protocol without setting custom Handler.
// ALPN Proxy will handle MongoDB connection internally (terminate wrapped TLS traffic) and route
// extracted connection to ALPN Proxy DB TLS Handler.
MatchFunc: alpnproxy.MatchByProtocol(alpnproxy.ProtocolMongoDB),
MatchFunc: alpnproxy.MatchByProtocol(alpncommon.ProtocolMongoDB),
})
}

Expand Down Expand Up @@ -3039,14 +3040,21 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
if !cfg.Proxy.DisableTLS && !cfg.Proxy.DisableALPNSNIListener && listeners.web != nil {
authDialerService := alpnproxyauth.NewAuthProxyDialerService(tsrv, accessPoint)
alpnRouter.Add(alpnproxy.HandlerDecs{
MatchFunc: alpnproxy.MatchByALPNPrefix(string(alpnproxy.ProtocolAuth)),
MatchFunc: alpnproxy.MatchByALPNPrefix(string(alpncommon.ProtocolAuth)),
HandlerWithConnInfo: authDialerService.HandleConnection,
ForwardTLS: true,
})
identityTLSConf, err := conn.ServerIdentity.TLSConfig(cfg.CipherSuites)
if err != nil {
return trace.Wrap(err)
}
alpnServer, err = alpnproxy.New(alpnproxy.ProxyConfig{
TLSConfig: tlsConfigWeb.Clone(),
Router: alpnRouter,
Listener: listeners.alpn,
WebTLSConfig: tlsConfigWeb.Clone(),
IdentityTLSConfig: identityTLSConf,
Router: alpnRouter,
Listener: listeners.alpn,
ClusterName: clusterName,
AccessPoint: accessPoint,
})
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -3213,7 +3221,7 @@ func setupALPNRouter(listeners *proxyListeners, cfg *Config) *alpnproxy.Router {
if !cfg.Proxy.DisableReverseTunnel {
reverseTunnel := alpnproxy.NewMuxListenerWrapper(listeners.reverseTunnel, listeners.web)
router.Add(alpnproxy.HandlerDecs{
MatchFunc: alpnproxy.MatchByProtocol(alpnproxy.ProtocolReverseTunnel),
MatchFunc: alpnproxy.MatchByProtocol(alpncommon.ProtocolReverseTunnel),
Handler: reverseTunnel.HandleConnection,
})
listeners.reverseTunnel = reverseTunnel
Expand All @@ -3223,9 +3231,9 @@ func setupALPNRouter(listeners *proxyListeners, cfg *Config) *alpnproxy.Router {
webWrapper := alpnproxy.NewMuxListenerWrapper(nil, listeners.web)
router.Add(alpnproxy.HandlerDecs{
MatchFunc: alpnproxy.MatchByProtocol(
alpnproxy.ProtocolHTTP,
alpnproxy.ProtocolHTTP2,
alpnproxy.ProtocolDefault,
alpncommon.ProtocolHTTP,
alpncommon.ProtocolHTTP2,
alpncommon.ProtocolDefault,
acme.ALPNProto,
),
Handler: webWrapper.HandleConnection,
Expand All @@ -3235,7 +3243,7 @@ func setupALPNRouter(listeners *proxyListeners, cfg *Config) *alpnproxy.Router {
}
sshProxyListener := alpnproxy.NewMuxListenerWrapper(listeners.ssh, listeners.web)
router.Add(alpnproxy.HandlerDecs{
MatchFunc: alpnproxy.MatchByProtocol(alpnproxy.ProtocolProxySSH),
MatchFunc: alpnproxy.MatchByProtocol(alpncommon.ProtocolProxySSH),
Handler: sshProxyListener.HandleConnection,
})
listeners.ssh = sshProxyListener
Expand Down
Loading

0 comments on commit e8f9220

Please sign in to comment.