diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel index 4125a98e4fca..5e17445aad3f 100644 --- a/pkg/ccl/sqlproxyccl/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/BUILD.bazel @@ -57,6 +57,7 @@ go_test( size = "medium", srcs = [ "authentication_test.go", + "backend_dialer_test.go", "conn_migration_test.go", "connector_test.go", "forwarder_test.go", diff --git a/pkg/ccl/sqlproxyccl/backend_dialer.go b/pkg/ccl/sqlproxyccl/backend_dialer.go index 5448133971fd..475612be211d 100644 --- a/pkg/ccl/sqlproxyccl/backend_dialer.go +++ b/pkg/ccl/sqlproxyccl/backend_dialer.go @@ -25,9 +25,7 @@ import ( // BackendDial uses a dial timeout of 5 seconds to mitigate network black // holes. // -// TODO(jaylim-crl): Move dialer into connector in the future. When moving this -// into the connector, we should be careful as this is also used by CC's -// codebase. +// TODO(jaylim-crl): Move dialer into connector in the future. var BackendDial = func( msg *pgproto3.StartupMessage, serverAddress string, tlsConfig *tls.Config, ) (_ net.Conn, retErr error) { @@ -37,61 +35,36 @@ var BackendDial = func( // speculative retries. conn, err := net.DialTimeout("tcp", serverAddress, time.Second*5) if err != nil { - return nil, newErrorf( - codeBackendDown, "unable to reach backend SQL server: %v", err, - ) + return nil, newErrorf(codeBackendDown, "unable to reach backend SQL server: %v", err) } + + // Ensure that conn is closed whenever BackendDial returns an error. defer func() { if retErr != nil { conn.Close() } }() - conn, err = sslOverlay(conn, tlsConfig) - if err != nil { - return nil, err - } - err = relayStartupMsg(conn, msg) - if err != nil { - return nil, newErrorf( - codeBackendDown, "relaying StartupMessage to target server %v: %v", - serverAddress, err) - } - return conn, nil -} - -// sslOverlay attempts to upgrade the PG connection to use SSL if a tls.Config -// is specified. -func sslOverlay(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { - if tlsConfig == nil { - return conn, nil - } - var err error - // Send SSLRequest. - if err := binary.Write(conn, binary.BigEndian, pgSSLRequest); err != nil { - return nil, newErrorf( - codeBackendDown, "sending SSLRequest to target server: %v", err, - ) - } - - response := make([]byte, 1) - if _, err = io.ReadFull(conn, response); err != nil { - return nil, - newErrorf(codeBackendDown, "reading response to SSLRequest") + // Try to upgrade the PG connection to use SSL. + if tlsConfig != nil { + // Send SSLRequest. + if err := binary.Write(conn, binary.BigEndian, pgSSLRequest); err != nil { + return nil, newErrorf(codeBackendDown, "sending SSLRequest to target server: %v", err) + } + response := make([]byte, 1) + if _, err = io.ReadFull(conn, response); err != nil { + return nil, newErrorf(codeBackendDown, "reading response to SSLRequest") + } + if response[0] != pgAcceptSSLRequest { + return nil, newErrorf(codeBackendRefusedTLS, "target server refused TLS connection") + } + conn = tls.Client(conn, tlsConfig.Clone()) } - if response[0] != pgAcceptSSLRequest { - return nil, newErrorf( - codeBackendRefusedTLS, "target server refused TLS connection", - ) + // Forward startup message to the backend connection. + if _, err := conn.Write(msg.Encode(nil)); err != nil { + return nil, newErrorf(codeBackendDown, + "relaying StartupMessage to target server %v: %v", serverAddress, err) } - - outCfg := tlsConfig.Clone() - return tls.Client(conn, outCfg), nil -} - -// relayStartupMsg forwards the start message on the backend connection. -func relayStartupMsg(conn net.Conn, msg *pgproto3.StartupMessage) (err error) { - _, err = conn.Write(msg.Encode(nil)) - return + return conn, nil } diff --git a/pkg/ccl/sqlproxyccl/backend_dialer_test.go b/pkg/ccl/sqlproxyccl/backend_dialer_test.go new file mode 100644 index 000000000000..c266f4a35c91 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/backend_dialer_test.go @@ -0,0 +1,50 @@ +// Copyright 2022 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package sqlproxyccl + +import ( + "context" + "crypto/tls" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + pgproto3 "github.com/jackc/pgproto3/v2" + "github.com/stretchr/testify/require" +) + +func TestBackendDialTLS(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + startupMsg := &pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber} + tlsConfig := &tls.Config{InsecureSkipVerify: true} + + t.Run("insecure server", func(t *testing.T) { + sql, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: true}) + defer sql.Stopper().Stop(ctx) + + conn, err := BackendDial(startupMsg, sql.ServingSQLAddr(), tlsConfig) + require.Error(t, err) + require.Regexp(t, "target server refused TLS connection", err) + require.Nil(t, conn) + }) + + t.Run("secure server", func(t *testing.T) { + sql, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: false}) + defer sql.Stopper().Stop(ctx) + + conn, err := BackendDial(startupMsg, sql.ServingSQLAddr(), tlsConfig) + require.NoError(t, err) + require.NotNil(t, conn) + }) +}