Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: tests for the pgproxy protocol #3440

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 31 additions & 21 deletions internal/pgproxy/pgproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,34 +49,37 @@ func (p *PgProxy) Start(ctx context.Context) error {
logger.Errorf(err, "failed to accept connection")
continue
}
go p.handleConnection(ctx, conn)
go HandleConnection(ctx, conn, p.connectionStringFn)
}
}

func (p *PgProxy) handleConnection(ctx context.Context, conn net.Conn) {
// HandleConnection proxies a single connection.
//
// This should be run as the first thing after accepting a connection.
// It will block until the connection is closed.
func HandleConnection(ctx context.Context, conn net.Conn, connectionFn DSNConstructor) {
defer conn.Close()

logger := log.FromContext(ctx)
logger.Infof("new connection established: %s", conn.RemoteAddr())

backend, startup, err := p.connectBackend(ctx, conn)
backend, startup, err := connectBackend(ctx, conn)
if err != nil {
logger.Errorf(err, "failed to connect backend")
return
}
logger.Debugf("startup message: %+v", startup)
logger.Debugf("backend connected: %s", conn.RemoteAddr())

frontend, err := p.connectFrontend(ctx, startup)
dsn, err := connectionFn(ctx, startup.Parameters)
if err != nil {
logger.Errorf(err, "failed to connect frontend")
backend.Send(&pgproto3.ErrorResponse{
Severity: "FATAL",
Message: err.Error(),
})
if err := backend.Flush(); err != nil {
logger.Errorf(err, "failed to flush backend error response")
}
handleBackendError(ctx, backend, err)
return
}

frontend, err := connectFrontend(ctx, dsn)
if err != nil {
handleBackendError(ctx, backend, err)
return
}
logger.Debugf("frontend connected")
Expand All @@ -88,15 +91,27 @@ func (p *PgProxy) handleConnection(ctx context.Context, conn net.Conn) {
return
}

if err := p.proxy(ctx, backend, frontend); err != nil {
if err := proxy(ctx, backend, frontend); err != nil {
logger.Warnf("disconnecting %s due to: %s", conn.RemoteAddr(), err)
return
}
logger.Infof("terminating connection to %s", conn.RemoteAddr())
}

func handleBackendError(ctx context.Context, backend *pgproto3.Backend, err error) {
logger := log.FromContext(ctx)
logger.Errorf(err, "backend error")
backend.Send(&pgproto3.ErrorResponse{
Severity: "FATAL",
Message: err.Error(),
})
if err := backend.Flush(); err != nil {
logger.Errorf(err, "failed to flush backend error response")
}
}

// connectBackend establishes a connection according to https://www.postgresql.org/docs/current/protocol-flow.html
func (p *PgProxy) connectBackend(_ context.Context, conn net.Conn) (*pgproto3.Backend, *pgproto3.StartupMessage, error) {
func connectBackend(_ context.Context, conn net.Conn) (*pgproto3.Backend, *pgproto3.StartupMessage, error) {
backend := pgproto3.NewBackend(conn, conn)

for {
Expand Down Expand Up @@ -127,12 +142,7 @@ func (p *PgProxy) connectBackend(_ context.Context, conn net.Conn) (*pgproto3.Ba
}
}

func (p *PgProxy) connectFrontend(ctx context.Context, startup *pgproto3.StartupMessage) (*pgproto3.Frontend, error) {
dsn, err := p.connectionStringFn(ctx, startup.Parameters)
if err != nil {
return nil, err
}

func connectFrontend(ctx context.Context, dsn string) (*pgproto3.Frontend, error) {
conn, err := pgconn.Connect(ctx, dsn)
if err != nil {
return nil, fmt.Errorf("failed to connect to backend: %w", err)
Expand All @@ -142,7 +152,7 @@ func (p *PgProxy) connectFrontend(ctx context.Context, startup *pgproto3.Startup
return frontend, nil
}

func (p *PgProxy) proxy(ctx context.Context, backend *pgproto3.Backend, frontend *pgproto3.Frontend) error {
func proxy(ctx context.Context, backend *pgproto3.Backend, frontend *pgproto3.Frontend) error {
logger := log.FromContext(ctx)
frontendMessages := make(chan pgproto3.BackendMessage)
backendMessages := make(chan pgproto3.FrontendMessage)
Expand Down
83 changes: 83 additions & 0 deletions internal/pgproxy/pgproxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package pgproxy_test

import (
"context"
"net"
"testing"

"github.com/TBD54566975/ftl/internal/dev"
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/pgproxy"
"github.com/alecthomas/assert/v2"
"github.com/jackc/pgx/v5/pgproto3"
)

func TestPgProxy(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
client, proxy := net.Pipe()

dsn, err := dev.SetupPostgres(ctx, "postgres:15.8", 0, false)
assert.NoError(t, err)

frontend := pgproto3.NewFrontend(client, client)

ctx, cancel := context.WithCancel(ctx)
defer cancel()
go pgproxy.HandleConnection(ctx, proxy, func(ctx context.Context, parameters map[string]string) (string, error) {
return dsn, nil
})

t.Run("denies SSL", func(t *testing.T) {
frontend.Send(&pgproto3.SSLRequest{})
assert.NoError(t, frontend.Flush())

assert.Equal(t, readOneByte(t, client), 'N')
})

t.Run("denies GSSEnc", func(t *testing.T) {
frontend.Send(&pgproto3.GSSEncRequest{})
assert.NoError(t, frontend.Flush())

assert.Equal(t, readOneByte(t, client), 'N')
})

t.Run("authenticates with startup message", func(t *testing.T) {
frontend.Send(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{
"user": "ftl",
}})
assert.NoError(t, frontend.Flush())

assertResponseType[*pgproto3.AuthenticationOk](t, frontend)
assertResponseType[*pgproto3.ReadyForQuery](t, frontend)
})

t.Run("proxies a query to the underlying DB", func(t *testing.T) {
frontend.Send(&pgproto3.Query{String: "SELECT 1"})
assert.NoError(t, frontend.Flush())

assertResponseType[*pgproto3.RowDescription](t, frontend)
assertResponseType[*pgproto3.DataRow](t, frontend)
assertResponseType[*pgproto3.CommandComplete](t, frontend)
assertResponseType[*pgproto3.ReadyForQuery](t, frontend)
})
}

func readOneByte(t *testing.T, client net.Conn) byte {
t.Helper()

response := make([]byte, 1)
n, err := client.Read(response)
assert.NoError(t, err)
assert.Equal(t, n, 1)
return response[0]
}

func assertResponseType[T any](t *testing.T, f *pgproto3.Frontend) {
t.Helper()

var zero T
resp, err := f.Receive()
assert.NoError(t, err)
_, ok := resp.(T)
assert.True(t, ok, "expected response type %T, got %T", zero, resp)
}
Loading