Skip to content

Commit

Permalink
core/store/migrate: add context and error; remove panics (#11295)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmank88 authored Nov 15, 2023
1 parent 75c70f7 commit 50fbfd2
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 38 deletions.
6 changes: 3 additions & 3 deletions core/cmd/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (n ChainlinkAppFactory) NewApplication(ctx context.Context, cfg chainlink.G
return nil, err
}

err = handleNodeVersioning(db, appLggr, cfg.RootDir(), cfg.Database(), cfg.WebServer().HTTPPort())
err = handleNodeVersioning(ctx, db, appLggr, cfg.RootDir(), cfg.Database(), cfg.WebServer().HTTPPort())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -229,7 +229,7 @@ func (n ChainlinkAppFactory) NewApplication(ctx context.Context, cfg chainlink.G
}

// handleNodeVersioning is a setup-time helper to encapsulate version changes and db migration
func handleNodeVersioning(db *sqlx.DB, appLggr logger.Logger, rootDir string, cfg config.Database, healthReportPort uint16) error {
func handleNodeVersioning(ctx context.Context, db *sqlx.DB, appLggr logger.Logger, rootDir string, cfg config.Database, healthReportPort uint16) error {
var err error
// Set up the versioning Configs
verORM := versioning.NewORM(db, appLggr, cfg.DefaultQueryTimeout())
Expand Down Expand Up @@ -260,7 +260,7 @@ func handleNodeVersioning(db *sqlx.DB, appLggr logger.Logger, rootDir string, cf

// Migrate the database
if cfg.MigrateDatabase() {
if err = migrate.Migrate(db.DB, appLggr); err != nil {
if err = migrate.Migrate(ctx, db.DB, appLggr); err != nil {
return fmt.Errorf("initializeORM#Migrate: %w", err)
}
}
Expand Down
34 changes: 23 additions & 11 deletions core/cmd/shell_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -705,9 +705,17 @@ func (s *Shell) validateDB(c *cli.Context) error {
return s.configExitErr(s.Config.ValidateDB)
}

// ctx returns a context.Context that will be cancelled when SIGINT|SIGTERM is received
func (s *Shell) ctx() context.Context {
ctx, cancel := context.WithCancel(context.Background())
go shutdown.HandleShutdown(func(_ string) { cancel() })
return ctx
}

// ResetDatabase drops, creates and migrates the database specified by CL_DATABASE_URL or Database.URL
// in secrets TOML. This is useful to set up the database for testing
func (s *Shell) ResetDatabase(c *cli.Context) error {
ctx := s.ctx()
cfg := s.Config.Database()
parsed := cfg.URL()
if parsed.String() == "" {
Expand All @@ -727,7 +735,7 @@ func (s *Shell) ResetDatabase(c *cli.Context) error {
return s.errorOut(err)
}
lggr.Debugf("Migrating database: %#v", parsed.String())
if err := migrateDB(cfg, lggr); err != nil {
if err := migrateDB(ctx, cfg, lggr); err != nil {
return s.errorOut(err)
}
schema, err := dumpSchema(parsed)
Expand All @@ -736,7 +744,7 @@ func (s *Shell) ResetDatabase(c *cli.Context) error {
}
lggr.Debugf("Testing rollback and re-migrate for database: %#v", parsed.String())
var baseVersionID int64 = 54
if err := downAndUpDB(cfg, lggr, baseVersionID); err != nil {
if err := downAndUpDB(ctx, cfg, lggr, baseVersionID); err != nil {
return s.errorOut(err)
}
if err := checkSchema(parsed, schema); err != nil {
Expand Down Expand Up @@ -828,6 +836,7 @@ func (s *Shell) PrepareTestDatabaseUserOnly(c *cli.Context) error {

// MigrateDatabase migrates the database
func (s *Shell) MigrateDatabase(_ *cli.Context) error {
ctx := s.ctx()
cfg := s.Config.Database()
parsed := cfg.URL()
if parsed.String() == "" {
Expand All @@ -840,14 +849,15 @@ func (s *Shell) MigrateDatabase(_ *cli.Context) error {
}

s.Logger.Infof("Migrating database: %#v", parsed.String())
if err := migrateDB(cfg, s.Logger); err != nil {
if err := migrateDB(ctx, cfg, s.Logger); err != nil {
return s.errorOut(err)
}
return nil
}

// RollbackDatabase rolls back the database via down migrations.
func (s *Shell) RollbackDatabase(c *cli.Context) error {
ctx := s.ctx()
var version null.Int
if c.Args().Present() {
arg := c.Args().First()
Expand All @@ -863,7 +873,7 @@ func (s *Shell) RollbackDatabase(c *cli.Context) error {
return fmt.Errorf("failed to initialize orm: %v", err)
}

if err := migrate.Rollback(db.DB, s.Logger, version); err != nil {
if err := migrate.Rollback(ctx, db.DB, s.Logger, version); err != nil {
return fmt.Errorf("migrateDB failed: %v", err)
}

Expand All @@ -872,12 +882,13 @@ func (s *Shell) RollbackDatabase(c *cli.Context) error {

// VersionDatabase displays the current database version.
func (s *Shell) VersionDatabase(_ *cli.Context) error {
ctx := s.ctx()
db, err := newConnection(s.Config.Database())
if err != nil {
return fmt.Errorf("failed to initialize orm: %v", err)
}

version, err := migrate.Current(db.DB, s.Logger)
version, err := migrate.Current(ctx, db.DB, s.Logger)
if err != nil {
return fmt.Errorf("migrateDB failed: %v", err)
}
Expand All @@ -888,12 +899,13 @@ func (s *Shell) VersionDatabase(_ *cli.Context) error {

// StatusDatabase displays the database migration status
func (s *Shell) StatusDatabase(_ *cli.Context) error {
ctx := s.ctx()
db, err := newConnection(s.Config.Database())
if err != nil {
return fmt.Errorf("failed to initialize orm: %v", err)
}

if err = migrate.Status(db.DB, s.Logger); err != nil {
if err = migrate.Status(ctx, db.DB, s.Logger); err != nil {
return fmt.Errorf("Status failed: %v", err)
}
return nil
Expand Down Expand Up @@ -1017,27 +1029,27 @@ func dropAndCreatePristineDB(db *sqlx.DB, template string) (err error) {
return nil
}

func migrateDB(config dbConfig, lggr logger.Logger) error {
func migrateDB(ctx context.Context, config dbConfig, lggr logger.Logger) error {
db, err := newConnection(config)
if err != nil {
return fmt.Errorf("failed to initialize orm: %v", err)
}

if err = migrate.Migrate(db.DB, lggr); err != nil {
if err = migrate.Migrate(ctx, db.DB, lggr); err != nil {
return fmt.Errorf("migrateDB failed: %v", err)
}
return db.Close()
}

func downAndUpDB(cfg dbConfig, lggr logger.Logger, baseVersionID int64) error {
func downAndUpDB(ctx context.Context, cfg dbConfig, lggr logger.Logger, baseVersionID int64) error {
db, err := newConnection(cfg)
if err != nil {
return fmt.Errorf("failed to initialize orm: %v", err)
}
if err = migrate.Rollback(db.DB, lggr, null.IntFrom(baseVersionID)); err != nil {
if err = migrate.Rollback(ctx, db.DB, lggr, null.IntFrom(baseVersionID)); err != nil {
return fmt.Errorf("test rollback failed: %v", err)
}
if err = migrate.Migrate(db.DB, lggr); err != nil {
if err = migrate.Migrate(ctx, db.DB, lggr); err != nil {
return fmt.Errorf("second migrateDB failed: %v", err)
}
return db.Close()
Expand Down
43 changes: 24 additions & 19 deletions core/store/migrate/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,22 @@ func init() {
}

// Ensure we migrated from v1 migrations to goose_migrations
func ensureMigrated(db *sql.DB, lggr logger.Logger) {
func ensureMigrated(ctx context.Context, db *sql.DB, lggr logger.Logger) error {
sqlxDB := pg.WrapDbWithSqlx(db)
var names []string
err := sqlxDB.Select(&names, `SELECT id FROM migrations`)
err := sqlxDB.SelectContext(ctx, &names, `SELECT id FROM migrations`)
if err != nil {
// already migrated
return
return nil
}
err = pg.SqlTransaction(context.Background(), db, lggr, func(tx *sqlx.Tx) error {
err = pg.SqlTransaction(ctx, db, lggr, func(tx *sqlx.Tx) error {
// ensure that no legacy job specs are present: we _must_ bail out early if
// so because otherwise we run the risk of dropping working jobs if the
// user has not read the release notes
return migrations.CheckNoLegacyJobs(tx.Tx)
})
if err != nil {
panic(err)
return err
}

// Look for the squashed migration. If not present, the db needs to be migrated on an earlier release first
Expand All @@ -62,18 +62,18 @@ func ensureMigrated(db *sql.DB, lggr logger.Logger) {
}
}
if !found {
panic("Database state is too old. Need to migrate to chainlink version 0.9.10 first before upgrading to this version. This upgrade is NOT REVERSIBLE, so it is STRONGLY RECOMMENDED that you take a database backup before continuing.")
return errors.New("database state is too old. Need to migrate to chainlink version 0.9.10 first before upgrading to this version. This upgrade is NOT REVERSIBLE, so it is STRONGLY RECOMMENDED that you take a database backup before continuing")
}

// ensure a goose migrations table exists with it's initial v0
if _, err = goose.GetDBVersion(db); err != nil {
panic(err)
return err
}

// insert records for existing migrations
//nolint
sql := fmt.Sprintf(`INSERT INTO %s (version_id, is_applied) VALUES ($1, true);`, goose.TableName())
err = pg.SqlTransaction(context.Background(), db, lggr, func(tx *sqlx.Tx) error {
return pg.SqlTransaction(ctx, db, lggr, func(tx *sqlx.Tx) error {
for _, name := range names {
var id int64
// the first migration doesn't follow the naming convention
Expand All @@ -100,33 +100,38 @@ func ensureMigrated(db *sql.DB, lggr logger.Logger) {
_, err = db.Exec("DROP TABLE migrations;")
return err
})
if err != nil {
panic(err)
}
}

func Migrate(db *sql.DB, lggr logger.Logger) error {
ensureMigrated(db, lggr)
func Migrate(ctx context.Context, db *sql.DB, lggr logger.Logger) error {
if err := ensureMigrated(ctx, db, lggr); err != nil {
return err
}
// WithAllowMissing is necessary when upgrading from 0.10.14 since it
// includes out-of-order migrations
return goose.Up(db, MIGRATIONS_DIR, goose.WithAllowMissing())
}

func Rollback(db *sql.DB, lggr logger.Logger, version null.Int) error {
ensureMigrated(db, lggr)
func Rollback(ctx context.Context, db *sql.DB, lggr logger.Logger, version null.Int) error {
if err := ensureMigrated(ctx, db, lggr); err != nil {
return err
}
if version.Valid {
return goose.DownTo(db, MIGRATIONS_DIR, version.Int64)
}
return goose.Down(db, MIGRATIONS_DIR)
}

func Current(db *sql.DB, lggr logger.Logger) (int64, error) {
ensureMigrated(db, lggr)
func Current(ctx context.Context, db *sql.DB, lggr logger.Logger) (int64, error) {
if err := ensureMigrated(ctx, db, lggr); err != nil {
return -1, err
}
return goose.EnsureDBVersion(db)
}

func Status(db *sql.DB, lggr logger.Logger) error {
ensureMigrated(db, lggr)
func Status(ctx context.Context, db *sql.DB, lggr logger.Logger) error {
if err := ensureMigrated(ctx, db, lggr); err != nil {
return err
}
return goose.Status(db, MIGRATIONS_DIR)
}

Expand Down
11 changes: 6 additions & 5 deletions core/store/migrate/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,25 +391,26 @@ func TestMigrate_101_GenericOCR2(t *testing.T) {
}

func TestMigrate(t *testing.T) {
ctx := testutils.Context(t)
lggr := logger.TestLogger(t)
_, db := heavyweight.FullTestDBEmptyV2(t, nil)
err := goose.UpTo(db.DB, migrationDir, 100)
require.NoError(t, err)

err = migrate.Status(db.DB, lggr)
err = migrate.Status(ctx, db.DB, lggr)
require.NoError(t, err)

ver, err := migrate.Current(db.DB, lggr)
ver, err := migrate.Current(ctx, db.DB, lggr)
require.NoError(t, err)
require.Equal(t, int64(100), ver)

err = migrate.Migrate(db.DB, lggr)
err = migrate.Migrate(ctx, db.DB, lggr)
require.NoError(t, err)

err = migrate.Rollback(db.DB, lggr, null.IntFrom(99))
err = migrate.Rollback(ctx, db.DB, lggr, null.IntFrom(99))
require.NoError(t, err)

ver, err = migrate.Current(db.DB, lggr)
ver, err = migrate.Current(ctx, db.DB, lggr)
require.NoError(t, err)
require.Equal(t, int64(99), ver)
}
Expand Down

0 comments on commit 50fbfd2

Please sign in to comment.