Skip to content

Commit

Permalink
feat(automigrate): detect renamed tables
Browse files Browse the repository at this point in the history
  • Loading branch information
bevzzz committed Oct 27, 2024
1 parent 6dfdf95 commit c03938f
Show file tree
Hide file tree
Showing 12 changed files with 335 additions and 101 deletions.
26 changes: 26 additions & 0 deletions dialect/pgdialect/alter_table.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package pgdialect

import (
"context"
"database/sql"
"fmt"

"github.com/uptrace/bun/migrate/sqlschema"
)

func (d *Dialect) Migrator(sqldb *sql.DB) sqlschema.Migrator {
return &Migrator{sqldb: sqldb}
}

type Migrator struct {
sqldb *sql.DB
}

func (m *Migrator) RenameTable(ctx context.Context, oldName, newName string) error {
query := fmt.Sprintf("ALTER TABLE %s RENAME TO %s", oldName, newName)
_, err := m.sqldb.ExecContext(ctx, query)
if err != nil {
return err
}
return nil
}
2 changes: 2 additions & 0 deletions dialect/pgdialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/migrate/sqlschema"
"github.com/uptrace/bun/schema"
"github.com/uptrace/bun/schema/inspector"
)
Expand All @@ -32,6 +33,7 @@ type Dialect struct {

var _ schema.Dialect = (*Dialect)(nil)
var _ inspector.Dialect = (*Dialect)(nil)
var _ sqlschema.MigratorDialect = (*Dialect)(nil)

func New() *Dialect {
d := new(Dialect)
Expand Down
29 changes: 19 additions & 10 deletions dialect/pgdialect/inspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,38 @@ import (
"github.com/uptrace/bun/schema"
)

func (d *Dialect) Inspector(db *bun.DB) schema.Inspector {
return newDatabaseInspector(db)
func (d *Dialect) Inspector(db *bun.DB, excludeTables ...string) schema.Inspector {
return newInspector(db, excludeTables...)
}

type DatabaseInspector struct {
db *bun.DB
type Inspector struct {
db *bun.DB
excludeTables []string
}

var _ schema.Inspector = (*DatabaseInspector)(nil)
var _ schema.Inspector = (*Inspector)(nil)

func newDatabaseInspector(db *bun.DB) *DatabaseInspector {
return &DatabaseInspector{db: db}
func newInspector(db *bun.DB, excludeTables ...string) *Inspector {
return &Inspector{db: db, excludeTables: excludeTables}
}

func (di *DatabaseInspector) Inspect(ctx context.Context) (schema.State, error) {
func (in *Inspector) Inspect(ctx context.Context) (schema.State, error) {
var state schema.State

exclude := in.excludeTables
if len(exclude) == 0 {
// Avoid getting NOT IN (NULL) if bun.In() is called with an empty slice.
exclude = []string{""}
}

var tables []*InformationSchemaTable
if err := di.db.NewRaw(sqlInspectTables).Scan(ctx, &tables); err != nil {
if err := in.db.NewRaw(sqlInspectTables, bun.In(exclude)).Scan(ctx, &tables); err != nil {
return state, err
}

for _, table := range tables {
var columns []*InformationSchemaColumn
if err := di.db.NewRaw(sqlInspectColumnsQuery, table.Schema, table.Name).Scan(ctx, &columns); err != nil {
if err := in.db.NewRaw(sqlInspectColumnsQuery, table.Schema, table.Name).Scan(ctx, &columns); err != nil {
return state, err
}
colDefs := make(map[string]schema.ColumnDef)
Expand Down Expand Up @@ -105,6 +113,7 @@ FROM information_schema.tables
WHERE table_type = 'BASE TABLE'
AND table_schema <> 'information_schema'
AND table_schema NOT LIKE 'pg_%'
AND table_name NOT IN (?)
`

// sqlInspectColumnsQuery retrieves column definitions for the specified table.
Expand Down
7 changes: 4 additions & 3 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,12 @@ func testEachDB(t *testing.T, f func(t *testing.T, dbName string, db *bun.DB)) {
}

// testEachDialect allows testing dialect-specific functionality that does not require database interactions.
func testEachDialect(t *testing.T, f func(t *testing.T, dialectName string, dialect func() schema.Dialect)) {
func testEachDialect(t *testing.T, f func(t *testing.T, dialectName string, dialect schema.Dialect)) {
for _, newDialect := range allDialects {
name := newDialect().Name().String()
d := newDialect()
name := d.Name().String()
t.Run(name, func(t *testing.T) {
f(t, name, newDialect)
f(t, name, d)
})
}
}
Expand Down
6 changes: 3 additions & 3 deletions internal/dbtest/inspect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
Title string `bun:",notnull,unique:title_author"`
Locale string `bun:",type:varchar(5),default:'en-GB'"`
Pages int8 `bun:"page_count,notnull,default:1"`
Count int32 `bun:"book_count,autoincrement"`
Count int32 `bun:"book_count,autoincrement"`
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand All @@ -34,7 +34,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
}

ctx := context.Background()
createTableOrSkip(t, ctx, db, (*Book)(nil))
mustResetModel(t, ctx, db, (*Book)(nil))

dbInspector := dialect.Inspector(db)
want := schema.State{
Expand Down Expand Up @@ -105,7 +105,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
func getDatabaseInspectorOrSkip(tb testing.TB, db *bun.DB) schema.Inspector {
dialect := db.Dialect()
if id, ok := dialect.(inspector.Dialect); ok {
return id.Inspector(db)
return id.Inspector(db, migrationsTable, migrationLocksTable)
}
tb.Skipf("%q dialect does not implement inspector.Dialect", dialect.Name())
return nil
Expand Down
51 changes: 20 additions & 31 deletions internal/dbtest/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ func testMigrateUpError(t *testing.T, db *bun.DB) {
require.Equal(t, []string{"down2", "down1"}, history)
}

func TestAutoMigrator_Migrate(t *testing.T) {
func TestAutoMigrator_Run(t *testing.T) {

tests := []struct {
fn func(t *testing.T, db *bun.DB)
}{
Expand Down Expand Up @@ -190,14 +191,17 @@ func testRenameTable(t *testing.T, db *bun.DB) {
// Arrange
ctx := context.Background()
di := getDatabaseInspectorOrSkip(t, db)
createTableOrSkip(t, ctx, db, (*initial)(nil))
mustResetModel(t, ctx, db, (*initial)(nil))
mustDropTableOnCleanup(t, ctx, db, (*changed)(nil))

m, err := migrate.NewAutoMigrator(db)
m, err := migrate.NewAutoMigrator(db,
migrate.WithTableNameAuto(migrationsTable),
migrate.WithLocksTableNameAuto(migrationLocksTable),
migrate.WithModel((*changed)(nil)))
require.NoError(t, err)
m.SetModels((*changed)(nil))

// Act
err = m.Migrate(ctx)
err = m.Run(ctx)
require.NoError(t, err)

// Assert
Expand All @@ -209,27 +213,13 @@ func testRenameTable(t *testing.T, db *bun.DB) {
require.Equal(t, "changed", tables[0].Name)
}

func createTableOrSkip(tb testing.TB, ctx context.Context, db *bun.DB, model interface{}) {
tb.Helper()
if _, err := db.NewCreateTable().IfNotExists().Model(model).Exec(ctx); err != nil {
tb.Skip("setup failed:", err)
}
tb.Cleanup(func() {
if _, err := db.NewDropTable().IfExists().Model(model).Exec(ctx); err != nil {
tb.Log("cleanup:", err)
}
})
}

func TestDetector_Diff(t *testing.T) {
tests := []struct {
name string
states func(testing.TB, context.Context, func() schema.Dialect) (stateDb schema.State, stateModel schema.State)
states func(testing.TB, context.Context, schema.Dialect) (stateDb schema.State, stateModel schema.State)
operations []migrate.Operation
}{
{
name: "find a renamed table",
states: renamedTableStates,
states: testDetectRenamedTable,
operations: []migrate.Operation{
&migrate.RenameTable{
From: "books",
Expand All @@ -239,13 +229,9 @@ func TestDetector_Diff(t *testing.T) {
},
}

testEachDialect(t, func(t *testing.T, dialectName string, dialect func() schema.Dialect) {
if dialectName != "pg" {
t.Skip()
}

testEachDialect(t, func(t *testing.T, dialectName string, dialect schema.Dialect) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Run(funcName(tt.states), func(t *testing.T) {
ctx := context.Background()
var d migrate.Detector
stateDb, stateModel := tt.states(t, ctx, dialect)
Expand All @@ -258,7 +244,7 @@ func TestDetector_Diff(t *testing.T) {
})
}

func renamedTableStates(tb testing.TB, ctx context.Context, dialect func() schema.Dialect) (s1, s2 schema.State) {
func testDetectRenamedTable(tb testing.TB, ctx context.Context, dialect schema.Dialect) (s1, s2 schema.State) {
type Book struct {
bun.BaseModel

Expand All @@ -279,17 +265,20 @@ func renamedTableStates(tb testing.TB, ctx context.Context, dialect func() schem
Title string `bun:"title,notnull"`
Pages int `bun:"page_count,notnull,default:0"`
}
return getState(tb, ctx, dialect(),
return getState(tb, ctx, dialect,
(*Author)(nil),
(*Book)(nil),
), getState(tb, ctx, dialect(),
), getState(tb, ctx, dialect,
(*Author)(nil),
(*BookRenamed)(nil),
)
}

func getState(tb testing.TB, ctx context.Context, dialect schema.Dialect, models ...interface{}) schema.State {
inspector := schema.NewInspector(dialect, models...)
tables := schema.NewTables(dialect)
tables.Register(models...)

inspector := schema.NewInspector(tables)
state, err := inspector.Inspect(ctx)
if err != nil {
tb.Skip("get state: %w", err)
Expand Down
Loading

0 comments on commit c03938f

Please sign in to comment.