Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix pg proxy issues, and remove hard coded DSNs #3501

Merged
merged 2 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -1063,18 +1063,8 @@ func (s *Service) CreateDeployment(ctx context.Context, req *connect.Request[ftl
return nil, fmt.Errorf("invalid module schema: %w", err)
}

for _, d := range module.Decls {
if db, ok := d.(*schema.Database); ok && db.Runtime != nil {
key := dsnSecretKey(module.Name, db.Name)

if err := s.sm.Set(ctx, configuration.NewRef(module.Name, key), db.Runtime.DSN); err != nil {
return nil, fmt.Errorf("could not set database secret %s: %w", key, err)
}
logger.Infof("Database declaration: %s -> %s type %s", db.Name, db.Runtime.DSN, db.Type)
}
}

dkey, err := s.dal.CreateDeployment(ctx, ms.Runtime.Language, module, artefacts)

if err != nil {
logger.Errorf(err, "Could not create deployment")
return nil, fmt.Errorf("could not create deployment: %w", err)
Expand Down
5 changes: 3 additions & 2 deletions backend/provisioner/provisioner_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ import (
"fmt"
"testing"

in "github.com/TBD54566975/ftl/internal/integration"
"github.com/alecthomas/assert/v2"

in "github.com/TBD54566975/ftl/internal/integration"
)

func TestDeploymentThrougDevProvisionerCreatePostgresDB(t *testing.T) {
func TestDeploymentThroughDevProvisionerCreatePostgresDB(t *testing.T) {
in.Run(t,
in.WithFTLConfig("./ftl-project.toml"),
in.CopyModule("echo"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- migrate:up
CREATE TABLE messages( message TEXT );
-- migrate:down
DROP TABLE messages;
11 changes: 2 additions & 9 deletions backend/provisioner/testdata/go/echo/echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,12 @@ func (EchoDBConfig) Name() string { return "echodb" }
//
//ftl:verb export
func Echo(ctx context.Context, req string, db ftl.DatabaseHandle[EchoDBConfig]) (string, error) {
_, err := db.Get(ctx).Exec(`CREATE TABLE IF NOT EXISTS messages(
message TEXT
);`)
_, err := db.Get(ctx).Exec(`INSERT INTO messages (message) VALUES ($1);`, req)
if err != nil {
return "", err
}

_, err = db.Get(ctx).Exec(`INSERT INTO messages (message) VALUES ($1);`, req)
if err != nil {
return "", err
}

rows, err := db.Get(ctx).Query(`SELECT message FROM messages;`)
rows, err := db.Get(ctx).Query(`SELECT DISTINCT message FROM messages;`)
if err != nil {
return "", err
}
Expand Down
9 changes: 0 additions & 9 deletions ftl-project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,9 @@ disable-ide-integration = true
key = "inline://InZhbHVlIg"

[modules]
[modules.database]
[modules.database.secrets]
FTL_DSN_DATABASE_TESTDB = "inline://InBvc3RncmVzOi8vMTI3LjAuMC4xOjE1NDMyL2RhdGFiYXNlX3Rlc3RkYj9zc2xtb2RlPWRpc2FibGVcdTAwMjZ1c2VyPXBvc3RncmVzXHUwMDI2cGFzc3dvcmQ9c2VjcmV0Ig"
[modules.echo]
[modules.echo.configuration]
default = "inline://ImFub255bW91cyI"
[modules.mysql]
[modules.mysql.secrets]
FTL_DSN_MYSQL_TESTDB = "inline://InJvb3Q6c2VjcmV0QHRjcCgxMjcuMC4wLjE6MTMzMDYpL215c3FsX3Rlc3RkYj9hbGxvd05hdGl2ZVBhc3N3b3Jkcz1UcnVlIg"
[modules.test]
[modules.test.configuration]
[modules.test.secrets]

[commands]
startup = ["echo 'FTL startup command ⚡️'"]
76 changes: 45 additions & 31 deletions internal/pgproxy/pgproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,11 @@ func (p *PgProxy) Start(ctx context.Context, started chan<- Started) error {
// It will block until the connection is closed.
func HandleConnection(ctx context.Context, conn net.Conn, connectionFn DSNConstructor) {
defer conn.Close()
ctx, cancel := context.WithCancel(ctx)
defer cancel()

logger := log.FromContext(ctx)
logger.Infof("new connection established: %s", conn.RemoteAddr())
logger.Debugf("new connection established: %s", conn.RemoteAddr())

backend, startup, err := connectBackend(ctx, conn)
if err != nil {
Expand All @@ -90,30 +92,33 @@ func HandleConnection(ctx context.Context, conn net.Conn, connectionFn DSNConstr
logger.Infof("client disconnected without startup message: %s", conn.RemoteAddr())
return
}
logger.Debugf("startup message: %+v", startup)
logger.Debugf("backend connected: %s", conn.RemoteAddr())
logger.Tracef("startup message: %+v", startup)
logger.Tracef("backend connected: %s", conn.RemoteAddr())

frontend, err := connectFrontend(ctx, connectionFn, startup)
hijacked, err := connectFrontend(ctx, connectionFn, startup)
if err != nil {
// try again, in case there was a credential rotation
logger.Warnf("failed to connect frontend: %s, trying again", err)
logger.Debugf("failed to connect frontend: %s, trying again", err)

frontend, err = connectFrontend(ctx, connectionFn, startup)
hijacked, err = connectFrontend(ctx, connectionFn, startup)
if err != nil {
handleBackendError(ctx, backend, err)
return
}
}
backend.Send(&pgproto3.AuthenticationOk{})
logger.Debugf("frontend connected")
for key, value := range hijacked.ParameterStatuses {
backend.Send(&pgproto3.ParameterStatus{Name: key, Value: value})
}

backend.Send(&pgproto3.AuthenticationOk{})
backend.Send(&pgproto3.ReadyForQuery{})
backend.Send(&pgproto3.ReadyForQuery{TxStatus: 'I'})
if err := backend.Flush(); err != nil {
logger.Errorf(err, "failed to flush backend authentication ok")
return
}

if err := proxy(ctx, backend, frontend); err != nil {
if err := proxy(ctx, backend, hijacked.Frontend); err != nil {
logger.Warnf("disconnecting %s due to: %s", conn.RemoteAddr(), err)
return
}
Expand Down Expand Up @@ -171,7 +176,7 @@ func connectBackend(ctx context.Context, conn net.Conn) (*pgproto3.Backend, *pgp
}
}

func connectFrontend(ctx context.Context, connectionFn DSNConstructor, startup *pgproto3.StartupMessage) (*pgproto3.Frontend, error) {
func connectFrontend(ctx context.Context, connectionFn DSNConstructor, startup *pgproto3.StartupMessage) (*pgconn.HijackedConn, error) {
dsn, err := connectionFn(ctx, startup.Parameters)
if err != nil {
return nil, fmt.Errorf("failed to construct dsn: %w", err)
Expand All @@ -181,59 +186,68 @@ func connectFrontend(ctx context.Context, connectionFn DSNConstructor, startup *
if err != nil {
return nil, fmt.Errorf("failed to connect to backend: %w", err)
}
frontend := pgproto3.NewFrontend(conn.Conn(), conn.Conn())

return frontend, nil
hijacked, err := conn.Hijack()
if err != nil {
return nil, fmt.Errorf("failed to hijack backend: %w", err)
}
return hijacked, nil
}

func proxy(ctx context.Context, backend *pgproto3.Backend, frontend *pgproto3.Frontend) error {
logger := log.FromContext(ctx)
frontendMessages := make(chan pgproto3.BackendMessage)
backendMessages := make(chan pgproto3.FrontendMessage)
errors := make(chan error, 2)

go func() {
for {
msg, err := backend.Receive()
select {
case <-ctx.Done():
return
default:
}
if err != nil {
errors <- fmt.Errorf("failed to receive backend message: %w", err)
return
}
logger.Tracef("backend message: %T", msg)
backendMessages <- msg
frontend.Send(msg)
err = frontend.Flush()
if err != nil {
errors <- fmt.Errorf("failed to receive backend message: %w", err)
return
}
if _, ok := msg.(*pgproto3.Terminate); ok {
return
}
}
}()

go func() {
for {
msg, err := frontend.Receive()
select {
case <-ctx.Done():
return
default:
}
if err != nil {
errors <- fmt.Errorf("failed to receive frontend message: %w", err)
return
}
logger.Tracef("frontend message: %T", msg)
frontendMessages <- msg
backend.Send(msg)
err = backend.Flush()
if err != nil {
errors <- fmt.Errorf("failed to receive backend message: %w", err)
return
}
}
}()

for {
select {
case <-ctx.Done():
return fmt.Errorf("context done: %w", ctx.Err())
case msg := <-backendMessages:
frontend.Send(msg)
if err := frontend.Flush(); err != nil {
return fmt.Errorf("failed to flush frontend message: %w", err)
}

if _, ok := msg.(*pgproto3.Terminate); ok {
return nil
}
case msg := <-frontendMessages:
backend.Send(msg)
if err := backend.Flush(); err != nil {
return fmt.Errorf("failed to flush backend message: %w", err)
}
case err := <-errors:
return err
}
Expand Down
8 changes: 6 additions & 2 deletions internal/pgproxy/pgproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ import (
"net"
"testing"

"github.com/alecthomas/assert/v2"
"github.com/jackc/pgx/v5/pgproto3"

"github.com/TBD54566975/ftl/internal/dev"
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/pgproxy"
"github.com/alecthomas/assert/v2"
"github.com/jackc/pgx/v5/pgproto3"
)

func TestPgProxy(t *testing.T) {
Expand Down Expand Up @@ -48,6 +49,9 @@ func TestPgProxy(t *testing.T) {
assert.NoError(t, frontend.Flush())

assertResponseType[*pgproto3.AuthenticationOk](t, frontend)
for range 13 {
assertResponseType[*pgproto3.ParameterStatus](t, frontend)
}
assertResponseType[*pgproto3.ReadyForQuery](t, frontend)
})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
import io.quarkus.agroal.spi.JdbcDataSourceBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.GeneratedResourceBuildItem;
import io.quarkus.deployment.builditem.SystemPropertyBuildItem;
import xyz.block.ftl.runtime.FTLDatasourceCredentials;
import xyz.block.ftl.runtime.FTLRecorder;
import xyz.block.ftl.runtime.config.FTLConfigSource;
import xyz.block.ftl.v1.ModuleContextResponse;
import xyz.block.ftl.v1.schema.Database;
import xyz.block.ftl.v1.schema.Decl;

Expand All @@ -21,10 +25,12 @@ public class DatasourceProcessor {
private static final Logger log = Logger.getLogger(DatasourceProcessor.class);

@BuildStep
@Record(ExecutionTime.STATIC_INIT)
public SchemaContributorBuildItem registerDatasources(
List<JdbcDataSourceBuildItem> datasources,
BuildProducer<SystemPropertyBuildItem> systemPropProducer,
BuildProducer<GeneratedResourceBuildItem> generatedResourceBuildItemBuildProducer) {
BuildProducer<GeneratedResourceBuildItem> generatedResourceBuildItemBuildProducer,
FTLRecorder recorder) {
log.infof("Processing %d datasource annotations into decls", datasources.size());
List<Decl> decls = new ArrayList<>();
List<String> namedDatasources = new ArrayList<>();
Expand All @@ -37,6 +43,11 @@ public SchemaContributorBuildItem registerDatasources(
// FTL and quarkus use slightly different names
dbKind = "postgres";
}
if (dbKind.equals("mysql")) {
recorder.registerDatabase(ds.getName(), ModuleContextResponse.DBType.MYSQL);
} else {
recorder.registerDatabase(ds.getName(), ModuleContextResponse.DBType.POSTGRES);
}
//default name is <default> which is not a valid name
String sanitisedName = ds.getName().replace("<", "").replace(">", "");
//we use a dynamic credentials provider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -36,6 +38,8 @@ public class FTLController implements LeaseClient {

private static volatile FTLController controller;

private final Map<String, ModuleContextResponse.DBType> databases = new ConcurrentHashMap<>();

/**
* TODO: look at how init should work, this is terrible and will break dev mode
*/
Expand Down Expand Up @@ -71,6 +75,10 @@ public static FTLController instance() {
verbService = VerbServiceGrpc.newStub(channel);
}

public void registerDatabase(String name, ModuleContextResponse.DBType type) {
databases.put(name, type);
}

public byte[] getSecret(String secretName) {
var context = getModuleContext();
if (context.containsSecrets(secretName)) {
Expand All @@ -88,6 +96,10 @@ public byte[] getConfig(String secretName) {
}

public Datasource getDatasource(String name) {
if (databases.get(name) == ModuleContextResponse.DBType.POSTGRES) {
var proxyAddress = System.getenv("FTL_PROXY_POSTGRES_ADDRESS");
return new Datasource("jdbc:postgresql://" + proxyAddress + "/" + name, "ftl", "ftl");
}
List<ModuleContextResponse.DSN> databasesList = getModuleContext().getDatabasesList();
for (var i : databasesList) {
if (i.getName().equals(name)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import xyz.block.ftl.runtime.http.FTLHttpHandler;
import xyz.block.ftl.runtime.http.HTTPVerbInvoker;
import xyz.block.ftl.v1.CallRequest;
import xyz.block.ftl.v1.ModuleContextResponse;

@Recorder
public class FTLRecorder {
Expand Down Expand Up @@ -171,4 +172,8 @@ public void run() {
}
});
}

public void registerDatabase(String dbKind, ModuleContextResponse.DBType name) {
FTLController.instance().registerDatabase(dbKind, name);
}
}
Loading