diff --git a/lib/service/service.go b/lib/service/service.go index aff47d596ed05..a4c289b1e4499 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -3079,6 +3079,10 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { if err != nil { return trace.Wrap(err) } + connLimiter, err := limiter.NewConnectionsLimiter(process.Config.Databases.Limiter) + if err != nil { + return trace.Wrap(err) + } dbProxyServer, err := db.NewProxyServer(process.ExitContext(), db.ProxyServerConfig{ AuthClient: conn.Client, @@ -3086,6 +3090,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { Authorizer: authorizer, Tunnel: tsrv, TLSConfig: tlsConfig, + Limiter: connLimiter, Emitter: asyncEmitter, Clock: process.Clock, ServerID: cfg.HostUUID, diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index 973992cb2f373..c3825afab7927 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -840,6 +840,9 @@ func setupTestContext(ctx context.Context, t *testing.T, withDatabases ...withDa testCtx.fakeRemoteSite, }, } + // Empty config means no limit. + connLimiter, err := limiter.NewConnectionsLimiter(limiter.Config{}) + require.NoError(t, err) // Create test audit events emitter. testCtx.emitter = newTestEmitter() @@ -851,6 +854,7 @@ func setupTestContext(ctx context.Context, t *testing.T, withDatabases ...withDa Authorizer: proxyAuthorizer, Tunnel: tunnel, TLSConfig: tlsConfig, + Limiter: connLimiter, Emitter: testCtx.emitter, Clock: testCtx.clock, ServerID: "proxy-server", diff --git a/lib/srv/db/common/interfaces.go b/lib/srv/db/common/interfaces.go index 64d760c15a6f8..90034b40ce1ba 100644 --- a/lib/srv/db/common/interfaces.go +++ b/lib/srv/db/common/interfaces.go @@ -30,6 +30,7 @@ type Proxy interface { HandleConnection(context.Context, net.Conn) error } +// ConnectParams keeps parameters used when connecting to Service. type ConnectParams struct { // User is a database username. User string diff --git a/lib/srv/db/dbutils/db.go b/lib/srv/db/dbutils/db.go index 3035dfd954acb..9397b73d79d51 100644 --- a/lib/srv/db/dbutils/db.go +++ b/lib/srv/db/dbutils/db.go @@ -18,7 +18,6 @@ package dbutils import ( "crypto/tls" - "net" "github.com/gravitational/trace" @@ -40,15 +39,3 @@ func IsDatabaseConnection(state tls.ConnectionState) (bool, error) { } return identity.RouteToDatabase.ServiceName != "", nil } - -// ClientIPFromConn extracts host from provided remote address. -func ClientIPFromConn(conn net.Conn) (string, error) { - clientRemoteAddr := conn.RemoteAddr() - - clientIP, _, err := net.SplitHostPort(clientRemoteAddr.String()) - if err != nil { - return "", trace.Wrap(err) - } - - return clientIP, nil -} diff --git a/lib/srv/db/mysql/proxy.go b/lib/srv/db/mysql/proxy.go index a4b6b9eb23340..4208931a9258f 100644 --- a/lib/srv/db/mysql/proxy.go +++ b/lib/srv/db/mysql/proxy.go @@ -26,8 +26,8 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/srv/db/common" - "github.com/gravitational/teleport/lib/srv/db/dbutils" "github.com/gravitational/teleport/lib/srv/db/mysql/protocol" + "github.com/gravitational/teleport/lib/utils" "github.com/siddontang/go-mysql/mysql" "github.com/siddontang/go-mysql/server" @@ -82,7 +82,7 @@ func (p *Proxy) HandleConnection(ctx context.Context, clientConn net.Conn) (err if err != nil { return trace.Wrap(err) } - clientIP, err := dbutils.ClientIPFromConn(clientConn) + clientIP, err := utils.ClientIPFromConn(clientConn) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/db/postgres/proxy.go b/lib/srv/db/postgres/proxy.go index 32d1b453878c3..f921734df3cd1 100644 --- a/lib/srv/db/postgres/proxy.go +++ b/lib/srv/db/postgres/proxy.go @@ -23,7 +23,7 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/srv/db/common" - "github.com/gravitational/teleport/lib/srv/db/dbutils" + "github.com/gravitational/teleport/lib/utils" "github.com/jackc/pgproto3/v2" @@ -64,7 +64,7 @@ func (p *Proxy) HandleConnection(ctx context.Context, clientConn net.Conn) (err if err != nil { return trace.Wrap(err) } - clientIP, err := dbutils.ClientIPFromConn(clientConn) + clientIP, err := utils.ClientIPFromConn(clientConn) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/db/proxyserver.go b/lib/srv/db/proxyserver.go index 54d1a13f6ce09..14b20b1da05bd 100644 --- a/lib/srv/db/proxyserver.go +++ b/lib/srv/db/proxyserver.go @@ -34,12 +34,12 @@ import ( apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/native" + "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/srv/db/common" - "github.com/gravitational/teleport/lib/srv/db/dbutils" "github.com/gravitational/teleport/lib/srv/db/mysql" "github.com/gravitational/teleport/lib/srv/db/postgres" "github.com/gravitational/teleport/lib/tlsca" @@ -76,6 +76,8 @@ type ProxyServerConfig struct { Tunnel reversetunnel.Server // TLSConfig is the proxy server TLS configuration. TLSConfig *tls.Config + // Limiter is the connection limiter. + Limiter *limiter.ConnectionsLimiter // Emitter is used to emit audit events. Emitter events.Emitter // Clock to override clock in tests. @@ -235,7 +237,7 @@ func (s *ProxyServer) handleConnection(conn net.Conn) error { if err != nil { return trace.Wrap(err) } - clientIP, err := dbutils.ClientIPFromConn(conn) + clientIP, err := utils.ClientIPFromConn(conn) if err != nil { return trace.Wrap(err) } @@ -269,7 +271,7 @@ func (s *ProxyServer) dispatch(clientConn net.Conn) (common.Proxy, error) { muxConn.Protocol()) } -// postgresProxy returns a new instance of the Postgres protocol aware proxy. +// PostgresProxy returns a new instance of the Postgres protocol aware proxy. func (s *ProxyServer) PostgresProxy() *postgres.Proxy { return &postgres.Proxy{ TLSConfig: s.cfg.TLSConfig, @@ -279,7 +281,7 @@ func (s *ProxyServer) PostgresProxy() *postgres.Proxy { } } -// mysqlProxy returns a new instance of the MySQL protocol aware proxy. +// MySQLProxy returns a new instance of the MySQL protocol aware proxy. func (s *ProxyServer) MySQLProxy() *mysql.Proxy { return &mysql.Proxy{ TLSConfig: s.cfg.TLSConfig, @@ -298,6 +300,12 @@ func (s *ProxyServer) MySQLProxy() *mysql.Proxy { // // Implements common.Service. func (s *ProxyServer) Connect(ctx context.Context, params common.ConnectParams) (net.Conn, *auth.Context, error) { + // Apply rate limiting. + if err := s.cfg.Limiter.AcquireConnection(params.ClientIP); err != nil { + return nil, nil, trace.LimitExceeded("client %v exceeded connection limit", params.ClientIP) + } + // Limiter will be decremented by ConnCloseWrapper below. + proxyContext, err := s.authorize(ctx, params) if err != nil { return nil, nil, trace.Wrap(err) @@ -329,7 +337,14 @@ func (s *ProxyServer) Connect(ctx context.Context, params common.ConnectParams) // remote server during TLS handshake. On the remote side, the connection // received from the reverse tunnel will be handled by tls.Server. serviceConn = tls.Client(serviceConn, tlsConfig) - return serviceConn, proxyContext.authContext, nil + // Wrap connection, so we can decrement the limit when the connection is closed. + srvConn := &utils.ConnCloseWrapper{ + Conn: serviceConn, + BeforeClose: func() { + s.cfg.Limiter.ReleaseConnection(params.ClientIP) + }, + } + return srvConn, proxyContext.authContext, nil } return nil, nil, trace.BadParameter("failed to connect to any of the database servers") } diff --git a/lib/srv/db/proxyserver_test.go b/lib/srv/db/proxyserver_test.go new file mode 100644 index 0000000000000..2a5c751d7dec5 --- /dev/null +++ b/lib/srv/db/proxyserver_test.go @@ -0,0 +1,71 @@ +/* +Copyright 2021 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package db + +import ( + "context" + "testing" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/limiter" + "github.com/jackc/pgconn" + "github.com/stretchr/testify/require" +) + +func TestProxyConnectionLimiting(t *testing.T) { + const ( + user = "bob" + role = "admin" + dbName = "postgres" + dbUser = user + connLimitNumber = 3 // Arbitrary number + ) + + ctx := context.Background() + testCtx := setupTestContext(ctx, t, withSelfHostedPostgres("postgres")) + + connLimit, err := limiter.NewConnectionsLimiter(limiter.Config{MaxConnections: connLimitNumber}) + require.NoError(t, err) + + // Set proxy connection limiter + testCtx.proxyServer.cfg.Limiter = connLimit + + go testCtx.startHandlingConnections() + + // Create user/role with the requested permissions. + testCtx.createUserAndRole(ctx, t, user, role, []string{types.Wildcard}, []string{types.Wildcard}) + + conns := make([]*pgconn.PgConn, 0) + defer func() { + for _, conn := range conns { + err := conn.Close(ctx) + require.NoError(t, err) + } + }() + + for i := 0; i < connLimitNumber; i++ { + // Try to connect to the database as this user. + pgConn, err := testCtx.postgresClient(ctx, user, "postgres", dbUser, dbName) + require.NoError(t, err) + + conns = append(conns, pgConn) + } + + _, err = testCtx.postgresClient(ctx, user, "postgres", dbUser, dbName) + require.Error(t, err) + require.Contains(t, err.Error(), "exceeded connection limit") +} diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index f725f3543120f..3f3440d27d9e7 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -173,11 +173,8 @@ func (c *Config) CheckAndSetDefaults(ctx context.Context) (err error) { } } if c.Limiter == nil { - // Set default limiter if one is not provided. - connLimiter := limiter.Config{} - defaults.ConfigureLimiter(&connLimiter) - - c.Limiter, err = limiter.NewConnectionsLimiter(connLimiter) + // Use default limiter if nothing is provided. Connection limiting will be disabled. + c.Limiter, err = limiter.NewConnectionsLimiter(limiter.Config{}) if err != nil { return trace.Wrap(err) } @@ -661,12 +658,16 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn) error { } clientIP := sessionCtx.Identity.ClientIP - s.log.Debugf("Real client IP %s", clientIP) + if clientIP != "" { + s.log.Debugf("Real client IP %s", clientIP) - if err := s.cfg.Limiter.AcquireConnection(clientIP); err != nil { - return trace.WrapWithMessage(err, "Exceeded connection limit.") + if err := s.cfg.Limiter.AcquireConnection(clientIP); err != nil { + return trace.LimitExceeded("client %v exceeded connection limit", clientIP) + } + defer s.cfg.Limiter.ReleaseConnection(clientIP) + } else { + s.log.Debug("ClientIP is not set (Proxy Service has to be updated). Rate limiting is disabled.") } - defer s.cfg.Limiter.ReleaseConnection(clientIP) streamWriter, err := s.newStreamWriter(sessionCtx) if err != nil { diff --git a/lib/utils/net.go b/lib/utils/net.go new file mode 100644 index 0000000000000..9c392b032a0c6 --- /dev/null +++ b/lib/utils/net.go @@ -0,0 +1,54 @@ +/* +Copyright 2021 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils + +import ( + "net" + + "github.com/gravitational/trace" +) + +// ClientIPFromConn extracts host from provided remote address. +func ClientIPFromConn(conn net.Conn) (string, error) { + clientRemoteAddr := conn.RemoteAddr() + + clientIP, _, err := net.SplitHostPort(clientRemoteAddr.String()) + if err != nil { + return "", trace.Wrap(err) + } + + return clientIP, nil +} + +// ConnCloseWrapper is a helper struct that allows to call additional function +// before net.Conn.Close() is called. +type ConnCloseWrapper struct { + // Underlying connection. + net.Conn + // BeforeClose will be called on Close(), before net.Conn.Close(). + BeforeClose func() +} + +// Close calls user provided beforeClose function and then Close() from +// underlying connection. +func (c *ConnCloseWrapper) Close() error { + if c.BeforeClose != nil { + c.BeforeClose() + } + + return c.Conn.Close() +}