feat: support UNIQUE constraints
The database inspector no longer mixes UNIQUE constraints with
other constraint types, such as PKs and FKs. While database
drivers may require some FK and PK fields to be unique, that
uniqueness is in practice enforced by distinct UNIQUE
constraints, which should not be treated as synonymous with
the PK/FK constraints themselves.

Changes to UNIQUE constraints can now be detected on any table
that hasn't been dropped.
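
For illustration, here is a minimal sketch (not part of this commit) of how unique constraints are declared on a bun model and what the inspector now reports for them. The struct tags mirror the test fixtures further down; the model itself is made up.

package model

import "github.com/uptrace/bun"

// Author declares a named composite unique constraint ("full_name", spanning
// two columns) and an unnamed single-column constraint on "email".
type Author struct {
	bun.BaseModel `bun:"table:authors"`

	ID        int64  `bun:"id,pk"`
	FirstName string `bun:"first_name,notnull,unique:full_name"`
	LastName  string `bun:"last_name,notnull,unique:full_name"`
	Email     string `bun:"email,notnull,unique"`
}

// The inspector reports these as two distinct UNIQUE constraints,
// separate from the primary key:
//
//	UNIQUE ("first_name", "last_name")  -- named "full_name"
//	UNIQUE ("email")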
bevzzz committed Oct 27, 2024
1 parent fed6012 commit 3c4d5d2
Showing 8 changed files with 369 additions and 115 deletions.
33 changes: 28 additions & 5 deletions dialect/pgdialect/alter_table.go
@@ -62,10 +62,14 @@ func (m *migrator) Apply(ctx context.Context, changes ...interface{}) error {
b, err = m.addColumn(fmter, b, change)
case *migrate.DropColumn:
b, err = m.dropColumn(fmter, b, change)
case *migrate.DropConstraint:
b, err = m.dropContraint(fmter, b, change)
case *migrate.AddForeignKey:
b, err = m.addForeignKey(fmter, b, change)
case *migrate.AddUniqueConstraint:
b, err = m.addUnique(fmter, b, change)
case *migrate.DropUniqueConstraint:
b, err = m.dropConstraint(fmter, b, change.FQN, change.Unique.Name)
case *migrate.DropConstraint:
b, err = m.dropConstraint(fmter, b, change.FQN(), change.ConstraintName)
case *migrate.RenameConstraint:
b, err = m.renameConstraint(fmter, b, change)
case *migrate.ChangeColumnType:
@@ -147,15 +151,34 @@ func (m *migrator) renameConstraint(fmter schema.Formatter, b []byte, rename *mi
return b, nil
}

func (m *migrator) dropContraint(fmter schema.Formatter, b []byte, drop *migrate.DropConstraint) (_ []byte, err error) {
func (m *migrator) addUnique(fmter schema.Formatter, b []byte, change *migrate.AddUniqueConstraint) (_ []byte, err error) {
b = append(b, "ALTER TABLE "...)
if b, err = change.FQN.AppendQuery(fmter, b); err != nil {
return b, err
}

b = append(b, " ADD CONSTRAINT "...)
if change.Unique.Name != "" {
b = fmter.AppendName(b, change.Unique.Name)
} else {
// Default naming scheme for unique constraints in Postgres is <table>_<column>_key
b = fmter.AppendName(b, fmt.Sprintf("%s_%s_key", change.FQN.Table, change.Unique.Columns))
}
b = append(b, " UNIQUE ("...)
b, _ = change.Unique.Columns.Safe().AppendQuery(fmter, b)
b = append(b, ")"...)

return b, nil
}

func (m *migrator) dropConstraint(fmter schema.Formatter, b []byte, fqn schema.FQN, name string) (_ []byte, err error) {
b = append(b, "ALTER TABLE "...)
fqn := drop.FQN()
if b, err = fqn.AppendQuery(fmter, b); err != nil {
return b, err
}

b = append(b, " DROP CONSTRAINT "...)
b = fmter.AppendName(b, drop.ConstraintName)
b = fmter.AppendName(b, name)

return b, nil
}
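
For context, a standalone sketch (not part of the commit, with made-up identifiers) of the shape of the DDL that addUnique and dropConstraint build. It prints plain approximations instead of going through the dialect's Formatter.

package main

import (
	"fmt"
	"strings"
)

// quoteAll double-quotes each column name and joins the result with commas,
// approximating the formatter's output for a column list.
func quoteAll(cols []string) string {
	quoted := make([]string, len(cols))
	for i, c := range cols {
		quoted[i] = fmt.Sprintf("%q", c)
	}
	return strings.Join(quoted, ", ")
}

func main() {
	table := "authors"

	// AddUniqueConstraint with an explicit name (e.g. from a `unique:full_name` tag):
	fmt.Printf("ALTER TABLE %q ADD CONSTRAINT %q UNIQUE (%s)\n",
		table, "full_name", quoteAll([]string{"first_name", "last_name"}))

	// Without a name, addUnique falls back to Postgres' <table>_<column>_key scheme:
	fmt.Printf("ALTER TABLE %q ADD CONSTRAINT %q UNIQUE (%s)\n",
		table, table+"_email_key", quoteAll([]string{"email"}))

	// DropUniqueConstraint and DropConstraint both go through dropConstraint:
	fmt.Printf("ALTER TABLE %q DROP CONSTRAINT %q\n", table, "full_name")
}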
102 changes: 23 additions & 79 deletions dialect/pgdialect/inspector.go
@@ -48,7 +48,10 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.State, error) {
if err := in.db.NewRaw(sqlInspectColumnsQuery, table.Schema, table.Name).Scan(ctx, &columns); err != nil {
return state, err
}

colDefs := make(map[string]sqlschema.Column)
uniqueGroups := make(map[string][]string)

for _, c := range columns {
def := c.Default
if c.IsSerial || c.IsIdentity {
@@ -66,12 +69,25 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.State, error) {
IsAutoIncrement: c.IsSerial,
IsIdentity: c.IsIdentity,
}

for _, group := range c.UniqueGroups {
uniqueGroups[group] = append(uniqueGroups[group], c.Name)
}
}

var unique []sqlschema.Unique
for name, columns := range uniqueGroups {
unique = append(unique, sqlschema.Unique{
Name: name,
Columns: sqlschema.NewComposite(columns...),
})
}

state.Tables = append(state.Tables, sqlschema.Table{
Schema: table.Schema,
Name: table.Name,
Columns: colDefs,
Schema: table.Schema,
Name: table.Name,
Columns: colDefs,
UniqueContraints: unique,
})
}
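
For reference, a condensed standalone sketch of the grouping step above: each column row carries the names of the unique constraints it belongs to (unique_groups), and the inspector inverts that into one entry per constraint. The types here are simplified stand-ins, not the actual sqlschema definitions.

package inspect

// uniqueConstraint is a simplified stand-in for sqlschema.Unique.
type uniqueConstraint struct {
	Name    string
	Columns []string
}

// groupUniques inverts per-column membership (column name -> constraint names)
// into per-constraint column lists, mirroring the uniqueGroups map above.
func groupUniques(membership map[string][]string) []uniqueConstraint {
	groups := make(map[string][]string)
	for column, constraints := range membership {
		for _, name := range constraints {
			groups[name] = append(groups[name], column)
		}
	}
	out := make([]uniqueConstraint, 0, len(groups))
	for name, columns := range groups {
		out = append(out, uniqueConstraint{Name: name, Columns: columns})
	}
	return out
}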

@@ -106,8 +122,7 @@ type InformationSchemaColumn struct {
IndentityType string `bun:"identity_type"`
IsSerial bool `bun:"is_serial"`
IsNullable bool `bun:"is_nullable"`
IsUnique bool `bun:"is_unique"`
UniqueGroup []string `bun:"unique_group,array"`
UniqueGroups []string `bun:"unique_groups,array"`
}

type ForeignKey struct {
@@ -156,8 +171,7 @@ SELECT
"c".column_default = format('nextval(''%s_%s_seq''::regclass)', "c".table_name, "c".column_name) AS is_serial,
COALESCE("c".identity_type, '') AS identity_type,
"c".is_nullable = 'YES' AS is_nullable,
'u' = ANY("c".constraint_type) AS is_unique,
"c"."constraint_name" AS unique_group
"c"."unique_groups" AS unique_groups
FROM (
SELECT
"table_schema",
@@ -170,7 +184,7 @@ FROM (
"c".is_nullable,
att.array_dims,
att.identity_type,
att."constraint_name",
att."unique_groups",
att."constraint_type"
FROM information_schema.columns "c"
LEFT JOIN (
@@ -180,7 +194,7 @@ FROM (
"c".attname AS "column_name",
"c".attndims AS array_dims,
"c".attidentity AS identity_type,
ARRAY_AGG(con.conname) AS "constraint_name",
ARRAY_AGG(con.conname) FILTER (WHERE con.contype = 'u') AS "unique_groups",
ARRAY_AGG(con.contype) AS "constraint_type"
FROM (
SELECT
@@ -200,76 +214,6 @@ FROM (
) "c"
WHERE "table_schema" = ? AND "table_name" = ?
ORDER BY "table_schema", "table_name", "column_name"
`

// sqlInspectSchema retrieves column type definitions for all user-defined tables.
// Other relations, such as views and indices, as well as Posgres's internal relations are excluded.
//
// TODO: implement scanning ORM relations for RawQuery too, so that one could scan this query directly to InformationSchemaTable.
sqlInspectSchema = `
SELECT
"t"."table_schema",
"t".table_name,
"c".column_name,
"c".data_type,
"c".character_maximum_length::integer AS varchar_len,
"c".data_type = 'ARRAY' AS is_array,
COALESCE("c".array_dims, 0) AS array_dims,
CASE
WHEN "c".column_default ~ '^''.*''::.*$' THEN substring("c".column_default FROM '^''(.*)''::.*$')
ELSE "c".column_default
END AS "default",
"c".constraint_type = 'p' AS is_pk,
"c".is_identity = 'YES' AS is_identity,
"c".column_default = format('nextval(''%s_%s_seq''::regclass)', "t".table_name, "c".column_name) AS is_serial,
COALESCE("c".identity_type, '') AS identity_type,
"c".is_nullable = 'YES' AS is_nullable,
"c".constraint_type = 'u' AS is_unique,
"c"."constraint_name" AS unique_group
FROM information_schema.tables "t"
LEFT JOIN (
SELECT
"table_schema",
"table_name",
"column_name",
"c".data_type,
"c".character_maximum_length,
"c".column_default,
"c".is_identity,
"c".is_nullable,
att.array_dims,
att.identity_type,
att."constraint_name",
att."constraint_type"
FROM information_schema.columns "c"
LEFT JOIN (
SELECT
s.nspname AS table_schema,
"t".relname AS "table_name",
"c".attname AS "column_name",
"c".attndims AS array_dims,
"c".attidentity AS identity_type,
con.conname AS "constraint_name",
con.contype AS "constraint_type"
FROM (
SELECT
conname,
contype,
connamespace,
conrelid,
conrelid AS attrelid,
UNNEST(conkey) AS attnum
FROM pg_constraint
) con
LEFT JOIN pg_attribute "c" USING (attrelid, attnum)
LEFT JOIN pg_namespace s ON s.oid = con.connamespace
LEFT JOIN pg_class "t" ON "t".oid = con.conrelid
) att USING (table_schema, "table_name", "column_name")
) "c" USING (table_schema, "table_name")
WHERE table_type = 'BASE TABLE'
AND table_schema <> 'information_schema'
AND table_schema NOT LIKE 'pg_%'
ORDER BY table_schema, table_name
`

// sqlInspectForeignKeys get FK definitions for user-defined tables.
67 changes: 61 additions & 6 deletions internal/dbtest/inspect_test.go
@@ -42,7 +42,7 @@ type Office struct {

type Publisher struct {
ID string `bun:"publisher_id,pk,default:gen_random_uuid(),unique:office_fk"`
Name string `bun:"publisher_name,unique,notnull,unique:office_fk"`
Name string `bun:"publisher_name,notnull,unique:office_fk"`
CreatedAt time.Time `bun:"created_at,default:current_timestamp"`

// Writers write articles for this publisher.
@@ -63,8 +63,9 @@ type PublisherToJournalist struct {
type Journalist struct {
bun.BaseModel `bun:"table:authors"`
ID int `bun:"author_id,pk,identity"`
FirstName string `bun:",notnull"`
LastName string
FirstName string `bun:"first_name,notnull,unique:full_name"`
LastName string `bun:"last_name,notnull,unique:full_name"`
Email string `bun:"email,notnull,unique"`

// Articles that this journalist has written.
Articles []*Article `bun:"rel:has-many,join:author_id=author_id"`
@@ -171,6 +172,9 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
SQLType: "bigint",
},
},
UniqueContraints: []sqlschema.Unique{
{Columns: sqlschema.NewComposite("editor", "title")},
},
},
{
Schema: defaultSchema,
@@ -185,10 +189,16 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
SQLType: sqltype.VarChar,
},
"last_name": {
SQLType: sqltype.VarChar,
IsNullable: true,
SQLType: sqltype.VarChar,
},
"email": {
SQLType: sqltype.VarChar,
},
},
UniqueContraints: []sqlschema.Unique{
{Columns: sqlschema.NewComposite("first_name", "last_name")},
{Columns: sqlschema.NewComposite("email")},
},
},
{
Schema: defaultSchema,
Expand Down Expand Up @@ -222,6 +232,9 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
IsNullable: true,
},
},
UniqueContraints: []sqlschema.Unique{
{Columns: sqlschema.NewComposite("publisher_id", "publisher_name")},
},
},
}

@@ -268,7 +281,7 @@ func mustCreateTableWithFKs(tb testing.TB, ctx context.Context, db *bun.DB, mode
for _, model := range models {
create := db.NewCreateTable().Model(model).WithForeignKeys()
_, err := create.Exec(ctx)
require.NoError(tb, err, "must create table %q:", create.GetTableName())
require.NoError(tb, err, "arrange: must create table %q:", create.GetTableName())
mustDropTableOnCleanup(tb, ctx, db, model)
}
}
@@ -303,9 +316,11 @@ func cmpTables(tb testing.TB, d sqlschema.InspectorDialect, want, got []sqlschem
}

cmpColumns(tb, d, wt.Name, wt.Columns, gt.Columns)
cmpConstraints(tb, wt, gt)
}
}

// cmpColumns compares the column definitions of two tables.
func cmpColumns(tb testing.TB, d sqlschema.InspectorDialect, tableName string, want, got map[string]sqlschema.Column) {
tb.Helper()
var errs []string
@@ -362,6 +377,20 @@ func cmpColumns(tb testing.TB, d sqlschema.InspectorDialect, tableName string, w
}
}

// cmpConstraints compares constraints defined on the table with the expected ones.
func cmpConstraints(tb testing.TB, want, got sqlschema.Table) {
tb.Helper()

// Only keep columns included in each unique constraint for comparison.
stripNames := func(uniques []sqlschema.Unique) (res []string) {
for _, u := range uniques {
res = append(res, u.Columns.String())
}
return
}
require.ElementsMatch(tb, stripNames(want.UniqueContraints), stripNames(got.UniqueContraints), "table %q does not have expected unique constraints (listA=want, listB=got)", want.Name)
}

func tableNames(tables []sqlschema.Table) (names []string) {
for i := range tables {
names = append(names, tables[i].Name)
@@ -441,5 +470,31 @@ func TestSchemaInspector_Inspect(t *testing.T) {
require.Len(t, got.Tables, 1)
cmpColumns(t, dialect.(sqlschema.InspectorDialect), "model", want, got.Tables[0].Columns)
})

t.Run("inspect unique constraints", func(t *testing.T) {
type Model struct {
ID string `bun:",unique"`
FirstName string `bun:"first_name,unique:full_name"`
LastName string `bun:"last_name,unique:full_name"`
}

tables := schema.NewTables(dialect)
tables.Register((*Model)(nil))
inspector := sqlschema.NewSchemaInspector(tables)

want := sqlschema.Table{
Name: "models",
UniqueContraints: []sqlschema.Unique{
{Columns: sqlschema.NewComposite("id")},
{Name: "full_name", Columns: sqlschema.NewComposite("first_name", "last_name")},
},
}

got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

require.Len(t, got.Tables, 1)
cmpConstraints(t, want, got.Tables[0])
})
})
}
