diff --git a/Justfile b/Justfile index a271a6ebd8..d60f20b532 100644 --- a/Justfile +++ b/Justfile @@ -66,7 +66,7 @@ init-db: # Regenerate SQLC code (requires init-db to be run first) build-sqlc: - @mk backend/controller/sql/{db.go,models.go,querier.go,queries.sql.go} : backend/controller/sql/queries.sql backend/controller/sql/schema sqlc.yaml -- "just init-db && sqlc generate" + @mk backend/controller/sql/{db.go,models.go,querier.go,queries.sql.go} common/configuration/sql/{db.go,models.go,querier.go,queries.sql.go} : backend/controller/sql/queries.sql common/configuration/sql/queries.sql backend/controller/sql/schema sqlc.yaml -- "just init-db && sqlc generate" # Build the ZIP files that are embedded in the FTL release binaries build-zips: build-kt-runtime diff --git a/backend/controller/controller.go b/backend/controller/controller.go index 856bd899cf..39be4be094 100644 --- a/backend/controller/controller.go +++ b/backend/controller/controller.go @@ -47,6 +47,7 @@ import ( schemapb "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/schema" "github.com/TBD54566975/ftl/backend/schema" cf "github.com/TBD54566975/ftl/common/configuration" + "github.com/TBD54566975/ftl/db/dalerrs" frontend "github.com/TBD54566975/ftl/frontend" "github.com/TBD54566975/ftl/internal/cors" ftlhttp "github.com/TBD54566975/ftl/internal/http" @@ -282,7 +283,7 @@ func New(ctx context.Context, db *dal.DAL, config Config, runnerScaling scaling. func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { routes, err := s.dal.GetIngressRoutes(r.Context(), r.Method) if err != nil { - if errors.Is(err, dal.ErrNotFound) { + if errors.Is(err, dalerrs.ErrNotFound) { http.NotFound(w, r) return } @@ -481,7 +482,7 @@ func (s *Service) UpdateDeploy(ctx context.Context, req *connect.Request[ftlv1.U err = s.dal.SetDeploymentReplicas(ctx, deploymentKey, int(req.Msg.MinReplicas)) if err != nil { - if errors.Is(err, dal.ErrNotFound) { + if errors.Is(err, dalerrs.ErrNotFound) { logger.Errorf(err, "Deployment not found: %s", deploymentKey) return nil, connect.NewError(connect.CodeNotFound, errors.New("deployment not found")) } @@ -503,10 +504,10 @@ func (s *Service) ReplaceDeploy(ctx context.Context, c *connect.Request[ftlv1.Re err = s.dal.ReplaceDeployment(ctx, newDeploymentKey, int(c.Msg.MinReplicas)) if err != nil { - if errors.Is(err, dal.ErrNotFound) { + if errors.Is(err, dalerrs.ErrNotFound) { logger.Errorf(err, "Deployment not found: %s", newDeploymentKey) return nil, connect.NewError(connect.CodeNotFound, errors.New("deployment not found")) - } else if errors.Is(err, dal.ErrConflict) { + } else if errors.Is(err, dalerrs.ErrConflict) { logger.Infof("Reusing deployment: %s", newDeploymentKey) } else { logger.Errorf(err, "Could not replace deployment: %s", newDeploymentKey) @@ -566,14 +567,14 @@ func (s *Service) RegisterRunner(ctx context.Context, stream *connect.ClientStre Deployment: maybeDeployment, Labels: msg.Labels.AsMap(), }) - if errors.Is(err, dal.ErrConflict) { + if errors.Is(err, dalerrs.ErrConflict) { return nil, connect.NewError(connect.CodeAlreadyExists, err) } else if err != nil { return nil, err } routes, err := s.dal.GetRoutingTable(ctx, nil) - if errors.Is(err, dal.ErrNotFound) { + if errors.Is(err, dalerrs.ErrNotFound) { routes = map[string][]dal.Route{} } else if err != nil { return nil, err @@ -1220,7 +1221,7 @@ func (s *Service) executeAsyncCalls(ctx context.Context) (time.Duration, error) logger.Tracef("Acquiring async call") call, err := s.dal.AcquireAsyncCall(ctx) - if errors.Is(err, dal.ErrNotFound) { + if errors.Is(err, dalerrs.ErrNotFound) { logger.Tracef("No async calls to execute") return time.Second * 2, nil } else if err != nil { @@ -1555,7 +1556,7 @@ func (s *Service) getDeploymentLogger(ctx context.Context, deploymentKey model.D // Periodically sync the routing table from the DB. func (s *Service) syncRoutes(ctx context.Context) (time.Duration, error) { routes, err := s.dal.GetRoutingTable(ctx, nil) - if errors.Is(err, dal.ErrNotFound) { + if errors.Is(err, dalerrs.ErrNotFound) { routes = map[string][]dal.Route{} } else if err != nil { return 0, err diff --git a/backend/controller/dal/async_calls.go b/backend/controller/dal/async_calls.go index b2dfd26654..9b760b88b3 100644 --- a/backend/controller/dal/async_calls.go +++ b/backend/controller/dal/async_calls.go @@ -12,6 +12,7 @@ import ( "github.com/TBD54566975/ftl/backend/controller/sql" "github.com/TBD54566975/ftl/backend/schema" + "github.com/TBD54566975/ftl/db/dalerrs" ) type asyncOriginParseRoot struct { @@ -94,10 +95,10 @@ func (d *DAL) AcquireAsyncCall(ctx context.Context) (call *AsyncCall, err error) ttl := time.Second * 5 row, err := tx.db.AcquireAsyncCall(ctx, ttl) if err != nil { - err = translatePGError(err) + err = dalerrs.TranslatePGError(err) // We get a NULL constraint violation if there are no async calls to acquire, so translate it to ErrNotFound. - if errors.Is(err, ErrConstraint) { - return nil, fmt.Errorf("no pending async calls: %w", ErrNotFound) + if errors.Is(err, dalerrs.ErrConstraint) { + return nil, fmt.Errorf("no pending async calls: %w", dalerrs.ErrNotFound) } return nil, fmt.Errorf("failed to acquire async call: %w", err) } @@ -126,7 +127,7 @@ func (d *DAL) AcquireAsyncCall(ctx context.Context) (call *AsyncCall, err error) func (d *DAL) CompleteAsyncCall(ctx context.Context, call *AsyncCall, result either.Either[[]byte, string], finalise func(tx *Tx) error) (err error) { tx, err := d.Begin(ctx) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } defer tx.CommitOrRollback(ctx, &err) @@ -134,7 +135,7 @@ func (d *DAL) CompleteAsyncCall(ctx context.Context, call *AsyncCall, result eit case either.Left[[]byte, string]: // Successful response. _, err = tx.db.SucceedAsyncCall(ctx, result.Get(), call.ID) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } case either.Right[[]byte, string]: // Failure message. @@ -148,12 +149,12 @@ func (d *DAL) CompleteAsyncCall(ctx context.Context, call *AsyncCall, result eit ScheduledAt: time.Now().Add(call.Backoff), }) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } } else { _, err = tx.db.FailAsyncCall(ctx, result.Get(), call.ID) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } } } @@ -164,7 +165,7 @@ func (d *DAL) CompleteAsyncCall(ctx context.Context, call *AsyncCall, result eit func (d *DAL) LoadAsyncCall(ctx context.Context, id int64) (*AsyncCall, error) { row, err := d.db.LoadAsyncCall(ctx, id) if err != nil { - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } origin, err := ParseAsyncOrigin(row.Origin) if err != nil { diff --git a/backend/controller/dal/dal.go b/backend/controller/dal/dal.go index 708f91f6bb..d57862128f 100644 --- a/backend/controller/dal/dal.go +++ b/backend/controller/dal/dal.go @@ -3,7 +3,6 @@ package dal import ( "context" - stdsql "database/sql" "encoding/json" "errors" "fmt" @@ -14,15 +13,13 @@ import ( "github.com/alecthomas/types/optional" "github.com/alecthomas/types/pubsub" sets "github.com/deckarep/golang-set/v2" - "github.com/jackc/pgerrcode" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" "google.golang.org/protobuf/proto" "github.com/TBD54566975/ftl/backend/controller/sql" ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1" "github.com/TBD54566975/ftl/backend/schema" + "github.com/TBD54566975/ftl/db/dalerrs" "github.com/TBD54566975/ftl/internal/log" "github.com/TBD54566975/ftl/internal/maps" "github.com/TBD54566975/ftl/internal/model" @@ -30,17 +27,6 @@ import ( "github.com/TBD54566975/ftl/internal/slices" ) -var ( - // ErrConflict is returned by select methods in the DAL when a resource already exists. - // - // Its use will be documented in the corresponding methods. - ErrConflict = errors.New("conflict") - // ErrNotFound is returned by select methods in the DAL when no results are found. - ErrNotFound = errors.New("not found") - // ErrConstraint is returned by select methods in the DAL when a constraint is violated. - ErrConstraint = errors.New("constraint violation") -) - type IngressRoute struct { Runner model.RunnerKey Deployment model.DeploymentKey @@ -284,7 +270,7 @@ func (t *Tx) Rollback(ctx context.Context) error { func (d *DAL) Begin(ctx context.Context) (*Tx, error) { tx, err := d.db.Begin(ctx) if err != nil { - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } return &Tx{&DAL{ db: tx, @@ -295,7 +281,7 @@ func (d *DAL) Begin(ctx context.Context) (*Tx, error) { func (d *DAL) GetActiveControllers(ctx context.Context) ([]Controller, error) { controllers, err := d.db.GetActiveControllers(ctx) if err != nil { - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } return slices.Map(controllers, func(in sql.Controller) Controller { return Controller{ @@ -308,23 +294,23 @@ func (d *DAL) GetActiveControllers(ctx context.Context) ([]Controller, error) { func (d *DAL) GetStatus(ctx context.Context) (Status, error) { controllers, err := d.GetActiveControllers(ctx) if err != nil { - return Status{}, fmt.Errorf("could not get control planes: %w", translatePGError(err)) + return Status{}, fmt.Errorf("could not get control planes: %w", dalerrs.TranslatePGError(err)) } runners, err := d.db.GetActiveRunners(ctx) if err != nil { - return Status{}, fmt.Errorf("could not get active runners: %w", translatePGError(err)) + return Status{}, fmt.Errorf("could not get active runners: %w", dalerrs.TranslatePGError(err)) } deployments, err := d.db.GetActiveDeployments(ctx) if err != nil { - return Status{}, fmt.Errorf("could not get active deployments: %w", translatePGError(err)) + return Status{}, fmt.Errorf("could not get active deployments: %w", dalerrs.TranslatePGError(err)) } ingressRoutes, err := d.db.GetActiveIngressRoutes(ctx) if err != nil { - return Status{}, fmt.Errorf("could not get ingress routes: %w", translatePGError(err)) + return Status{}, fmt.Errorf("could not get ingress routes: %w", dalerrs.TranslatePGError(err)) } routes, err := d.db.GetRoutingTable(ctx, nil) if err != nil { - return Status{}, fmt.Errorf("could not get routing table: %w", translatePGError(err)) + return Status{}, fmt.Errorf("could not get routing table: %w", dalerrs.TranslatePGError(err)) } statusDeployments, err := slices.MapErr(deployments, func(in sql.GetActiveDeploymentsRow) (Deployment, error) { labels := model.Labels{} @@ -397,7 +383,7 @@ func (d *DAL) GetRunnersForDeployment(ctx context.Context, deployment model.Depl runners := []Runner{} rows, err := d.db.GetRunnersForDeployment(ctx, deployment) if err != nil { - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } for _, row := range rows { attrs := model.Labels{} @@ -418,14 +404,14 @@ func (d *DAL) GetRunnersForDeployment(ctx context.Context, deployment model.Depl func (d *DAL) UpsertModule(ctx context.Context, language, name string) (err error) { _, err = d.db.UpsertModule(ctx, language, name) - return translatePGError(err) + return dalerrs.TranslatePGError(err) } // GetMissingArtefacts returns the digests of the artefacts that are missing from the database. func (d *DAL) GetMissingArtefacts(ctx context.Context, digests []sha256.SHA256) ([]sha256.SHA256, error) { have, err := d.db.GetArtefactDigests(ctx, sha256esToBytes(digests)) if err != nil { - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } haveStr := slices.Map(have, func(in sql.GetArtefactDigestsRow) sha256.SHA256 { return sha256.FromBytes(in.Digest) @@ -437,7 +423,7 @@ func (d *DAL) GetMissingArtefacts(ctx context.Context, digests []sha256.SHA256) func (d *DAL) CreateArtefact(ctx context.Context, content []byte) (digest sha256.SHA256, err error) { sha256digest := sha256.Sum(content) _, err = d.db.CreateArtefact(ctx, sha256digest[:], content) - return sha256digest, translatePGError(err) + return sha256digest, dalerrs.TranslatePGError(err) } type IngressRoutingEntry struct { @@ -481,7 +467,7 @@ func (d *DAL) CreateDeployment(ctx context.Context, language string, moduleSchem // TODO(aat): "schema" containing language? _, err = tx.UpsertModule(ctx, language, moduleSchema.Name) if err != nil { - return model.DeploymentKey{}, fmt.Errorf("failed to upsert module: %w", translatePGError(err)) + return model.DeploymentKey{}, fmt.Errorf("failed to upsert module: %w", dalerrs.TranslatePGError(err)) } // upsert topics @@ -497,7 +483,7 @@ func (d *DAL) CreateDeployment(ctx context.Context, language string, moduleSchem EventType: t.Event.String(), }) if err != nil { - return model.DeploymentKey{}, fmt.Errorf("could not insert topic: %w", translatePGError(err)) + return model.DeploymentKey{}, fmt.Errorf("could not insert topic: %w", dalerrs.TranslatePGError(err)) } } @@ -506,7 +492,7 @@ func (d *DAL) CreateDeployment(ctx context.Context, language string, moduleSchem // Create the deployment err = tx.CreateDeployment(ctx, moduleSchema.Name, schemaBytes, deploymentKey) if err != nil { - return model.DeploymentKey{}, fmt.Errorf("failed to create deployment: %w", translatePGError(err)) + return model.DeploymentKey{}, fmt.Errorf("failed to create deployment: %w", dalerrs.TranslatePGError(err)) } uploadedDigests := slices.Map(artefacts, func(in DeploymentArtefact) []byte { return in.Digest[:] }) @@ -529,7 +515,7 @@ func (d *DAL) CreateDeployment(ctx context.Context, language string, moduleSchem Path: artefact.Path, }) if err != nil { - return model.DeploymentKey{}, fmt.Errorf("failed to associate artefact with deployment: %w", translatePGError(err)) + return model.DeploymentKey{}, fmt.Errorf("failed to associate artefact with deployment: %w", dalerrs.TranslatePGError(err)) } } @@ -542,7 +528,7 @@ func (d *DAL) CreateDeployment(ctx context.Context, language string, moduleSchem Verb: ingressRoute.Verb, }) if err != nil { - return model.DeploymentKey{}, fmt.Errorf("failed to create ingress route: %w", translatePGError(err)) + return model.DeploymentKey{}, fmt.Errorf("failed to create ingress route: %w", dalerrs.TranslatePGError(err)) } } @@ -559,16 +545,17 @@ func (d *DAL) CreateDeployment(ctx context.Context, language string, moduleSchem NextExecution: job.NextExecution, }) if err != nil { - return model.DeploymentKey{}, fmt.Errorf("failed to create cron job: %w", translatePGError(err)) + return model.DeploymentKey{}, fmt.Errorf("failed to create cron job: %w", dalerrs.TranslatePGError(err)) } } + return deploymentKey, nil } func (d *DAL) GetDeployment(ctx context.Context, key model.DeploymentKey) (*model.Deployment, error) { deployment, err := d.db.GetDeployment(ctx, key) if err != nil { - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } return d.loadDeployment(ctx, deployment) } @@ -590,7 +577,7 @@ func (d *DAL) UpsertRunner(ctx context.Context, runner Runner) error { Labels: attrBytes, }) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } if runner.Deployment.Ok() && !deploymentID.Ok() { return fmt.Errorf("deployment %s not found", runner.Deployment) @@ -614,10 +601,10 @@ func (d *DAL) KillStaleControllers(ctx context.Context, age time.Duration) (int6 func (d *DAL) DeregisterRunner(ctx context.Context, key model.RunnerKey) error { count, err := d.db.DeregisterRunner(ctx, key) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } if count == 0 { - return ErrNotFound + return dalerrs.ErrNotFound } return nil } @@ -634,18 +621,18 @@ func (d *DAL) ReserveRunnerForDeployment(ctx context.Context, deployment model.D tx, err := d.db.Begin(ctx) if err != nil { cancel() - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } runner, err := tx.ReserveRunner(ctx, time.Now().Add(reservationTimeout), deployment, jsonLabels) if err != nil { if rerr := tx.Rollback(context.Background()); rerr != nil { - err = errors.Join(err, translatePGError(rerr)) + err = errors.Join(err, dalerrs.TranslatePGError(rerr)) } cancel() - if isNotFound(err) { - return nil, fmt.Errorf("no idle runners found matching labels %s: %w", jsonLabels, ErrNotFound) + if dalerrs.IsNotFound(err) { + return nil, fmt.Errorf("no idle runners found matching labels %s: %w", jsonLabels, dalerrs.ErrNotFound) } - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } runnerLabels := model.Labels{} if err := json.Unmarshal(runner.Labels, &runnerLabels); err != nil { @@ -676,12 +663,12 @@ type postgresClaim struct { func (p *postgresClaim) Commit(ctx context.Context) error { defer p.cancel() - return translatePGError(p.tx.Commit(ctx)) + return dalerrs.TranslatePGError(p.tx.Commit(ctx)) } func (p *postgresClaim) Rollback(ctx context.Context) error { defer p.cancel() - return translatePGError(p.tx.Rollback(ctx)) + return dalerrs.TranslatePGError(p.tx.Rollback(ctx)) } func (p *postgresClaim) Runner() Runner { return p.runner } @@ -691,29 +678,29 @@ func (d *DAL) SetDeploymentReplicas(ctx context.Context, key model.DeploymentKey // Start the transaction tx, err := d.db.Begin(ctx) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } defer tx.CommitOrRollback(ctx, &err) deployment, err := d.db.GetDeployment(ctx, key) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } err = d.db.SetDeploymentDesiredReplicas(ctx, key, int32(minReplicas)) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } if minReplicas == 0 { err = d.deploymentWillDeactivate(ctx, tx, key) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } } else if deployment.MinReplicas == 0 { err = d.deploymentWillActivate(ctx, tx, key) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } } err = tx.InsertDeploymentUpdatedEvent(ctx, sql.InsertDeploymentUpdatedEventParams{ @@ -722,7 +709,7 @@ func (d *DAL) SetDeploymentReplicas(ctx context.Context, key model.DeploymentKey PrevMinReplicas: deployment.MinReplicas, }) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } return nil @@ -733,19 +720,19 @@ func (d *DAL) ReplaceDeployment(ctx context.Context, newDeploymentKey model.Depl // Start the transaction tx, err := d.db.Begin(ctx) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } defer tx.CommitOrRollback(ctx, &err) newDeployment, err := tx.GetDeployment(ctx, newDeploymentKey) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } // must be called before deploymentWillDeactivate for the old deployment err = d.deploymentWillActivate(ctx, tx, newDeploymentKey) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } // If there's an existing deployment, set its desired replicas to 0 @@ -754,23 +741,23 @@ func (d *DAL) ReplaceDeployment(ctx context.Context, newDeploymentKey model.Depl if err == nil { count, err := tx.ReplaceDeployment(ctx, oldDeployment.Key, newDeploymentKey, int32(minReplicas)) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } if count == 1 { - return fmt.Errorf("deployment already exists: %w", ErrConflict) + return fmt.Errorf("deployment already exists: %w", dalerrs.ErrConflict) } err = d.deploymentWillDeactivate(ctx, tx, oldDeployment.Key) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } replacedDeploymentKey = optional.Some(oldDeployment.Key) - } else if !isNotFound(err) { - return translatePGError(err) + } else if !dalerrs.IsNotFound(err) { + return dalerrs.TranslatePGError(err) } else { // Set the desired replicas for the new deployment err = tx.SetDeploymentDesiredReplicas(ctx, newDeploymentKey, int32(minReplicas)) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } } @@ -782,7 +769,7 @@ func (d *DAL) ReplaceDeployment(ctx context.Context, newDeploymentKey model.Depl Replaced: replacedDeploymentKey, }) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } return nil @@ -795,7 +782,7 @@ func (d *DAL) ReplaceDeployment(ctx context.Context, newDeploymentKey model.Depl func (d *DAL) deploymentWillActivate(ctx context.Context, tx *sql.Tx, key model.DeploymentKey) error { module, err := tx.GetSchemaForDeployment(ctx, key) if err != nil { - return fmt.Errorf("could not get schema: %w", translatePGError(err)) + return fmt.Errorf("could not get schema: %w", dalerrs.TranslatePGError(err)) } err = d.createSubscriptions(ctx, tx, key, module) if err != nil { @@ -816,10 +803,10 @@ func (d *DAL) deploymentWillDeactivate(ctx context.Context, tx *sql.Tx, key mode func (d *DAL) GetDeploymentsNeedingReconciliation(ctx context.Context) ([]Reconciliation, error) { counts, err := d.db.GetDeploymentsNeedingReconciliation(ctx) if err != nil { - if isNotFound(err) { + if dalerrs.IsNotFound(err) { return nil, nil } - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } return slices.Map(counts, func(t sql.GetDeploymentsNeedingReconciliationRow) Reconciliation { return Reconciliation{ @@ -836,10 +823,10 @@ func (d *DAL) GetDeploymentsNeedingReconciliation(ctx context.Context) ([]Reconc func (d *DAL) GetActiveDeployments(ctx context.Context) ([]Deployment, error) { rows, err := d.db.GetActiveDeployments(ctx) if err != nil { - if isNotFound(err) { + if dalerrs.IsNotFound(err) { return nil, nil } - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } return slices.MapErr(rows, func(in sql.GetActiveDeploymentsRow) (Deployment, error) { return Deployment{ @@ -857,10 +844,10 @@ func (d *DAL) GetActiveDeployments(ctx context.Context) ([]Deployment, error) { func (d *DAL) GetDeploymentsWithMinReplicas(ctx context.Context) ([]Deployment, error) { rows, err := d.db.GetDeploymentsWithMinReplicas(ctx) if err != nil { - if isNotFound(err) { + if dalerrs.IsNotFound(err) { return nil, nil } - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } return slices.MapErr(rows, func(in sql.GetDeploymentsWithMinReplicasRow) (Deployment, error) { return Deployment{ @@ -877,7 +864,7 @@ func (d *DAL) GetDeploymentsWithMinReplicas(ctx context.Context) ([]Deployment, func (d *DAL) GetActiveDeploymentSchemas(ctx context.Context) ([]*schema.Module, error) { rows, err := d.db.GetActiveDeploymentSchemas(ctx) if err != nil { - return nil, fmt.Errorf("could not get active deployments: %w", translatePGError(err)) + return nil, fmt.Errorf("could not get active deployments: %w", dalerrs.TranslatePGError(err)) } return slices.MapErr(rows, func(in sql.GetActiveDeploymentSchemasRow) (*schema.Module, error) { return in.Schema, nil }) } @@ -899,7 +886,7 @@ type Process struct { func (d *DAL) GetProcessList(ctx context.Context) ([]Process, error) { rows, err := d.db.GetProcessList(ctx) if err != nil { - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } return slices.MapErr(rows, func(row sql.GetProcessListRow) (Process, error) { var runner optional.Option[ProcessRunner] @@ -944,10 +931,10 @@ func (d *DAL) GetIdleRunners(ctx context.Context, limit int, labels model.Labels return nil, fmt.Errorf("could not marshal labels: %w", err) } runners, err := d.db.GetIdleRunners(ctx, jsonb, int64(limit)) - if isNotFound(err) { + if dalerrs.IsNotFound(err) { return nil, nil } else if err != nil { - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } return slices.MapErr(runners, func(row sql.Runner) (Runner, error) { rowLabels := model.Labels{} @@ -972,10 +959,10 @@ func (d *DAL) GetIdleRunners(ctx context.Context, limit int, labels model.Labels func (d *DAL) GetRoutingTable(ctx context.Context, modules []string) (map[string][]Route, error) { routes, err := d.db.GetRoutingTable(ctx, modules) if err != nil { - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } if len(routes) == 0 { - return nil, fmt.Errorf("no routes found: %w", ErrNotFound) + return nil, fmt.Errorf("no routes found: %w", dalerrs.ErrNotFound) } out := make(map[string][]Route, len(routes)) for _, route := range routes { @@ -995,7 +982,7 @@ func (d *DAL) GetRoutingTable(ctx context.Context, modules []string) (map[string func (d *DAL) GetRunnerState(ctx context.Context, runnerKey model.RunnerKey) (RunnerState, error) { state, err := d.db.GetRunnerState(ctx, runnerKey) if err != nil { - return "", translatePGError(err) + return "", dalerrs.TranslatePGError(err) } return RunnerState(state), nil } @@ -1003,14 +990,14 @@ func (d *DAL) GetRunnerState(ctx context.Context, runnerKey model.RunnerKey) (Ru func (d *DAL) GetRunner(ctx context.Context, runnerKey model.RunnerKey) (Runner, error) { row, err := d.db.GetRunner(ctx, runnerKey) if err != nil { - return Runner{}, translatePGError(err) + return Runner{}, dalerrs.TranslatePGError(err) } return runnerFromDB(row), nil } func (d *DAL) ExpireRunnerClaims(ctx context.Context) (int64, error) { count, err := d.db.ExpireRunnerReservations(ctx) - return count, translatePGError(err) + return count, dalerrs.TranslatePGError(err) } func cronJobFromRow(row sql.GetCronJobsRow) model.CronJob { @@ -1029,7 +1016,7 @@ func cronJobFromRow(row sql.GetCronJobsRow) model.CronJob { func (d *DAL) GetCronJobs(ctx context.Context) ([]model.CronJob, error) { rows, err := d.db.GetCronJobs(ctx) if err != nil { - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } return slices.Map(rows, cronJobFromRow), nil } @@ -1047,7 +1034,7 @@ func (d *DAL) StartCronJobs(ctx context.Context, jobs []model.CronJob) (attempte } rows, err := d.db.StartCronJobs(ctx, slices.Map(jobs, func(job model.CronJob) string { return job.Key.String() })) if err != nil { - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } attemptedJobs = []AttemptedCronJob{} @@ -1075,7 +1062,7 @@ func (d *DAL) StartCronJobs(ctx context.Context, jobs []model.CronJob) (attempte func (d *DAL) EndCronJob(ctx context.Context, job model.CronJob, next time.Time) (model.CronJob, error) { row, err := d.db.EndCronJob(ctx, next, job.Key, job.StartTime) if err != nil { - return model.CronJob{}, translatePGError(err) + return model.CronJob{}, dalerrs.TranslatePGError(err) } return cronJobFromRow(sql.GetCronJobsRow(row)), nil } @@ -1084,7 +1071,7 @@ func (d *DAL) EndCronJob(ctx context.Context, job model.CronJob, next time.Time) func (d *DAL) GetStaleCronJobs(ctx context.Context, duration time.Duration) ([]model.CronJob, error) { rows, err := d.db.GetStaleCronJobs(ctx, duration) if err != nil { - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } return slices.Map(rows, func(row sql.GetStaleCronJobsRow) model.CronJob { return cronJobFromRow(sql.GetCronJobsRow(row)) @@ -1100,7 +1087,7 @@ func (d *DAL) InsertLogEvent(ctx context.Context, log *LogEvent) error { if name, ok := log.RequestKey.Get(); ok { requestKey = optional.Some(name.String()) } - return translatePGError(d.db.InsertLogEvent(ctx, sql.InsertLogEventParams{ + return dalerrs.TranslatePGError(d.db.InsertLogEvent(ctx, sql.InsertLogEventParams{ DeploymentKey: log.DeploymentKey, RequestKey: requestKey, TimeStamp: log.Time, @@ -1121,7 +1108,7 @@ func (d *DAL) loadDeployment(ctx context.Context, deployment sql.GetDeploymentRo } artefacts, err := d.db.GetDeploymentArtefacts(ctx, deployment.Deployment.ID) if err != nil { - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } out.Artefacts = slices.Map(artefacts, func(row sql.GetDeploymentArtefactsRow) *model.Artefact { return &model.Artefact{ @@ -1136,7 +1123,7 @@ func (d *DAL) loadDeployment(ctx context.Context, deployment sql.GetDeploymentRo func (d *DAL) CreateRequest(ctx context.Context, key model.RequestKey, addr string) error { if err := d.db.CreateRequest(ctx, sql.Origin(key.Payload.Origin), key, addr); err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } return nil } @@ -1144,10 +1131,10 @@ func (d *DAL) CreateRequest(ctx context.Context, key model.RequestKey, addr stri func (d *DAL) GetIngressRoutes(ctx context.Context, method string) ([]IngressRoute, error) { routes, err := d.db.GetIngressRoutes(ctx, method) if err != nil { - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } if len(routes) == 0 { - return nil, ErrNotFound + return nil, dalerrs.ErrNotFound } return slices.Map(routes, func(row sql.GetIngressRoutesRow) IngressRoute { return IngressRoute{ @@ -1163,7 +1150,7 @@ func (d *DAL) GetIngressRoutes(ctx context.Context, method string) ([]IngressRou func (d *DAL) UpsertController(ctx context.Context, key model.ControllerKey, addr string) (int64, error) { id, err := d.db.UpsertController(ctx, key, addr) - return id, translatePGError(err) + return id, dalerrs.TranslatePGError(err) } func (d *DAL) InsertCallEvent(ctx context.Context, call *CallEvent) error { @@ -1175,7 +1162,7 @@ func (d *DAL) InsertCallEvent(ctx context.Context, call *CallEvent) error { if rn, ok := call.RequestKey.Get(); ok { requestKey = optional.Some(rn.String()) } - return translatePGError(d.db.InsertCallEvent(ctx, sql.InsertCallEventParams{ + return dalerrs.TranslatePGError(d.db.InsertCallEvent(ctx, sql.InsertCallEventParams{ DeploymentKey: call.DeploymentKey, RequestKey: requestKey, TimeStamp: call.Time, @@ -1194,39 +1181,13 @@ func (d *DAL) InsertCallEvent(ctx context.Context, call *CallEvent) error { func (d *DAL) GetActiveRunners(ctx context.Context) ([]Runner, error) { rows, err := d.db.GetActiveRunners(ctx) if err != nil { - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } return slices.Map(rows, func(row sql.GetActiveRunnersRow) Runner { return runnerFromDB(sql.GetRunnerRow(row)) }), nil } -func (d *DAL) GetModuleConfiguration(ctx context.Context, module optional.Option[string], name string) ([]byte, error) { - b, err := d.db.GetModuleConfiguration(ctx, module, name) - if err != nil { - return nil, translatePGError(err) - } - return b, nil -} - -func (d *DAL) SetModuleConfiguration(ctx context.Context, module optional.Option[string], name string, value []byte) error { - err := d.db.SetModuleConfiguration(ctx, module, name, value) - return translatePGError(err) -} - -func (d *DAL) UnsetModuleConfiguration(ctx context.Context, module optional.Option[string], name string) error { - err := d.db.UnsetModuleConfiguration(ctx, module, name) - return translatePGError(err) -} - -func (d *DAL) ListModuleConfiguration(ctx context.Context) ([]sql.ModuleConfiguration, error) { - l, err := d.db.ListModuleConfiguration(ctx) - if err != nil { - return nil, translatePGError(err) - } - return l, nil -} - // Check if a deployment exists that exactly matches the given artefacts and schema. func (*DAL) checkForExistingDeployments(ctx context.Context, tx *sql.Tx, moduleSchema *schema.Module, artefacts []DeploymentArtefact) (model.DeploymentKey, error) { schemaBytes, err := schema.ModuleToBytes(moduleSchema) @@ -1262,7 +1223,7 @@ func (r *artefactReader) Close() error { return nil } func (r *artefactReader) Read(p []byte) (n int, err error) { content, err := r.db.GetArtefactContentRange(context.Background(), r.offset+1, int32(len(p)), r.id) if err != nil { - return 0, translatePGError(err) + return 0, dalerrs.TranslatePGError(err) } copy(p, content) clen := len(content) @@ -1272,32 +1233,3 @@ func (r *artefactReader) Read(p []byte) (n int, err error) { } return clen, err } - -func isNotFound(err error) bool { - return errors.Is(err, stdsql.ErrNoRows) || errors.Is(err, pgx.ErrNoRows) -} - -func translatePGError(err error) error { - if err == nil { - return nil - } - var pgErr *pgconn.PgError - if errors.As(err, &pgErr) { - switch pgErr.Code { - case pgerrcode.ForeignKeyViolation: - return fmt.Errorf("%s: %w", strings.TrimSuffix(strings.TrimPrefix(pgErr.ConstraintName, pgErr.TableName+"_"), "_id_fkey"), ErrNotFound) - case pgerrcode.UniqueViolation: - return fmt.Errorf("%s: %w", pgErr.Message, ErrConflict) - case pgerrcode.IntegrityConstraintViolation, - pgerrcode.RestrictViolation, - pgerrcode.NotNullViolation, - pgerrcode.CheckViolation, - pgerrcode.ExclusionViolation: - return fmt.Errorf("%s: %w", pgErr.Message, ErrConstraint) - default: - } - } else if isNotFound(err) { - return ErrNotFound - } - return err -} diff --git a/backend/controller/dal/dal_test.go b/backend/controller/dal/dal_test.go index 85a4215f83..a1b0334c9c 100644 --- a/backend/controller/dal/dal_test.go +++ b/backend/controller/dal/dal_test.go @@ -12,10 +12,10 @@ import ( "github.com/alecthomas/types/optional" "golang.org/x/sync/errgroup" - "github.com/TBD54566975/ftl/backend/controller/sql" "github.com/TBD54566975/ftl/backend/controller/sql/sqltest" ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1" "github.com/TBD54566975/ftl/backend/schema" + "github.com/TBD54566975/ftl/db/dalerrs" "github.com/TBD54566975/ftl/internal/log" "github.com/TBD54566975/ftl/internal/model" "github.com/TBD54566975/ftl/internal/sha256" @@ -88,7 +88,7 @@ func TestDAL(t *testing.T) { t.Run("GetMissingDeployment", func(t *testing.T) { _, err := dal.GetDeployment(ctx, model.NewDeploymentKey("invalid")) - assert.IsError(t, err, ErrNotFound) + assert.IsError(t, err, dalerrs.ErrNotFound) }) t.Run("GetMissingArtefacts", func(t *testing.T) { @@ -119,7 +119,7 @@ func TestDAL(t *testing.T) { State: RunnerStateIdle, }) assert.Error(t, err) - assert.IsError(t, err, ErrConflict) + assert.IsError(t, err, dalerrs.ErrConflict) }) t.Run("GetIdleRunnersForLanguage", func(t *testing.T) { @@ -168,7 +168,7 @@ func TestDAL(t *testing.T) { t.Run("ReserveRunnerForInvalidDeployment", func(t *testing.T) { _, err := dal.ReserveRunnerForDeployment(ctx, model.NewDeploymentKey("invalid"), time.Second, labels) assert.Error(t, err) - assert.IsError(t, err, ErrNotFound) + assert.IsError(t, err, dalerrs.ErrNotFound) assert.EqualError(t, err, "deployment: not found") }) @@ -192,7 +192,7 @@ func TestDAL(t *testing.T) { t.Run("ReserveRunnerForDeploymentFailsOnInvalidDeployment", func(t *testing.T) { _, err = dal.ReserveRunnerForDeployment(ctx, model.NewDeploymentKey("test"), time.Second, labels) - assert.IsError(t, err, ErrNotFound) + assert.IsError(t, err, dalerrs.ErrNotFound) }) t.Run("UpdateRunnerAssigned", func(t *testing.T) { @@ -319,7 +319,7 @@ func TestDAL(t *testing.T) { Deployment: optional.Some(model.NewDeploymentKey("test")), }) assert.Error(t, err) - assert.IsError(t, err, ErrNotFound) + assert.IsError(t, err, dalerrs.ErrNotFound) }) t.Run("ReleaseRunnerReservation", func(t *testing.T) { @@ -342,7 +342,7 @@ func TestDAL(t *testing.T) { t.Run("GetRoutingTable", func(t *testing.T) { _, err := dal.GetRoutingTable(ctx, []string{deployment.Module}) - assert.IsError(t, err, ErrNotFound) + assert.IsError(t, err, dalerrs.ErrNotFound) }) t.Run("DeregisterRunner", func(t *testing.T) { @@ -352,7 +352,7 @@ func TestDAL(t *testing.T) { t.Run("DeregisterRunnerFailsOnMissing", func(t *testing.T) { err = dal.DeregisterRunner(ctx, model.NewRunnerKey("localhost", "8080")) - assert.IsError(t, err, ErrNotFound) + assert.IsError(t, err, dalerrs.ErrNotFound) }) t.Run("VerifyDeploymentNotifications", func(t *testing.T) { @@ -402,93 +402,3 @@ func assertEventsEqual(t *testing.T, expected, actual []Event) { t.Helper() assert.Equal(t, normaliseEvents(expected), normaliseEvents(actual)) } - -func TestModuleConfiguration(t *testing.T) { - ctx := log.ContextWithNewDefaultLogger(context.Background()) - conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn) - assert.NoError(t, err) - assert.NotZero(t, dal) - - tests := []struct { - TestName string - ModuleSet optional.Option[string] - ModuleGet optional.Option[string] - PresetGlobal bool - }{ - { - "SetModuleGetModule", - optional.Some("echo"), - optional.Some("echo"), - false, - }, - { - "SetGlobalGetGlobal", - optional.None[string](), - optional.None[string](), - false, - }, - { - "SetGlobalGetModule", - optional.None[string](), - optional.Some("echo"), - false, - }, - { - "SetModuleOverridesGlobal", - optional.Some("echo"), - optional.Some("echo"), - true, - }, - } - - b := []byte(`"asdf"`) - for _, test := range tests { - t.Run(test.TestName, func(t *testing.T) { - if test.PresetGlobal { - err := dal.SetModuleConfiguration(ctx, optional.None[string](), "configname", []byte(`"qwerty"`)) - assert.NoError(t, err) - } - err := dal.SetModuleConfiguration(ctx, test.ModuleSet, "configname", b) - assert.NoError(t, err) - gotBytes, err := dal.GetModuleConfiguration(ctx, test.ModuleGet, "configname") - assert.NoError(t, err) - assert.Equal(t, b, gotBytes) - err = dal.UnsetModuleConfiguration(ctx, test.ModuleGet, "configname") - assert.NoError(t, err) - }) - } - - t.Run("List", func(t *testing.T) { - sortedList := []sql.ModuleConfiguration{ - { - Module: optional.Some("echo"), - Name: "a", - }, - { - Module: optional.Some("echo"), - Name: "b", - }, - { - Module: optional.None[string](), - Name: "a", - }, - } - - // Insert entries in a separate order from how they should be returned to - // test sorting logic in the SQL query - err := dal.SetModuleConfiguration(ctx, sortedList[1].Module, sortedList[1].Name, []byte(`""`)) - assert.NoError(t, err) - err = dal.SetModuleConfiguration(ctx, sortedList[2].Module, sortedList[2].Name, []byte(`""`)) - assert.NoError(t, err) - err = dal.SetModuleConfiguration(ctx, sortedList[0].Module, sortedList[0].Name, []byte(`""`)) - assert.NoError(t, err) - - gotList, err := dal.ListModuleConfiguration(ctx) - assert.NoError(t, err) - for i := range sortedList { - assert.Equal(t, sortedList[i].Module, gotList[i].Module) - assert.Equal(t, sortedList[i].Name, gotList[i].Name) - } - }) -} diff --git a/backend/controller/dal/events.go b/backend/controller/dal/events.go index adec1b2a92..54839d9dfc 100644 --- a/backend/controller/dal/events.go +++ b/backend/controller/dal/events.go @@ -12,6 +12,7 @@ import ( "github.com/TBD54566975/ftl/backend/controller/sql" "github.com/TBD54566975/ftl/backend/schema" + "github.com/TBD54566975/ftl/db/dalerrs" "github.com/TBD54566975/ftl/internal/log" "github.com/TBD54566975/ftl/internal/model" ) @@ -260,7 +261,7 @@ func (d *DAL) QueryEvents(ctx context.Context, limit int, filters ...EventFilter } rows, err := d.db.Conn().Query(ctx, deploymentQuery, deploymentArgs...) if err != nil { - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } deploymentIDs := []int64{} for rows.Next() { @@ -315,7 +316,7 @@ func (d *DAL) QueryEvents(ctx context.Context, limit int, filters ...EventFilter // Issue query. rows, err = d.db.Conn().Query(ctx, q, args...) if err != nil { - return nil, fmt.Errorf("%s: %w", q, translatePGError(err)) + return nil, fmt.Errorf("%s: %w", q, dalerrs.TranslatePGError(err)) } defer rows.Close() diff --git a/backend/controller/dal/fsm.go b/backend/controller/dal/fsm.go index 43fbc39282..cb1a599ea9 100644 --- a/backend/controller/dal/fsm.go +++ b/backend/controller/dal/fsm.go @@ -12,6 +12,7 @@ import ( "github.com/TBD54566975/ftl/backend/controller/leases" "github.com/TBD54566975/ftl/backend/controller/sql" "github.com/TBD54566975/ftl/backend/schema" + "github.com/TBD54566975/ftl/db/dalerrs" ) // StartFSMTransition sends an event to an executing instance of an FSM. @@ -39,7 +40,7 @@ func (d *DAL) StartFSMTransition(ctx context.Context, fsm schema.RefKey, executi MaxBackoff: retryParams.MaxBackoff, }) if err != nil { - return fmt.Errorf("failed to create FSM async call: %w", translatePGError(err)) + return fmt.Errorf("failed to create FSM async call: %w", dalerrs.TranslatePGError(err)) } // Start a transition. @@ -50,9 +51,9 @@ func (d *DAL) StartFSMTransition(ctx context.Context, fsm schema.RefKey, executi AsyncCallID: asyncCallID, }) if err != nil { - err = translatePGError(err) - if errors.Is(err, ErrNotFound) { - return fmt.Errorf("transition already executing: %w", ErrConflict) + err = dalerrs.TranslatePGError(err) + if errors.Is(err, dalerrs.ErrNotFound) { + return fmt.Errorf("transition already executing: %w", dalerrs.ErrConflict) } return fmt.Errorf("failed to start FSM transition: %w", err) } @@ -61,17 +62,17 @@ func (d *DAL) StartFSMTransition(ctx context.Context, fsm schema.RefKey, executi func (d *DAL) FinishFSMTransition(ctx context.Context, fsm schema.RefKey, instanceKey string) error { _, err := d.db.FinishFSMTransition(ctx, fsm, instanceKey) - return translatePGError(err) + return dalerrs.TranslatePGError(err) } func (d *DAL) FailFSMInstance(ctx context.Context, fsm schema.RefKey, instanceKey string) error { _, err := d.db.FailFSMInstance(ctx, fsm, instanceKey) - return translatePGError(err) + return dalerrs.TranslatePGError(err) } func (d *DAL) SucceedFSMInstance(ctx context.Context, fsm schema.RefKey, instanceKey string) error { _, err := d.db.SucceedFSMInstance(ctx, fsm, instanceKey) - return translatePGError(err) + return dalerrs.TranslatePGError(err) } type FSMStatus = sql.FsmStatus @@ -103,8 +104,8 @@ func (d *DAL) AcquireFSMInstance(ctx context.Context, fsm schema.RefKey, instanc } row, err := d.db.GetFSMInstance(ctx, fsm, instanceKey) if err != nil { - err = translatePGError(err) - if !errors.Is(err, ErrNotFound) { + err = dalerrs.TranslatePGError(err) + if !errors.Is(err, dalerrs.ErrNotFound) { return nil, err } row.Status = sql.FsmStatusRunning diff --git a/backend/controller/dal/fsm_test.go b/backend/controller/dal/fsm_test.go index a3c50bf74f..02a8041ecf 100644 --- a/backend/controller/dal/fsm_test.go +++ b/backend/controller/dal/fsm_test.go @@ -10,6 +10,7 @@ import ( "github.com/TBD54566975/ftl/backend/controller/sql/sqltest" "github.com/TBD54566975/ftl/backend/schema" + "github.com/TBD54566975/ftl/db/dalerrs" "github.com/TBD54566975/ftl/internal/log" ) @@ -20,14 +21,14 @@ func TestSendFSMEvent(t *testing.T) { assert.NoError(t, err) _, err = dal.AcquireAsyncCall(ctx) - assert.IsError(t, err, ErrNotFound) + assert.IsError(t, err, dalerrs.ErrNotFound) ref := schema.RefKey{Module: "module", Name: "verb"} err = dal.StartFSMTransition(ctx, schema.RefKey{Module: "test", Name: "test"}, "invoiceID", ref, []byte(`{}`), schema.RetryParams{}) assert.NoError(t, err) err = dal.StartFSMTransition(ctx, schema.RefKey{Module: "test", Name: "test"}, "invoiceID", ref, []byte(`{}`), schema.RetryParams{}) - assert.IsError(t, err, ErrConflict) + assert.IsError(t, err, dalerrs.ErrConflict) assert.EqualError(t, err, "transition already executing: conflict") call, err := dal.AcquireAsyncCall(ctx) diff --git a/backend/controller/dal/lease.go b/backend/controller/dal/lease.go index 138c571905..e9711c55ec 100644 --- a/backend/controller/dal/lease.go +++ b/backend/controller/dal/lease.go @@ -12,6 +12,7 @@ import ( "github.com/TBD54566975/ftl/backend/controller/leases" "github.com/TBD54566975/ftl/backend/controller/sql" + "github.com/TBD54566975/ftl/db/dalerrs" "github.com/TBD54566975/ftl/internal/log" ) @@ -47,8 +48,8 @@ func (l *Lease) renew(ctx context.Context, cancelCtx context.CancelFunc) { cancel() if err != nil { - err = translatePGError(err) - if errors.Is(err, ErrNotFound) { + err = dalerrs.TranslatePGError(err) + if errors.Is(err, dalerrs.ErrNotFound) { logger.Warnf("Lease expired") } else { logger.Errorf(err, "Failed to renew lease %s", l.key) @@ -64,7 +65,7 @@ func (l *Lease) renew(ctx context.Context, cancelCtx context.CancelFunc) { } logger.Debugf("Releasing lease") _, err := l.db.ReleaseLease(ctx, l.idempotencyKey, l.key) - l.errch <- translatePGError(err) + l.errch <- dalerrs.TranslatePGError(err) cancelCtx() return } @@ -95,7 +96,7 @@ func (d *DAL) AcquireLease(ctx context.Context, key leases.Key, ttl time.Duratio } idempotencyKey, err := d.db.NewLease(ctx, key, ttl, metadataBytes) if err != nil { - return nil, nil, translatePGError(err) + return nil, nil, dalerrs.TranslatePGError(err) } leaseCtx, lease := d.newLease(ctx, key, idempotencyKey, ttl) return leaseCtx, lease, nil @@ -121,7 +122,7 @@ func (d *DAL) newLease(ctx context.Context, key leases.Key, idempotencyKey uuid. func (d *DAL) GetLeaseInfo(ctx context.Context, key leases.Key, metadata any) (expiry time.Time, err error) { l, err := d.db.GetLeaseInfo(ctx, key) if err != nil { - return expiry, translatePGError(err) + return expiry, dalerrs.TranslatePGError(err) } if err := json.Unmarshal(l.Metadata, metadata); err != nil { return expiry, fmt.Errorf("could not unmarshal lease metadata: %w", err) @@ -136,5 +137,5 @@ func (d *DAL) ExpireLeases(ctx context.Context) error { if count > 0 { log.FromContext(ctx).Warnf("Expired %d leases", count) } - return translatePGError(err) + return dalerrs.TranslatePGError(err) } diff --git a/backend/controller/dal/lease_test.go b/backend/controller/dal/lease_test.go index 67205abcf4..88bd7e64db 100644 --- a/backend/controller/dal/lease_test.go +++ b/backend/controller/dal/lease_test.go @@ -13,16 +13,17 @@ import ( "github.com/TBD54566975/ftl/backend/controller/leases" "github.com/TBD54566975/ftl/backend/controller/sql" "github.com/TBD54566975/ftl/backend/controller/sql/sqltest" + "github.com/TBD54566975/ftl/db/dalerrs" "github.com/TBD54566975/ftl/internal/log" ) func leaseExists(t *testing.T, conn sql.ConnI, idempotencyKey uuid.UUID, key leases.Key) bool { t.Helper() var count int - err := translatePGError(conn. + err := dalerrs.TranslatePGError(conn. QueryRow(context.Background(), "SELECT COUNT(*) FROM leases WHERE idempotency_key = $1 AND key = $2", idempotencyKey, key). Scan(&count)) - if errors.Is(err, ErrNotFound) { + if errors.Is(err, dalerrs.ErrNotFound) { return false } assert.NoError(t, err) @@ -49,7 +50,7 @@ func TestLease(t *testing.T) { // Try to acquire the same lease again, which should fail. _, _, err = dal.AcquireLease(ctx, leases.SystemKey("test"), time.Second*5, optional.None[any]()) - assert.IsError(t, err, ErrConflict) + assert.IsError(t, err, dalerrs.ErrConflict) time.Sleep(time.Second * 6) diff --git a/backend/controller/dal/notify.go b/backend/controller/dal/notify.go index 7cdd192d29..c3afa5975f 100644 --- a/backend/controller/dal/notify.go +++ b/backend/controller/dal/notify.go @@ -12,6 +12,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jpillora/backoff" + "github.com/TBD54566975/ftl/db/dalerrs" "github.com/TBD54566975/ftl/internal/log" "github.com/TBD54566975/ftl/internal/model" ) @@ -102,7 +103,7 @@ func (d *DAL) publishNotification(ctx context.Context, notification event, logge deployment, err := decodeNotification(notification, func(key model.DeploymentKey) (Deployment, optional.Option[model.DeploymentKey], error) { row, err := d.db.GetDeployment(ctx, key) if err != nil { - return Deployment{}, optional.None[model.DeploymentKey](), translatePGError(err) + return Deployment{}, optional.None[model.DeploymentKey](), dalerrs.TranslatePGError(err) } return Deployment{ CreatedAt: row.Deployment.CreatedAt, diff --git a/backend/controller/dal/pubsub.go b/backend/controller/dal/pubsub.go index 3988d3c00e..815b720a10 100644 --- a/backend/controller/dal/pubsub.go +++ b/backend/controller/dal/pubsub.go @@ -7,6 +7,7 @@ import ( "github.com/TBD54566975/ftl/backend/controller/sql" "github.com/TBD54566975/ftl/backend/schema" + "github.com/TBD54566975/ftl/db/dalerrs" "github.com/TBD54566975/ftl/internal/log" "github.com/TBD54566975/ftl/internal/model" "github.com/TBD54566975/ftl/internal/slices" @@ -20,7 +21,7 @@ func (d *DAL) PublishEventForTopic(ctx context.Context, module, topic string, pa Payload: payload, }) if err != nil { - return translatePGError(err) + return dalerrs.TranslatePGError(err) } return nil } @@ -28,7 +29,7 @@ func (d *DAL) PublishEventForTopic(ctx context.Context, module, topic string, pa func (d *DAL) GetSubscriptionsNeedingUpdate(ctx context.Context) ([]model.Subscription, error) { rows, err := d.db.GetSubscriptionsNeedingUpdate(ctx) if err != nil { - return nil, translatePGError(err) + return nil, dalerrs.TranslatePGError(err) } return slices.Map(rows, func(row sql.GetSubscriptionsNeedingUpdateRow) model.Subscription { return model.Subscription{ @@ -53,18 +54,18 @@ func (d *DAL) ProgressSubscriptions(ctx context.Context, eventConsumptionDelay t // also gets a lock on the subscription, and skips any subscriptions locked by others subs, err := tx.db.GetSubscriptionsNeedingUpdate(ctx) if err != nil { - return 0, fmt.Errorf("could not get subscriptions to progress: %w", translatePGError(err)) + return 0, fmt.Errorf("could not get subscriptions to progress: %w", dalerrs.TranslatePGError(err)) } successful := 0 for _, subscription := range subs { nextCursor, err := tx.db.GetNextEventForSubscription(ctx, eventConsumptionDelay, subscription.Topic, subscription.Cursor) if err != nil { - return 0, fmt.Errorf("failed to get next cursor: %w", translatePGError(err)) + return 0, fmt.Errorf("failed to get next cursor: %w", dalerrs.TranslatePGError(err)) } nextCursorKey, ok := nextCursor.Event.Get() if !ok { - return 0, fmt.Errorf("could not find event to progress subscription: %w", translatePGError(err)) + return 0, fmt.Errorf("could not find event to progress subscription: %w", dalerrs.TranslatePGError(err)) } if !nextCursor.Ready { logger.Tracef("Skipping subscription %s because event is too new", subscription.Key) @@ -79,7 +80,7 @@ func (d *DAL) ProgressSubscriptions(ctx context.Context, eventConsumptionDelay t err = tx.db.BeginConsumingTopicEvent(ctx, subscription.Key, nextCursorKey) if err != nil { - return 0, fmt.Errorf("failed to progress subscription: %w", translatePGError(err)) + return 0, fmt.Errorf("failed to progress subscription: %w", dalerrs.TranslatePGError(err)) } origin := AsyncOriginPubSub{ @@ -97,7 +98,7 @@ func (d *DAL) ProgressSubscriptions(ctx context.Context, eventConsumptionDelay t MaxBackoff: subscriber.MaxBackoff, }) if err != nil { - return 0, fmt.Errorf("failed to schedule async task for subscription: %w", translatePGError(err)) + return 0, fmt.Errorf("failed to schedule async task for subscription: %w", dalerrs.TranslatePGError(err)) } successful++ } @@ -107,7 +108,7 @@ func (d *DAL) ProgressSubscriptions(ctx context.Context, eventConsumptionDelay t func (d *DAL) CompleteEventForSubscription(ctx context.Context, module, name string) error { err := d.db.CompleteEventForSubscription(ctx, name, module) if err != nil { - return fmt.Errorf("failed to complete event for subscription: %w", translatePGError(err)) + return fmt.Errorf("failed to complete event for subscription: %w", dalerrs.TranslatePGError(err)) } return nil } @@ -134,7 +135,7 @@ func (d *DAL) createSubscriptions(ctx context.Context, tx *sql.Tx, key model.Dep TopicName: s.Topic.Name, Name: s.Name, }); err != nil { - return fmt.Errorf("could not insert subscription: %w", translatePGError(err)) + return fmt.Errorf("could not insert subscription: %w", dalerrs.TranslatePGError(err)) } } return nil @@ -193,7 +194,7 @@ func (d *DAL) createSubscribers(ctx context.Context, tx *sql.Tx, key model.Deplo MaxBackoff: retryParams.MaxBackoff, }) if err != nil { - return fmt.Errorf("could not insert subscriber: %w", translatePGError(err)) + return fmt.Errorf("could not insert subscriber: %w", dalerrs.TranslatePGError(err)) } } } @@ -202,10 +203,10 @@ func (d *DAL) createSubscribers(ctx context.Context, tx *sql.Tx, key model.Deplo func (d *DAL) removeSubscriptionsAndSubscribers(ctx context.Context, tx *sql.Tx, key model.DeploymentKey) error { if err := tx.DeleteSubscriptions(ctx, key); err != nil { - return fmt.Errorf("could not delete old subscriptions: %w", translatePGError(err)) + return fmt.Errorf("could not delete old subscriptions: %w", dalerrs.TranslatePGError(err)) } if err := tx.DeleteSubscribers(ctx, key); err != nil { - return fmt.Errorf("could not delete old subscribers: %w", translatePGError(err)) + return fmt.Errorf("could not delete old subscribers: %w", dalerrs.TranslatePGError(err)) } return nil } diff --git a/backend/controller/ingress/handler.go b/backend/controller/ingress/handler.go index c875d3a19b..349727628b 100644 --- a/backend/controller/ingress/handler.go +++ b/backend/controller/ingress/handler.go @@ -13,6 +13,7 @@ import ( ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1" schemapb "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/schema" "github.com/TBD54566975/ftl/backend/schema" + "github.com/TBD54566975/ftl/db/dalerrs" "github.com/TBD54566975/ftl/internal/log" "github.com/TBD54566975/ftl/internal/model" ) @@ -30,7 +31,7 @@ func Handle( logger.Debugf("%s %s", r.Method, r.URL.Path) route, err := GetIngressRoute(routes, r.Method, r.URL.Path) if err != nil { - if errors.Is(err, dal.ErrNotFound) { + if errors.Is(err, dalerrs.ErrNotFound) { http.NotFound(w, r) return } diff --git a/backend/controller/ingress/ingress.go b/backend/controller/ingress/ingress.go index ce6f3ceb5a..d35a62ddd0 100644 --- a/backend/controller/ingress/ingress.go +++ b/backend/controller/ingress/ingress.go @@ -12,6 +12,7 @@ import ( "github.com/TBD54566975/ftl/backend/controller/dal" "github.com/TBD54566975/ftl/backend/schema" + "github.com/TBD54566975/ftl/db/dalerrs" "github.com/TBD54566975/ftl/internal/slices" ) @@ -27,7 +28,7 @@ func GetIngressRoute(routes []dal.IngressRoute, method string, path string) (*da }) if len(matchedRoutes) == 0 { - return nil, dal.ErrNotFound + return nil, dalerrs.ErrNotFound } // TODO: add load balancing at some point diff --git a/backend/controller/leader/leader.go b/backend/controller/leader/leader.go index 8352b235d5..81a34c7911 100644 --- a/backend/controller/leader/leader.go +++ b/backend/controller/leader/leader.go @@ -20,8 +20,8 @@ import ( "sync" "time" - "github.com/TBD54566975/ftl/backend/controller/dal" "github.com/TBD54566975/ftl/backend/controller/leases" + "github.com/TBD54566975/ftl/db/dalerrs" "github.com/TBD54566975/ftl/internal/log" "github.com/alecthomas/types/optional" ) @@ -132,7 +132,7 @@ func (c *Coordinator[P]) Get() (leaderOrFollower P, err error) { logger.Tracef("new leader for %s: %s", c.key, c.advertise) return l, nil } - if !errors.Is(leaseErr, dal.ErrConflict) { + if !errors.Is(leaseErr, dalerrs.ErrConflict) { return leaderOrFollower, fmt.Errorf("could not acquire lease for %s: %w", c.key, leaseErr) } // lease already held @@ -155,7 +155,7 @@ func (c *Coordinator[P]) createFollower() (out P, err error) { var urlString string expiry, err := c.leaser.GetLeaseInfo(c.ctx, c.key, &urlString) if err != nil { - if errors.Is(err, dal.ErrNotFound) { + if errors.Is(err, dalerrs.ErrNotFound) { return out, fmt.Errorf("could not acquire or find lease for %s", c.key) } return out, fmt.Errorf("could not get lease for %s: %w", c.key, err) diff --git a/backend/controller/sql/querier.go b/backend/controller/sql/querier.go index 9ac20032ba..3d14409a71 100644 --- a/backend/controller/sql/querier.go +++ b/backend/controller/sql/querier.go @@ -64,7 +64,6 @@ type Querier interface { // Get the runner endpoints corresponding to the given ingress route. GetIngressRoutes(ctx context.Context, method string) ([]GetIngressRoutesRow, error) GetLeaseInfo(ctx context.Context, key leases.Key) (GetLeaseInfoRow, error) - GetModuleConfiguration(ctx context.Context, module optional.Option[string], name string) ([]byte, error) GetModulesByID(ctx context.Context, ids []int64) ([]Module, error) GetNextEventForSubscription(ctx context.Context, consumptionDelay time.Duration, topic model.TopicKey, cursor optional.Option[model.TopicEventKey]) (GetNextEventForSubscriptionRow, error) GetProcessList(ctx context.Context) ([]GetProcessListRow, error) @@ -90,7 +89,6 @@ type Querier interface { // Mark any controller entries that haven't been updated recently as dead. KillStaleControllers(ctx context.Context, timeout time.Duration) (int64, error) KillStaleRunners(ctx context.Context, timeout time.Duration) (int64, error) - ListModuleConfiguration(ctx context.Context) ([]ModuleConfiguration, error) LoadAsyncCall(ctx context.Context, id int64) (AsyncCall, error) NewLease(ctx context.Context, key leases.Key, ttl time.Duration, metadata []byte) (uuid.UUID, error) PublishEventForTopic(ctx context.Context, arg PublishEventForTopicParams) error @@ -100,7 +98,6 @@ type Querier interface { // Find an idle runner and reserve it for the given deployment. ReserveRunner(ctx context.Context, reservationTimeout time.Time, deploymentKey model.DeploymentKey, labels []byte) (Runner, error) SetDeploymentDesiredReplicas(ctx context.Context, key model.DeploymentKey, minReplicas int32) error - SetModuleConfiguration(ctx context.Context, module optional.Option[string], name string, value []byte) error StartCronJobs(ctx context.Context, keys []string) ([]StartCronJobsRow, error) // Start a new FSM transition, populating the destination state and async call ID. // @@ -108,7 +105,6 @@ type Querier interface { StartFSMTransition(ctx context.Context, arg StartFSMTransitionParams) (FsmInstance, error) SucceedAsyncCall(ctx context.Context, response []byte, iD int64) (bool, error) SucceedFSMInstance(ctx context.Context, fsm schema.RefKey, key string) (bool, error) - UnsetModuleConfiguration(ctx context.Context, module optional.Option[string], name string) error UpsertController(ctx context.Context, key model.ControllerKey, endpoint string) (int64, error) UpsertModule(ctx context.Context, language string, name string) (int64, error) // Upsert a runner and return the deployment ID that it is assigned to, if any. diff --git a/backend/controller/sql/queries.sql b/backend/controller/sql/queries.sql index 0c7d97cb45..f766822048 100644 --- a/backend/controller/sql/queries.sql +++ b/backend/controller/sql/queries.sql @@ -798,25 +798,3 @@ UPDATE topic_subscriptions SET state = 'idle' WHERE name = @name::TEXT AND module_id = (SELECT id FROM module); - --- name: GetModuleConfiguration :one -SELECT value -FROM module_configuration -WHERE - (module IS NULL OR module = @module) - AND name = @name -ORDER BY module NULLS LAST -LIMIT 1; - --- name: ListModuleConfiguration :many -SELECT * -FROM module_configuration -ORDER BY module, name; - --- name: SetModuleConfiguration :exec -INSERT INTO module_configuration (module, name, value) -VALUES ($1, $2, $3); - --- name: UnsetModuleConfiguration :exec -DELETE FROM module_configuration -WHERE module = @module AND name = @name; diff --git a/backend/controller/sql/queries.sql.go b/backend/controller/sql/queries.sql.go index f29e291933..1f5a26efb8 100644 --- a/backend/controller/sql/queries.sql.go +++ b/backend/controller/sql/queries.sql.go @@ -1184,23 +1184,6 @@ func (q *Queries) GetLeaseInfo(ctx context.Context, key leases.Key) (GetLeaseInf return i, err } -const getModuleConfiguration = `-- name: GetModuleConfiguration :one -SELECT value -FROM module_configuration -WHERE - (module IS NULL OR module = $1) - AND name = $2 -ORDER BY module NULLS LAST -LIMIT 1 -` - -func (q *Queries) GetModuleConfiguration(ctx context.Context, module optional.Option[string], name string) ([]byte, error) { - row := q.db.QueryRow(ctx, getModuleConfiguration, module, name) - var value []byte - err := row.Scan(&value) - return value, err -} - const getModulesByID = `-- name: GetModulesByID :many SELECT id, language, name FROM modules @@ -1932,38 +1915,6 @@ func (q *Queries) KillStaleRunners(ctx context.Context, timeout time.Duration) ( return count, err } -const listModuleConfiguration = `-- name: ListModuleConfiguration :many -SELECT id, created_at, module, name, value -FROM module_configuration -ORDER BY module, name -` - -func (q *Queries) ListModuleConfiguration(ctx context.Context) ([]ModuleConfiguration, error) { - rows, err := q.db.Query(ctx, listModuleConfiguration) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ModuleConfiguration - for rows.Next() { - var i ModuleConfiguration - if err := rows.Scan( - &i.ID, - &i.CreatedAt, - &i.Module, - &i.Name, - &i.Value, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - const loadAsyncCall = `-- name: LoadAsyncCall :one SELECT id, created_at, lease_id, verb, state, origin, scheduled_at, request, response, error, remaining_attempts, backoff, max_backoff FROM async_calls @@ -2146,16 +2097,6 @@ func (q *Queries) SetDeploymentDesiredReplicas(ctx context.Context, key model.De return err } -const setModuleConfiguration = `-- name: SetModuleConfiguration :exec -INSERT INTO module_configuration (module, name, value) -VALUES ($1, $2, $3) -` - -func (q *Queries) SetModuleConfiguration(ctx context.Context, module optional.Option[string], name string, value []byte) error { - _, err := q.db.Exec(ctx, setModuleConfiguration, module, name, value) - return err -} - const startCronJobs = `-- name: StartCronJobs :many WITH updates AS ( UPDATE cron_jobs @@ -2310,16 +2251,6 @@ func (q *Queries) SucceedFSMInstance(ctx context.Context, fsm schema.RefKey, key return column_1, err } -const unsetModuleConfiguration = `-- name: UnsetModuleConfiguration :exec -DELETE FROM module_configuration -WHERE module = $1 AND name = $2 -` - -func (q *Queries) UnsetModuleConfiguration(ctx context.Context, module optional.Option[string], name string) error { - _, err := q.db.Exec(ctx, unsetModuleConfiguration, module, name) - return err -} - const upsertController = `-- name: UpsertController :one INSERT INTO controller (key, endpoint) VALUES ($1, $2) diff --git a/cmd/ftl-controller/main.go b/cmd/ftl-controller/main.go index 4181a6fe7d..e6506a78fd 100644 --- a/cmd/ftl-controller/main.go +++ b/cmd/ftl-controller/main.go @@ -14,9 +14,9 @@ import ( "github.com/TBD54566975/ftl" "github.com/TBD54566975/ftl/backend/controller" - "github.com/TBD54566975/ftl/backend/controller/dal" "github.com/TBD54566975/ftl/backend/controller/scaling" cf "github.com/TBD54566975/ftl/common/configuration" + cfdal "github.com/TBD54566975/ftl/common/configuration/dal" _ "github.com/TBD54566975/ftl/internal/automaxprocs" // Set GOMAXPROCS to match Linux container CPU quota. "github.com/TBD54566975/ftl/internal/log" "github.com/TBD54566975/ftl/internal/observability" @@ -47,7 +47,7 @@ func main() { // The FTL controller currently only supports DB as a configuration provider/resolver. conn, err := pgxpool.New(ctx, cli.ControllerConfig.DSN) kctx.FatalIfErrorf(err) - dal, err := dal.New(ctx, conn) + dal, err := cfdal.New(ctx, conn) kctx.FatalIfErrorf(err) configProviders := []cf.Provider[cf.Configuration]{cf.NewDBConfigProvider(dal)} configResolver := cf.NewDBConfigResolver(dal) diff --git a/common/configuration/dal/dal.go b/common/configuration/dal/dal.go new file mode 100644 index 0000000000..7648792c62 --- /dev/null +++ b/common/configuration/dal/dal.go @@ -0,0 +1,47 @@ +// Package dal provides a data abstraction layer for managing module configurations +package dal + +import ( + "context" + + "github.com/alecthomas/types/optional" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/TBD54566975/ftl/common/configuration/sql" + "github.com/TBD54566975/ftl/db/dalerrs" +) + +type DAL struct { + db sql.DBI +} + +func New(ctx context.Context, pool *pgxpool.Pool) (*DAL, error) { + dal := &DAL{db: sql.NewDB(pool)} + return dal, nil +} + +func (d *DAL) GetModuleConfiguration(ctx context.Context, module optional.Option[string], name string) ([]byte, error) { + b, err := d.db.GetModuleConfiguration(ctx, module, name) + if err != nil { + return nil, dalerrs.TranslatePGError(err) + } + return b, nil +} + +func (d *DAL) SetModuleConfiguration(ctx context.Context, module optional.Option[string], name string, value []byte) error { + err := d.db.SetModuleConfiguration(ctx, module, name, value) + return dalerrs.TranslatePGError(err) +} + +func (d *DAL) UnsetModuleConfiguration(ctx context.Context, module optional.Option[string], name string) error { + err := d.db.UnsetModuleConfiguration(ctx, module, name) + return dalerrs.TranslatePGError(err) +} + +func (d *DAL) ListModuleConfiguration(ctx context.Context) ([]sql.ModuleConfiguration, error) { + l, err := d.db.ListModuleConfiguration(ctx) + if err != nil { + return nil, dalerrs.TranslatePGError(err) + } + return l, nil +} diff --git a/common/configuration/dal/dal_test.go b/common/configuration/dal/dal_test.go new file mode 100644 index 0000000000..6b4434cedb --- /dev/null +++ b/common/configuration/dal/dal_test.go @@ -0,0 +1,103 @@ +package dal + +import ( + "context" + "testing" + + "github.com/alecthomas/assert/v2" + "github.com/alecthomas/types/optional" + + "github.com/TBD54566975/ftl/backend/controller/sql" + "github.com/TBD54566975/ftl/backend/controller/sql/sqltest" + "github.com/TBD54566975/ftl/internal/log" +) + +func TestModuleConfiguration(t *testing.T) { + ctx := log.ContextWithNewDefaultLogger(context.Background()) + conn := sqltest.OpenForTesting(ctx, t) + dal, err := New(ctx, conn) + assert.NoError(t, err) + assert.NotZero(t, dal) + + tests := []struct { + TestName string + ModuleSet optional.Option[string] + ModuleGet optional.Option[string] + PresetGlobal bool + }{ + { + "SetModuleGetModule", + optional.Some("echo"), + optional.Some("echo"), + false, + }, + { + "SetGlobalGetGlobal", + optional.None[string](), + optional.None[string](), + false, + }, + { + "SetGlobalGetModule", + optional.None[string](), + optional.Some("echo"), + false, + }, + { + "SetModuleOverridesGlobal", + optional.Some("echo"), + optional.Some("echo"), + true, + }, + } + + b := []byte(`"asdf"`) + for _, test := range tests { + t.Run(test.TestName, func(t *testing.T) { + if test.PresetGlobal { + err := dal.SetModuleConfiguration(ctx, optional.None[string](), "configname", []byte(`"qwerty"`)) + assert.NoError(t, err) + } + err := dal.SetModuleConfiguration(ctx, test.ModuleSet, "configname", b) + assert.NoError(t, err) + gotBytes, err := dal.GetModuleConfiguration(ctx, test.ModuleGet, "configname") + assert.NoError(t, err) + assert.Equal(t, b, gotBytes) + err = dal.UnsetModuleConfiguration(ctx, test.ModuleGet, "configname") + assert.NoError(t, err) + }) + } + + t.Run("List", func(t *testing.T) { + sortedList := []sql.ModuleConfiguration{ + { + Module: optional.Some("echo"), + Name: "a", + }, + { + Module: optional.Some("echo"), + Name: "b", + }, + { + Module: optional.None[string](), + Name: "a", + }, + } + + // Insert entries in a separate order from how they should be returned to + // test sorting logic in the SQL query + err := dal.SetModuleConfiguration(ctx, sortedList[1].Module, sortedList[1].Name, []byte(`""`)) + assert.NoError(t, err) + err = dal.SetModuleConfiguration(ctx, sortedList[2].Module, sortedList[2].Name, []byte(`""`)) + assert.NoError(t, err) + err = dal.SetModuleConfiguration(ctx, sortedList[0].Module, sortedList[0].Name, []byte(`""`)) + assert.NoError(t, err) + + gotList, err := dal.ListModuleConfiguration(ctx) + assert.NoError(t, err) + for i := range sortedList { + assert.Equal(t, sortedList[i].Module, gotList[i].Module) + assert.Equal(t, sortedList[i].Name, gotList[i].Name) + } + }) +} diff --git a/common/configuration/db_config_provider.go b/common/configuration/db_config_provider.go index efee5a5810..7337cfe436 100644 --- a/common/configuration/db_config_provider.go +++ b/common/configuration/db_config_provider.go @@ -4,7 +4,7 @@ import ( "context" "net/url" - "github.com/TBD54566975/ftl/backend/controller/dal" + "github.com/TBD54566975/ftl/db/dalerrs" "github.com/alecthomas/types/optional" ) @@ -31,7 +31,7 @@ func (DBConfigProvider) Key() string { return "db" } func (d DBConfigProvider) Load(ctx context.Context, ref Ref, key *url.URL) ([]byte, error) { value, err := d.dal.GetModuleConfiguration(ctx, ref.Module, ref.Name) if err != nil { - return nil, dal.ErrNotFound + return nil, dalerrs.ErrNotFound } return value, nil } diff --git a/common/configuration/db_config_resolver.go b/common/configuration/db_config_resolver.go index 7f04d00b12..3fc5aeb538 100644 --- a/common/configuration/db_config_resolver.go +++ b/common/configuration/db_config_resolver.go @@ -4,7 +4,7 @@ import ( "context" "net/url" - "github.com/TBD54566975/ftl/backend/controller/sql" + "github.com/TBD54566975/ftl/common/configuration/sql" "github.com/TBD54566975/ftl/internal/slices" ) diff --git a/common/configuration/db_config_resolver_test.go b/common/configuration/db_config_resolver_test.go index 5204d4f9a3..3ff3bd403f 100644 --- a/common/configuration/db_config_resolver_test.go +++ b/common/configuration/db_config_resolver_test.go @@ -4,7 +4,7 @@ import ( "context" "testing" - "github.com/TBD54566975/ftl/backend/controller/sql" + "github.com/TBD54566975/ftl/common/configuration/sql" "github.com/alecthomas/assert/v2" ) diff --git a/common/configuration/sql/conn.go b/common/configuration/sql/conn.go new file mode 100644 index 0000000000..065487cefa --- /dev/null +++ b/common/configuration/sql/conn.go @@ -0,0 +1,21 @@ +package sql + +type DBI interface { + Querier + Conn() ConnI +} + +type ConnI interface { + DBTX +} + +type DB struct { + conn ConnI + *Queries +} + +func NewDB(conn ConnI) *DB { + return &DB{conn: conn, Queries: New(conn)} +} + +func (d *DB) Conn() ConnI { return d.conn } diff --git a/common/configuration/sql/db.go b/common/configuration/sql/db.go new file mode 100644 index 0000000000..c4b45fb311 --- /dev/null +++ b/common/configuration/sql/db.go @@ -0,0 +1,32 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.26.0 + +package sql + +import ( + "context" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/common/configuration/sql/models.go b/common/configuration/sql/models.go new file mode 100644 index 0000000000..6d2095f7ba --- /dev/null +++ b/common/configuration/sql/models.go @@ -0,0 +1,541 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.26.0 + +package sql + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "time" + + "github.com/TBD54566975/ftl/backend/controller/leases" + "github.com/TBD54566975/ftl/backend/schema" + "github.com/TBD54566975/ftl/internal/model" + "github.com/alecthomas/types/optional" + "github.com/google/uuid" +) + +type AsyncCallState string + +const ( + AsyncCallStatePending AsyncCallState = "pending" + AsyncCallStateExecuting AsyncCallState = "executing" + AsyncCallStateSuccess AsyncCallState = "success" + AsyncCallStateError AsyncCallState = "error" +) + +func (e *AsyncCallState) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = AsyncCallState(s) + case string: + *e = AsyncCallState(s) + default: + return fmt.Errorf("unsupported scan type for AsyncCallState: %T", src) + } + return nil +} + +type NullAsyncCallState struct { + AsyncCallState AsyncCallState + Valid bool // Valid is true if AsyncCallState is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullAsyncCallState) Scan(value interface{}) error { + if value == nil { + ns.AsyncCallState, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.AsyncCallState.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullAsyncCallState) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.AsyncCallState), nil +} + +type ControllerState string + +const ( + ControllerStateLive ControllerState = "live" + ControllerStateDead ControllerState = "dead" +) + +func (e *ControllerState) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = ControllerState(s) + case string: + *e = ControllerState(s) + default: + return fmt.Errorf("unsupported scan type for ControllerState: %T", src) + } + return nil +} + +type NullControllerState struct { + ControllerState ControllerState + Valid bool // Valid is true if ControllerState is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullControllerState) Scan(value interface{}) error { + if value == nil { + ns.ControllerState, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.ControllerState.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullControllerState) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.ControllerState), nil +} + +type CronJobState string + +const ( + CronJobStateIdle CronJobState = "idle" + CronJobStateExecuting CronJobState = "executing" +) + +func (e *CronJobState) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = CronJobState(s) + case string: + *e = CronJobState(s) + default: + return fmt.Errorf("unsupported scan type for CronJobState: %T", src) + } + return nil +} + +type NullCronJobState struct { + CronJobState CronJobState + Valid bool // Valid is true if CronJobState is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullCronJobState) Scan(value interface{}) error { + if value == nil { + ns.CronJobState, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.CronJobState.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullCronJobState) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.CronJobState), nil +} + +type EventType string + +const ( + EventTypeCall EventType = "call" + EventTypeLog EventType = "log" + EventTypeDeploymentCreated EventType = "deployment_created" + EventTypeDeploymentUpdated EventType = "deployment_updated" +) + +func (e *EventType) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = EventType(s) + case string: + *e = EventType(s) + default: + return fmt.Errorf("unsupported scan type for EventType: %T", src) + } + return nil +} + +type NullEventType struct { + EventType EventType + Valid bool // Valid is true if EventType is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullEventType) Scan(value interface{}) error { + if value == nil { + ns.EventType, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.EventType.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullEventType) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.EventType), nil +} + +type FsmStatus string + +const ( + FsmStatusRunning FsmStatus = "running" + FsmStatusCompleted FsmStatus = "completed" + FsmStatusFailed FsmStatus = "failed" +) + +func (e *FsmStatus) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = FsmStatus(s) + case string: + *e = FsmStatus(s) + default: + return fmt.Errorf("unsupported scan type for FsmStatus: %T", src) + } + return nil +} + +type NullFsmStatus struct { + FsmStatus FsmStatus + Valid bool // Valid is true if FsmStatus is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullFsmStatus) Scan(value interface{}) error { + if value == nil { + ns.FsmStatus, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.FsmStatus.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullFsmStatus) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.FsmStatus), nil +} + +type Origin string + +const ( + OriginIngress Origin = "ingress" + OriginCron Origin = "cron" + OriginPubsub Origin = "pubsub" +) + +func (e *Origin) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = Origin(s) + case string: + *e = Origin(s) + default: + return fmt.Errorf("unsupported scan type for Origin: %T", src) + } + return nil +} + +type NullOrigin struct { + Origin Origin + Valid bool // Valid is true if Origin is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullOrigin) Scan(value interface{}) error { + if value == nil { + ns.Origin, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.Origin.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullOrigin) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.Origin), nil +} + +type RunnerState string + +const ( + RunnerStateIdle RunnerState = "idle" + RunnerStateReserved RunnerState = "reserved" + RunnerStateAssigned RunnerState = "assigned" + RunnerStateDead RunnerState = "dead" +) + +func (e *RunnerState) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = RunnerState(s) + case string: + *e = RunnerState(s) + default: + return fmt.Errorf("unsupported scan type for RunnerState: %T", src) + } + return nil +} + +type NullRunnerState struct { + RunnerState RunnerState + Valid bool // Valid is true if RunnerState is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullRunnerState) Scan(value interface{}) error { + if value == nil { + ns.RunnerState, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.RunnerState.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullRunnerState) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.RunnerState), nil +} + +type TopicSubscriptionState string + +const ( + TopicSubscriptionStateIdle TopicSubscriptionState = "idle" + TopicSubscriptionStateExecuting TopicSubscriptionState = "executing" +) + +func (e *TopicSubscriptionState) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = TopicSubscriptionState(s) + case string: + *e = TopicSubscriptionState(s) + default: + return fmt.Errorf("unsupported scan type for TopicSubscriptionState: %T", src) + } + return nil +} + +type NullTopicSubscriptionState struct { + TopicSubscriptionState TopicSubscriptionState + Valid bool // Valid is true if TopicSubscriptionState is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullTopicSubscriptionState) Scan(value interface{}) error { + if value == nil { + ns.TopicSubscriptionState, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.TopicSubscriptionState.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullTopicSubscriptionState) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.TopicSubscriptionState), nil +} + +type Artefact struct { + ID int64 + CreatedAt time.Time + Digest []byte + Content []byte +} + +type AsyncCall struct { + ID int64 + CreatedAt time.Time + LeaseID optional.Option[int64] + Verb schema.RefKey + State AsyncCallState + Origin string + ScheduledAt time.Time + Request []byte + Response []byte + Error optional.Option[string] + RemainingAttempts int32 + Backoff time.Duration + MaxBackoff time.Duration +} + +type Controller struct { + ID int64 + Key model.ControllerKey + Created time.Time + LastSeen time.Time + State ControllerState + Endpoint string +} + +type CronJob struct { + ID int64 + Key model.CronJobKey + DeploymentID int64 + Verb string + Schedule string + StartTime time.Time + NextExecution time.Time + State model.CronJobState + ModuleName string +} + +type Deployment struct { + ID int64 + CreatedAt time.Time + ModuleID int64 + Key model.DeploymentKey + Schema *schema.Module + Labels []byte + MinReplicas int32 +} + +type DeploymentArtefact struct { + ArtefactID int64 + DeploymentID int64 + CreatedAt time.Time + Executable bool + Path string +} + +type Event struct { + ID int64 + TimeStamp time.Time + DeploymentID int64 + RequestID optional.Option[int64] + Type EventType + CustomKey1 optional.Option[string] + CustomKey2 optional.Option[string] + CustomKey3 optional.Option[string] + CustomKey4 optional.Option[string] + Payload json.RawMessage +} + +type FsmInstance struct { + ID int64 + CreatedAt time.Time + Fsm schema.RefKey + Key string + Status FsmStatus + CurrentState optional.Option[schema.RefKey] + DestinationState optional.Option[schema.RefKey] + AsyncCallID optional.Option[int64] +} + +type IngressRoute struct { + Method string + Path string + DeploymentID int64 + Module string + Verb string +} + +type Lease struct { + ID int64 + IdempotencyKey uuid.UUID + Key leases.Key + CreatedAt time.Time + ExpiresAt time.Time + Metadata []byte +} + +type Module struct { + ID int64 + Language string + Name string +} + +type ModuleConfiguration struct { + ID int64 + CreatedAt time.Time + Module optional.Option[string] + Name string + Value []byte +} + +type Request struct { + ID int64 + Origin Origin + Key model.RequestKey + SourceAddr string +} + +type Runner struct { + ID int64 + Key model.RunnerKey + Created time.Time + LastSeen time.Time + ReservationTimeout optional.Option[time.Time] + State RunnerState + Endpoint string + ModuleName optional.Option[string] + DeploymentID optional.Option[int64] + Labels []byte +} + +type Topic struct { + ID int64 + Key model.TopicKey + CreatedAt time.Time + ModuleID int64 + Name string + Type string + Head optional.Option[int64] +} + +type TopicEvent struct { + ID int64 + CreatedAt time.Time + Key model.TopicEventKey + TopicID int64 + Payload []byte +} + +type TopicSubscriber struct { + ID int64 + Key model.SubscriberKey + CreatedAt time.Time + TopicSubscriptionsID int64 + DeploymentID int64 + Sink schema.RefKey + RetryAttempts int32 + Backoff time.Duration + MaxBackoff time.Duration +} + +type TopicSubscription struct { + ID int64 + Key model.SubscriptionKey + CreatedAt time.Time + TopicID int64 + ModuleID int64 + DeploymentID int64 + Name string + Cursor optional.Option[int64] + State TopicSubscriptionState +} diff --git a/common/configuration/sql/querier.go b/common/configuration/sql/querier.go new file mode 100644 index 0000000000..17b7d75e74 --- /dev/null +++ b/common/configuration/sql/querier.go @@ -0,0 +1,20 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.26.0 + +package sql + +import ( + "context" + + "github.com/alecthomas/types/optional" +) + +type Querier interface { + GetModuleConfiguration(ctx context.Context, module optional.Option[string], name string) ([]byte, error) + ListModuleConfiguration(ctx context.Context) ([]ModuleConfiguration, error) + SetModuleConfiguration(ctx context.Context, module optional.Option[string], name string, value []byte) error + UnsetModuleConfiguration(ctx context.Context, module optional.Option[string], name string) error +} + +var _ Querier = (*Queries)(nil) diff --git a/common/configuration/sql/queries.sql b/common/configuration/sql/queries.sql new file mode 100644 index 0000000000..bf8835e2f3 --- /dev/null +++ b/common/configuration/sql/queries.sql @@ -0,0 +1,21 @@ +-- name: GetModuleConfiguration :one +SELECT value +FROM module_configuration +WHERE + (module IS NULL OR module = @module) + AND name = @name +ORDER BY module NULLS LAST +LIMIT 1; + +-- name: ListModuleConfiguration :many +SELECT * +FROM module_configuration +ORDER BY module, name; + +-- name: SetModuleConfiguration :exec +INSERT INTO module_configuration (module, name, value) +VALUES ($1, $2, $3); + +-- name: UnsetModuleConfiguration :exec +DELETE FROM module_configuration +WHERE module = @module AND name = @name; diff --git a/common/configuration/sql/queries.sql.go b/common/configuration/sql/queries.sql.go new file mode 100644 index 0000000000..9a640d3efa --- /dev/null +++ b/common/configuration/sql/queries.sql.go @@ -0,0 +1,81 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.26.0 +// source: queries.sql + +package sql + +import ( + "context" + + "github.com/alecthomas/types/optional" +) + +const getModuleConfiguration = `-- name: GetModuleConfiguration :one +SELECT value +FROM module_configuration +WHERE + (module IS NULL OR module = $1) + AND name = $2 +ORDER BY module NULLS LAST +LIMIT 1 +` + +func (q *Queries) GetModuleConfiguration(ctx context.Context, module optional.Option[string], name string) ([]byte, error) { + row := q.db.QueryRow(ctx, getModuleConfiguration, module, name) + var value []byte + err := row.Scan(&value) + return value, err +} + +const listModuleConfiguration = `-- name: ListModuleConfiguration :many +SELECT id, created_at, module, name, value +FROM module_configuration +ORDER BY module, name +` + +func (q *Queries) ListModuleConfiguration(ctx context.Context) ([]ModuleConfiguration, error) { + rows, err := q.db.Query(ctx, listModuleConfiguration) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ModuleConfiguration + for rows.Next() { + var i ModuleConfiguration + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.Module, + &i.Name, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const setModuleConfiguration = `-- name: SetModuleConfiguration :exec +INSERT INTO module_configuration (module, name, value) +VALUES ($1, $2, $3) +` + +func (q *Queries) SetModuleConfiguration(ctx context.Context, module optional.Option[string], name string, value []byte) error { + _, err := q.db.Exec(ctx, setModuleConfiguration, module, name, value) + return err +} + +const unsetModuleConfiguration = `-- name: UnsetModuleConfiguration :exec +DELETE FROM module_configuration +WHERE module = $1 AND name = $2 +` + +func (q *Queries) UnsetModuleConfiguration(ctx context.Context, module optional.Option[string], name string) error { + _, err := q.db.Exec(ctx, unsetModuleConfiguration, module, name) + return err +} diff --git a/db/dalerrs/dalerrs.go b/db/dalerrs/dalerrs.go new file mode 100644 index 0000000000..216ebbcfc8 --- /dev/null +++ b/db/dalerrs/dalerrs.go @@ -0,0 +1,55 @@ +// Package dalerrs provides common error handling utilities for all domain-specific DALs, +// e.g. controller DAL and configuration DAL, which all connect to the same underlying DB +// and maintain the same interface guarantees +package dalerrs + +import ( + stdsql "database/sql" + "errors" + "fmt" + "strings" + + "github.com/jackc/pgerrcode" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +var ( + // ErrConflict is returned by select methods in the DAL when a resource already exists. + // + // Its use will be documented in the corresponding methods. + ErrConflict = errors.New("conflict") + // ErrNotFound is returned by select methods in the DAL when no results are found. + ErrNotFound = errors.New("not found") + // ErrConstraint is returned by select methods in the DAL when a constraint is violated. + ErrConstraint = errors.New("constraint violation") +) + +func IsNotFound(err error) bool { + return errors.Is(err, stdsql.ErrNoRows) || errors.Is(err, pgx.ErrNoRows) +} + +func TranslatePGError(err error) error { + if err == nil { + return nil + } + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + switch pgErr.Code { + case pgerrcode.ForeignKeyViolation: + return fmt.Errorf("%s: %w", strings.TrimSuffix(strings.TrimPrefix(pgErr.ConstraintName, pgErr.TableName+"_"), "_id_fkey"), ErrNotFound) + case pgerrcode.UniqueViolation: + return fmt.Errorf("%s: %w", pgErr.Message, ErrConflict) + case pgerrcode.IntegrityConstraintViolation, + pgerrcode.RestrictViolation, + pgerrcode.NotNullViolation, + pgerrcode.CheckViolation, + pgerrcode.ExclusionViolation: + return fmt.Errorf("%s: %w", pgErr.Message, ErrConstraint) + default: + } + } else if IsNotFound(err) { + return ErrNotFound + } + return err +} diff --git a/sqlc.yaml b/sqlc.yaml index 28d2ff675c..cd5fd3dcf3 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -1,12 +1,13 @@ version: "2" sql: - - engine: "postgresql" + - &daldir + engine: "postgresql" queries: "backend/controller/sql/queries.sql" schema: "backend/controller/sql/schema" database: uri: postgres://localhost:15432/ftl?sslmode=disable&user=postgres&password=secret gen: - go: + go: &gengo package: "sql" sql_package: "pgx/v5" out: "backend/controller/sql" @@ -137,6 +138,12 @@ sql: - sqlc/db-prepare # - postgresql-query-too-costly - postgresql-no-seq-scan + - <<: *daldir + queries: "common/configuration/sql/queries.sql" + gen: + go: + <<: *gengo + out: "common/configuration/sql" rules: - name: postgresql-query-too-costly message: "Query cost estimate is too high"