Skip to content

Commit

Permalink
Add "limiter" support to database service (#9087)
Browse files Browse the repository at this point in the history
Add rate and connection limiter to database service.
  • Loading branch information
jakule authored Jan 7, 2022
1 parent 622e0aa commit c7c9411
Show file tree
Hide file tree
Showing 19 changed files with 667 additions and 69 deletions.
1 change: 1 addition & 0 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ type Server struct {
// if not set, cache uses itself
cache Cache

// limiter limits the number of active connections per client IP.
limiter *limiter.ConnectionsLimiter

// Emitter is events emitter, used to submit discrete events
Expand Down
1 change: 1 addition & 0 deletions lib/config/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ func ApplyFileConfig(fc *FileConfig, cfg *service.Config) error {
&cfg.SSH.Limiter,
&cfg.Auth.Limiter,
&cfg.Proxy.Limiter,
&cfg.Databases.Limiter,
&cfg.Kube.Limiter,
&cfg.WindowsDesktop.ConnLimiter,
}
Expand Down
1 change: 0 additions & 1 deletion lib/limiter/connlimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ func (l *ConnectionsLimiter) AcquireConnection(token string) error {

// ReleaseConnection decrements the counter
func (l *ConnectionsLimiter) ReleaseConnection(token string) {

l.Lock()
defer l.Unlock()

Expand Down
25 changes: 24 additions & 1 deletion lib/limiter/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,31 @@ func (l *Limiter) RegisterRequestWithCustomRate(token string, customRate *rateli
return l.rateLimiter.RegisterRequest(token, customRate)
}

// Add limiter to the handle
// WrapHandle adds limiter to the handle
func (l *Limiter) WrapHandle(h http.Handler) {
l.rateLimiter.Wrap(h)
l.ConnLimiter.Wrap(l.rateLimiter)
}

// RegisterRequestAndConnection register a rate and connection limiter for a given token. Close function is returned,
// and it must be called to release the token. When a limit is hit an error is returned.
// Example usage:
//
// release, err := limiter.RegisterRequestAndConnection(clientIP)
// if err != nil {
// return trace.Wrap(err)
// }
// defer release()
func (l *Limiter) RegisterRequestAndConnection(token string) (func(), error) {
// Apply rate limiting.
if err := l.RegisterRequest(token); err != nil {
return func() {}, trace.LimitExceeded("rate limit exceeded for %q", token)
}

// Apply connection limiting.
if err := l.AcquireConnection(token); err != nil {
return func() {}, trace.LimitExceeded("exceeded connection limit for %q", token)
}

return func() { l.ReleaseConnection(token) }, nil
}
3 changes: 3 additions & 0 deletions lib/service/cfg.go
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,8 @@ type DatabasesConfig struct {
ResourceMatchers []services.ResourceMatcher
// AWSMatchers match AWS hosted databases.
AWSMatchers []services.AWSMatcher
// Limiter limits the connection and request rates.
Limiter limiter.Config
}

// Database represents a single database that's being proxied.
Expand Down Expand Up @@ -1103,6 +1105,7 @@ func ApplyDefaults(cfg *Config) {

// Databases proxy service is disabled by default.
cfg.Databases.Enabled = false
defaults.ConfigureLimiter(&cfg.Databases.Limiter)

// Metrics service defaults.
cfg.Metrics.Enabled = false
Expand Down
7 changes: 7 additions & 0 deletions lib/service/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/limiter"
"github.com/gravitational/teleport/lib/reversetunnel"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/db"
Expand Down Expand Up @@ -167,6 +168,11 @@ func (process *TeleportProcess) initDatabaseService() (retErr error) {
return trace.Wrap(err)
}

connLimiter, err := limiter.NewLimiter(process.Config.Databases.Limiter)
if err != nil {
return trace.Wrap(err)
}

// Create and start the database service.
dbService, err := db.New(process.ExitContext(), db.Config{
Clock: process.Clock,
Expand All @@ -179,6 +185,7 @@ func (process *TeleportProcess) initDatabaseService() (retErr error) {
},
Authorizer: authorizer,
TLSConfig: tlsConfig,
Limiter: connLimiter,
GetRotation: process.getRotation,
Hostname: process.Config.Hostname,
HostID: process.Config.HostUUID,
Expand Down
5 changes: 5 additions & 0 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3120,13 +3120,18 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
if err != nil {
return trace.Wrap(err)
}
connLimiter, err := limiter.NewLimiter(process.Config.Databases.Limiter)
if err != nil {
return trace.Wrap(err)
}
dbProxyServer, err := db.NewProxyServer(process.ExitContext(),
db.ProxyServerConfig{
AuthClient: conn.Client,
AccessPoint: accessPoint,
Authorizer: authorizer,
Tunnel: tsrv,
TLSConfig: tlsConfig,
Limiter: connLimiter,
Emitter: asyncEmitter,
Clock: process.Clock,
ServerID: cfg.HostUUID,
Expand Down
12 changes: 11 additions & 1 deletion lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/limiter"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/multiplexer"
"github.com/gravitational/teleport/lib/reversetunnel"
Expand Down Expand Up @@ -520,7 +521,7 @@ type testModules struct {

func (m *testModules) Features() modules.Features {
return modules.Features{
DB: false, // Explicily turn off database access.
DB: false, // Explicitly turn off database access.
}
}

Expand Down Expand Up @@ -938,6 +939,9 @@ func setupTestContext(ctx context.Context, t *testing.T, withDatabases ...withDa
testCtx.fakeRemoteSite,
},
}
// Empty config means no limit.
connLimiter, err := limiter.NewLimiter(limiter.Config{})
require.NoError(t, err)

// Create test audit events emitter.
testCtx.emitter = newTestEmitter()
Expand All @@ -949,6 +953,7 @@ func setupTestContext(ctx context.Context, t *testing.T, withDatabases ...withDa
Authorizer: proxyAuthorizer,
Tunnel: tunnel,
TLSConfig: tlsConfig,
Limiter: connLimiter,
Emitter: testCtx.emitter,
Clock: testCtx.clock,
ServerID: "proxy-server",
Expand Down Expand Up @@ -1021,6 +1026,10 @@ func (c *testContext) setupDatabaseServer(ctx context.Context, t *testing.T, p a
})
require.NoError(t, err)

// Create default limiter.
connLimiter, err := limiter.NewLimiter(limiter.Config{})
require.NoError(t, err)

// Create database server agent itself.
server, err := New(ctx, Config{
Clock: clockwork.NewFakeClockAt(time.Now()),
Expand All @@ -1032,6 +1041,7 @@ func (c *testContext) setupDatabaseServer(ctx context.Context, t *testing.T, p a
Hostname: constants.APIDomain,
HostID: p.HostID,
TLSConfig: tlsConfig,
Limiter: connLimiter,
Auth: testAuth,
Databases: p.Databases,
ResourceMatchers: p.ResourceMatchers,
Expand Down
20 changes: 18 additions & 2 deletions lib/srv/db/common/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,34 @@ type Proxy interface {
HandleConnection(context.Context, net.Conn) error
}

// ConnectParams keeps parameters used when connecting to Service.
type ConnectParams struct {
// User is a database username.
User string
// Database is a database name/schema.
Database string
// ClientIP is a client real IP. Currently, used for rate limiting.
ClientIP string
}

// Service defines an interface for connecting to a remote database service.
type Service interface {
// Connect is used to connect to remote database server over reverse tunnel.
Connect(ctx context.Context, user, database string) (net.Conn, *auth.Context, error)
Connect(ctx context.Context, params ConnectParams) (net.Conn, *auth.Context, error)
// Proxy starts proxying between client and service connections.
Proxy(ctx context.Context, authContext *auth.Context, clientConn, serviceConn net.Conn) error
}

// Engine defines an interface for specific database protocol engine such
// as Postgres or MySQL.
type Engine interface {
// InitializeConnection initializes the client connection. No DB connection is made at this point, but a message
// can be sent to a client in a database format.
InitializeConnection(clientConn net.Conn, sessionCtx *Session) error
// SendError sends an error to a client in database encoded format.
// NOTE: Client connection must be initialized before this function is called.
SendError(error)
// HandleConnection proxies the connection received from the proxy to
// the particular database instance.
HandleConnection(context.Context, *Session, net.Conn) error
HandleConnection(context.Context, *Session) error
}
28 changes: 19 additions & 9 deletions lib/srv/db/mongodb/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,21 @@ type Engine struct {
Clock clockwork.Clock
// Log is used for logging.
Log logrus.FieldLogger
// clientConn is an incoming client connection.
clientConn net.Conn
}

// InitializeConnection initializes the client connection.
func (e *Engine) InitializeConnection(clientConn net.Conn, _ *common.Session) error {
e.clientConn = clientConn
return nil
}

// SendError sends an error to the connected client in MongoDB understandable format.
func (e *Engine) SendError(err error) {
if err != nil && !utils.IsOKNetworkError(err) {
e.replyError(e.clientConn, nil, err)
}
}

// HandleConnection processes the connection from MongoDB proxy coming
Expand All @@ -58,14 +73,9 @@ type Engine struct {
// It handles all necessary startup actions, authorization and acts as a
// middleman between the proxy and the database intercepting and interpreting
// all messages i.e. doing protocol parsing.
func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Session, clientConn net.Conn) (err error) {
defer func() {
if err != nil && !utils.IsOKNetworkError(err) {
e.replyError(clientConn, nil, err)
}
}()
func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Session) error {
// Check that the user has access to the database.
err = e.authorizeConnection(ctx, sessionCtx)
err := e.authorizeConnection(ctx, sessionCtx)
if err != nil {
return trace.Wrap(err, "error authorizing database access")
}
Expand All @@ -84,11 +94,11 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio
defer e.Audit.OnSessionEnd(e.Context, sessionCtx)
// Start reading client messages and sending them to server.
for {
clientMessage, err := protocol.ReadMessage(clientConn)
clientMessage, err := protocol.ReadMessage(e.clientConn)
if err != nil {
return trace.Wrap(err)
}
err = e.handleClientMessage(ctx, sessionCtx, clientMessage, clientConn, serverConn)
err = e.handleClientMessage(ctx, sessionCtx, clientMessage, e.clientConn, serverConn)
if err != nil {
return trace.Wrap(err)
}
Expand Down
35 changes: 21 additions & 14 deletions lib/srv/db/mysql/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,22 @@ type Engine struct {
Clock clockwork.Clock
// Log is used for logging.
Log logrus.FieldLogger
// proxyConn is a client connection.
proxyConn server.Conn
}

// InitializeConnection initializes the engine with client connection.
func (e *Engine) InitializeConnection(clientConn net.Conn, _ *common.Session) error {
// Make server conn to get access to protocol's WriteOK/WriteError methods.
e.proxyConn = server.Conn{Conn: packet.NewConn(clientConn)}
return nil
}

// SendError sends an error to connected client in the MySQL understandable format.
func (e *Engine) SendError(err error) {
if writeErr := e.proxyConn.WriteError(err); writeErr != nil {
e.Log.WithError(writeErr).Debugf("Failed to send error %q to MySQL client.", err)
}
}

// HandleConnection processes the connection from MySQL proxy coming
Expand All @@ -66,18 +82,9 @@ type Engine struct {
// It handles all necessary startup actions, authorization and acts as a
// middleman between the proxy and the database intercepting and interpreting
// all messages i.e. doing protocol parsing.
func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Session, clientConn net.Conn) (err error) {
// Make server conn to get access to protocol's WriteOK/WriteError methods.
proxyConn := server.Conn{Conn: packet.NewConn(clientConn)}
defer func() {
if err != nil {
if writeErr := proxyConn.WriteError(err); writeErr != nil {
e.Log.WithError(writeErr).Debugf("Failed to send error %q to MySQL client.", err)
}
}
}()
func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Session) error {
// Perform authorization checks.
err = e.checkAccess(ctx, sessionCtx)
err := e.checkAccess(ctx, sessionCtx)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -97,7 +104,7 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio
}()
// Send back OK packet to indicate auth/connect success. At this point
// the original client should consider the connection phase completed.
err = proxyConn.WriteOK(nil)
err = e.proxyConn.WriteOK(nil)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -106,8 +113,8 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio
// Copy between the connections.
clientErrCh := make(chan error, 1)
serverErrCh := make(chan error, 1)
go e.receiveFromClient(clientConn, serverConn, clientErrCh, sessionCtx)
go e.receiveFromServer(serverConn, clientConn, serverErrCh)
go e.receiveFromClient(e.proxyConn.Conn, serverConn, clientErrCh, sessionCtx)
go e.receiveFromServer(serverConn, e.proxyConn.Conn, serverErrCh)
select {
case err := <-clientErrCh:
e.Log.WithError(err).Debug("Client done.")
Expand Down
24 changes: 22 additions & 2 deletions lib/srv/db/mysql/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ import (

"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/limiter"
"github.com/gravitational/teleport/lib/multiplexer"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/mysql/protocol"
"github.com/gravitational/teleport/lib/utils"

"github.com/siddontang/go-mysql/mysql"
"github.com/siddontang/go-mysql/server"
Expand All @@ -48,6 +50,8 @@ type Proxy struct {
Service common.Service
// Log is used for logging.
Log logrus.FieldLogger
// Limiter limits the number of active connections per client IP.
Limiter *limiter.Limiter
}

// HandleConnection accepts connection from a MySQL client, authenticates
Expand All @@ -58,7 +62,7 @@ func (p *Proxy) HandleConnection(ctx context.Context, clientConn net.Conn) (err
// proxy protocol which otherwise would interfere with MySQL protocol.
conn := multiplexer.NewConn(clientConn)
server := p.makeServer(conn)
// If any error happens, make sure to send it back to the client so it
// If any error happens, make sure to send it back to the client, so it
// has a chance to close the connection from its side.
defer func() {
if r := recover(); r != nil {
Expand All @@ -81,7 +85,23 @@ func (p *Proxy) HandleConnection(ctx context.Context, clientConn net.Conn) (err
if err != nil {
return trace.Wrap(err)
}
serviceConn, authContext, err := p.Service.Connect(ctx, server.GetUser(), server.GetDatabase())

clientIP, err := utils.ClientIPFromConn(clientConn)
if err != nil {
return trace.Wrap(err)
}
// Apply connection and rate limiting.
releaseConn, err := p.Limiter.RegisterRequestAndConnection(clientIP)
if err != nil {
return trace.Wrap(err)
}
defer releaseConn()

serviceConn, authContext, err := p.Service.Connect(ctx, common.ConnectParams{
User: server.GetUser(),
Database: server.GetDatabase(),
ClientIP: clientIP,
})
if err != nil {
return trace.Wrap(err)
}
Expand Down
Loading

0 comments on commit c7c9411

Please sign in to comment.