diff --git a/dialect/pgdialect/alter_table.go b/dialect/pgdialect/alter_table.go index 237e0bb78..d88974bb0 100644 --- a/dialect/pgdialect/alter_table.go +++ b/dialect/pgdialect/alter_table.go @@ -185,32 +185,31 @@ func (m *migrator) changeColumnType(fmter schema.Formatter, b []byte, colDef *mi return b, err } + // alterColumn never re-assigns err, so there is no need to check for err != nil after calling it var i int appendAlterColumn := func() { if i > 0 { - b = append(b, ", "...) + b = append(b, ","...) } b = append(b, " ALTER COLUMN "...) - b, err = bun.Ident(colDef.Column).AppendQuery(fmter, b) + b, _ = bun.Ident(colDef.Column).AppendQuery(fmter, b) i++ } got, want := colDef.From, colDef.To if want.SQLType != got.SQLType { - if appendAlterColumn(); err != nil { - return b, err - } + appendAlterColumn() b = append(b, " SET DATA TYPE "...) if b, err = want.AppendQuery(fmter, b); err != nil { return b, err } } + // Column must be declared NOT NULL before identity can be added. + // Although PG can resolve the order of operations itself, we make this explicit in the query. if want.IsNullable != got.IsNullable { - if appendAlterColumn(); err != nil { - return b, err - } + appendAlterColumn() if !want.IsNullable { b = append(b, " SET NOT NULL"...) } else { @@ -218,10 +217,18 @@ func (m *migrator) changeColumnType(fmter schema.Formatter, b []byte, colDef *mi } } - if want.DefaultValue != got.DefaultValue { - if appendAlterColumn(); err != nil { - return b, err + if want.IsIdentity != got.IsIdentity { + appendAlterColumn() + if !want.IsIdentity { + b = append(b, " DROP IDENTITY"...) + } else { + b = append(b, " ADD"...) + b = appendGeneratedAsIdentity(b) } + } + + if want.DefaultValue != got.DefaultValue { + appendAlterColumn() if want.DefaultValue == "" { b = append(b, " DROP DEFAULT"...) } else { diff --git a/dialect/pgdialect/dialect.go b/dialect/pgdialect/dialect.go index 73355b6c0..e95b581fd 100644 --- a/dialect/pgdialect/dialect.go +++ b/dialect/pgdialect/dialect.go @@ -123,5 +123,10 @@ func (d *Dialect) AppendUint64(b []byte, n uint64) []byte { } func (d *Dialect) AppendSequence(b []byte, _ *schema.Table, _ *schema.Field) []byte { + return appendGeneratedAsIdentity(b) +} + +// appendGeneratedAsIdentity appends GENERATED BY DEFAULT AS IDENTITY to the column definition. +func appendGeneratedAsIdentity(b []byte) []byte { return append(b, " GENERATED BY DEFAULT AS IDENTITY"...) } diff --git a/internal/dbtest/migrate_test.go b/internal/dbtest/migrate_test.go index 5b82609eb..553114e93 100644 --- a/internal/dbtest/migrate_test.go +++ b/internal/dbtest/migrate_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqltype" "github.com/uptrace/bun/migrate" "github.com/uptrace/bun/migrate/sqlschema" ) @@ -206,6 +207,7 @@ func TestAutoMigrator_Run(t *testing.T) { {testForceRenameFK}, {testRenameColumnRenamesFK}, {testChangeColumnType_AutoCast}, + {testIdentity}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -660,7 +662,50 @@ func testChangeColumnType_AutoCast(t *testing.T, db *bun.DB) { // Assert state := inspect(ctx) cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.Tables) - // require.Equal(t, wantTables, state.Tables +} + +func testIdentity(t *testing.T, db *bun.DB) { + type TableBefore struct { + bun.BaseModel `bun:"table:table"` + A int64 `bun:",notnull,identity"` + B int64 + } + + type TableAfter struct { + bun.BaseModel `bun:"table:table"` + A int64 `bun:",notnull"` + B int64 `bun:",notnull,identity"` + } + + wantTables := []sqlschema.Table{ + { + Schema: db.Dialect().DefaultSchema(), + Name: "table", + Columns: map[string]sqlschema.Column{ + "a": { + SQLType: sqltype.BigInt, + IsIdentity: false, // <- drop IDENTITY + }, + "b": { + SQLType: sqltype.BigInt, + IsIdentity: true, // <- add IDENTITY + }, + }, + }, + } + + ctx := context.Background() + inspect := inspectDbOrSkip(t, db) + mustResetModel(t, ctx, db, (*TableBefore)(nil)) + m := newAutoMigrator(t, db, migrate.WithModel((*TableAfter)(nil))) + + // Act + err := m.Run(ctx) + require.NoError(t, err) + + // Assert + state := inspect(ctx) + cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.Tables) } // // TODO: rewrite these tests into AutoMigrator tests, Diff should be moved to migrate/internal package diff --git a/migrate/diff.go b/migrate/diff.go index 75036a797..bb13723b4 100644 --- a/migrate/diff.go +++ b/migrate/diff.go @@ -356,9 +356,8 @@ func (d *detector) detectColumnChanges(current, target sqlschema.Table, checkTyp continue } - // Find the column with the same type and the for cName, cCol := range current.Columns { - if !d.equalColumns(tCol, cCol) { + if _, keep := target.Columns[cName]; keep || !d.equalColumns(tCol, cCol) { continue } d.changes.Add(&RenameColumn{ diff --git a/migrate/operations.go b/migrate/operations.go index 3d37c8c76..950dd5e2e 100644 --- a/migrate/operations.go +++ b/migrate/operations.go @@ -206,6 +206,7 @@ func (op *DropConstraint) GetReverse() Operation { } } +// Change column type. type ChangeColumnType struct { Schema string Table string