Skip to content

Commit

Permalink
Add rate limiting to DB proxy server.
Browse files Browse the repository at this point in the history
Improve documentation.
Minor refactoring.
  • Loading branch information
jakule committed Dec 4, 2021
1 parent 9b7d19b commit 0480908
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 31 deletions.
5 changes: 5 additions & 0 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3079,13 +3079,18 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
if err != nil {
return trace.Wrap(err)
}
connLimiter, err := limiter.NewConnectionsLimiter(process.Config.Databases.Limiter)
if err != nil {
return trace.Wrap(err)
}
dbProxyServer, err := db.NewProxyServer(process.ExitContext(),
db.ProxyServerConfig{
AuthClient: conn.Client,
AccessPoint: accessPoint,
Authorizer: authorizer,
Tunnel: tsrv,
TLSConfig: tlsConfig,
Limiter: connLimiter,
Emitter: asyncEmitter,
Clock: process.Clock,
ServerID: cfg.HostUUID,
Expand Down
4 changes: 4 additions & 0 deletions lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,9 @@ func setupTestContext(ctx context.Context, t *testing.T, withDatabases ...withDa
testCtx.fakeRemoteSite,
},
}
// Empty config means no limit.
connLimiter, err := limiter.NewConnectionsLimiter(limiter.Config{})
require.NoError(t, err)

// Create test audit events emitter.
testCtx.emitter = newTestEmitter()
Expand All @@ -851,6 +854,7 @@ func setupTestContext(ctx context.Context, t *testing.T, withDatabases ...withDa
Authorizer: proxyAuthorizer,
Tunnel: tunnel,
TLSConfig: tlsConfig,
Limiter: connLimiter,
Emitter: testCtx.emitter,
Clock: testCtx.clock,
ServerID: "proxy-server",
Expand Down
1 change: 1 addition & 0 deletions lib/srv/db/common/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type Proxy interface {
HandleConnection(context.Context, net.Conn) error
}

// ConnectParams keeps parameters used when connecting to Service.
type ConnectParams struct {
// User is a database username.
User string
Expand Down
13 changes: 0 additions & 13 deletions lib/srv/db/dbutils/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package dbutils

import (
"crypto/tls"
"net"

"github.com/gravitational/trace"

Expand All @@ -40,15 +39,3 @@ func IsDatabaseConnection(state tls.ConnectionState) (bool, error) {
}
return identity.RouteToDatabase.ServiceName != "", nil
}

// ClientIPFromConn extracts host from provided remote address.
func ClientIPFromConn(conn net.Conn) (string, error) {
clientRemoteAddr := conn.RemoteAddr()

clientIP, _, err := net.SplitHostPort(clientRemoteAddr.String())
if err != nil {
return "", trace.Wrap(err)
}

return clientIP, nil
}
4 changes: 2 additions & 2 deletions lib/srv/db/mysql/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ import (
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/multiplexer"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/dbutils"
"github.com/gravitational/teleport/lib/srv/db/mysql/protocol"
"github.com/gravitational/teleport/lib/utils"

"github.com/siddontang/go-mysql/mysql"
"github.com/siddontang/go-mysql/server"
Expand Down Expand Up @@ -82,7 +82,7 @@ func (p *Proxy) HandleConnection(ctx context.Context, clientConn net.Conn) (err
if err != nil {
return trace.Wrap(err)
}
clientIP, err := dbutils.ClientIPFromConn(clientConn)
clientIP, err := utils.ClientIPFromConn(clientConn)
if err != nil {
return trace.Wrap(err)
}
Expand Down
4 changes: 2 additions & 2 deletions lib/srv/db/postgres/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (

"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/dbutils"
"github.com/gravitational/teleport/lib/utils"

"github.com/jackc/pgproto3/v2"

Expand Down Expand Up @@ -64,7 +64,7 @@ func (p *Proxy) HandleConnection(ctx context.Context, clientConn net.Conn) (err
if err != nil {
return trace.Wrap(err)
}
clientIP, err := dbutils.ClientIPFromConn(clientConn)
clientIP, err := utils.ClientIPFromConn(clientConn)
if err != nil {
return trace.Wrap(err)
}
Expand Down
25 changes: 20 additions & 5 deletions lib/srv/db/proxyserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ import (
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/native"
"github.com/gravitational/teleport/lib/limiter"
"github.com/gravitational/teleport/lib/multiplexer"
"github.com/gravitational/teleport/lib/reversetunnel"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/dbutils"
"github.com/gravitational/teleport/lib/srv/db/mysql"
"github.com/gravitational/teleport/lib/srv/db/postgres"
"github.com/gravitational/teleport/lib/tlsca"
Expand Down Expand Up @@ -76,6 +76,8 @@ type ProxyServerConfig struct {
Tunnel reversetunnel.Server
// TLSConfig is the proxy server TLS configuration.
TLSConfig *tls.Config
// Limiter is the connection limiter.
Limiter *limiter.ConnectionsLimiter
// Emitter is used to emit audit events.
Emitter events.Emitter
// Clock to override clock in tests.
Expand Down Expand Up @@ -235,7 +237,7 @@ func (s *ProxyServer) handleConnection(conn net.Conn) error {
if err != nil {
return trace.Wrap(err)
}
clientIP, err := dbutils.ClientIPFromConn(conn)
clientIP, err := utils.ClientIPFromConn(conn)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -269,7 +271,7 @@ func (s *ProxyServer) dispatch(clientConn net.Conn) (common.Proxy, error) {
muxConn.Protocol())
}

// postgresProxy returns a new instance of the Postgres protocol aware proxy.
// PostgresProxy returns a new instance of the Postgres protocol aware proxy.
func (s *ProxyServer) PostgresProxy() *postgres.Proxy {
return &postgres.Proxy{
TLSConfig: s.cfg.TLSConfig,
Expand All @@ -279,7 +281,7 @@ func (s *ProxyServer) PostgresProxy() *postgres.Proxy {
}
}

// mysqlProxy returns a new instance of the MySQL protocol aware proxy.
// MySQLProxy returns a new instance of the MySQL protocol aware proxy.
func (s *ProxyServer) MySQLProxy() *mysql.Proxy {
return &mysql.Proxy{
TLSConfig: s.cfg.TLSConfig,
Expand All @@ -298,6 +300,12 @@ func (s *ProxyServer) MySQLProxy() *mysql.Proxy {
//
// Implements common.Service.
func (s *ProxyServer) Connect(ctx context.Context, params common.ConnectParams) (net.Conn, *auth.Context, error) {
// Apply rate limiting.
if err := s.cfg.Limiter.AcquireConnection(params.ClientIP); err != nil {
return nil, nil, trace.LimitExceeded("client %v exceeded connection limit", params.ClientIP)
}
// Limiter will be decremented by ConnCloseWrapper below.

proxyContext, err := s.authorize(ctx, params)
if err != nil {
return nil, nil, trace.Wrap(err)
Expand Down Expand Up @@ -329,7 +337,14 @@ func (s *ProxyServer) Connect(ctx context.Context, params common.ConnectParams)
// remote server during TLS handshake. On the remote side, the connection
// received from the reverse tunnel will be handled by tls.Server.
serviceConn = tls.Client(serviceConn, tlsConfig)
return serviceConn, proxyContext.authContext, nil
// Wrap connection, so we can decrement the limit when the connection is closed.
srvConn := &utils.ConnCloseWrapper{
Conn: serviceConn,
BeforeClose: func() {
s.cfg.Limiter.ReleaseConnection(params.ClientIP)
},
}
return srvConn, proxyContext.authContext, nil
}
return nil, nil, trace.BadParameter("failed to connect to any of the database servers")
}
Expand Down
71 changes: 71 additions & 0 deletions lib/srv/db/proxyserver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
Copyright 2021 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package db

import (
"context"
"testing"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/limiter"
"github.com/jackc/pgconn"
"github.com/stretchr/testify/require"
)

func TestProxyConnectionLimiting(t *testing.T) {
const (
user = "bob"
role = "admin"
dbName = "postgres"
dbUser = user
connLimitNumber = 3 // Arbitrary number
)

ctx := context.Background()
testCtx := setupTestContext(ctx, t, withSelfHostedPostgres("postgres"))

connLimit, err := limiter.NewConnectionsLimiter(limiter.Config{MaxConnections: connLimitNumber})
require.NoError(t, err)

// Set proxy connection limiter
testCtx.proxyServer.cfg.Limiter = connLimit

go testCtx.startHandlingConnections()

// Create user/role with the requested permissions.
testCtx.createUserAndRole(ctx, t, user, role, []string{types.Wildcard}, []string{types.Wildcard})

conns := make([]*pgconn.PgConn, 0)
defer func() {
for _, conn := range conns {
err := conn.Close(ctx)
require.NoError(t, err)
}
}()

for i := 0; i < connLimitNumber; i++ {
// Try to connect to the database as this user.
pgConn, err := testCtx.postgresClient(ctx, user, "postgres", dbUser, dbName)
require.NoError(t, err)

conns = append(conns, pgConn)
}

_, err = testCtx.postgresClient(ctx, user, "postgres", dbUser, dbName)
require.Error(t, err)
require.Contains(t, err.Error(), "exceeded connection limit")
}
19 changes: 10 additions & 9 deletions lib/srv/db/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,8 @@ func (c *Config) CheckAndSetDefaults(ctx context.Context) (err error) {
}
}
if c.Limiter == nil {
// Set default limiter if one is not provided.
connLimiter := limiter.Config{}
defaults.ConfigureLimiter(&connLimiter)

c.Limiter, err = limiter.NewConnectionsLimiter(connLimiter)
// Use default limiter if nothing is provided. Connection limiting will be disabled.
c.Limiter, err = limiter.NewConnectionsLimiter(limiter.Config{})
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -661,12 +658,16 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn) error {
}

clientIP := sessionCtx.Identity.ClientIP
s.log.Debugf("Real client IP %s", clientIP)
if clientIP != "" {
s.log.Debugf("Real client IP %s", clientIP)

if err := s.cfg.Limiter.AcquireConnection(clientIP); err != nil {
return trace.WrapWithMessage(err, "Exceeded connection limit.")
if err := s.cfg.Limiter.AcquireConnection(clientIP); err != nil {
return trace.LimitExceeded("client %v exceeded connection limit", clientIP)
}
defer s.cfg.Limiter.ReleaseConnection(clientIP)
} else {
s.log.Debug("ClientIP is not set (Proxy Service has to be updated). Rate limiting is disabled.")
}
defer s.cfg.Limiter.ReleaseConnection(clientIP)

streamWriter, err := s.newStreamWriter(sessionCtx)
if err != nil {
Expand Down
54 changes: 54 additions & 0 deletions lib/utils/net.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
Copyright 2021 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package utils

import (
"net"

"github.com/gravitational/trace"
)

// ClientIPFromConn extracts host from provided remote address.
func ClientIPFromConn(conn net.Conn) (string, error) {
clientRemoteAddr := conn.RemoteAddr()

clientIP, _, err := net.SplitHostPort(clientRemoteAddr.String())
if err != nil {
return "", trace.Wrap(err)
}

return clientIP, nil
}

// ConnCloseWrapper is a helper struct that allows to call additional function
// before net.Conn.Close() is called.
type ConnCloseWrapper struct {
// Underlying connection.
net.Conn
// BeforeClose will be called on Close(), before net.Conn.Close().
BeforeClose func()
}

// Close calls user provided beforeClose function and then Close() from
// underlying connection.
func (c *ConnCloseWrapper) Close() error {
if c.BeforeClose != nil {
c.BeforeClose()
}

return c.Conn.Close()
}

0 comments on commit 0480908

Please sign in to comment.