Skip to content

Commit

Permalink
changed up and down functions to an interface
Browse files Browse the repository at this point in the history
  • Loading branch information
hilariocoelho committed Mar 2, 2020
1 parent ff2a18b commit deb0514
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 76 deletions.
29 changes: 14 additions & 15 deletions 1_global_migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,7 @@ func TestMigrationsRegistration(t *testing.T) {
}()
globalMigrate = NewMigrate(nil, nil)

err := Register(func(session *mongo.Session, db *mongo.Database) error {
return nil
}, func(session *mongo.Session, db *mongo.Database) error {
return nil
})
err := Register(migration{})
if err != nil {
t.Errorf("Unexpected register error: %v", err)
return
Expand All @@ -53,11 +49,7 @@ func TestMigrationsRegistration(t *testing.T) {
t.Errorf("Unexpected version/description: %d %s", registered[0].Version, registered[0].Description)
}

err = Register(func(session *mongo.Session, db *mongo.Database) error {
return nil
}, func(session *mongo.Session, db *mongo.Database) error {
return nil
})
err = Register(migration{})
if err == nil {
t.Errorf("Unexpected nil error")
}
Expand All @@ -72,11 +64,7 @@ func TestMigrationMustRegistration(t *testing.T) {
}
}()
globalMigrate = NewMigrate(nil, nil)
MustRegister(func(session *mongo.Session, db *mongo.Database) error {
return nil
}, func(session *mongo.Session, db *mongo.Database) error {
return nil
})
MustRegister(migration{})
registered := RegisteredMigrations()
if len(registered) <= 0 || len(registered) > 1 {
t.Errorf("Unexpected length of registered migrations")
Expand All @@ -86,3 +74,14 @@ func TestMigrationMustRegistration(t *testing.T) {
t.Errorf("Unexpected version/description: %d %s", registered[0].Version, registered[0].Description)
}
}

type migration struct {
}

func (migration) Up(client *mongo.Client, db *mongo.Database) error {
return nil
}

func (migration) Down(client *mongo.Client, db *mongo.Database) error {
return nil
}
33 changes: 20 additions & 13 deletions 1_sample_data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,24 @@ import (
const globalTestCollection = "test-global"

func init() {
Register(func(client *mongo.Client, db *mongo.Database) error {
_, err := db.Collection(globalTestCollection).InsertOne(context.TODO(), bson.D{{"a", "b"}})
if err != nil {
return err
}
return nil
}, func(client *mongo.Client, db *mongo.Database) error {
_, err := db.Collection(globalTestCollection).DeleteOne(context.TODO(), bson.D{{"a", "b"}})
if err != nil {
return err
}
return nil
})
Register(migration{})
}

type migration struct {
}

func (migration) Up(client *mongo.Client, db *mongo.Database) error {
_, err := db.Collection(globalTestCollection).InsertOne(context.TODO(), bson.D{{"a", "b"}})
if err != nil {
return err
}
return nil
}

func (migration) Down(client *mongo.Client, db *mongo.Database) error {
_, err := db.Collection(globalTestCollection).DeleteOne(context.TODO(), bson.D{{"a", "b"}})
if err != nil {
return err
}
return nil
}
39 changes: 22 additions & 17 deletions 2_sample_index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,31 @@ import (

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)

const globalTestIndexName = "test_idx_2"

func init() {
Register(func(client *mongo.Client, db *mongo.Database) error {
keys := bson.D{{"a", 1}}
opt := options.Index().SetName(globalTestIndexName)
model := mongo.IndexModel{Keys: keys, Options: opt}
_, err := db.Collection(globalTestCollection).Indexes().CreateOne(context.TODO(), model)
if err != nil {
return err
}
return nil
}, func(client *mongo.Client, db *mongo.Database) error {
_, err := db.Collection(globalTestCollection).Indexes().DropOne(context.TODO(), globalTestIndexName)
if err != nil {
return err
}
return nil
})
Register(migration{})
}

type migration struct{}

func (migration) Up(client *mongo.Client, db *mongo.Database) error {
keys := bson.D{{"a", 1}}
opt := options.Index().SetName(globalTestIndexName)
model := mongo.IndexModel{Keys: keys, Options: opt}
_, err := db.Collection(globalTestCollection).Indexes().CreateOne(context.TODO(), model)
if err != nil {
return err
}
return nil
}

func (migration) Down(client *mongo.Client, db *mongo.Database) error {
_, err := db.Collection(globalTestCollection).Indexes().DropOne(context.TODO(), globalTestIndexName)
if err != nil {
return err
}
return nil
}
14 changes: 2 additions & 12 deletions bad_migration_file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package migrate

import (
"testing"

"go.mongodb.org/mongo-driver/mongo"
)

func TestBadMigrationFile(t *testing.T) {
Expand All @@ -13,11 +11,7 @@ func TestBadMigrationFile(t *testing.T) {
}()
globalMigrate = NewMigrate(nil, nil)

err := Register(func(client *mongo.Client, db *mongo.Database) error {
return nil
}, func(client *mongo.Client, db *mongo.Database) error {
return nil
})
err := Register(migration{})
if err == nil {
t.Errorf("Unexpected nil error")
}
Expand All @@ -32,9 +26,5 @@ func TestBadMigrationFilePanic(t *testing.T) {
}
}()
globalMigrate = NewMigrate(nil, nil)
MustRegister(func(client *mongo.Client, db *mongo.Database) error {
return nil
}, func(client *mongo.Client, db *mongo.Database) error {
return nil
})
MustRegister(migration{})
}
17 changes: 8 additions & 9 deletions global_migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

