diff --git a/dialect/pgdialect/alter_table.go b/dialect/pgdialect/alter_table.go index d88974bb0..63e4fa38a 100644 --- a/dialect/pgdialect/alter_table.go +++ b/dialect/pgdialect/alter_table.go @@ -41,12 +41,14 @@ func (m *migrator) Apply(ctx context.Context, changes ...sqlschema.Operation) er switch change := change.(type) { case *migrate.CreateTable: + log.Printf("create table %q", change.Name) err = m.CreateTable(ctx, change.Model) if err != nil { return fmt.Errorf("apply changes: create table %s: %w", change.FQN(), err) } continue case *migrate.DropTable: + log.Printf("drop table %q", change.Name) err = m.DropTable(ctx, change.Schema, change.Name) if err != nil { return fmt.Errorf("apply changes: drop table %s: %w", change.FQN(), err) @@ -56,6 +58,10 @@ func (m *migrator) Apply(ctx context.Context, changes ...sqlschema.Operation) er b, err = m.renameTable(fmter, b, change) case *migrate.RenameColumn: b, err = m.renameColumn(fmter, b, change) + case *migrate.AddColumn: + b, err = m.addColumn(fmter, b, change) + case *migrate.DropColumn: + b, err = m.dropColumn(fmter, b, change) case *migrate.DropConstraint: b, err = m.dropContraint(fmter, b, change) case *migrate.AddForeignKey: @@ -87,9 +93,7 @@ func (m *migrator) renameTable(fmter schema.Formatter, b []byte, rename *migrate return b, err } b = append(b, " RENAME TO "...) - if b, err = bun.Ident(rename.NewName).AppendQuery(fmter, b); err != nil { - return b, err - } + b = fmter.AppendIdent(b, rename.NewName) return b, nil } @@ -101,14 +105,36 @@ func (m *migrator) renameColumn(fmter schema.Formatter, b []byte, rename *migrat } b = append(b, " RENAME COLUMN "...) - if b, err = bun.Ident(rename.OldName).AppendQuery(fmter, b); err != nil { - return b, err - } + b = fmter.AppendIdent(b, rename.OldName) b = append(b, " TO "...) - if b, err = bun.Ident(rename.NewName).AppendQuery(fmter, b); err != nil { - return b, err - } + b = fmter.AppendIdent(b, rename.NewName) + + return b, nil +} + +func (m *migrator) addColumn(fmter schema.Formatter, b []byte, add *migrate.AddColumn) (_ []byte, err error) { + b = append(b, "ALTER TABLE "...) + fqn := add.FQN() + b, _ = fqn.AppendQuery(fmter, b) + + b = append(b, " ADD COLUMN "...) + b = fmter.AppendName(b, add.Column) + b = append(b, " "...) + + b, _ = add.ColDef.AppendQuery(fmter, b) + + return b, nil +} + +func (m *migrator) dropColumn(fmter schema.Formatter, b []byte, drop *migrate.DropColumn) (_ []byte, err error) { + b = append(b, "ALTER TABLE "...) + fqn := drop.FQN() + b, _ = fqn.AppendQuery(fmter, b) + + b = append(b, " DROP COLUMN "...) + b = fmter.AppendName(b, drop.Column) + return b, nil } @@ -120,14 +146,11 @@ func (m *migrator) renameConstraint(fmter schema.Formatter, b []byte, rename *mi } b = append(b, " RENAME CONSTRAINT "...) - if b, err = bun.Ident(rename.OldName).AppendQuery(fmter, b); err != nil { - return b, err - } + b = fmter.AppendIdent(b, rename.OldName) b = append(b, " TO "...) - if b, err = bun.Ident(rename.NewName).AppendQuery(fmter, b); err != nil { - return b, err - } + b = fmter.AppendIdent(b, rename.NewName) + return b, nil } @@ -139,9 +162,8 @@ func (m *migrator) dropContraint(fmter schema.Formatter, b []byte, drop *migrate } b = append(b, " DROP CONSTRAINT "...) - if b, err = bun.Ident(drop.ConstraintName).AppendQuery(fmter, b); err != nil { - return b, err - } + b = fmter.AppendIdent(b, drop.ConstraintName) + return b, nil } @@ -153,9 +175,7 @@ func (m *migrator) addForeignKey(fmter schema.Formatter, b []byte, add *migrate. } b = append(b, " ADD CONSTRAINT "...) - if b, err = bun.Ident(add.ConstraintName).AppendQuery(fmter, b); err != nil { - return b, err - } + b = fmter.AppendIdent(b, add.ConstraintName) b = append(b, " FOREIGN KEY ("...) if b, err = add.FK.From.Column.Safe().AppendQuery(fmter, b); err != nil { @@ -192,7 +212,7 @@ func (m *migrator) changeColumnType(fmter schema.Formatter, b []byte, colDef *mi b = append(b, ","...) } b = append(b, " ALTER COLUMN "...) - b, _ = bun.Ident(colDef.Column).AppendQuery(fmter, b) + b = fmter.AppendIdent(b, colDef.Column) i++ } diff --git a/internal/dbtest/inspect_test.go b/internal/dbtest/inspect_test.go index 346be401e..bd758d1f9 100644 --- a/internal/dbtest/inspect_test.go +++ b/internal/dbtest/inspect_test.go @@ -3,6 +3,7 @@ package dbtest_test import ( "context" "fmt" + "strings" "testing" "time" @@ -306,14 +307,17 @@ func cmpTables(tb testing.TB, d sqlschema.InspectorDialect, want, got []sqlschem } func cmpColumns(tb testing.TB, d sqlschema.InspectorDialect, tableName string, want, got map[string]sqlschema.Column) { + tb.Helper() var errs []string + + var missing []string for colName, wantCol := range want { errorf := func(format string, args ...interface{}) { errs = append(errs, fmt.Sprintf("[%s.%s] "+format, append([]interface{}{tableName, colName}, args...)...)) } gotCol, ok := got[colName] if !ok { - errorf("column is missing") + missing = append(missing, colName) continue } @@ -338,6 +342,21 @@ func cmpColumns(tb testing.TB, d sqlschema.InspectorDialect, tableName string, w } } + if len(missing) > 0 { + errs = append(errs, fmt.Sprintf("%q has missing columns: %q", tableName, strings.Join(missing, "\", \""))) + } + + var extra []string + for colName := range got { + if _, ok := want[colName]; !ok { + extra = append(extra, colName) + } + } + + if len(extra) > 0 { + errs = append(errs, fmt.Sprintf("%q has extra columns: %q", tableName, strings.Join(extra, "\", \""))) + } + for _, errMsg := range errs { tb.Error(errMsg) } diff --git a/internal/dbtest/migrate_test.go b/internal/dbtest/migrate_test.go index 553114e93..bb424b9f7 100644 --- a/internal/dbtest/migrate_test.go +++ b/internal/dbtest/migrate_test.go @@ -208,6 +208,7 @@ func TestAutoMigrator_Run(t *testing.T) { {testRenameColumnRenamesFK}, {testChangeColumnType_AutoCast}, {testIdentity}, + {testAddDropColumn}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -708,6 +709,51 @@ func testIdentity(t *testing.T, db *bun.DB) { cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.Tables) } +func testAddDropColumn(t *testing.T, db *bun.DB) { + type TableBefore struct { + bun.BaseModel `bun:"table:table"` + DoNotTouch string `bun:"do_not_touch"` + DropMe string `bun:"dropme"` + } + + type TableAfter struct { + bun.BaseModel `bun:"table:table"` + DoNotTouch string `bun:"do_not_touch"` + AddMe bool `bun:"addme"` + } + + wantTables := []sqlschema.Table{ + { + Schema: db.Dialect().DefaultSchema(), + Name: "table", + Columns: map[string]sqlschema.Column{ + "do_not_touch": { + SQLType: sqltype.VarChar, + IsNullable: true, + }, + "addme": { + SQLType: sqltype.Boolean, + IsNullable: true, + }, + }, + }, + } + + ctx := context.Background() + inspect := inspectDbOrSkip(t, db) + mustResetModel(t, ctx, db, (*TableBefore)(nil)) + m := newAutoMigrator(t, db, migrate.WithModel((*TableAfter)(nil))) + + // Act + err := m.Run(ctx) + require.NoError(t, err) + + // Assert + state := inspect(ctx) + cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.Tables) +} + + // // TODO: rewrite these tests into AutoMigrator tests, Diff should be moved to migrate/internal package // func TestDiff(t *testing.T) { // type Journal struct { diff --git a/migrate/diff.go b/migrate/diff.go index bb13723b4..1afa4dd96 100644 --- a/migrate/diff.go +++ b/migrate/diff.go @@ -340,8 +340,12 @@ func (d *detector) makeTargetColDef(current, target sqlschema.Column) sqlschema. // detechColumnChanges finds renamed columns and, if checkType == true, columns with changed type. func (d *detector) detectColumnChanges(current, target sqlschema.Table, checkType bool) { +ChangedRenamed: for tName, tCol := range target.Columns { - // This column exists in the database, so it hasn't been renamed. + + // This column exists in the database, so it hasn't been renamed, dropped, or added. + // Still, we should not delete(columns, thisColumn), because later we will need to + // check that we do not try to rename a column to an already a name that already exists. if cCol, ok := current.Columns[tName]; ok { if checkType && !d.equalColumns(cCol, tCol) { d.changes.Add(&ChangeColumnType{ @@ -351,13 +355,15 @@ func (d *detector) detectColumnChanges(current, target sqlschema.Table, checkTyp From: cCol, To: d.makeTargetColDef(cCol, tCol), }) - // TODO: Can I delete (current.Column, tName) then? Because if it's type has changed, it will never match in the line 343. } continue } + // Column tName does not exist in the database -- it's been either renamed or added. + // Find renamed columns first. for cName, cCol := range current.Columns { - if _, keep := target.Columns[cName]; keep || !d.equalColumns(tCol, cCol) { + // Cannot rename if a column with this name already exists or the types differ. + if _, exists := current.Columns[tName]; exists || !d.equalColumns(tCol, cCol) { continue } d.changes.Add(&RenameColumn{ @@ -368,7 +374,27 @@ func (d *detector) detectColumnChanges(current, target sqlschema.Table, checkTyp }) delete(current.Columns, cName) // no need to check this column again d.refMap.UpdateC(sqlschema.C(target.Schema, target.Name, cName), tName) - break + + continue ChangedRenamed + } + + d.changes.Add(&AddColumn{ + Schema: target.Schema, + Table: target.Name, + Column: tName, + ColDef: tCol, + }) + } + + // Drop columns which do not exist in the target schema and were not renamed. + for cName, cCol := range current.Columns { + if _, keep := target.Columns[cName]; !keep { + d.changes.Add(&DropColumn{ + Schema: target.Schema, + Table: target.Name, + Column: cName, + ColDef: cCol, + }) } } } diff --git a/migrate/operations.go b/migrate/operations.go index 950dd5e2e..016910edc 100644 --- a/migrate/operations.go +++ b/migrate/operations.go @@ -120,6 +120,84 @@ func (op *RenameColumn) DependsOn(another Operation) bool { return ok && rt.Schema == op.Schema && rt.NewName == op.Table } +type AddColumn struct { + Schema string + Table string + Column string + ColDef sqlschema.Column +} + +var _ Operation = (*AddColumn)(nil) +var _ sqlschema.Operation = (*AddColumn)(nil) + +func (op *AddColumn) FQN() schema.FQN { + return schema.FQN{ + Schema: op.Schema, + Table: op.Table, + } +} + +func (op *AddColumn) GetReverse() Operation { + return &DropColumn{ + Schema: op.Schema, + Table: op.Table, + Column: op.Column, + } +} + +type DropColumn struct { + Schema string + Table string + Column string + ColDef sqlschema.Column +} + +var _ Operation = (*DropColumn)(nil) +var _ sqlschema.Operation = (*DropColumn)(nil) + +func (op *DropColumn) FQN() schema.FQN { + return schema.FQN{ + Schema: op.Schema, + Table: op.Table, + } +} + +func (op *DropColumn) GetReverse() Operation { + return &AddColumn{ + Schema: op.Schema, + Table: op.Table, + Column: op.Column, + ColDef: op.ColDef, + } +} + +func (op *DropColumn) DependsOn(another Operation) bool { + // TODO: refactor + if dc, ok := another.(*DropConstraint); ok { + var fCol bool + fCols := dc.FK.From.Column.Split() + for _, c := range fCols { + if c == op.Column { + fCol = true + break + } + } + + var tCol bool + tCols := dc.FK.To.Column.Split() + for _, c := range tCols { + if c == op.Column { + tCol = true + break + } + } + + return (dc.FK.From.Schema == op.Schema && dc.FK.From.Table == op.Table && fCol) || + (dc.FK.To.Schema == op.Schema && dc.FK.To.Table == op.Table && tCol) + } + return false +} + // RenameConstraint. type RenameConstraint struct { FK sqlschema.FK @@ -168,6 +246,7 @@ func (op *AddForeignKey) FQN() schema.FQN { func (op *AddForeignKey) DependsOn(another Operation) bool { switch another := another.(type) { case *RenameTable: + // TODO: provide some sort of "DependsOn" method for FK return another.Schema == op.FK.From.Schema && another.NewName == op.FK.From.Table case *CreateTable: return (another.Schema == op.FK.To.Schema && another.Name == op.FK.To.Table) || // either it's the referencing one