Skip to content

Commit

Permalink
sqlproxyccl: handle black hole sql servers
Browse files Browse the repository at this point in the history
Previously, if a sql server did not respond to the TLS handshake, the
sql proxy would wait forever. This could happen in production if a sql
server is overloaded. It can also cause test flakes if a port is reused
by something that does not understand the pgwire protocol.

Release Note: None
Fixes: cockroachdb#106554
Part of: cockroachdb#105402
  • Loading branch information
jeffswenson committed Jul 11, 2023
1 parent 71ff6f9 commit 6116fac
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 67 deletions.
107 changes: 66 additions & 41 deletions pkg/ccl/sqlproxyccl/backend_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
package sqlproxyccl

import (
"context"
"crypto/tls"
"encoding/binary"
"io"
"net"
"time"

"github.com/cockroachdb/errors"
"github.com/jackc/pgproto3/v2"
Expand All @@ -22,59 +22,84 @@ import (
// BackendDial is an example backend dialer that does a TCP/IP connection
// to a backend, SSL and forwards the start message. It is defined as a variable
// so it can be redirected for testing.
//
// BackendDial uses a dial timeout of 5 seconds to mitigate network black
// holes.
//
// TODO(jaylim-crl): Move dialer into connector in the future.
var BackendDial = func(
msg *pgproto3.StartupMessage, serverAddress string, tlsConfig *tls.Config,
) (_ net.Conn, retErr error) {
// TODO(JeffSwenson): This behavior may need to change once multi-region
// multi-tenant clusters are supported. The fixed timeout may need to be
// replaced by an adaptive timeout or the timeout could be replaced by
// speculative retries.
conn, err := net.DialTimeout("tcp", serverAddress, time.Second*5)
ctx context.Context, msg *pgproto3.StartupMessage, serverAddress string, tlsConfig *tls.Config,
) (net.Conn, error) {
var d net.Dialer

conn, err := d.DialContext(ctx, "tcp", serverAddress)
if err != nil {
return nil, withCode(
errors.Wrap(err, "unable to reach backend SQL server"),
codeBackendDown)
}

// Ensure that conn is closed whenever BackendDial returns an error.
defer func() {
if retErr != nil {
conn.Close()
}
}()

// Try to upgrade the PG connection to use SSL.
if tlsConfig != nil {
// Send SSLRequest.
if err := binary.Write(conn, binary.BigEndian, pgSSLRequest); err != nil {
return nil, withCode(
errors.Wrap(err, "sending SSLRequest to target server"),
codeBackendDown)
err = func() error {
// If the context is cancelled during the negotiation process, close the
// connection. Closing the connection unblocks active reads or writes on
// the connection.
removeCancelHook := closeWhenCancelled(ctx, conn)
defer removeCancelHook()

if tlsConfig != nil {
// Send SSLRequest.
if err := binary.Write(conn, binary.BigEndian, pgSSLRequest); err != nil {
return withCode(
errors.Wrap(err, "sending SSLRequest to target server"),
codeBackendDown)
}
response := make([]byte, 1)
if _, err = io.ReadFull(conn, response); err != nil {
return withCode(
errors.New("reading response to SSLRequest"),
codeBackendDown)
}
if response[0] != pgAcceptSSLRequest {
return withCode(
errors.New("target server refused TLS connection"),
codeBackendRefusedTLS)
}
conn = tls.Client(conn, tlsConfig.Clone())
}
response := make([]byte, 1)
if _, err = io.ReadFull(conn, response); err != nil {
return nil, withCode(
errors.New("reading response to SSLRequest"),

// Forward startup message to the backend connection.
if _, err := conn.Write(msg.Encode(nil)); err != nil {
return withCode(
errors.Wrapf(err, "relaying StartupMessage to target server %v", serverAddress),
codeBackendDown)
}
if response[0] != pgAcceptSSLRequest {
return nil, withCode(
errors.New("target server refused TLS connection"),
codeBackendRefusedTLS)
}
conn = tls.Client(conn, tlsConfig.Clone())
}

// Forward startup message to the backend connection.
if _, err := conn.Write(msg.Encode(nil)); err != nil {
return nil, withCode(
errors.Wrapf(err, "relaying StartupMessage to target server %v", serverAddress),
codeBackendDown)
return nil
}()
if ctx.Err() != nil {
// If the context is cancelled, overwrite the error because closing the
// connection caused the connection to fail at an arbitrary step.
err = withCode(
errors.Wrapf(ctx.Err(), "unable to negotiate connection with %s", serverAddress),
codeBackendDown,
)
}
if err != nil {
_ = conn.Close()
return nil, err
}

return conn, nil
}

// closeWhenCancelled will close the connection if the context is cancelled
// before the cleanup function is called.
func closeWhenCancelled(ctx context.Context, conn net.Conn) (cleanup func()) {
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
conn.Close()
case <-done:
// Do nothing because the cleanup function was called.
}
}()
return func() { close(done) }
}
31 changes: 29 additions & 2 deletions pkg/ccl/sqlproxyccl/backend_dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"net"
"path/filepath"
"testing"
"time"

"github.com/cockroachdb/cockroach/pkg/base"
"github.com/cockroachdb/cockroach/pkg/roachpb"
Expand All @@ -36,12 +38,37 @@ func TestBackendDialTLSInsecure(t *testing.T) {
sql, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: true})
defer sql.Stopper().Stop(ctx)

conn, err := BackendDial(startupMsg, sql.ServingSQLAddr(), &tls.Config{})
conn, err := BackendDial(context.Background(), startupMsg, sql.ServingSQLAddr(), &tls.Config{})
require.Error(t, err)
require.Regexp(t, "target server refused TLS connection", err)
require.Nil(t, conn)
}

func TestBackendDialBlackhole(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)

listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
conChannel := make(chan net.Conn, 1)
go func() {
// accept then ignore the connection
conn, err := listener.Accept()
require.NoError(t, err)
conChannel <- conn
}()

startupMsg := &pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber}

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

_, err = BackendDial(ctx, startupMsg, listener.Addr().String(), &tls.Config{})
require.Error(t, err)
require.ErrorIs(t, err, ctx.Err())
(<-conChannel).Close()
}

