Skip to content

Commit

Permalink
refactor: decouple conflicting ModuleContext requirements (#2376)
Browse files Browse the repository at this point in the history
The ModuleContext was designed to be an abstract data model in the
Controller for the resources required by a module, but along the way it
started to be used for storing DB connections for use by the go-runtime.
This change cleanly separates those requirements so that the go-runtime
is entirely responsible for creating new connections from the DSN
provided by the ModuleContext.

I think there's a bit more work to be done here, in that the
ModuleContext knows about testing in the go-runtime, which it really
shouldn't, but this will unblock #2373 for now.
  • Loading branch information
alecthomas authored Aug 15, 2024
1 parent af62b9e commit e339fb6
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 20 deletions.
22 changes: 18 additions & 4 deletions go-runtime/ftl/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"fmt"

"github.com/alecthomas/types/once"
_ "github.com/jackc/pgx/v5/stdlib" // Register Postgres driver

"github.com/TBD54566975/ftl/internal/modulecontext"
Expand All @@ -13,24 +14,37 @@ import (
type Database struct {
Name string
DBType modulecontext.DBType

db *once.Handle[*sql.DB]
}

// PostgresDatabase returns a handler for the named database.
func PostgresDatabase(name string) Database {
return Database{
Name: name,
DBType: modulecontext.DBTypePostgres,
db: once.Once(func(ctx context.Context) (*sql.DB, error) {
provider := modulecontext.FromContext(ctx).CurrentContext()
dsn, err := provider.GetDatabase(name, modulecontext.DBTypePostgres)
if err != nil {
return nil, fmt.Errorf("failed to get database %q: %w", name, err)
}
db, err := sql.Open("pgx", dsn)
if err != nil {
return nil, fmt.Errorf("failed to open database %q: %w", name, err)
}
return db, nil
}),
}
}

func (d Database) String() string { return fmt.Sprintf("database %q", d.Name) }

// Get returns the sql db connection for the database.
// Get returns the SQL DB connection for the database.
func (d Database) Get(ctx context.Context) *sql.DB {
provider := modulecontext.FromContext(ctx).CurrentContext()
db, err := provider.GetDatabase(d.Name, d.DBType)
db, err := d.db.Get(ctx)
if err != nil {
panic(err.Error())
panic(err)
}
return db
}
7 changes: 0 additions & 7 deletions internal/modulecontext/database.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package modulecontext

import (
"database/sql"
"fmt"
"strconv"

Expand All @@ -14,19 +13,13 @@ type Database struct {
DSN string
DBType DBType
isTestDB bool
db *sql.DB
}

// NewDatabase creates a Database that can be added to ModuleContext
func NewDatabase(dbType DBType, dsn string) (Database, error) {
db, err := sql.Open("pgx", dsn)
if err != nil {
return Database{}, fmt.Errorf("failed to bring up DB connection: %w", err)
}
return Database{
DSN: dsn,
DBType: dbType,
db: db,
}, nil
}

Expand Down
18 changes: 9 additions & 9 deletions internal/modulecontext/module_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package modulecontext

import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -130,22 +129,23 @@ func (m ModuleContext) GetSecret(name string, value any) error {
return json.Unmarshal(data, value)
}

// GetDatabase gets a database connection
// GetDatabase gets a database DSN by name and type.
//
// Returns an error if no database with that name is found or it is not the expected type
// When in a testing context (via ftltest), an error is returned if the database is not a test database
func (m ModuleContext) GetDatabase(name string, dbType DBType) (*sql.DB, error) {
// Returns an error if no database with that name is found or it is not the
// expected type. When in a testing context (via ftltest), an error is returned
// if the database is not a test database.
func (m ModuleContext) GetDatabase(name string, dbType DBType) (string, error) {
db, ok := m.databases[name]
if !ok {
return nil, fmt.Errorf("missing DSN for database %s", name)
return "", fmt.Errorf("missing DSN for database %s", name)
}
if db.DBType != dbType {
return nil, fmt.Errorf("database %s does not match expected type of %s", name, dbType)
return "", fmt.Errorf("database %s does not match expected type of %s", name, dbType)
}
if m.isTesting && !db.isTestDB {
return nil, fmt.Errorf("accessing non-test database %q while testing: try adding ftltest.WithDatabase(db) as an option with ftltest.Context(...)", name)
return "", fmt.Errorf("accessing non-test database %q while testing: try adding ftltest.WithDatabase(db) as an option with ftltest.Context(...)", name)
}
return db.db, nil
return db.DSN, nil
}

// LeaseClient is the interface for acquiring, heartbeating and releasing leases
Expand Down

0 comments on commit e339fb6

Please sign in to comment.