Skip to content

Commit

Permalink
Create an alias for Dialect from database.Dialect (#636)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman authored Nov 10, 2023
1 parent 3fade11 commit 6692639
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 32 deletions.
22 changes: 12 additions & 10 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@ package goose
import (
"fmt"

"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/dialect"
)

// Dialect is the type of database dialect.
type Dialect string
// Dialect is the type of database dialect. It is an alias for [database.Dialect].
type Dialect = database.Dialect

const (
DialectClickHouse Dialect = "clickhouse"
DialectMSSQL Dialect = "mssql"
DialectMySQL Dialect = "mysql"
DialectPostgres Dialect = "postgres"
DialectRedshift Dialect = "redshift"
DialectSQLite3 Dialect = "sqlite3"
DialectTiDB Dialect = "tidb"
DialectVertica Dialect = "vertica"
DialectClickHouse Dialect = database.DialectClickHouse
DialectMSSQL Dialect = database.DialectMSSQL
DialectMySQL Dialect = database.DialectMySQL
DialectPostgres Dialect = database.DialectPostgres
DialectRedshift Dialect = database.DialectRedshift
DialectSQLite3 Dialect = database.DialectSQLite3
DialectTiDB Dialect = database.DialectTiDB
DialectVertica Dialect = database.DialectVertica
DialectYdB Dialect = database.DialectYdB
)

func init() {
Expand Down
5 changes: 5 additions & 0 deletions migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ func NewGoMigration(version int64, up, down *GoFunc) *Migration {
if f.RunTx == nil && f.RunDB != nil {
f.Mode = TransactionDisabled
}
// Always default to TransactionEnabled if both functions are nil. This is the most
// common use case.
if f.RunDB == nil && f.RunTx == nil {
f.Mode = TransactionEnabled
}
}
return f
}
Expand Down
2 changes: 1 addition & 1 deletion provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type Provider struct {
// Unless otherwise specified, all methods on Provider are safe for concurrent use.
//
// Experimental: This API is experimental and may change in the future.
func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
if db == nil {
return nil, errors.New("db must not be nil")
}
Expand Down
16 changes: 8 additions & 8 deletions provider_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,34 +30,34 @@ func TestNewProvider(t *testing.T) {
_, err = goose.NewProvider("unknown-dialect", db, fsys)
check.HasError(t, err)
// Nil db not allowed
_, err = goose.NewProvider(database.DialectSQLite3, nil, fsys)
_, err = goose.NewProvider(goose.DialectSQLite3, nil, fsys)
check.HasError(t, err)
// Nil store not allowed
_, err = goose.NewProvider(database.DialectSQLite3, db, nil, goose.WithStore(nil))
_, err = goose.NewProvider(goose.DialectSQLite3, db, nil, goose.WithStore(nil))
check.HasError(t, err)
// Cannot set both dialect and store
store, err := database.NewStore(database.DialectSQLite3, "custom_table")
store, err := database.NewStore(goose.DialectSQLite3, "custom_table")
check.NoError(t, err)
_, err = goose.NewProvider(database.DialectSQLite3, db, nil, goose.WithStore(store))
_, err = goose.NewProvider(goose.DialectSQLite3, db, nil, goose.WithStore(store))
check.HasError(t, err)
// Multiple stores not allowed
_, err = goose.NewProvider(database.DialectSQLite3, db, nil,
_, err = goose.NewProvider(goose.DialectSQLite3, db, nil,
goose.WithStore(store),
goose.WithStore(store),
)
check.HasError(t, err)
})
t.Run("valid", func(t *testing.T) {
// Valid dialect, db, and fsys allowed
_, err = goose.NewProvider(database.DialectSQLite3, db, fsys)
_, err = goose.NewProvider(goose.DialectSQLite3, db, fsys)
check.NoError(t, err)
// Valid dialect, db, fsys, and verbose allowed
_, err = goose.NewProvider(database.DialectSQLite3, db, fsys,
_, err = goose.NewProvider(goose.DialectSQLite3, db, fsys,
goose.WithVerbose(testing.Verbose()),
)
check.NoError(t, err)
// Custom store allowed
store, err := database.NewStore(database.DialectSQLite3, "custom_table")
store, err := database.NewStore(goose.DialectSQLite3, "custom_table")
check.NoError(t, err)
_, err = goose.NewProvider("", db, nil, goose.WithStore(store))
check.HasError(t, err)
Expand Down
19 changes: 9 additions & 10 deletions provider_run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"testing/fstest"

"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/testdb"
"github.com/pressly/goose/v3/lock"
Expand Down Expand Up @@ -321,7 +320,7 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-2');
INSERT INTO owners (owner_name) VALUES ('seed-user-3');
`),
}
p, err := goose.NewProvider(database.DialectSQLite3, db, mapFS)
p, err := goose.NewProvider(goose.DialectSQLite3, db, mapFS)
check.NoError(t, err)
_, err = p.Up(ctx)
check.HasError(t, err)
Expand Down Expand Up @@ -486,7 +485,7 @@ func TestNoVersioning(t *testing.T) {
// These are owners created by migration files.
wantOwnerCount = 4
)
p, err := goose.NewProvider(database.DialectSQLite3, db, fsys,
p, err := goose.NewProvider(goose.DialectSQLite3, db, fsys,
goose.WithVerbose(testing.Verbose()),
goose.WithDisableVersioning(false), // This is the default.
)
Expand All @@ -499,7 +498,7 @@ func TestNoVersioning(t *testing.T) {
check.Number(t, baseVersion, 3)
t.Run("seed-up-down-to-zero", func(t *testing.T) {
fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "seed"))
p, err := goose.NewProvider(database.DialectSQLite3, db, fsys,
p, err := goose.NewProvider(goose.DialectSQLite3, db, fsys,
goose.WithVerbose(testing.Verbose()),
goose.WithDisableVersioning(true), // Provider with no versioning.
)
Expand Down Expand Up @@ -552,7 +551,7 @@ func TestAllowMissing(t *testing.T) {

t.Run("missing_now_allowed", func(t *testing.T) {
db := newDB(t)
p, err := goose.NewProvider(database.DialectSQLite3, db, newFsys(),
p, err := goose.NewProvider(goose.DialectSQLite3, db, newFsys(),
goose.WithAllowOutofOrder(false),
)
check.NoError(t, err)
Expand Down Expand Up @@ -607,7 +606,7 @@ func TestAllowMissing(t *testing.T) {

t.Run("missing_allowed", func(t *testing.T) {
db := newDB(t)
p, err := goose.NewProvider(database.DialectSQLite3, db, newFsys(),
p, err := goose.NewProvider(goose.DialectSQLite3, db, newFsys(),
goose.WithAllowOutofOrder(true),
)
check.NoError(t, err)
Expand Down Expand Up @@ -723,7 +722,7 @@ func TestGoOnly(t *testing.T) {
&goose.GoFunc{RunTx: newTxFn("DELETE FROM users")},
),
}
p, err := goose.NewProvider(database.DialectSQLite3, db, nil,
p, err := goose.NewProvider(goose.DialectSQLite3, db, nil,
goose.WithGoMigrations(register...),
)
check.NoError(t, err)
Expand Down Expand Up @@ -779,7 +778,7 @@ func TestGoOnly(t *testing.T) {
&goose.GoFunc{RunDB: newDBFn("DELETE FROM users")},
),
}
p, err := goose.NewProvider(database.DialectSQLite3, db, nil,
p, err := goose.NewProvider(goose.DialectSQLite3, db, nil,
goose.WithGoMigrations(register...),
)
check.NoError(t, err)
Expand Down Expand Up @@ -836,7 +835,7 @@ func TestLockModeAdvisorySession(t *testing.T) {
)
check.NoError(t, err)
p, err := goose.NewProvider(
database.DialectPostgres,
goose.DialectPostgres,
db,
os.DirFS("testdata/migrations"),
goose.WithSessionLocker(sessionLocker), // Use advisory session lock mode.
Expand Down Expand Up @@ -1093,7 +1092,7 @@ func newProviderWithDB(t *testing.T, opts ...goose.ProviderOption) (*goose.Provi
opts,
goose.WithVerbose(testing.Verbose()),
)
p, err := goose.NewProvider(database.DialectSQLite3, db, newFsys(), opts...)
p, err := goose.NewProvider(goose.DialectSQLite3, db, newFsys(), opts...)
check.NoError(t, err)
return p, db
}
Expand Down
5 changes: 2 additions & 3 deletions provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"testing/fstest"

"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/check"
_ "modernc.org/sqlite"
)
Expand All @@ -19,7 +18,7 @@ func TestProvider(t *testing.T) {
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
check.NoError(t, err)
t.Run("empty", func(t *testing.T) {
_, err := goose.NewProvider(database.DialectSQLite3, db, fstest.MapFS{})
_, err := goose.NewProvider(goose.DialectSQLite3, db, fstest.MapFS{})
check.HasError(t, err)
check.Bool(t, errors.Is(err, goose.ErrNoMigrations), true)
})
Expand All @@ -30,7 +29,7 @@ func TestProvider(t *testing.T) {
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
p, err := goose.NewProvider(database.DialectSQLite3, db, fsys)
p, err := goose.NewProvider(goose.DialectSQLite3, db, fsys)
check.NoError(t, err)
sources := p.ListSources()
check.Equal(t, len(sources), 2)
Expand Down

0 comments on commit 6692639

Please sign in to comment.