var globalMigrate = NewMigrate(nil, nil)

func internalRegister(up, down MigrationFunc, skip int) error {
func internalRegister(implementation IMigration, skip int) error {
_, file, _, _ := runtime.Caller(skip)
version, description, err := extractVersionDescription(file)
if err != nil {
Expand All @@ -19,10 +19,9 @@ func internalRegister(up, down MigrationFunc, skip int) error {
return fmt.Errorf("migration with version %v already registered", version)
}
globalMigrate.migrations = append(globalMigrate.migrations, Migration{
Version: version,
Description: description,
Up: up,
Down: down,
Version: version,
Description: description,
Implementation: implementation,
})
return nil
}
Expand Down Expand Up @@ -61,13 +60,13 @@ func internalRegister(up, down MigrationFunc, skip int) error {
// return nil
// })
// }
func Register(up, down MigrationFunc) error {
return internalRegister(up, down, 2)
func Register(implementation IMigration) error {
return internalRegister(implementation, 2)
}

// MustRegister acts like Register but panics on errors.
func MustRegister(up, down MigrationFunc) {
if err := internalRegister(up, down, 2); err != nil {
func MustRegister(implementation IMigration) {
if err := internalRegister(implementation, 2); err != nil {
panic(err)
}
}
Expand Down
9 changes: 5 additions & 4 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type Migrate struct {
migrationsCollection string
}

// NewMigrate creates a migration
func NewMigrate(client *mongo.Client, db *mongo.Database, migrations ...Migration) *Migrate {
internalMigrations := make([]Migration, len(migrations))
copy(internalMigrations, migrations)
Expand Down Expand Up @@ -187,11 +188,11 @@ func (m *Migrate) Up(n int) error {

for i, p := 0, 0; i < len(m.migrations) && p < n; i++ {
migration := m.migrations[i]
if migration.Version <= currentVersion || migration.Up == nil {
if migration.Version <= currentVersion || migration.Implementation == nil {
continue
}
p++
if err := migration.Up(m.client, m.db); err != nil {
if err := migration.Implementation.Up(m.client, m.db); err != nil {
return err
}
if err := m.SetVersion(migration.Version, migration.Description); err != nil {
Expand All @@ -216,11 +217,11 @@ func (m *Migrate) Down(n int) error {

for i, p := len(m.migrations)-1, 0; i >= 0 && p < n; i-- {
migration := m.migrations[i]
if migration.Version > currentVersion || migration.Down == nil {
if migration.Version > currentVersion || migration.Implementation == nil {
continue
}
p++
if err := migration.Down(m.client, m.db); err != nil {
if err := migration.Implementation.Down(m.client, m.db); err != nil {
return err
}

Expand Down
14 changes: 8 additions & 6 deletions migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@ import (
"go.mongodb.org/mongo-driver/mongo"
)

// MigrationFunc used to define actions to be performed for a migration.
type MigrationFunc func(client *mongo.Client, db *mongo.Database) error
// IMigration represents the interface that each migration needs to implement
type IMigration interface {
Up(client *mongo.Client, db *mongo.Database) error
Down(client *mongo.Client, db *mongo.Database) error
}

// Migration represents single database migration.
// Migration contains:
Expand All @@ -20,10 +23,9 @@ type MigrationFunc func(client *mongo.Client, db *mongo.Database) error
//
// - down: callback which will be called in "down" migration process for reverting changes
type Migration struct {
Version uint64
Description string
Up MigrationFunc
Down MigrationFunc
Version uint64
Description string
Implementation IMigration
}

func migrationSort(migrations []Migration) {
Expand Down

0 comments on commit deb0514

Please sign in to comment.