func TestBackendDialTLS(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)
Expand Down Expand Up @@ -118,7 +145,7 @@ func TestBackendDialTLS(t *testing.T) {
tenantConfig, err := tlsConfigForTenant(tenantID, tc.addr, tlsConfig)
require.NoError(t, err)

conn, err := BackendDial(startupMsg, tc.addr, tenantConfig)
conn, err := BackendDial(context.Background(), startupMsg, tc.addr, tenantConfig)

if tc.errCode != codeNone {
require.Equal(t, tc.errCode, getErrorCode(err))
Expand Down
15 changes: 12 additions & 3 deletions pkg/ccl/sqlproxyccl/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func (c *connector) dialTenantCluster(
serverAssignment := balancer.NewServerAssignment(
c.TenantID, c.Balancer.GetTracker(), requester, serverAddr,
)
crdbConn, err = c.dialSQLServer(serverAssignment)
crdbConn, err = c.dialSQLServer(ctx, serverAssignment)
if err != nil {
// Clean up the server assignment in case of an error. If there
// was no error, the cleanup process is merged with net.Conn.Close().
Expand Down Expand Up @@ -349,7 +349,7 @@ func (c *connector) lookupAddr(ctx context.Context) (string, error) {
// transient, this will return an error that has been marked with
// errRetryConnectorSentinel (i.e. markAsRetriableConnectorError).
func (c *connector) dialSQLServer(
serverAssignment *balancer.ServerAssignment,
ctx context.Context, serverAssignment *balancer.ServerAssignment,
) (_ net.Conn, retErr error) {
if c.testingKnobs.dialSQLServer != nil {
return c.testingKnobs.dialSQLServer(serverAssignment)
Expand All @@ -363,7 +363,16 @@ func (c *connector) dialSQLServer(
}
}

conn, err := BackendDial(c.StartupMsg, serverAssignment.Addr(), tlsConf)
// TODO(JeffSwenson): The five second time out is pretty mediocre. It's too
// short if the sql server is overloaded and too long if everything is
// working the way it should. Ideally the fixed the timeout would be replaced
// by an adaptive timeout or maybe speculative retries on a different server.
var conn net.Conn
err := timeutil.RunWithTimeout(ctx, "backend-dial", time.Second*5, func(ctx context.Context) error {
var err error
conn, err = BackendDial(ctx, c.StartupMsg, serverAssignment.Addr(), tlsConf)
return err
})
if err != nil {
if getErrorCode(err) == codeBackendDown {
return nil, markAsRetriableConnectorError(err)
Expand Down
18 changes: 9 additions & 9 deletions pkg/ccl/sqlproxyccl/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ func TestConnector_dialSQLServer(t *testing.T) {
defer crdbConn.Close()

defer testutils.TestingHook(&BackendDial,
func(msg *pgproto3.StartupMessage, serverAddress string,
func(ctx context.Context, msg *pgproto3.StartupMessage, serverAddress string,
tlsConfig *tls.Config) (net.Conn, error) {
require.Equal(t, c.StartupMsg, msg)
require.Equal(t, "10.11.12.13:80", serverAddress)
Expand All @@ -800,7 +800,7 @@ func TestConnector_dialSQLServer(t *testing.T) {
)()

sa := balancer.NewServerAssignment(tenantID, tracker, nil, "10.11.12.13:80")
conn, err := c.dialSQLServer(sa)
conn, err := c.dialSQLServer(ctx, sa)
require.NoError(t, err)
defer conn.Close()

Expand All @@ -823,7 +823,7 @@ func TestConnector_dialSQLServer(t *testing.T) {
sa := balancer.NewServerAssignment(tenantID, tracker, nil, "!@#$::")
defer sa.Close()

conn, err := c.dialSQLServer(sa)
conn, err := c.dialSQLServer(ctx, sa)
require.Error(t, err)
require.Regexp(t, "invalid address format", err)
require.False(t, isRetriableConnectorError(err))
Expand All @@ -836,7 +836,7 @@ func TestConnector_dialSQLServer(t *testing.T) {
defer crdbConn.Close()

defer testutils.TestingHook(&BackendDial,
func(msg *pgproto3.StartupMessage, serverAddress string,
func(ctx context.Context, msg *pgproto3.StartupMessage, serverAddress string,
tlsConfig *tls.Config) (net.Conn, error) {
require.Equal(t, c.StartupMsg, msg)
require.Equal(t, "10.11.12.13:1234", serverAddress)
Expand All @@ -845,7 +845,7 @@ func TestConnector_dialSQLServer(t *testing.T) {
},
)()
sa := balancer.NewServerAssignment(tenantID, tracker, nil, "10.11.12.13:1234")
conn, err := c.dialSQLServer(sa)
conn, err := c.dialSQLServer(ctx, sa)
require.NoError(t, err)
defer conn.Close()

Expand All @@ -863,7 +863,7 @@ func TestConnector_dialSQLServer(t *testing.T) {
t.Run("failed to dial with non-transient error", func(t *testing.T) {
c := &connector{StartupMsg: &pgproto3.StartupMessage{}}
defer testutils.TestingHook(&BackendDial,
func(msg *pgproto3.StartupMessage, serverAddress string,
func(ctx context.Context, msg *pgproto3.StartupMessage, serverAddress string,
tlsConfig *tls.Config) (net.Conn, error) {
require.Equal(t, c.StartupMsg, msg)
require.Equal(t, "127.0.0.1:1234", serverAddress)
Expand All @@ -874,7 +874,7 @@ func TestConnector_dialSQLServer(t *testing.T) {
sa := balancer.NewServerAssignment(tenantID, tracker, nil, "127.0.0.1:1234")
defer sa.Close()

conn, err := c.dialSQLServer(sa)
conn, err := c.dialSQLServer(ctx, sa)
require.EqualError(t, err, "foo")
require.False(t, isRetriableConnectorError(err))
require.Nil(t, conn)
Expand All @@ -883,7 +883,7 @@ func TestConnector_dialSQLServer(t *testing.T) {
t.Run("failed to dial with transient error", func(t *testing.T) {
c := &connector{StartupMsg: &pgproto3.StartupMessage{}}
defer testutils.TestingHook(&BackendDial,
func(msg *pgproto3.StartupMessage, serverAddress string,
func(ctx context.Context, msg *pgproto3.StartupMessage, serverAddress string,
tlsConfig *tls.Config) (net.Conn, error) {
require.Equal(t, c.StartupMsg, msg)
require.Equal(t, "127.0.0.2:4567", serverAddress)
Expand All @@ -894,7 +894,7 @@ func TestConnector_dialSQLServer(t *testing.T) {
sa := balancer.NewServerAssignment(tenantID, tracker, nil, "127.0.0.2:4567")
defer sa.Close()

conn, err := c.dialSQLServer(sa)
conn, err := c.dialSQLServer(ctx, sa)
require.EqualError(t, err, "codeBackendDown: bar")
require.True(t, isRetriableConnectorError(err))
require.Nil(t, conn)
Expand Down
Loading

0 comments on commit 6116fac

Please sign in to comment.