diff --git a/dialect/pgdialect/inspector.go b/dialect/pgdialect/inspector.go index db457de21..95d2581b2 100644 --- a/dialect/pgdialect/inspector.go +++ b/dialect/pgdialect/inspector.go @@ -39,6 +39,12 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.State, error) { return state, err } + var fks []*ForeignKey + if err := in.db.NewRaw(sqlInspectForeignKeys, bun.In(exclude), bun.In(exclude)).Scan(ctx, &fks); err != nil { + return state, err + } + state.FKs = make(map[sqlschema.FK]string, len(fks)) + for _, table := range tables { var columns []*InformationSchemaColumn if err := in.db.NewRaw(sqlInspectColumnsQuery, table.Schema, table.Name).Scan(ctx, &columns); err != nil { @@ -72,12 +78,17 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.State, error) { Columns: colDefs, }) } + + for _, fk := range fks { + state.FKs[sqlschema.FK{ + From: sqlschema.C(fk.SourceSchema, fk.SourceTable, fk.SourceColumns...), + To: sqlschema.C(fk.TargetSchema, fk.TargetTable, fk.TargetColumns...), + }] = fk.ConstraintName + } return state, nil } type InformationSchemaTable struct { - bun.BaseModel - Schema string `bun:"table_schema,pk"` Name string `bun:"table_name,pk"` @@ -85,8 +96,6 @@ type InformationSchemaTable struct { } type InformationSchemaColumn struct { - bun.BaseModel - Schema string `bun:"table_schema"` Table string `bun:"table_name"` Name string `bun:"column_name"` @@ -104,17 +113,29 @@ type InformationSchemaColumn struct { UniqueGroup []string `bun:"unique_group,array"` } +type ForeignKey struct { + ConstraintName string `bun:"constraint_name"` + SourceSchema string `bun:"schema_name"` + SourceTable string `bun:"table_name"` + SourceColumns []string `bun:"columns,array"` + TargetSchema string `bun:"target_schema"` + TargetTable string `bun:"target_table"` + TargetColumns []string `bun:"target_columns,array"` +} + const ( // sqlInspectTables retrieves all user-defined tables across all schemas. // It excludes relations from Postgres's reserved "pg_" schemas and views from the "information_schema". + // Pass bun.In([]string{...}) to exclude tables from this inspection or bun.In([]string{''}) to include all results. sqlInspectTables = ` -SELECT table_schema, table_name +SELECT "table_schema", "table_name" FROM information_schema.tables WHERE table_type = 'BASE TABLE' - AND table_schema <> 'information_schema' - AND table_schema NOT LIKE 'pg_%' - AND table_name NOT IN (?) - ` + AND "table_schema" <> 'information_schema' + AND "table_schema" NOT LIKE 'pg_%' + AND "table_name" NOT IN (?) +ORDER BY "table_schema", "table_name" +` // sqlInspectColumnsQuery retrieves column definitions for the specified table. // Unlike sqlInspectTables and sqlInspectSchema, it should be passed to bun.NewRaw @@ -180,10 +201,13 @@ FROM ( ) att USING ("table_schema", "table_name", "column_name") ) "c" WHERE "table_schema" = ? AND "table_name" = ? - ` +ORDER BY "table_schema", "table_name", "column_name" +` // sqlInspectSchema retrieves column type definitions for all user-defined tables. // Other relations, such as views and indices, as well as Posgres's internal relations are excluded. + // + // TODO: implement scanning ORM relations for RawQuery too, so that one could scan this query directly to InformationSchemaTable. sqlInspectSchema = ` SELECT "t"."table_schema", @@ -247,5 +271,45 @@ FROM information_schema.tables "t" WHERE table_type = 'BASE TABLE' AND table_schema <> 'information_schema' AND table_schema NOT LIKE 'pg_%' - ` +ORDER BY table_schema, table_name +` + + // sqlInspectForeignKeys get FK definitions for user-defined tables. + // Pass bun.In([]string{...}) to exclude tables from this inspection or bun.In([]string{''}) to include all results. + sqlInspectForeignKeys = ` +WITH + "schemas" AS ( + SELECT oid, nspname + FROM pg_namespace + ), + "tables" AS ( + SELECT oid, relnamespace, relname, relkind + FROM pg_class + ), + "columns" AS ( + SELECT attrelid, attname, attnum + FROM pg_attribute + WHERE attisdropped = false + ) +SELECT DISTINCT + co.conname AS "constraint_name", + ss.nspname AS schema_name, + s.relname AS "table_name", + ARRAY_AGG(sc.attname) AS "columns", + ts.nspname AS target_schema, + "t".relname AS target_table, + ARRAY_AGG(tc.attname) AS target_columns +FROM pg_constraint co + LEFT JOIN "tables" s ON s.oid = co.conrelid + LEFT JOIN "schemas" ss ON ss.oid = s.relnamespace + LEFT JOIN "columns" sc ON sc.attrelid = s.oid AND sc.attnum = ANY(co.conkey) + LEFT JOIN "tables" t ON t.oid = co.confrelid + LEFT JOIN "schemas" ts ON ts.oid = "t".relnamespace + LEFT JOIN "columns" tc ON tc.attrelid = "t".oid AND tc.attnum = ANY(co.confkey) +WHERE co.contype = 'f' + AND co.conrelid IN (SELECT oid FROM pg_class WHERE relkind = 'r') + AND ARRAY_POSITION(co.conkey, sc.attnum) = ARRAY_POSITION(co.confkey, tc.attnum) + AND s.relname NOT IN (?) AND "t".relname NOT IN (?) +GROUP BY "constraint_name", "schema_name", "table_name", target_schema, target_table +` ) diff --git a/internal/dbtest/inspect_test.go b/internal/dbtest/inspect_test.go index 53c8905cc..42e200e2c 100644 --- a/internal/dbtest/inspect_test.go +++ b/internal/dbtest/inspect_test.go @@ -9,89 +9,265 @@ import ( "github.com/uptrace/bun/migrate/sqlschema" ) -func TestDatabaseInspector_Inspect(t *testing.T) { +type Article struct { + ISBN int `bun:",pk,identity"` + Editor string `bun:",notnull,unique:title_author,default:'john doe'"` + Title string `bun:",notnull,unique:title_author"` + Locale string `bun:",type:varchar(5),default:'en-GB'"` + Pages int8 `bun:"page_count,notnull,default:1"` + Count int32 `bun:"book_count,autoincrement"` + PublisherID string `bun:"publisher_id,notnull"` + AuthorID int `bun:"author_id,notnull"` - type Book struct { - bun.BaseModel `bun:"table:books"` + // Publisher that published this article. + Publisher *Publisher `bun:"rel:belongs-to,join:publisher_id=publisher_id"` - ISBN int `bun:",pk,identity"` - Author string `bun:",notnull,unique:title_author,default:'john doe'"` - Title string `bun:",notnull,unique:title_author"` - Locale string `bun:",type:varchar(5),default:'en-GB'"` - Pages int8 `bun:"page_count,notnull,default:1"` - Count int32 `bun:"book_count,autoincrement"` - } + // Author wrote this article. + Author *Journalist `bun:"rel:belongs-to,join:author_id=author_id"` +} + +type Office struct { + bun.BaseModel `bun:"table:admin.offices"` + Name string `bun:"office_name,pk"` + TennantID string `bun:"publisher_id"` + TennantName string `bun:"publisher_name"` + + Tennant *Publisher `bun:"rel:has-one,join:publisher_id=publisher_id,join:publisher_name=publisher_name"` +} + +type Publisher struct { + ID string `bun:"publisher_id,pk,default:gen_random_uuid(),unique:office_fk"` + Name string `bun:"publisher_name,unique,notnull,unique:office_fk"` + + // Writers write articles for this publisher. + Writers []Journalist `bun:"m2m:publisher_to_journalists,join:Publisher=Author"` +} + +// PublisherToJournalist describes which journalist work with each publisher. +// One publisher can also work with many journalists. It's an N:N (or m2m) relation. +type PublisherToJournalist struct { + bun.BaseModel `bun:"table:publisher_to_journalists"` + PublisherID string `bun:"publisher_id,pk"` + AuthorID int `bun:"author_id,pk"` + + Publisher *Publisher `bun:"rel:belongs-to,join:publisher_id=publisher_id"` + Author *Journalist `bun:"rel:belongs-to,join:author_id=author_id"` +} + +type Journalist struct { + bun.BaseModel `bun:"table:authors"` + ID int `bun:"author_id,pk,identity"` + FirstName string `bun:",notnull"` + LastName string + + // Articles that this journalist has written. + Articles []*Article `bun:"rel:has-many,join:author_id=author_id"` +} + +func TestDatabaseInspector_Inspect(t *testing.T) { testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { + db.RegisterModel((*PublisherToJournalist)(nil)) + dbInspector, err := sqlschema.NewInspector(db) if err != nil { t.Skip(err) } ctx := context.Background() - mustResetModel(t, ctx, db, (*Book)(nil)) - - want := sqlschema.State{ - Tables: []sqlschema.Table{ - { - Schema: "public", - Name: "books", - Columns: map[string]sqlschema.Column{ - "isbn": { - SQLType: "bigint", - IsPK: true, - IsNullable: false, - IsAutoIncrement: false, - IsIdentity: true, - DefaultValue: "", - }, - "author": { - SQLType: "varchar", - IsPK: false, - IsNullable: false, - IsAutoIncrement: false, - IsIdentity: false, - DefaultValue: "john doe", - }, - "title": { - SQLType: "varchar", - IsPK: false, - IsNullable: false, - IsAutoIncrement: false, - IsIdentity: false, - DefaultValue: "", - }, - "locale": { - SQLType: "varchar(5)", - IsPK: false, - IsNullable: true, - IsAutoIncrement: false, - IsIdentity: false, - DefaultValue: "en-GB", - }, - "page_count": { - SQLType: "smallint", - IsPK: false, - IsNullable: false, - IsAutoIncrement: false, - IsIdentity: false, - DefaultValue: "1", - }, - "book_count": { - SQLType: "integer", - IsPK: false, - IsNullable: false, - IsAutoIncrement: true, - IsIdentity: false, - DefaultValue: "", - }, + mustCreateSchema(t, ctx, db, "admin") + mustCreateTableWithFKs(t, ctx, db, + // Order of creation matters: + (*Journalist)(nil), // does not reference other tables + (*Publisher)(nil), // does not reference other tables + (*Office)(nil), // references Publisher + (*PublisherToJournalist)(nil), // references Journalist and Publisher + (*Article)(nil), // references Journalist and Publisher + ) + defaultSchema := db.Dialect().DefaultSchema() + + // Tables come sorted alphabetically by schema and table. + wantTables := []sqlschema.Table{ + { + Schema: "admin", + Name: "offices", + Columns: map[string]sqlschema.Column{ + "office_name": { + SQLType: "varchar", + IsPK: true, + }, + "publisher_id": { + SQLType: "varchar", + IsNullable: true, + }, + "publisher_name": { + SQLType: "varchar", + IsNullable: true, + }, + }, + }, + { + Schema: defaultSchema, + Name: "articles", + Columns: map[string]sqlschema.Column{ + "isbn": { + SQLType: "bigint", + IsPK: true, + IsNullable: false, + IsAutoIncrement: false, + IsIdentity: true, + DefaultValue: "", + }, + "editor": { + SQLType: "varchar", + IsPK: false, + IsNullable: false, + IsAutoIncrement: false, + IsIdentity: false, + DefaultValue: "john doe", + }, + "title": { + SQLType: "varchar", + IsPK: false, + IsNullable: false, + IsAutoIncrement: false, + IsIdentity: false, + DefaultValue: "", + }, + "locale": { + SQLType: "varchar(5)", + IsPK: false, + IsNullable: true, + IsAutoIncrement: false, + IsIdentity: false, + DefaultValue: "en-GB", + }, + "page_count": { + SQLType: "smallint", + IsPK: false, + IsNullable: false, + IsAutoIncrement: false, + IsIdentity: false, + DefaultValue: "1", + }, + "book_count": { + SQLType: "integer", + IsPK: false, + IsNullable: false, + IsAutoIncrement: true, + IsIdentity: false, + DefaultValue: "", + }, + "publisher_id": { + SQLType: "varchar", + }, + "author_id": { + SQLType: "bigint", }, }, }, + { + Schema: defaultSchema, + Name: "authors", + Columns: map[string]sqlschema.Column{ + "author_id": { + SQLType: "bigint", + IsPK: true, + IsIdentity: true, + }, + "first_name": { + SQLType: "varchar", + }, + "last_name": { + SQLType: "varchar", + IsNullable: true, + }, + }, + }, + { + Schema: defaultSchema, + Name: "publisher_to_journalists", + Columns: map[string]sqlschema.Column{ + "publisher_id": { + SQLType: "varchar", + IsPK: true, + }, + "author_id": { + SQLType: "bigint", + IsPK: true, + }, + }, + }, + { + Schema: defaultSchema, + Name: "publishers", + Columns: map[string]sqlschema.Column{ + "publisher_id": { + SQLType: "varchar", + IsPK: true, + DefaultValue: "gen_random_uuid()", + }, + "publisher_name": { + SQLType: "varchar", + }, + }, + }, + } + + wantFKs := []sqlschema.FK{ + { // + From: sqlschema.C(defaultSchema, "articles", "publisher_id"), + To: sqlschema.C(defaultSchema, "publishers", "publisher_id"), + }, + { + From: sqlschema.C(defaultSchema, "articles", "author_id"), + To: sqlschema.C(defaultSchema, "authors", "author_id"), + }, + { // + From: sqlschema.C("admin", "offices", "publisher_name", "publisher_id"), + To: sqlschema.C(defaultSchema, "publishers", "publisher_name", "publisher_id"), + }, + { // + From: sqlschema.C(defaultSchema, "publisher_to_journalists", "publisher_id"), + To: sqlschema.C(defaultSchema, "publishers", "publisher_id"), + }, + { // + From: sqlschema.C(defaultSchema, "publisher_to_journalists", "author_id"), + To: sqlschema.C(defaultSchema, "authors", "author_id"), + }, } got, err := dbInspector.Inspect(ctx) require.NoError(t, err) - require.Equal(t, want, got) + + // State.FKs store their database names, which differ from dialect to dialect. + // Because of that we compare FKs and Tables separately. + require.Equal(t, wantTables, got.Tables) + + var fks []sqlschema.FK + for fk := range got.FKs { + fks = append(fks, fk) + } + require.ElementsMatch(t, wantFKs, fks) + }) +} + +func mustCreateTableWithFKs(tb testing.TB, ctx context.Context, db *bun.DB, models ...interface{}) { + tb.Helper() + for _, model := range models { + create := db.NewCreateTable().Model(model).WithForeignKeys() + _, err := create.Exec(ctx) + require.NoError(tb, err, "must create table %q:", create.GetTableName()) + mustDropTableOnCleanup(tb, ctx, db, model) + } +} + +func mustCreateSchema(tb testing.TB, ctx context.Context, db *bun.DB, schema string) { + tb.Helper() + _, err := db.NewRaw("CREATE SCHEMA IF NOT EXISTS ?", bun.Ident(schema)).Exec(ctx) + require.NoError(tb, err, "create schema %q:", schema) + + tb.Cleanup(func() { + db.NewRaw("DROP SCHEMA IF EXISTS ?", bun.Ident(schema)).Exec(ctx) }) } diff --git a/internal/dbtest/migrate_test.go b/internal/dbtest/migrate_test.go index 91bb59265..d1da1ea34 100644 --- a/internal/dbtest/migrate_test.go +++ b/internal/dbtest/migrate_test.go @@ -260,22 +260,41 @@ func testCreateDropTable(t *testing.T, db *bun.DB) { require.Equal(t, "createme", tables[0].Name) } -type Journal struct { - ISBN string `bun:"isbn,pk"` - Title string `bun:"title,notnull"` - Pages int `bun:"page_count,notnull,default:0"` -} +func TestDetector_Diff(t *testing.T) { + type Journal struct { + ISBN string `bun:"isbn,pk"` + Title string `bun:"title,notnull"` + Pages int `bun:"page_count,notnull,default:0"` + } -type Reader struct { - Username string `bun:",pk,default:gen_random_uuid()"` -} + type Reader struct { + Username string `bun:",pk,default:gen_random_uuid()"` + } -type ExternalUsers struct { - bun.BaseModel `bun:"external.users"` - Name string `bun:",pk"` -} + type ExternalUsers struct { + bun.BaseModel `bun:"external.users"` + Name string `bun:",pk"` + } + + // ------------------------------------------------------------------------ + type ThingNoOwner struct { + bun.BaseModel `bun:"things"` + ID int64 `bun:"thing_id,pk"` + OwnerID int64 `bun:",notnull"` + } + + type Owner struct { + ID int64 `bun:",pk"` + } + + type Thing struct { + bun.BaseModel `bun:"things"` + ID int64 `bun:"thing_id,pk"` + OwnerID int64 `bun:",notnull"` + + Owner *Owner `bun:"rel:belongs-to,join:owner_id=id"` + } -func TestDetector_Diff(t *testing.T) { testEachDialect(t, func(t *testing.T, dialectName string, dialect schema.Dialect) { for _, tt := range []struct { name string @@ -283,7 +302,7 @@ func TestDetector_Diff(t *testing.T) { want []migrate.Operation }{ { - name: "1 table renamed, 1 added, 2 dropped", + name: "1 table renamed, 1 created, 2 dropped", states: func(tb testing.TB, ctx context.Context, d schema.Dialect) (stateDb sqlschema.State, stateModel sqlschema.State) { // Database state ------------- type Subscription struct { @@ -361,13 +380,111 @@ func TestDetector_Diff(t *testing.T) { }, }, }, + { + name: "detect new FKs on existing columns", + states: func(t testing.TB, ctx context.Context, d schema.Dialect) (stateDb sqlschema.State, stateModel sqlschema.State) { + // database state + type LonelyUser struct { + bun.BaseModel `bun:"table:users"` + Username string `bun:",pk"` + DreamPetKind string `bun:"pet_kind,notnull"` + DreamPetName string `bun:"pet_name,notnull"` + ImaginaryFriend string `bun:"friend"` + } + + type Pet struct { + Nickname string `bun:",pk"` + Kind string `bun:",pk"` + } + + // model state + type HappyUser struct { + bun.BaseModel `bun:"table:users"` + Username string `bun:",pk"` + PetKind string `bun:"pet_kind,notnull"` + PetName string `bun:"pet_name,notnull"` + Friend string `bun:"friend"` + + Pet *Pet `bun:"rel:has-one,join:pet_kind=kind,join:pet_name=nickname"` + BestFriend *HappyUser `bun:"rel:has-one,join:friend=username"` + } + + return getState(t, ctx, d, + (*LonelyUser)(nil), + (*Pet)(nil), + ), getState(t, ctx, d, + (*HappyUser)(nil), + (*Pet)(nil), + ) + }, + want: []migrate.Operation{ + &migrate.AddForeignKey{ + SourceTable: "users", + SourceColumns: []string{"pet_kind", "pet_name"}, + TargetTable: "pets", + TargetColums: []string{"kind", "nickname"}, + }, + &migrate.AddForeignKey{ + SourceTable: "users", + SourceColumns: []string{"friend"}, + TargetTable: "users", + TargetColums: []string{"username"}, + }, + }, + }, + { + name: "create FKs for new tables", // TODO: update test case to detect an added column too + states: func(t testing.TB, ctx context.Context, d schema.Dialect) (stateDb sqlschema.State, stateModel sqlschema.State) { + return getState(t, ctx, d, + (*ThingNoOwner)(nil), + ), getState(t, ctx, d, + (*Owner)(nil), + (*Thing)(nil), + ) + }, + want: []migrate.Operation{ + &migrate.CreateTable{ + Model: &Owner{}, + }, + &migrate.AddForeignKey{ + SourceTable: "things", + SourceColumns: []string{"owner_id"}, + TargetTable: "owners", + TargetColums: []string{"id"}, + }, + }, + }, + { + name: "drop FKs for dropped tables", // TODO: update test case to detect dropped columns too + states: func(t testing.TB, ctx context.Context, d schema.Dialect) (sqlschema.State, sqlschema.State) { + stateDb := getState(t, ctx, d, (*Owner)(nil), (*Thing)(nil)) + stateModel := getState(t, ctx, d, (*ThingNoOwner)(nil)) + + // Normally a database state will have the names of the constraints filled in, but we need to mimic that for the test. + stateDb.FKs[sqlschema.FK{ + From: sqlschema.C(d.DefaultSchema(), "things", "owner_id"), + To: sqlschema.C(d.DefaultSchema(), "owners", "id"), + }] = "test_fkey" + return stateDb, stateModel + }, + want: []migrate.Operation{ + &migrate.DropTable{ + Schema: dialect.DefaultSchema(), + Name: "owners", + }, + &migrate.DropForeignKey{ + Schema: dialect.DefaultSchema(), + Table: "things", + ConstraintName: "test_fkey", + }, + }, + }, } { - t.Run(funcName(tt.states), func(t *testing.T) { + t.Run(tt.name, func(t *testing.T) { ctx := context.Background() - var d migrate.Detector stateDb, stateModel := tt.states(t, ctx, dialect) - got := d.Diff(stateDb, stateModel).Operations() + got := migrate.Diff(stateDb, stateModel).Operations() checkEqualChangeset(t, got, tt.want) }) } diff --git a/internal/dbtest/sqlschema_test.go b/internal/dbtest/sqlschema_test.go new file mode 100644 index 000000000..29f709e14 --- /dev/null +++ b/internal/dbtest/sqlschema_test.go @@ -0,0 +1,222 @@ +package dbtest_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/uptrace/bun/migrate/sqlschema" +) + +func TestRefMap_Update(t *testing.T) { + for _, tt := range []struct { + name string + fks []sqlschema.FK + update func(rm sqlschema.RefMap) int + wantUpdated int + wantFKs []sqlschema.FK + }{ + { + name: "update table reference in all FKs that reference its columns", + fks: []sqlschema.FK{ + { + From: sqlschema.C("x", "y", "z"), + To: sqlschema.C("a", "b", "c"), + }, + { + From: sqlschema.C("m", "n", "o"), + To: sqlschema.C("a", "b", "d"), + }, + }, + update: func(rm sqlschema.RefMap) int { + return rm.UpdateT(sqlschema.T("a", "b"), sqlschema.T("a", "new_b")) + }, + wantUpdated: 2, + wantFKs: []sqlschema.FK{ // checking 1 of the 2 updated ones should be enough + { + From: sqlschema.C("x", "y", "z"), + To: sqlschema.C("a", "new_b", "c"), + }, + }, + }, + { + name: "update table reference in FK which points to the same table", + fks: []sqlschema.FK{ + { + From: sqlschema.C("a", "b", "child"), + To: sqlschema.C("a", "b", "parent"), + }, + }, + update: func(rm sqlschema.RefMap) int { + return rm.UpdateT(sqlschema.T("a", "b"), sqlschema.T("a", "new_b")) + }, + wantUpdated: 1, + wantFKs: []sqlschema.FK{ + { + From: sqlschema.C("a", "new_b", "child"), + To: sqlschema.C("a", "new_b", "parent"), + }, + }, + }, + { + name: "update column reference in all FKs which depend on it", + fks: []sqlschema.FK{ + { + From: sqlschema.C("x", "y", "z"), + To: sqlschema.C("a", "b", "c"), + }, + { + From: sqlschema.C("a", "b", "c"), + To: sqlschema.C("m", "n", "o"), + }, + }, + update: func(rm sqlschema.RefMap) int { + return rm.UpdateC(sqlschema.C("a", "b", "c"), "c_new") + }, + wantUpdated: 2, + wantFKs: []sqlschema.FK{ + { + From: sqlschema.C("x", "y", "z"), + To: sqlschema.C("a", "b", "c_new"), + }, + }, + }, + { + name: "foreign keys defined on multiple columns", + fks: []sqlschema.FK{ + { + From: sqlschema.C("a", "b", "c1", "c2"), + To: sqlschema.C("q", "r", "s1", "s2"), + }, + { + From: sqlschema.C("m", "n", "o", "p"), + To: sqlschema.C("a", "b", "c2"), + }, + }, + update: func(rm sqlschema.RefMap) int { + return rm.UpdateC(sqlschema.C("a", "b", "c2"), "x2") + }, + wantUpdated: 2, + wantFKs: []sqlschema.FK{ + { + From: sqlschema.C("a", "b", "c1", "x2"), + To: sqlschema.C("q", "r", "s1", "s2"), + }, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + rm := sqlschema.NewRefMap(tt.fks...) + + n := tt.update(rm) + + require.Equal(t, tt.wantUpdated, n) + require.Equal(t, tt.wantUpdated, len(rm.Updated())) + checkHasFK(t, rm, tt.wantFKs...) + }) + } +} + +func checkHasFK(tb testing.TB, rm sqlschema.RefMap, fks ...sqlschema.FK) { +outer: + for _, want := range fks { + for _, gotptr := range rm { + if got := *gotptr; got == want { + continue outer + } + } + tb.Fatalf("did not find FK%+v", want) + } +} + +func TestRefMap_Delete(t *testing.T) { + for _, tt := range []struct { + name string + fks []sqlschema.FK + del func(rm sqlschema.RefMap) int + wantDeleted []sqlschema.FK + }{ + { + name: "delete FKs that depend on the table", + fks: []sqlschema.FK{ + { + From: sqlschema.C("a", "b", "c"), + To: sqlschema.C("x", "y", "z"), + }, + { + From: sqlschema.C("m", "n", "o"), + To: sqlschema.C("a", "b", "d"), + }, + { + From: sqlschema.C("q", "r", "s"), + To: sqlschema.C("w", "w", "w"), + }, + }, + del: func(rm sqlschema.RefMap) int { + return rm.DeleteT(sqlschema.T("a", "b")) + }, + wantDeleted: []sqlschema.FK{ + { + From: sqlschema.C("a", "b", "c"), + To: sqlschema.C("x", "y", "z"), + }, + { + From: sqlschema.C("m", "n", "o"), + To: sqlschema.C("a", "b", "d"), + }, + }, + }, + { + name: "delete FKs that depend on the column", + fks: []sqlschema.FK{ + { + From: sqlschema.C("a", "b", "c"), + To: sqlschema.C("x", "y", "z"), + }, + { + From: sqlschema.C("q", "r", "s"), + To: sqlschema.C("w", "w", "w"), + }, + }, + del: func(rm sqlschema.RefMap) int { + return rm.DeleteC(sqlschema.C("a", "b", "c")) + }, + wantDeleted: []sqlschema.FK{ + { + From: sqlschema.C("a", "b", "c"), + To: sqlschema.C("x", "y", "z"), + }, + }, + }, + { + name: "foreign keys defined on multiple columns", + fks: []sqlschema.FK{ + { + From: sqlschema.C("a", "b", "c1", "c2"), + To: sqlschema.C("q", "r", "s1", "s2"), + }, + { + From: sqlschema.C("m", "n", "o", "p"), + To: sqlschema.C("a", "b", "c2"), + }, + }, + del: func(rm sqlschema.RefMap) int { + return rm.DeleteC(sqlschema.C("a", "b", "c1")) + }, + wantDeleted: []sqlschema.FK{ + { + From: sqlschema.C("a", "b", "c1", "c2"), + To: sqlschema.C("q", "r", "s1", "s2"), + }, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + rm := sqlschema.NewRefMap(tt.fks...) + + n := tt.del(rm) + + require.Equal(t, len(tt.wantDeleted), n) + require.ElementsMatch(t, rm.Deleted(), tt.wantDeleted) + }) + } +} diff --git a/migrate/auto.go b/migrate/auto.go index c58e243d9..677ce6787 100644 --- a/migrate/auto.go +++ b/migrate/auto.go @@ -107,7 +107,6 @@ func NewAutoMigrator(db *bun.DB, opts ...AutoMigratorOption) (*AutoMigrator, err } func (am *AutoMigrator) diff(ctx context.Context) (Changeset, error) { - var detector Detector var changes Changeset var err error @@ -120,7 +119,7 @@ func (am *AutoMigrator) diff(ctx context.Context) (Changeset, error) { if err != nil { return changes, err } - return detector.Diff(got, want), nil + return Diff(got, want), nil } // Migrate writes required changes to a new migration file and runs the migration. @@ -167,12 +166,23 @@ func (am *AutoMigrator) Run(ctx context.Context) error { // INTERNAL ------------------------------------------------------------------- -type Detector struct{} +func Diff(got, want sqlschema.State) Changeset { + detector := newDetector() + return detector.DetectChanges(got, want) +} -func (d *Detector) Diff(got, want sqlschema.State) Changeset { - var changes Changeset +type detector struct { + changes Changeset +} + +func newDetector() *detector { + return &detector{} +} - oldModels := newTableSet(got.Tables...) +func (d *detector) DetectChanges(got, want sqlschema.State) Changeset { + + // TableSets for discovering CREATE/RENAME/DROP TABLE + oldModels := newTableSet(got.Tables...) // newModels := newTableSet(want.Tables...) addedModels := newModels.Sub(oldModels) @@ -182,7 +192,7 @@ AddedLoop: removedModels := oldModels.Sub(newModels) for _, removed := range removedModels.Values() { if d.canRename(added, removed) { - changes.Add(&RenameTable{ + d.changes.Add(&RenameTable{ Schema: removed.Schema, From: removed.Name, To: added.Name, @@ -196,7 +206,7 @@ AddedLoop: } } // If a new table did not appear because of the rename operation, then it must've been created. - changes.Add(&CreateTable{ + d.changes.Add(&CreateTable{ Schema: added.Schema, Name: added.Name, Model: added.Model, @@ -205,17 +215,39 @@ AddedLoop: // Tables that aren't present anymore and weren't renamed were deleted. for _, t := range oldModels.Sub(newModels).Values() { - changes.Add(&DropTable{ + d.changes.Add(&DropTable{ Schema: t.Schema, Name: t.Name, }) } - return changes + // Compare FKs + for fk /*, fkName */ := range want.FKs { + if _, ok := got.FKs[fk]; !ok { + d.changes.Add(&AddForeignKey{ + SourceTable: fk.From.Table, + SourceColumns: fk.From.Column.Split(), + TargetTable: fk.To.Table, + TargetColums: fk.To.Column.Split(), + }) + } + } + + for fk, fkName := range got.FKs { + if _, ok := want.FKs[fk]; !ok { + d.changes.Add(&DropForeignKey{ + Schema: fk.From.Schema, + Table: fk.From.Table, + ConstraintName: fkName, + }) + } + } + + return d.changes } // canRename checks if t1 can be renamed to t2. -func (d Detector) canRename(t1, t2 sqlschema.Table) bool { +func (d detector) canRename(t1, t2 sqlschema.Table) bool { return t1.Schema == t2.Schema && sqlschema.EqualSignatures(t1, t2) } @@ -231,6 +263,9 @@ func (c Changeset) String() string { for _, op := range c.operations { ops = append(ops, op.String()) } + if len(ops) == 0 { + return "" + } return strings.Join(ops, "\n") } @@ -384,6 +419,50 @@ func trimSchema(name string) string { return name } +type AddForeignKey struct { + SourceTable string + SourceColumns []string + TargetTable string + TargetColums []string +} + +var _ Operation = (*AddForeignKey)(nil) + +func (op AddForeignKey) String() string { + return fmt.Sprintf("AddForeignKey %s(%s) references %s(%s)", + op.SourceTable, strings.Join(op.SourceColumns, ","), + op.TargetTable, strings.Join(op.TargetColums, ","), + ) +} + +func (op *AddForeignKey) Func(m sqlschema.Migrator) MigrationFunc { + return nil +} + +func (op *AddForeignKey) GetReverse() Operation { + return nil +} + +type DropForeignKey struct { + Schema string + Table string + ConstraintName string +} + +var _ Operation = (*DropForeignKey)(nil) + +func (op *DropForeignKey) String() string { + return fmt.Sprintf("DropFK %q on table %q.%q", op.ConstraintName, op.Schema, op.Table) +} + +func (op *DropForeignKey) Func(m sqlschema.Migrator) MigrationFunc { + return nil +} + +func (op *DropForeignKey) GetReverse() Operation { + return nil +} + // sqlschema utils ------------------------------------------------------------ // tableSet stores unique table definitions. diff --git a/migrate/sqlschema/inspector.go b/migrate/sqlschema/inspector.go index 2f44f93c5..2060fef0c 100644 --- a/migrate/sqlschema/inspector.go +++ b/migrate/sqlschema/inspector.go @@ -47,7 +47,9 @@ func NewSchemaInspector(tables *schema.Tables) *SchemaInspector { } func (si *SchemaInspector) Inspect(ctx context.Context) (State, error) { - var state State + state := State{ + FKs: make(map[FK]string), + } for _, t := range si.tables.All() { columns := make(map[string]Column) for _, f := range t.Fields { @@ -67,6 +69,22 @@ func (si *SchemaInspector) Inspect(ctx context.Context) (State, error) { Model: t.ZeroIface, Columns: columns, }) + + for _, rel := range t.Relations { + var fromCols, toCols []string + for _, f := range rel.BaseFields { + fromCols = append(fromCols, f.Name) + } + for _, f := range rel.JoinFields { + toCols = append(toCols, f.Name) + } + + target := rel.JoinTable + state.FKs[FK{ + From: C(t.Schema, t.Name, fromCols...), + To: C(target.Schema, target.Name, toCols...), + }] = "" + } } return state, nil } diff --git a/migrate/sqlschema/state.go b/migrate/sqlschema/state.go index 8f7e96b0d..c57ea36d0 100644 --- a/migrate/sqlschema/state.go +++ b/migrate/sqlschema/state.go @@ -1,7 +1,10 @@ package sqlschema +import "strings" + type State struct { Tables []Table + FKs map[FK]string } type Table struct { @@ -61,3 +64,217 @@ func (s *signature) Equals(other signature) bool { } return true } + +// tFQN is a fully-qualified table name. +type tFQN struct { + Schema string + Table string +} + +func T(schema, table string) tFQN { return tFQN{Schema: schema, Table: table} } + +// cFQN is a fully-qualified column name. +type cFQN struct { + tFQN + Column composite +} + +func C(schema, table string, columns ...string) cFQN { + return cFQN{tFQN: T(schema, table), Column: newComposite(columns...)} +} + +// composite is a hashable representation of []string used to define FKs that depend on multiple columns. +// Although having duplicated column references in a FK is illegal, composite neither validate nor enforce this constraint on the caller. +type composite string + +// newComposite creates a composite column from a slice of column names. +func newComposite(columns ...string) composite { + return composite(strings.Join(columns, ",")) +} + +// Split returns a slice of column names that make up the composite. +func (c composite) Split() []string { + return strings.Split(string(c), ",") +} + +// Contains checks that a composite column contains every part of another composite. +func (c composite) Contains(other composite) bool { + var count int + checkColumns := other.Split() + wantCount := len(checkColumns) + + for _, check := range checkColumns { + for _, column := range c.Split() { + if check == column { + count++ + } + if count == wantCount { + return true + } + } + } + return count == wantCount +} + +// Replace renames a column if it is part of the composite. +// If a composite consists of multiple columns, only one column will be renamed. +func (c composite) Replace(oldColumn, newColumn string) composite { + columns := c.Split() + for i, column := range columns { + if column == oldColumn { + columns[i] = newColumn + return newComposite(columns...) + } + } + return c +} + +// T returns the FQN of the column's parent table. +func (c cFQN) T() tFQN { + return tFQN{Schema: c.Schema, Table: c.Table} +} + +// FK defines a foreign key constraint. +// FK depends on a column/table if their FQN is included in its definition. +// +// Example: +// +// FK{ +// From: C{"A", "B", "C"}, +// To: C{"X", "Y", "Z"}, +// } +// - depends on C{"A", "B", "C"} +// - depends on C{"X", "Y", "Z"} +// - depends on T{"A", "B"} and T{"X", "Y"} +// +// FIXME: current design does not allow for one column referencing multiple columns. Or does it? Think again. +// Consider: +// +// CONSTRAINT fk_customers FOREIGN KEY (customer_id) REFERENCES customers(id) +// CONSTRAINT fk_orders FOREIGN KEY (customer_id) REFERENCES orders(customer_id) +type FK struct { + From cFQN // From is the referencing column. + To cFQN // To is the referenced column. +} + +// DependsT checks if either part of the FK's definition mentions T +// and returns the columns that belong to T. Notice that *C allows modifying the column's FQN. +func (fk *FK) DependsT(t tFQN) (ok bool, cols []*cFQN) { + if c := &fk.From; c.T() == t { + ok = true + cols = append(cols, c) + } + if c := &fk.To; c.T() == t { + ok = true + cols = append(cols, c) + } + if !ok { + return false, nil + } + return +} + +// DependsC checks if the FK definition mentions C and returns a modifiable FQN of the matching column. +func (fk *FK) DependsC(c cFQN) (bool, *cFQN) { + switch { + case fk.From.Column.Contains(c.Column): + return true, &fk.From + case fk.To.Column.Contains(c.Column): + return true, &fk.To + } + return false, nil +} + +// RefMap helps detecting modified FK relations. +// It starts with an initial state and provides methods to update and delete +// foreign key relations based on the column or table they depend on. +type RefMap map[FK]*FK + +// deleted is a special value that RefMap uses to denote a deleted FK constraint. +var deleted FK + +// NewRefMap records the FK's initial state to a RefMap. +func NewRefMap(fks ...FK) RefMap { + ref := make(RefMap) + for _, fk := range fks { + copyfk := fk + ref[fk] = ©fk + } + return ref +} + +// UpdateT updates the table FQN in all FKs that depend on it, e.g. if a table is renamed or moved to a different schema. +// Returns the number of updated entries. +func (r RefMap) UpdateT(oldT, newT tFQN) (n int) { + for _, fk := range r { + ok, cols := fk.DependsT(oldT) + if !ok { + continue + } + for _, c := range cols { + c.Schema = newT.Schema + c.Table = newT.Table + } + n++ + } + return +} + +// UpdateC updates the column FQN in all FKs that depend on it, e.g. if a column is renamed, +// and so, only the column-name part of the FQN can be updated. Returns the number of updated entries. +func (r RefMap) UpdateC(oldC cFQN, newColumn string) (n int) { + for _, fk := range r { + if ok, col := fk.DependsC(oldC); ok { + oldColumns := oldC.Column.Split() + // UpdateC can only update 1 column at a time. + col.Column = col.Column.Replace(oldColumns[0], newColumn) + n++ + } + } + return +} + +// DeleteT marks all FKs that depend on the table as deleted. +// Returns the number of deleted entries. +func (r RefMap) DeleteT(t tFQN) (n int) { + for old, fk := range r { + if ok, _ := fk.DependsT(t); ok { + r[old] = &deleted + n++ + } + } + return +} + +// DeleteC marks all FKs that depend on the column as deleted. +// Returns the number of deleted entries. +func (r RefMap) DeleteC(c cFQN) (n int) { + for old, fk := range r { + if ok, _ := fk.DependsC(c); ok { + r[old] = &deleted + n++ + } + } + return +} + +// Updated returns FKs that were updated, both their old and new defitions. +func (r RefMap) Updated() map[FK]FK { + fks := make(map[FK]FK) + for old, fk := range r { + if old != *fk { + fks[old] = *fk + } + } + return fks +} + +// Deleted gets all FKs that were marked as deleted. +func (r RefMap) Deleted() (fks []FK) { + for old, fk := range r { + if fk == &deleted { + fks = append(fks, old) + } + } + return +}