diff --git a/dialect/pgdialect/alter_table.go b/dialect/pgdialect/alter_table.go new file mode 100644 index 000000000..f8341fbf7 --- /dev/null +++ b/dialect/pgdialect/alter_table.go @@ -0,0 +1,26 @@ +package pgdialect + +import ( + "context" + "database/sql" + "fmt" + + "github.com/uptrace/bun/migrate/sqlschema" +) + +func (d *Dialect) Migrator(sqldb *sql.DB) sqlschema.Migrator { + return &Migrator{sqldb: sqldb} +} + +type Migrator struct { + sqldb *sql.DB +} + +func (m *Migrator) RenameTable(ctx context.Context, oldName, newName string) error { + query := fmt.Sprintf("ALTER TABLE %s RENAME TO %s", oldName, newName) + _, err := m.sqldb.ExecContext(ctx, query) + if err != nil { + return err + } + return nil +} diff --git a/dialect/pgdialect/dialect.go b/dialect/pgdialect/dialect.go index 766aa1be4..022c74bd0 100644 --- a/dialect/pgdialect/dialect.go +++ b/dialect/pgdialect/dialect.go @@ -10,6 +10,7 @@ import ( "github.com/uptrace/bun/dialect" "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/migrate/sqlschema" "github.com/uptrace/bun/schema" "github.com/uptrace/bun/schema/inspector" ) @@ -32,6 +33,7 @@ type Dialect struct { var _ schema.Dialect = (*Dialect)(nil) var _ inspector.Dialect = (*Dialect)(nil) +var _ sqlschema.MigratorDialect = (*Dialect)(nil) func New() *Dialect { d := new(Dialect) diff --git a/dialect/pgdialect/inspector.go b/dialect/pgdialect/inspector.go index 418140855..5aa2995f3 100644 --- a/dialect/pgdialect/inspector.go +++ b/dialect/pgdialect/inspector.go @@ -10,30 +10,38 @@ import ( "github.com/uptrace/bun/schema" ) -func (d *Dialect) Inspector(db *bun.DB) schema.Inspector { - return newDatabaseInspector(db) +func (d *Dialect) Inspector(db *bun.DB, excludeTables ...string) schema.Inspector { + return newInspector(db, excludeTables...) } -type DatabaseInspector struct { - db *bun.DB +type Inspector struct { + db *bun.DB + excludeTables []string } -var _ schema.Inspector = (*DatabaseInspector)(nil) +var _ schema.Inspector = (*Inspector)(nil) -func newDatabaseInspector(db *bun.DB) *DatabaseInspector { - return &DatabaseInspector{db: db} +func newInspector(db *bun.DB, excludeTables ...string) *Inspector { + return &Inspector{db: db, excludeTables: excludeTables} } -func (di *DatabaseInspector) Inspect(ctx context.Context) (schema.State, error) { +func (in *Inspector) Inspect(ctx context.Context) (schema.State, error) { var state schema.State + + exclude := in.excludeTables + if len(exclude) == 0 { + // Avoid getting NOT IN (NULL) if bun.In() is called with an empty slice. + exclude = []string{""} + } + var tables []*InformationSchemaTable - if err := di.db.NewRaw(sqlInspectTables).Scan(ctx, &tables); err != nil { + if err := in.db.NewRaw(sqlInspectTables, bun.In(exclude)).Scan(ctx, &tables); err != nil { return state, err } for _, table := range tables { var columns []*InformationSchemaColumn - if err := di.db.NewRaw(sqlInspectColumnsQuery, table.Schema, table.Name).Scan(ctx, &columns); err != nil { + if err := in.db.NewRaw(sqlInspectColumnsQuery, table.Schema, table.Name).Scan(ctx, &columns); err != nil { return state, err } colDefs := make(map[string]schema.ColumnDef) @@ -105,6 +113,7 @@ FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema <> 'information_schema' AND table_schema NOT LIKE 'pg_%' + AND table_name NOT IN (?) ` // sqlInspectColumnsQuery retrieves column definitions for the specified table. diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index e93ce5843..c3ad08565 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -225,11 +225,12 @@ func testEachDB(t *testing.T, f func(t *testing.T, dbName string, db *bun.DB)) { } // testEachDialect allows testing dialect-specific functionality that does not require database interactions. -func testEachDialect(t *testing.T, f func(t *testing.T, dialectName string, dialect func() schema.Dialect)) { +func testEachDialect(t *testing.T, f func(t *testing.T, dialectName string, dialect schema.Dialect)) { for _, newDialect := range allDialects { - name := newDialect().Name().String() + d := newDialect() + name := d.Name().String() t.Run(name, func(t *testing.T) { - f(t, name, newDialect) + f(t, name, d) }) } } diff --git a/internal/dbtest/inspect_test.go b/internal/dbtest/inspect_test.go index 9b092ef4d..bd17a8a32 100644 --- a/internal/dbtest/inspect_test.go +++ b/internal/dbtest/inspect_test.go @@ -20,7 +20,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) { Title string `bun:",notnull,unique:title_author"` Locale string `bun:",type:varchar(5),default:'en-GB'"` Pages int8 `bun:"page_count,notnull,default:1"` - Count int32 `bun:"book_count,autoincrement"` + Count int32 `bun:"book_count,autoincrement"` } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -34,7 +34,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) { } ctx := context.Background() - createTableOrSkip(t, ctx, db, (*Book)(nil)) + mustResetModel(t, ctx, db, (*Book)(nil)) dbInspector := dialect.Inspector(db) want := schema.State{ @@ -105,7 +105,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) { func getDatabaseInspectorOrSkip(tb testing.TB, db *bun.DB) schema.Inspector { dialect := db.Dialect() if id, ok := dialect.(inspector.Dialect); ok { - return id.Inspector(db) + return id.Inspector(db, migrationsTable, migrationLocksTable) } tb.Skipf("%q dialect does not implement inspector.Dialect", dialect.Name()) return nil diff --git a/internal/dbtest/migrate_test.go b/internal/dbtest/migrate_test.go index bab42e9b3..c9b5197cf 100644 --- a/internal/dbtest/migrate_test.go +++ b/internal/dbtest/migrate_test.go @@ -160,7 +160,8 @@ func testMigrateUpError(t *testing.T, db *bun.DB) { require.Equal(t, []string{"down2", "down1"}, history) } -func TestAutoMigrator_Migrate(t *testing.T) { +func TestAutoMigrator_Run(t *testing.T) { + tests := []struct { fn func(t *testing.T, db *bun.DB) }{ @@ -190,14 +191,17 @@ func testRenameTable(t *testing.T, db *bun.DB) { // Arrange ctx := context.Background() di := getDatabaseInspectorOrSkip(t, db) - createTableOrSkip(t, ctx, db, (*initial)(nil)) + mustResetModel(t, ctx, db, (*initial)(nil)) + mustDropTableOnCleanup(t, ctx, db, (*changed)(nil)) - m, err := migrate.NewAutoMigrator(db) + m, err := migrate.NewAutoMigrator(db, + migrate.WithTableNameAuto(migrationsTable), + migrate.WithLocksTableNameAuto(migrationLocksTable), + migrate.WithModel((*changed)(nil))) require.NoError(t, err) - m.SetModels((*changed)(nil)) // Act - err = m.Migrate(ctx) + err = m.Run(ctx) require.NoError(t, err) // Assert @@ -209,27 +213,13 @@ func testRenameTable(t *testing.T, db *bun.DB) { require.Equal(t, "changed", tables[0].Name) } -func createTableOrSkip(tb testing.TB, ctx context.Context, db *bun.DB, model interface{}) { - tb.Helper() - if _, err := db.NewCreateTable().IfNotExists().Model(model).Exec(ctx); err != nil { - tb.Skip("setup failed:", err) - } - tb.Cleanup(func() { - if _, err := db.NewDropTable().IfExists().Model(model).Exec(ctx); err != nil { - tb.Log("cleanup:", err) - } - }) -} - func TestDetector_Diff(t *testing.T) { tests := []struct { - name string - states func(testing.TB, context.Context, func() schema.Dialect) (stateDb schema.State, stateModel schema.State) + states func(testing.TB, context.Context, schema.Dialect) (stateDb schema.State, stateModel schema.State) operations []migrate.Operation }{ { - name: "find a renamed table", - states: renamedTableStates, + states: testDetectRenamedTable, operations: []migrate.Operation{ &migrate.RenameTable{ From: "books", @@ -239,13 +229,9 @@ func TestDetector_Diff(t *testing.T) { }, } - testEachDialect(t, func(t *testing.T, dialectName string, dialect func() schema.Dialect) { - if dialectName != "pg" { - t.Skip() - } - + testEachDialect(t, func(t *testing.T, dialectName string, dialect schema.Dialect) { for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + t.Run(funcName(tt.states), func(t *testing.T) { ctx := context.Background() var d migrate.Detector stateDb, stateModel := tt.states(t, ctx, dialect) @@ -258,7 +244,7 @@ func TestDetector_Diff(t *testing.T) { }) } -func renamedTableStates(tb testing.TB, ctx context.Context, dialect func() schema.Dialect) (s1, s2 schema.State) { +func testDetectRenamedTable(tb testing.TB, ctx context.Context, dialect schema.Dialect) (s1, s2 schema.State) { type Book struct { bun.BaseModel @@ -279,17 +265,20 @@ func renamedTableStates(tb testing.TB, ctx context.Context, dialect func() schem Title string `bun:"title,notnull"` Pages int `bun:"page_count,notnull,default:0"` } - return getState(tb, ctx, dialect(), + return getState(tb, ctx, dialect, (*Author)(nil), (*Book)(nil), - ), getState(tb, ctx, dialect(), + ), getState(tb, ctx, dialect, (*Author)(nil), (*BookRenamed)(nil), ) } func getState(tb testing.TB, ctx context.Context, dialect schema.Dialect, models ...interface{}) schema.State { - inspector := schema.NewInspector(dialect, models...) + tables := schema.NewTables(dialect) + tables.Register(models...) + + inspector := schema.NewInspector(tables) state, err := inspector.Inspect(ctx) if err != nil { tb.Skip("get state: %w", err) diff --git a/migrate/auto.go b/migrate/auto.go index 8453e069d..fac7dcf97 100644 --- a/migrate/auto.go +++ b/migrate/auto.go @@ -3,62 +3,167 @@ package migrate import ( "context" "fmt" + "strings" "github.com/uptrace/bun" + "github.com/uptrace/bun/migrate/sqlschema" "github.com/uptrace/bun/schema" "github.com/uptrace/bun/schema/inspector" ) +type AutoMigratorOption func(m *AutoMigrator) + +// WithModel adds a bun.Model to the scope of migrations. +func WithModel(models ...interface{}) AutoMigratorOption { + return func(m *AutoMigrator) { + m.includeModels = append(m.includeModels, models...) + } +} + +// WithExcludeTable tells the AutoMigrator to ignore a table in the database. +func WithExcludeTable(tables ...string) AutoMigratorOption { + return func(m *AutoMigrator) { + m.excludeTables = append(m.excludeTables, tables...) + } +} + +// WithTableNameAuto overrides default migrations table name. +func WithTableNameAuto(table string) AutoMigratorOption { + return func(m *AutoMigrator) { + m.table = table + m.migratorOpts = append(m.migratorOpts, WithTableName(table)) + } +} + +// WithLocksTableNameAuto overrides default migration locks table name. +func WithLocksTableNameAuto(table string) AutoMigratorOption { + return func(m *AutoMigrator) { + m.locksTable = table + m.migratorOpts = append(m.migratorOpts, WithLocksTableName(table)) + } +} + +// WithMarkAppliedOnSuccessAuto sets the migrator to only mark migrations as applied/unapplied +// when their up/down is successful. +func WithMarkAppliedOnSuccessAuto(enabled bool) AutoMigratorOption { + return func(m *AutoMigrator) { + m.migratorOpts = append(m.migratorOpts, WithMarkAppliedOnSuccess(enabled)) + } +} + type AutoMigrator struct { db *bun.DB - // models limit the set of tables considered for the migration. - models []interface{} - // dbInspector creates the current state for the target database. dbInspector schema.Inspector // modelInspector creates the desired state based on the model definitions. modelInspector schema.Inspector + + // schemaMigrator executes ALTER TABLE queries. + schemaMigrator sqlschema.Migrator + + table string + locksTable string + + // includeModels define the migration scope. + includeModels []interface{} + + // excludeTables are excluded from database inspection. + excludeTables []string + + // migratorOpts are passed to Migrator constructor. + migratorOpts []MigratorOption } -func NewAutoMigrator(db *bun.DB) (*AutoMigrator, error) { +func NewAutoMigrator(db *bun.DB, opts ...AutoMigratorOption) (*AutoMigrator, error) { + am := &AutoMigrator{ + db: db, + table: defaultTable, + locksTable: defaultLocksTable, + } + + for _, opt := range opts { + opt(am) + } + am.excludeTables = append(am.excludeTables, am.table, am.locksTable) + + schemaMigrator, err := sqlschema.NewMigrator(db) + if err != nil { + return nil, err + } + am.schemaMigrator = schemaMigrator + + tables := schema.NewTables(db.Dialect()) + tables.Register(am.includeModels...) + am.modelInspector = schema.NewInspector(tables) + dialect := db.Dialect() - withInspector, ok := dialect.(inspector.Dialect) + inspectorDialect, ok := dialect.(inspector.Dialect) if !ok { return nil, fmt.Errorf("%q dialect does not implement inspector.Dialect", dialect.Name()) } + am.dbInspector = inspectorDialect.Inspector(db, am.excludeTables...) - return &AutoMigrator{ - db: db, - dbInspector: withInspector.Inspector(db), - }, nil -} - -func (am *AutoMigrator) SetModels(models ...interface{}) { - am.models = models + return am, nil } func (am *AutoMigrator) diff(ctx context.Context) (Changeset, error) { + var detector Detector var changes Changeset var err error - // TODO: do on "SetModels" - am.modelInspector = schema.NewInspector(am.db.Dialect(), am.models...) - - _, err = am.dbInspector.Inspect(ctx) + got, err := am.dbInspector.Inspect(ctx) if err != nil { return changes, err } - _, err = am.modelInspector.Inspect(ctx) + want, err := am.modelInspector.Inspect(ctx) if err != nil { return changes, err } - return changes, nil + return detector.Diff(got, want), nil +} + +// Migrate writes required changes to a new migration file and runs the migration. +// This will create and entry in the migrations table, making it possible to revert +// the changes with Migrator.Rollback(). +func (am *AutoMigrator) Migrate(ctx context.Context, opts ...MigrationOption) error { + changeset, err := am.diff(ctx) + if err != nil { + return fmt.Errorf("auto migrate: %w", err) + } + + migrations := NewMigrations() + name, _ := genMigrationName("auto") + migrations.Add(Migration{ + Name: name, + Up: changeset.Up(am.schemaMigrator), + Down: changeset.Down(am.schemaMigrator), + Comment: "Changes detected by bun.migrate.AutoMigrator", + }) + + migrator := NewMigrator(am.db, migrations, am.migratorOpts...) + if err := migrator.Init(ctx); err != nil { + return fmt.Errorf("auto migrate: %w", err) + } + + if _, err := migrator.Migrate(ctx, opts...); err != nil { + return fmt.Errorf("auto migrate: %w", err) + } + return nil } -func (am *AutoMigrator) Migrate(ctx context.Context) error { +// Run runs required migrations in-place and without creating a database entry. +func (am *AutoMigrator) Run(ctx context.Context) error { + changeset, err := am.diff(ctx) + if err != nil { + return fmt.Errorf("run auto migrate: %w", err) + } + up := changeset.Up(am.schemaMigrator) + if err := up(ctx, am.db); err != nil { + return fmt.Errorf("run auto migrate: %w", err) + } return nil } @@ -68,7 +173,9 @@ func (am *AutoMigrator) Migrate(ctx context.Context) error { // Apart from storing the function to execute the change, // it knows how to *write* the corresponding code, and what the reverse operation is. type Operation interface { - Func() MigrationFunc + Func(sqlschema.Migrator) MigrationFunc + // GetReverse returns an operation that can revert the current one. + GetReverse() Operation } type RenameTable struct { @@ -76,10 +183,16 @@ type RenameTable struct { To string } -func (rt *RenameTable) Func() MigrationFunc { +func (rt *RenameTable) Func(m sqlschema.Migrator) MigrationFunc { return func(ctx context.Context, db *bun.DB) error { - db.Dialect() - return nil + return m.RenameTable(ctx, rt.From, rt.To) + } +} + +func (rt *RenameTable) GetReverse() Operation { + return &RenameTable{ + From: rt.To, + To: rt.From, } } @@ -88,6 +201,8 @@ type Changeset struct { operations []Operation } +var _ Operation = (*Changeset)(nil) + func (c Changeset) Operations() []Operation { return c.operations } @@ -96,6 +211,36 @@ func (c *Changeset) Add(op Operation) { c.operations = append(c.operations, op) } +func (c *Changeset) Func(m sqlschema.Migrator) MigrationFunc { + return func(ctx context.Context, db *bun.DB) error { + for _, op := range c.operations { + fn := op.Func(m) + if err := fn(ctx, db); err != nil { + return err + } + } + return nil + } +} + +func (c *Changeset) GetReverse() Operation { + var reverse Changeset + for _, op := range c.operations { + reverse.Add(op.GetReverse()) + } + return &reverse +} + +// Up is syntactic sugar. +func (c *Changeset) Up(m sqlschema.Migrator) MigrationFunc { + return c.Func(m) +} + +// Down is syntactic sugar. +func (c *Changeset) Down(m sqlschema.Migrator) MigrationFunc { + return c.GetReverse().Func(m) +} + type Detector struct{} func (d *Detector) Diff(got, want schema.State) Changeset { @@ -177,6 +322,17 @@ func (set tableSet) clone() tableSet { return res } +func (set tableSet) String() string { + var s strings.Builder + for k := range set.underlying { + if s.Len() > 0 { + s.WriteString(", ") + } + s.WriteString(k) + } + return s.String() +} + // signature is a set of column definitions, which allows "relation/name-agnostic" comparison between them; // meaning that two columns are considered equal if their types are the same. type signature struct { diff --git a/migrate/migrator.go b/migrate/migrator.go index e6d70e39f..b14ad64ca 100644 --- a/migrate/migrator.go +++ b/migrate/migrator.go @@ -12,14 +12,21 @@ import ( "github.com/uptrace/bun" ) +const ( + defaultTable = "bun_migrations" + defaultLocksTable = "bun_migration_locks" +) + type MigratorOption func(m *Migrator) +// WithTableName overrides default migrations table name. func WithTableName(table string) MigratorOption { return func(m *Migrator) { m.table = table } } +// WithLocksTableName overrides default migration locks table name. func WithLocksTableName(table string) MigratorOption { return func(m *Migrator) { m.locksTable = table @@ -27,7 +34,7 @@ func WithLocksTableName(table string) MigratorOption { } // WithMarkAppliedOnSuccess sets the migrator to only mark migrations as applied/unapplied -// when their up/down is successful +// when their up/down is successful. func WithMarkAppliedOnSuccess(enabled bool) MigratorOption { return func(m *Migrator) { m.markAppliedOnSuccess = enabled @@ -52,8 +59,8 @@ func NewMigrator(db *bun.DB, migrations *Migrations, opts ...MigratorOption) *Mi ms: migrations.ms, - table: "bun_migrations", - locksTable: "bun_migration_locks", + table: defaultTable, + locksTable: defaultLocksTable, } for _, opt := range opts { opt(m) @@ -246,7 +253,7 @@ func (m *Migrator) CreateGoMigration( opt(cfg) } - name, err := m.genMigrationName(name) + name, err := genMigrationName(name) if err != nil { return nil, err } @@ -329,7 +336,7 @@ func (m *Migrator) createSQL(ctx context.Context, fname string, transactional bo var nameRE = regexp.MustCompile(`^[0-9a-z_\-]+$`) -func (m *Migrator) genMigrationName(name string) (string, error) { +func genMigrationName(name string) (string, error) { const timeFormat = "20060102150405" if name == "" { diff --git a/migrate/sqlschema/migrate.go b/migrate/sqlschema/migrate.go new file mode 100644 index 000000000..0b51e2b02 --- /dev/null +++ b/migrate/sqlschema/migrate.go @@ -0,0 +1,33 @@ +package sqlschema + +import ( + "context" + "database/sql" + "fmt" + + "github.com/uptrace/bun" + "github.com/uptrace/bun/schema" +) + +type MigratorDialect interface { + schema.Dialect + Migrator(*sql.DB) Migrator +} + +type Migrator interface { + RenameTable(ctx context.Context, oldName, newName string) error +} + +type migrator struct { + Migrator +} + +func NewMigrator(db *bun.DB) (Migrator, error) { + md, ok := db.Dialect().(MigratorDialect) + if !ok { + return nil, fmt.Errorf("%q dialect does not implement sqlschema.Migrator", db.Dialect().Name()) + } + return &migrator{ + Migrator: md.Migrator(db.DB), + }, nil +} diff --git a/schema/inspector.go b/schema/inspector.go index 464cfa81f..e767742a8 100644 --- a/schema/inspector.go +++ b/schema/inspector.go @@ -14,8 +14,8 @@ type State struct { } type TableDef struct { - Schema string - Name string + Schema string + Name string Columns map[string]ColumnDef } @@ -29,48 +29,39 @@ type ColumnDef struct { } type SchemaInspector struct { - dialect Dialect + tables *Tables } var _ Inspector = (*SchemaInspector)(nil) -func NewInspector(dialect Dialect, models ...interface{}) *SchemaInspector { - dialect.Tables().Register(models...) +func NewInspector(tables *Tables) *SchemaInspector { return &SchemaInspector{ - dialect: dialect, + tables: tables, } } +// Inspect creates the current project state from the passed bun.Models. +// Do not recycle SchemaInspector for different sets of models, as older models will not be de-registerred before the next run. func (si *SchemaInspector) Inspect(ctx context.Context) (State, error) { var state State - for _, t := range si.dialect.Tables().All() { + for _, t := range si.tables.All() { columns := make(map[string]ColumnDef) for _, f := range t.Fields { columns[f.Name] = ColumnDef{ - SQLType: f.CreateTableSQLType, - DefaultValue: f.SQLDefault, - IsPK: f.IsPK, - IsNullable: !f.NotNull, + SQLType: strings.ToLower(f.CreateTableSQLType), + DefaultValue: f.SQLDefault, + IsPK: f.IsPK, + IsNullable: !f.NotNull, IsAutoIncrement: f.AutoIncrement, - IsIdentity: f.Identity, + IsIdentity: f.Identity, } } - schema, table := splitTableNameTag(si.dialect, t.Name) state.Tables = append(state.Tables, TableDef{ - Schema: schema, - Name: table, + Schema: t.Schema, + Name: t.Name, Columns: columns, }) } return state, nil } - -// splitTableNameTag -func splitTableNameTag(d Dialect, nameTag string) (string, string) { - schema, table := d.DefaultSchema(), nameTag - if schemaTable := strings.Split(nameTag, "."); len(schemaTable) == 2 { - schema, table = schemaTable[0], schemaTable[1] - } - return schema, table -} \ No newline at end of file diff --git a/schema/inspector/dialect.go b/schema/inspector/dialect.go index 701300da9..beb23eea5 100644 --- a/schema/inspector/dialect.go +++ b/schema/inspector/dialect.go @@ -7,5 +7,5 @@ import ( type Dialect interface { schema.Dialect - Inspector(db *bun.DB) schema.Inspector + Inspector(db *bun.DB, excludeTables ...string) schema.Inspector } diff --git a/schema/table.go b/schema/table.go index c8e71e38f..f770a2f76 100644 --- a/schema/table.go +++ b/schema/table.go @@ -45,6 +45,7 @@ type Table struct { TypeName string ModelName string + Schema string Name string SQLName Safe SQLNameForSelects Safe @@ -371,10 +372,18 @@ func (t *Table) processBaseModelField(f reflect.StructField) { } if tag.Name != "" { + schema, _ := t.schemaFromTagName(tag.Name) + t.Schema = schema + + // Eventually, we should only assign the "table" portion as the table name, + // which will also require a change in how the table name is appended to queries. + // Until that is done, set table name to tag.Name. t.setName(tag.Name) } if s, ok := tag.Option("table"); ok { + schema, _ := t.schemaFromTagName(tag.Name) + t.Schema = schema t.setName(s) } @@ -388,6 +397,17 @@ func (t *Table) processBaseModelField(f reflect.StructField) { } } +// schemaFromTagName splits the bun.BaseModel tag name into schema and table name +// in case it is specified in the "schema"."table" format. +// Assume default schema if one isn't explicitly specified. +func (t *Table) schemaFromTagName(name string) (string, string) { + schema, table := t.dialect.DefaultSchema(), name + if schemaTable := strings.Split(name, "."); len(schemaTable) == 2 { + schema, table = schemaTable[0], schemaTable[1] + } + return schema, table +} + // nolint func (t *Table) newField(sf reflect.StructField, tag tagparser.Tag) *Field { sqlName := internal.Underscore(sf.Name)