diff --git a/core/cmd/shell.go b/core/cmd/shell.go index e1ac0b99caa..07cd2185dc2 100644 --- a/core/cmd/shell.go +++ b/core/cmd/shell.go @@ -145,7 +145,7 @@ func (n ChainlinkAppFactory) NewApplication(ctx context.Context, cfg chainlink.G return nil, err } - err = handleNodeVersioning(db, appLggr, cfg.RootDir(), cfg.Database(), cfg.WebServer().HTTPPort()) + err = handleNodeVersioning(ctx, db, appLggr, cfg.RootDir(), cfg.Database(), cfg.WebServer().HTTPPort()) if err != nil { return nil, err } @@ -229,7 +229,7 @@ func (n ChainlinkAppFactory) NewApplication(ctx context.Context, cfg chainlink.G } // handleNodeVersioning is a setup-time helper to encapsulate version changes and db migration -func handleNodeVersioning(db *sqlx.DB, appLggr logger.Logger, rootDir string, cfg config.Database, healthReportPort uint16) error { +func handleNodeVersioning(ctx context.Context, db *sqlx.DB, appLggr logger.Logger, rootDir string, cfg config.Database, healthReportPort uint16) error { var err error // Set up the versioning Configs verORM := versioning.NewORM(db, appLggr, cfg.DefaultQueryTimeout()) @@ -260,7 +260,7 @@ func handleNodeVersioning(db *sqlx.DB, appLggr logger.Logger, rootDir string, cf // Migrate the database if cfg.MigrateDatabase() { - if err = migrate.Migrate(db.DB, appLggr); err != nil { + if err = migrate.Migrate(ctx, db.DB, appLggr); err != nil { return fmt.Errorf("initializeORM#Migrate: %w", err) } } diff --git a/core/cmd/shell_local.go b/core/cmd/shell_local.go index 954cead5c37..70d75f761b5 100644 --- a/core/cmd/shell_local.go +++ b/core/cmd/shell_local.go @@ -705,9 +705,17 @@ func (s *Shell) validateDB(c *cli.Context) error { return s.configExitErr(s.Config.ValidateDB) } +// ctx returns a context.Context that will be cancelled when SIGINT|SIGTERM is received +func (s *Shell) ctx() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + go shutdown.HandleShutdown(func(_ string) { cancel() }) + return ctx +} + // ResetDatabase drops, creates and migrates the database specified by CL_DATABASE_URL or Database.URL // in secrets TOML. This is useful to set up the database for testing func (s *Shell) ResetDatabase(c *cli.Context) error { + ctx := s.ctx() cfg := s.Config.Database() parsed := cfg.URL() if parsed.String() == "" { @@ -727,7 +735,7 @@ func (s *Shell) ResetDatabase(c *cli.Context) error { return s.errorOut(err) } lggr.Debugf("Migrating database: %#v", parsed.String()) - if err := migrateDB(cfg, lggr); err != nil { + if err := migrateDB(ctx, cfg, lggr); err != nil { return s.errorOut(err) } schema, err := dumpSchema(parsed) @@ -736,7 +744,7 @@ func (s *Shell) ResetDatabase(c *cli.Context) error { } lggr.Debugf("Testing rollback and re-migrate for database: %#v", parsed.String()) var baseVersionID int64 = 54 - if err := downAndUpDB(cfg, lggr, baseVersionID); err != nil { + if err := downAndUpDB(ctx, cfg, lggr, baseVersionID); err != nil { return s.errorOut(err) } if err := checkSchema(parsed, schema); err != nil { @@ -828,6 +836,7 @@ func (s *Shell) PrepareTestDatabaseUserOnly(c *cli.Context) error { // MigrateDatabase migrates the database func (s *Shell) MigrateDatabase(_ *cli.Context) error { + ctx := s.ctx() cfg := s.Config.Database() parsed := cfg.URL() if parsed.String() == "" { @@ -840,7 +849,7 @@ func (s *Shell) MigrateDatabase(_ *cli.Context) error { } s.Logger.Infof("Migrating database: %#v", parsed.String()) - if err := migrateDB(cfg, s.Logger); err != nil { + if err := migrateDB(ctx, cfg, s.Logger); err != nil { return s.errorOut(err) } return nil @@ -848,6 +857,7 @@ func (s *Shell) MigrateDatabase(_ *cli.Context) error { // RollbackDatabase rolls back the database via down migrations. func (s *Shell) RollbackDatabase(c *cli.Context) error { + ctx := s.ctx() var version null.Int if c.Args().Present() { arg := c.Args().First() @@ -863,7 +873,7 @@ func (s *Shell) RollbackDatabase(c *cli.Context) error { return fmt.Errorf("failed to initialize orm: %v", err) } - if err := migrate.Rollback(db.DB, s.Logger, version); err != nil { + if err := migrate.Rollback(ctx, db.DB, s.Logger, version); err != nil { return fmt.Errorf("migrateDB failed: %v", err) } @@ -872,12 +882,13 @@ func (s *Shell) RollbackDatabase(c *cli.Context) error { // VersionDatabase displays the current database version. func (s *Shell) VersionDatabase(_ *cli.Context) error { + ctx := s.ctx() db, err := newConnection(s.Config.Database()) if err != nil { return fmt.Errorf("failed to initialize orm: %v", err) } - version, err := migrate.Current(db.DB, s.Logger) + version, err := migrate.Current(ctx, db.DB, s.Logger) if err != nil { return fmt.Errorf("migrateDB failed: %v", err) } @@ -888,12 +899,13 @@ func (s *Shell) VersionDatabase(_ *cli.Context) error { // StatusDatabase displays the database migration status func (s *Shell) StatusDatabase(_ *cli.Context) error { + ctx := s.ctx() db, err := newConnection(s.Config.Database()) if err != nil { return fmt.Errorf("failed to initialize orm: %v", err) } - if err = migrate.Status(db.DB, s.Logger); err != nil { + if err = migrate.Status(ctx, db.DB, s.Logger); err != nil { return fmt.Errorf("Status failed: %v", err) } return nil @@ -1017,27 +1029,27 @@ func dropAndCreatePristineDB(db *sqlx.DB, template string) (err error) { return nil } -func migrateDB(config dbConfig, lggr logger.Logger) error { +func migrateDB(ctx context.Context, config dbConfig, lggr logger.Logger) error { db, err := newConnection(config) if err != nil { return fmt.Errorf("failed to initialize orm: %v", err) } - if err = migrate.Migrate(db.DB, lggr); err != nil { + if err = migrate.Migrate(ctx, db.DB, lggr); err != nil { return fmt.Errorf("migrateDB failed: %v", err) } return db.Close() } -func downAndUpDB(cfg dbConfig, lggr logger.Logger, baseVersionID int64) error { +func downAndUpDB(ctx context.Context, cfg dbConfig, lggr logger.Logger, baseVersionID int64) error { db, err := newConnection(cfg) if err != nil { return fmt.Errorf("failed to initialize orm: %v", err) } - if err = migrate.Rollback(db.DB, lggr, null.IntFrom(baseVersionID)); err != nil { + if err = migrate.Rollback(ctx, db.DB, lggr, null.IntFrom(baseVersionID)); err != nil { return fmt.Errorf("test rollback failed: %v", err) } - if err = migrate.Migrate(db.DB, lggr); err != nil { + if err = migrate.Migrate(ctx, db.DB, lggr); err != nil { return fmt.Errorf("second migrateDB failed: %v", err) } return db.Close() diff --git a/core/store/migrate/migrate.go b/core/store/migrate/migrate.go index 1e58d7a0b05..04da4c48e92 100644 --- a/core/store/migrate/migrate.go +++ b/core/store/migrate/migrate.go @@ -36,22 +36,22 @@ func init() { } // Ensure we migrated from v1 migrations to goose_migrations -func ensureMigrated(db *sql.DB, lggr logger.Logger) { +func ensureMigrated(ctx context.Context, db *sql.DB, lggr logger.Logger) error { sqlxDB := pg.WrapDbWithSqlx(db) var names []string - err := sqlxDB.Select(&names, `SELECT id FROM migrations`) + err := sqlxDB.SelectContext(ctx, &names, `SELECT id FROM migrations`) if err != nil { // already migrated - return + return nil } - err = pg.SqlTransaction(context.Background(), db, lggr, func(tx *sqlx.Tx) error { + err = pg.SqlTransaction(ctx, db, lggr, func(tx *sqlx.Tx) error { // ensure that no legacy job specs are present: we _must_ bail out early if // so because otherwise we run the risk of dropping working jobs if the // user has not read the release notes return migrations.CheckNoLegacyJobs(tx.Tx) }) if err != nil { - panic(err) + return err } // Look for the squashed migration. If not present, the db needs to be migrated on an earlier release first @@ -62,18 +62,18 @@ func ensureMigrated(db *sql.DB, lggr logger.Logger) { } } if !found { - panic("Database state is too old. Need to migrate to chainlink version 0.9.10 first before upgrading to this version. This upgrade is NOT REVERSIBLE, so it is STRONGLY RECOMMENDED that you take a database backup before continuing.") + return errors.New("database state is too old. Need to migrate to chainlink version 0.9.10 first before upgrading to this version. This upgrade is NOT REVERSIBLE, so it is STRONGLY RECOMMENDED that you take a database backup before continuing") } // ensure a goose migrations table exists with it's initial v0 if _, err = goose.GetDBVersion(db); err != nil { - panic(err) + return err } // insert records for existing migrations //nolint sql := fmt.Sprintf(`INSERT INTO %s (version_id, is_applied) VALUES ($1, true);`, goose.TableName()) - err = pg.SqlTransaction(context.Background(), db, lggr, func(tx *sqlx.Tx) error { + return pg.SqlTransaction(ctx, db, lggr, func(tx *sqlx.Tx) error { for _, name := range names { var id int64 // the first migration doesn't follow the naming convention @@ -100,33 +100,38 @@ func ensureMigrated(db *sql.DB, lggr logger.Logger) { _, err = db.Exec("DROP TABLE migrations;") return err }) - if err != nil { - panic(err) - } } -func Migrate(db *sql.DB, lggr logger.Logger) error { - ensureMigrated(db, lggr) +func Migrate(ctx context.Context, db *sql.DB, lggr logger.Logger) error { + if err := ensureMigrated(ctx, db, lggr); err != nil { + return err + } // WithAllowMissing is necessary when upgrading from 0.10.14 since it // includes out-of-order migrations return goose.Up(db, MIGRATIONS_DIR, goose.WithAllowMissing()) } -func Rollback(db *sql.DB, lggr logger.Logger, version null.Int) error { - ensureMigrated(db, lggr) +func Rollback(ctx context.Context, db *sql.DB, lggr logger.Logger, version null.Int) error { + if err := ensureMigrated(ctx, db, lggr); err != nil { + return err + } if version.Valid { return goose.DownTo(db, MIGRATIONS_DIR, version.Int64) } return goose.Down(db, MIGRATIONS_DIR) } -func Current(db *sql.DB, lggr logger.Logger) (int64, error) { - ensureMigrated(db, lggr) +func Current(ctx context.Context, db *sql.DB, lggr logger.Logger) (int64, error) { + if err := ensureMigrated(ctx, db, lggr); err != nil { + return -1, err + } return goose.EnsureDBVersion(db) } -func Status(db *sql.DB, lggr logger.Logger) error { - ensureMigrated(db, lggr) +func Status(ctx context.Context, db *sql.DB, lggr logger.Logger) error { + if err := ensureMigrated(ctx, db, lggr); err != nil { + return err + } return goose.Status(db, MIGRATIONS_DIR) } diff --git a/core/store/migrate/migrate_test.go b/core/store/migrate/migrate_test.go index ef105c75ff6..0ff09000161 100644 --- a/core/store/migrate/migrate_test.go +++ b/core/store/migrate/migrate_test.go @@ -391,25 +391,26 @@ func TestMigrate_101_GenericOCR2(t *testing.T) { } func TestMigrate(t *testing.T) { + ctx := testutils.Context(t) lggr := logger.TestLogger(t) _, db := heavyweight.FullTestDBEmptyV2(t, nil) err := goose.UpTo(db.DB, migrationDir, 100) require.NoError(t, err) - err = migrate.Status(db.DB, lggr) + err = migrate.Status(ctx, db.DB, lggr) require.NoError(t, err) - ver, err := migrate.Current(db.DB, lggr) + ver, err := migrate.Current(ctx, db.DB, lggr) require.NoError(t, err) require.Equal(t, int64(100), ver) - err = migrate.Migrate(db.DB, lggr) + err = migrate.Migrate(ctx, db.DB, lggr) require.NoError(t, err) - err = migrate.Rollback(db.DB, lggr, null.IntFrom(99)) + err = migrate.Rollback(ctx, db.DB, lggr, null.IntFrom(99)) require.NoError(t, err) - ver, err = migrate.Current(db.DB, lggr) + ver, err = migrate.Current(ctx, db.DB, lggr) require.NoError(t, err) require.Equal(t, int64(99), ver) }