diff --git a/pkg/ccl/sqlproxyccl/backend_dialer.go b/pkg/ccl/sqlproxyccl/backend_dialer.go index 1e0bc5db34e8..f9f671ad76a9 100644 --- a/pkg/ccl/sqlproxyccl/backend_dialer.go +++ b/pkg/ccl/sqlproxyccl/backend_dialer.go @@ -9,11 +9,11 @@ package sqlproxyccl import ( + "context" "crypto/tls" "encoding/binary" "io" "net" - "time" "github.com/cockroachdb/errors" "github.com/jackc/pgproto3/v2" @@ -22,59 +22,79 @@ import ( // BackendDial is an example backend dialer that does a TCP/IP connection // to a backend, SSL and forwards the start message. It is defined as a variable // so it can be redirected for testing. -// -// BackendDial uses a dial timeout of 5 seconds to mitigate network black -// holes. -// // TODO(jaylim-crl): Move dialer into connector in the future. var BackendDial = func( - msg *pgproto3.StartupMessage, serverAddress string, tlsConfig *tls.Config, -) (_ net.Conn, retErr error) { - // TODO(JeffSwenson): This behavior may need to change once multi-region - // multi-tenant clusters are supported. The fixed timeout may need to be - // replaced by an adaptive timeout or the timeout could be replaced by - // speculative retries. - conn, err := net.DialTimeout("tcp", serverAddress, time.Second*5) + ctx context.Context, msg *pgproto3.StartupMessage, serverAddress string, tlsConfig *tls.Config, +) (net.Conn, error) { + var d net.Dialer + + conn, err := d.DialContext(ctx, "tcp", serverAddress) if err != nil { return nil, withCode( errors.Wrap(err, "unable to reach backend SQL server"), codeBackendDown) } - - // Ensure that conn is closed whenever BackendDial returns an error. - defer func() { - if retErr != nil { - conn.Close() - } - }() + defer closeIfCancelled(ctx, conn)() // Try to upgrade the PG connection to use SSL. - if tlsConfig != nil { - // Send SSLRequest. - if err := binary.Write(conn, binary.BigEndian, pgSSLRequest); err != nil { - return nil, withCode( - errors.Wrap(err, "sending SSLRequest to target server"), - codeBackendDown) + err = func() error { + if tlsConfig != nil { + // Send SSLRequest. + if err := binary.Write(conn, binary.BigEndian, pgSSLRequest); err != nil { + return withCode( + errors.Wrap(err, "sending SSLRequest to target server"), + codeBackendDown) + } + response := make([]byte, 1) + if _, err = io.ReadFull(conn, response); err != nil { + return withCode( + errors.New("reading response to SSLRequest"), + codeBackendDown) + } + if response[0] != pgAcceptSSLRequest { + return withCode( + errors.New("target server refused TLS connection"), + codeBackendRefusedTLS) + } + conn = tls.Client(conn, tlsConfig.Clone()) } - response := make([]byte, 1) - if _, err = io.ReadFull(conn, response); err != nil { - return nil, withCode( - errors.New("reading response to SSLRequest"), + + // Forward startup message to the backend connection. + if _, err := conn.Write(msg.Encode(nil)); err != nil { + return withCode( + errors.Wrapf(err, "relaying StartupMessage to target server %v", serverAddress), codeBackendDown) } - if response[0] != pgAcceptSSLRequest { - return nil, withCode( - errors.New("target server refused TLS connection"), - codeBackendRefusedTLS) - } - conn = tls.Client(conn, tlsConfig.Clone()) - } - // Forward startup message to the backend connection. - if _, err := conn.Write(msg.Encode(nil)); err != nil { - return nil, withCode( - errors.Wrapf(err, "relaying StartupMessage to target server %v", serverAddress), - codeBackendDown) + return nil + }() + if ctx.Err() != nil { + // If the context is cancelled, overwrite the error because closing the + // connection caused the connection to fail at an arbitrary step. + err = withCode( + errors.Wrapf(ctx.Err(), "unable to negotiate connection with %s", serverAddress), + codeBackendDown, + ) } + if err != nil { + _ = conn.Close() + return nil, err + } + return conn, nil } + +// closeIfCancelled will close the connection if the context is cancelled +// before the cleanup function is called. +func closeIfCancelled(ctx context.Context, conn net.Conn) (cleanup func()) { + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + conn.Close() + case <-done: + // Do nothing because the cleanup function was called. + } + }() + return func() { close(done) } +} diff --git a/pkg/ccl/sqlproxyccl/backend_dialer_test.go b/pkg/ccl/sqlproxyccl/backend_dialer_test.go index 444284d5fab9..bbb521fdfe72 100644 --- a/pkg/ccl/sqlproxyccl/backend_dialer_test.go +++ b/pkg/ccl/sqlproxyccl/backend_dialer_test.go @@ -12,8 +12,10 @@ import ( "context" "crypto/tls" "crypto/x509" + "net" "path/filepath" "testing" + "time" "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/roachpb" @@ -36,12 +38,37 @@ func TestBackendDialTLSInsecure(t *testing.T) { sql, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: true}) defer sql.Stopper().Stop(ctx) - conn, err := BackendDial(startupMsg, sql.ServingSQLAddr(), &tls.Config{}) + conn, err := BackendDial(context.Background(), startupMsg, sql.ServingSQLAddr(), &tls.Config{}) require.Error(t, err) require.Regexp(t, "target server refused TLS connection", err) require.Nil(t, conn) } +func TestBackendDialBlackhole(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + listener, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + conChannel := make(chan net.Conn, 1) + go func() { + // accept then ignore the connection + conn, err := listener.Accept() + require.NoError(t, err) + conChannel <- conn + }() + + startupMsg := &pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber} + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + _, err = BackendDial(ctx, startupMsg, listener.Addr().String(), &tls.Config{}) + require.Error(t, err) + require.ErrorIs(t, err, ctx.Err()) + (<-conChannel).Close() +} + func TestBackendDialTLS(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) @@ -118,7 +145,7 @@ func TestBackendDialTLS(t *testing.T) { tenantConfig, err := tlsConfigForTenant(tenantID, tc.addr, tlsConfig) require.NoError(t, err) - conn, err := BackendDial(startupMsg, tc.addr, tenantConfig) + conn, err := BackendDial(context.Background(), startupMsg, tc.addr, tenantConfig) if tc.errCode != codeNone { require.Equal(t, tc.errCode, getErrorCode(err)) diff --git a/pkg/ccl/sqlproxyccl/connector.go b/pkg/ccl/sqlproxyccl/connector.go index 3c5470c052a5..eb28e9b127ea 100644 --- a/pkg/ccl/sqlproxyccl/connector.go +++ b/pkg/ccl/sqlproxyccl/connector.go @@ -234,7 +234,7 @@ func (c *connector) dialTenantCluster( serverAssignment := balancer.NewServerAssignment( c.TenantID, c.Balancer.GetTracker(), requester, serverAddr, ) - crdbConn, err = c.dialSQLServer(serverAssignment) + crdbConn, err = c.dialSQLServer(ctx, serverAssignment) if err != nil { // Clean up the server assignment in case of an error. If there // was no error, the cleanup process is merged with net.Conn.Close(). @@ -349,7 +349,7 @@ func (c *connector) lookupAddr(ctx context.Context) (string, error) { // transient, this will return an error that has been marked with // errRetryConnectorSentinel (i.e. markAsRetriableConnectorError). func (c *connector) dialSQLServer( - serverAssignment *balancer.ServerAssignment, + ctx context.Context, serverAssignment *balancer.ServerAssignment, ) (_ net.Conn, retErr error) { if c.testingKnobs.dialSQLServer != nil { return c.testingKnobs.dialSQLServer(serverAssignment) @@ -363,7 +363,16 @@ func (c *connector) dialSQLServer( } } - conn, err := BackendDial(c.StartupMsg, serverAssignment.Addr(), tlsConf) + // TODO(JeffSwenson): The five second time out is pretty mediocre. It's too + // short if the sql server is overloaded and too long if everything is + // working the way it should. Ideally the fixed the timeout would be replaced + // by an adaptive timeout or maybe speculative retries on a different server. + var conn net.Conn + err := timeutil.RunWithTimeout(ctx, "backend-dial", time.Second*5, func(ctx context.Context) error { + var err error + conn, err = BackendDial(ctx, c.StartupMsg, serverAssignment.Addr(), tlsConf) + return err + }) if err != nil { if getErrorCode(err) == codeBackendDown { return nil, markAsRetriableConnectorError(err) diff --git a/pkg/ccl/sqlproxyccl/connector_test.go b/pkg/ccl/sqlproxyccl/connector_test.go index ae2979e59c66..7e2ac2f26eb2 100644 --- a/pkg/ccl/sqlproxyccl/connector_test.go +++ b/pkg/ccl/sqlproxyccl/connector_test.go @@ -790,7 +790,7 @@ func TestConnector_dialSQLServer(t *testing.T) { defer crdbConn.Close() defer testutils.TestingHook(&BackendDial, - func(msg *pgproto3.StartupMessage, serverAddress string, + func(ctx context.Context, msg *pgproto3.StartupMessage, serverAddress string, tlsConfig *tls.Config) (net.Conn, error) { require.Equal(t, c.StartupMsg, msg) require.Equal(t, "10.11.12.13:80", serverAddress) @@ -800,7 +800,7 @@ func TestConnector_dialSQLServer(t *testing.T) { )() sa := balancer.NewServerAssignment(tenantID, tracker, nil, "10.11.12.13:80") - conn, err := c.dialSQLServer(sa) + conn, err := c.dialSQLServer(ctx, sa) require.NoError(t, err) defer conn.Close() @@ -823,7 +823,7 @@ func TestConnector_dialSQLServer(t *testing.T) { sa := balancer.NewServerAssignment(tenantID, tracker, nil, "!@#$::") defer sa.Close() - conn, err := c.dialSQLServer(sa) + conn, err := c.dialSQLServer(ctx, sa) require.Error(t, err) require.Regexp(t, "invalid address format", err) require.False(t, isRetriableConnectorError(err)) @@ -836,7 +836,7 @@ func TestConnector_dialSQLServer(t *testing.T) { defer crdbConn.Close() defer testutils.TestingHook(&BackendDial, - func(msg *pgproto3.StartupMessage, serverAddress string, + func(ctx context.Context, msg *pgproto3.StartupMessage, serverAddress string, tlsConfig *tls.Config) (net.Conn, error) { require.Equal(t, c.StartupMsg, msg) require.Equal(t, "10.11.12.13:1234", serverAddress) @@ -845,7 +845,7 @@ func TestConnector_dialSQLServer(t *testing.T) { }, )() sa := balancer.NewServerAssignment(tenantID, tracker, nil, "10.11.12.13:1234") - conn, err := c.dialSQLServer(sa) + conn, err := c.dialSQLServer(ctx, sa) require.NoError(t, err) defer conn.Close() @@ -863,7 +863,7 @@ func TestConnector_dialSQLServer(t *testing.T) { t.Run("failed to dial with non-transient error", func(t *testing.T) { c := &connector{StartupMsg: &pgproto3.StartupMessage{}} defer testutils.TestingHook(&BackendDial, - func(msg *pgproto3.StartupMessage, serverAddress string, + func(ctx context.Context, msg *pgproto3.StartupMessage, serverAddress string, tlsConfig *tls.Config) (net.Conn, error) { require.Equal(t, c.StartupMsg, msg) require.Equal(t, "127.0.0.1:1234", serverAddress) @@ -874,7 +874,7 @@ func TestConnector_dialSQLServer(t *testing.T) { sa := balancer.NewServerAssignment(tenantID, tracker, nil, "127.0.0.1:1234") defer sa.Close() - conn, err := c.dialSQLServer(sa) + conn, err := c.dialSQLServer(ctx, sa) require.EqualError(t, err, "foo") require.False(t, isRetriableConnectorError(err)) require.Nil(t, conn) @@ -883,7 +883,7 @@ func TestConnector_dialSQLServer(t *testing.T) { t.Run("failed to dial with transient error", func(t *testing.T) { c := &connector{StartupMsg: &pgproto3.StartupMessage{}} defer testutils.TestingHook(&BackendDial, - func(msg *pgproto3.StartupMessage, serverAddress string, + func(ctx context.Context, msg *pgproto3.StartupMessage, serverAddress string, tlsConfig *tls.Config) (net.Conn, error) { require.Equal(t, c.StartupMsg, msg) require.Equal(t, "127.0.0.2:4567", serverAddress) @@ -894,7 +894,7 @@ func TestConnector_dialSQLServer(t *testing.T) { sa := balancer.NewServerAssignment(tenantID, tracker, nil, "127.0.0.2:4567") defer sa.Close() - conn, err := c.dialSQLServer(sa) + conn, err := c.dialSQLServer(ctx, sa) require.EqualError(t, err, "codeBackendDown: bar") require.True(t, isRetriableConnectorError(err)) require.Nil(t, conn) diff --git a/pkg/ccl/sqlproxyccl/proxy_handler_test.go b/pkg/ccl/sqlproxyccl/proxy_handler_test.go index 23abc08eb643..33d75d01c494 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler_test.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler_test.go @@ -549,7 +549,7 @@ func TestLongDBName(t *testing.T) { defer te.Close() defer testutils.TestingHook(&BackendDial, func( - _ *pgproto3.StartupMessage, outgoingAddr string, _ *tls.Config, + _ context.Context, _ *pgproto3.StartupMessage, outgoingAddr string, _ *tls.Config, ) (net.Conn, error) { require.Equal(t, outgoingAddr, "127.0.0.1:26257") return nil, withCode(errors.New("boom"), codeParamsRoutingFailed) @@ -588,7 +588,7 @@ func TestBackendDownRetry(t *testing.T) { callCount := 0 defer testutils.TestingHook(&BackendDial, func( - _ *pgproto3.StartupMessage, outgoingAddr string, _ *tls.Config, + _ context.Context, _ *pgproto3.StartupMessage, outgoingAddr string, _ *tls.Config, ) (net.Conn, error) { callCount++ // After 3 dials, we delete the tenant. @@ -820,7 +820,7 @@ func TestProxyTLSConf(t *testing.T) { defer te.Close() defer testutils.TestingHook(&BackendDial, func( - _ *pgproto3.StartupMessage, _ string, tlsConf *tls.Config, + _ context.Context, _ *pgproto3.StartupMessage, _ string, tlsConf *tls.Config, ) (net.Conn, error) { require.Nil(t, tlsConf) return nil, withCode(errors.New("boom"), codeParamsRoutingFailed) @@ -843,7 +843,7 @@ func TestProxyTLSConf(t *testing.T) { defer te.Close() defer testutils.TestingHook(&BackendDial, func( - _ *pgproto3.StartupMessage, _ string, tlsConf *tls.Config, + _ context.Context, _ *pgproto3.StartupMessage, _ string, tlsConf *tls.Config, ) (net.Conn, error) { require.True(t, tlsConf.InsecureSkipVerify) return nil, withCode(errors.New("boom"), codeParamsRoutingFailed) @@ -867,7 +867,7 @@ func TestProxyTLSConf(t *testing.T) { defer te.Close() defer testutils.TestingHook(&BackendDial, func( - _ *pgproto3.StartupMessage, outgoingAddress string, tlsConf *tls.Config, + _ context.Context, _ *pgproto3.StartupMessage, outgoingAddress string, tlsConf *tls.Config, ) (net.Conn, error) { outgoingHost, _, err := addr.SplitHostPort(outgoingAddress, "") require.NoError(t, err) @@ -987,7 +987,7 @@ func TestProxyModifyRequestParams(t *testing.T) { originalBackendDial := BackendDial defer testutils.TestingHook(&BackendDial, func( - msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config, + ctx context.Context, msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config, ) (net.Conn, error) { params := msg.Parameters authToken, ok := params["authToken"] @@ -1003,7 +1003,7 @@ func TestProxyModifyRequestParams(t *testing.T) { delete(params, "authToken") params["user"] = "testuser" - return originalBackendDial(msg, sql.ServingSQLAddr(), proxyOutgoingTLSConfig) + return originalBackendDial(ctx, msg, sql.ServingSQLAddr(), proxyOutgoingTLSConfig) })() s, proxyAddr, _ := newSecureProxyServer(ctx, t, sql.Stopper(), &ProxyOptions{}) @@ -1125,7 +1125,7 @@ func TestErroneousBackend(t *testing.T) { defer te.Close() defer testutils.TestingHook(&BackendDial, func( - msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config, + _ context.Context, msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config, ) (net.Conn, error) { return nil, errors.New(backendError) })() @@ -1151,7 +1151,7 @@ func TestProxyRefuseConn(t *testing.T) { defer te.Close() defer testutils.TestingHook(&BackendDial, func( - msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config, + _ context.Context, msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config, ) (net.Conn, error) { return nil, withCode(errors.New("too many attempts"), codeProxyRefusedConnection) })() @@ -1248,9 +1248,9 @@ func TestDenylistUpdate(t *testing.T) { originalBackendDial := BackendDial defer testutils.TestingHook(&BackendDial, func( - msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config, + ctx context.Context, msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config, ) (net.Conn, error) { - return originalBackendDial(msg, sql.ServingSQLAddr(), proxyOutgoingTLSConfig) + return originalBackendDial(ctx, msg, sql.ServingSQLAddr(), proxyOutgoingTLSConfig) })() opts := &ProxyOptions{ @@ -1352,7 +1352,7 @@ func TestDirectoryConnect(t *testing.T) { // Retry the backend connection 3 times before permanent failure. countFailures := 0 defer testutils.TestingHook(&BackendDial, func( - *pgproto3.StartupMessage, string, *tls.Config, + context.Context, *pgproto3.StartupMessage, string, *tls.Config, ) (net.Conn, error) { countFailures++ if countFailures >= 3 {