Skip to content

Commit

Permalink
feat: Add goose provider (#635)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman authored Nov 9, 2023
1 parent 8503d4e commit 04e12b8
Show file tree
Hide file tree
Showing 32 changed files with 1,524 additions and 1,689 deletions.
4 changes: 4 additions & 0 deletions database/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package database

import (
"context"
"database/sql"
"errors"
"fmt"

Expand Down Expand Up @@ -100,6 +101,9 @@ func (s *store) GetMigration(
&result.Timestamp,
&result.IsApplied,
); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, fmt.Errorf("%w: %d", ErrVersionNotFound, version)
}
return nil, fmt.Errorf("failed to get migration %d: %w", version, err)
}
return &result, nil
Expand Down
10 changes: 8 additions & 2 deletions database/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@ package database

import (
"context"
"errors"
"time"
)

var (
// ErrVersionNotFound must be returned by [GetMigration] when a migration version is not found.
ErrVersionNotFound = errors.New("version not found")
)

// Store is an interface that defines methods for managing database migrations and versioning. By
// defining a Store interface, we can support multiple databases with consistent functionality.
//
Expand All @@ -24,8 +30,8 @@ type Store interface {
// Delete deletes a version id from the version table.
Delete(ctx context.Context, db DBTxConn, version int64) error

// GetMigration retrieves a single migration by version id. This method may return the raw sql
// error if the query fails so the caller can assert for errors such as [sql.ErrNoRows].
// GetMigration retrieves a single migration by version id. If the query succeeds, but the
// version is not found, this method must return [ErrVersionNotFound].
GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error)

// ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If
Expand Down
2 changes: 1 addition & 1 deletion database/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func testStore(
err = runConn(ctx, db, func(conn *sql.Conn) error {
_, err := store.GetMigration(ctx, conn, 0)
check.HasError(t, err)
check.Bool(t, errors.Is(err, sql.ErrNoRows), true)
check.Bool(t, errors.Is(err, database.ErrVersionNotFound), true)
return nil
})
check.NoError(t, err)
Expand Down
60 changes: 17 additions & 43 deletions globals.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,20 @@ func ResetGlobalMigrations() {
// [NewGoMigration] function.
//
// Not safe for concurrent use.
func SetGlobalMigrations(migrations ...Migration) error {
for _, migration := range migrations {
m := &migration
func SetGlobalMigrations(migrations ...*Migration) error {
for _, m := range migrations {
if _, ok := registeredGoMigrations[m.Version]; ok {
return fmt.Errorf("go migration with version %d already registered", m.Version)
}
if err := checkMigration(m); err != nil {
if err := checkGoMigration(m); err != nil {
return fmt.Errorf("invalid go migration: %w", err)
}
registeredGoMigrations[m.Version] = m
}
return nil
}

func checkMigration(m *Migration) error {
func checkGoMigration(m *Migration) error {
if !m.construct {
return errors.New("must use NewGoMigration to construct migrations")
}
Expand All @@ -63,10 +62,10 @@ func checkMigration(m *Migration) error {
return fmt.Errorf("version:%d does not match numeric component in source %q", m.Version, m.Source)
}
}
if err := setGoFunc(m.goUp); err != nil {
if err := checkGoFunc(m.goUp); err != nil {
return fmt.Errorf("up function: %w", err)
}
if err := setGoFunc(m.goDown); err != nil {
if err := checkGoFunc(m.goDown); err != nil {
return fmt.Errorf("down function: %w", err)
}
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
Expand All @@ -84,47 +83,22 @@ func checkMigration(m *Migration) error {
return nil
}

func setGoFunc(f *GoFunc) error {
if f == nil {
f = &GoFunc{Mode: TransactionEnabled}
return nil
}
func checkGoFunc(f *GoFunc) error {
if f.RunTx != nil && f.RunDB != nil {
return errors.New("must specify exactly one of RunTx or RunDB")
}
if f.RunTx == nil && f.RunDB == nil {
switch f.Mode {
case 0:
// Default to TransactionEnabled ONLY if mode is not set explicitly.
f.Mode = TransactionEnabled
case TransactionEnabled, TransactionDisabled:
// No functions but mode is set. This is not an error. It means the user wants to record
// a version with the given mode but not run any functions.
default:
return fmt.Errorf("invalid mode: %d", f.Mode)
}
return nil
}
if f.RunDB != nil {
switch f.Mode {
case 0, TransactionDisabled:
f.Mode = TransactionDisabled
default:
return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set")
}
switch f.Mode {
case TransactionEnabled, TransactionDisabled:
// No functions, but mode is set. This is not an error. It means the user wants to
// record a version with the given mode but not run any functions.
default:
return fmt.Errorf("invalid mode: %d", f.Mode)
}
if f.RunTx != nil {
switch f.Mode {
case 0, TransactionEnabled:
f.Mode = TransactionEnabled
default:
return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set")
}
if f.RunDB != nil && f.Mode != TransactionDisabled {
return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set")
}
// This is a defensive check. If the mode is still 0, it means we failed to infer the mode from
// the functions or return an error. This should never happen.
if f.Mode == 0 {
return errors.New("failed to infer transaction mode")
if f.RunTx != nil && f.Mode != TransactionEnabled {
return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set")
}
return nil
}
63 changes: 41 additions & 22 deletions globals_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,15 @@ func TestTransactionMode(t *testing.T) {
// reset so we can check the default is set
migration2.goUp.Mode, migration2.goDown.Mode = 0, 0
err = SetGlobalMigrations(migration2)
check.NoError(t, err)
check.Number(t, len(registeredGoMigrations), 2)
registered = registeredGoMigrations[2]
check.Bool(t, registered.goUp != nil, true)
check.Bool(t, registered.goDown != nil, true)
check.Equal(t, registered.goUp.Mode, TransactionEnabled)
check.Equal(t, registered.goDown.Mode, TransactionEnabled)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: up function: invalid mode: 0")

migration3 := NewGoMigration(3, nil, nil)
// reset so we can check the default is set
migration3.goDown.Mode = 0
err = SetGlobalMigrations(migration3)
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: down function: invalid mode: 0")
})
t.Run("unknown_mode", func(t *testing.T) {
m := NewGoMigration(1, nil, nil)
Expand Down Expand Up @@ -192,7 +194,7 @@ func TestGlobalRegister(t *testing.T) {
runTx := func(context.Context, *sql.Tx) error { return nil }

// Success.
err := SetGlobalMigrations([]Migration{}...)
err := SetGlobalMigrations([]*Migration{}...)
check.NoError(t, err)
err = SetGlobalMigrations(
NewGoMigration(1, &GoFunc{RunTx: runTx}, nil),
Expand All @@ -204,62 +206,79 @@ func TestGlobalRegister(t *testing.T) {
)
check.HasError(t, err)
check.Contains(t, err.Error(), "go migration with version 1 already registered")
err = SetGlobalMigrations(Migration{Registered: true, Version: 2, Type: TypeGo})
err = SetGlobalMigrations(&Migration{Registered: true, Version: 2, Type: TypeGo})
check.HasError(t, err)
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
}

func TestCheckMigration(t *testing.T) {
// Success.
err := checkGoMigration(NewGoMigration(1, nil, nil))
check.NoError(t, err)
// Failures.
err := checkMigration(&Migration{})
err = checkGoMigration(&Migration{})
check.HasError(t, err)
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
err = checkMigration(&Migration{construct: true})
err = checkGoMigration(&Migration{construct: true})
check.HasError(t, err)
check.Contains(t, err.Error(), "must be registered")
err = checkMigration(&Migration{construct: true, Registered: true})
err = checkGoMigration(&Migration{construct: true, Registered: true})
check.HasError(t, err)
check.Contains(t, err.Error(), `type must be "go"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo})
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo})
check.HasError(t, err)
check.Contains(t, err.Error(), "version must be greater than zero")
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{}, goDown: &GoFunc{}})
check.HasError(t, err)
check.Contains(t, err.Error(), "up function: invalid mode: 0")
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{}})
check.HasError(t, err)
check.Contains(t, err.Error(), "down function: invalid mode: 0")
// Success.
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1})
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{Mode: TransactionEnabled}})
check.NoError(t, err)
// Failures.
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"})
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"})
check.HasError(t, err)
check.Contains(t, err.Error(), `source must have .go extension: "foo"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"})
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"})
check.HasError(t, err)
check.Contains(t, err.Error(), `no filename separator '_' found`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"})
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"})
check.HasError(t, err)
check.Contains(t, err.Error(), `source must have .go extension: "00001_foo.sql"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"})
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"})
check.HasError(t, err)
check.Contains(t, err.Error(), `version:2 does not match numeric component in source "00001_foo.go"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
UpFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
goUp: &GoFunc{Mode: TransactionEnabled},
goDown: &GoFunc{Mode: TransactionEnabled},
})
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
DownFnContext: func(context.Context, *sql.Tx) error { return nil },
DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
goUp: &GoFunc{Mode: TransactionEnabled},
goDown: &GoFunc{Mode: TransactionEnabled},
})
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
UpFn: func(*sql.Tx) error { return nil },
UpFnNoTx: func(*sql.DB) error { return nil },
goUp: &GoFunc{Mode: TransactionEnabled},
goDown: &GoFunc{Mode: TransactionEnabled},
})
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of UpFn or UpFnNoTx")
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
DownFn: func(*sql.Tx) error { return nil },
DownFnNoTx: func(*sql.DB) error { return nil },
goUp: &GoFunc{Mode: TransactionEnabled},
goDown: &GoFunc{Mode: TransactionEnabled},
})
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of DownFn or DownFnNoTx")
Expand Down
Loading

0 comments on commit 04e12b8

Please sign in to comment.