Skip to content

Commit

Permalink
feat: [#280] Add change table feature (#671)
Browse files Browse the repository at this point in the history
* feat: [#280] Add change table feature

* fix test

* remove fmt
  • Loading branch information
hwbrzzl authored Oct 5, 2024
1 parent fe51fc9 commit 90b2733
Show file tree
Hide file tree
Showing 16 changed files with 501 additions and 20 deletions.
2 changes: 2 additions & 0 deletions contracts/database/migration/blueprint.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ type Blueprint interface {
DropIfExists()
// GetAddedColumns Get the added columns.
GetAddedColumns() []ColumnDefinition
// GetChangedColumns Get the changed columns.
GetChangedColumns() []ColumnDefinition
// GetTableName Get the table name with prefix.
GetTableName() string
// HasCommand Determine if the blueprint has a specific command.
Expand Down
4 changes: 4 additions & 0 deletions contracts/database/migration/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package migration
type ColumnDefinition interface {
// AutoIncrement set the column as auto increment
AutoIncrement() ColumnDefinition
// Change the column
Change()
// GetAutoIncrement returns the autoIncrement value
GetAutoIncrement() bool
// GetChange returns the change value
Expand All @@ -17,6 +19,8 @@ type ColumnDefinition interface {
GetNullable() bool
// GetType returns the type value
GetType() string
// Nullable allow NULL values to be inserted into the column
Nullable() ColumnDefinition
// Unsigned set the column as unsigned
Unsigned() ColumnDefinition
}
4 changes: 4 additions & 0 deletions contracts/database/migration/grammar.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import (
)

type Grammar interface {
// CompileAdd Compile an add column command.
CompileAdd(blueprint Blueprint) string
// CompileChange Compile a change column command into a series of SQL statements.
CompileChange(blueprint Blueprint) string
// CompileCreate Compile a create table command.
CompileCreate(blueprint Blueprint, query orm.Query) string
// CompileDropIfExists Compile a drop table (if exists) command.
Expand Down
2 changes: 1 addition & 1 deletion contracts/database/migration/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type Schema interface {
// Sql Execute a sql directly.
Sql(sql string)
// Table Modify a table on the schema.
//Table(table string, callback func(table Blueprint))
Table(table string, callback func(table Blueprint)) error
}

type Migration interface {
Expand Down
8 changes: 4 additions & 4 deletions database/migration/blueprint.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package migration

import (
"fmt"

"github.com/goravel/framework/contracts/database/migration"
ormcontract "github.com/goravel/framework/contracts/database/orm"
"github.com/goravel/framework/support/convert"
Expand Down Expand Up @@ -48,8 +46,6 @@ func (r *Blueprint) BigInteger(column string) migration.ColumnDefinition {

func (r *Blueprint) Build(query ormcontract.Query, grammar migration.Grammar) error {
for _, sql := range r.ToSql(query, grammar) {
// TODO remove
fmt.Println("sql:", sql)
if _, err := query.Exec(sql); err != nil {
return err
}
Expand Down Expand Up @@ -152,6 +148,10 @@ func (r *Blueprint) ToSql(query ormcontract.Query, grammar migration.Grammar) []
var statements []string
for _, command := range r.commands {
switch command.Name {
case commandAdd:
statements = append(statements, grammar.CompileAdd(r))
case commandChange:
statements = append(statements, grammar.CompileChange(r))
case commandCreate:
statements = append(statements, grammar.CompileCreate(r, query))
case commandDropIfExists:
Expand Down
10 changes: 10 additions & 0 deletions database/migration/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ func (r *ColumnDefinition) AutoIncrement() migration.ColumnDefinition {
return r
}

func (r *ColumnDefinition) Change() {
r.change = convert.Pointer(true)
}

func (r *ColumnDefinition) GetAutoIncrement() (autoIncrement bool) {
if r.autoIncrement != nil {
return *r.autoIncrement
Expand Down Expand Up @@ -75,6 +79,12 @@ func (r *ColumnDefinition) GetType() (ttype string) {
return
}

func (r *ColumnDefinition) Nullable() migration.ColumnDefinition {
r.nullable = convert.Pointer(true)

return r
}

func (r *ColumnDefinition) Unsigned() migration.ColumnDefinition {
r.unsigned = convert.Pointer(true)

Expand Down
25 changes: 25 additions & 0 deletions database/migration/grammars/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,31 @@ func NewPostgres() *Postgres {
return postgres
}

func (r *Postgres) CompileAdd(blueprint migration.Blueprint) string {
return fmt.Sprintf("alter table %s %s", blueprint.GetTableName(), strings.Join(prefixArray("add column", getColumns(r, blueprint)), ","))
}

func (r *Postgres) CompileChange(blueprint migration.Blueprint) string {
var columns []string
for _, column := range blueprint.GetChangedColumns() {
var changes []string

for _, modifier := range r.modifiers {
if change := modifier(blueprint, column); change != "" {
changes = append(changes, change)
}
}

columns = append(columns, strings.Join(prefixArray("alter column "+column.GetName(), changes), ", "))
}

if len(columns) == 0 {
return ""
}

return fmt.Sprintf("alter table %s %s", blueprint.GetTableName(), strings.Join(columns, ", "))
}

func (r *Postgres) CompileCreate(blueprint migration.Blueprint, query orm.Query) string {
return fmt.Sprintf("create table %s (%s)", blueprint.GetTableName(), strings.Join(getColumns(r, blueprint), ","))
}
Expand Down
113 changes: 98 additions & 15 deletions database/migration/grammars/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,39 +19,122 @@ func TestPostgresSuite(t *testing.T) {
}

func (s *PostgresSuite) SetupTest() {
postgres := &Postgres{
attributeCommands: []string{"comment"},
serials: []string{"bigInteger"},
}
postgres.modifiers = []func(contractsmigration.Blueprint, contractsmigration.ColumnDefinition) string{
postgres.ModifyDefault,
s.grammar = NewPostgres()
}

func (s *PostgresSuite) TestCompileChange() {
var (
mockBlueprint *mocksmigration.Blueprint
mockColumn1 *mocksmigration.ColumnDefinition
mockColumn2 *mocksmigration.ColumnDefinition
)

tests := []struct {
name string
setup func()
expectSql string
}{
{
name: "no changes",
setup: func() {
mockBlueprint.EXPECT().GetChangedColumns().Return([]contractsmigration.ColumnDefinition{}).Once()
},
},
{
name: "single change",
setup: func() {
mockColumn1.EXPECT().GetAutoIncrement().Return(false).Once()
mockColumn1.EXPECT().GetDefault().Return("goravel").Twice()
mockColumn1.EXPECT().GetName().Return("name").Once()
mockColumn1.EXPECT().GetChange().Return(true).Times(3)
mockColumn1.EXPECT().GetNullable().Return(true).Once()
mockBlueprint.EXPECT().GetTableName().Return("users").Once()
mockBlueprint.EXPECT().GetChangedColumns().Return([]contractsmigration.ColumnDefinition{mockColumn1}).Once()
},
expectSql: "alter table users alter column name set default 'goravel', alter column name drop not null",
},
{
name: "multiple changes",
setup: func() {
mockColumn1.EXPECT().GetAutoIncrement().Return(false).Once()
mockColumn1.EXPECT().GetDefault().Return("goravel").Twice()
mockColumn1.EXPECT().GetName().Return("name").Once()
mockColumn1.EXPECT().GetChange().Return(true).Times(3)
mockColumn1.EXPECT().GetNullable().Return(true).Once()
mockColumn2.EXPECT().GetAutoIncrement().Return(false).Once()
mockColumn2.EXPECT().GetDefault().Return(1).Twice()
mockColumn2.EXPECT().GetName().Return("age").Once()
mockColumn2.EXPECT().GetChange().Return(true).Times(3)
mockColumn2.EXPECT().GetNullable().Return(false).Once()
mockBlueprint.EXPECT().GetTableName().Return("users").Once()
mockBlueprint.EXPECT().GetChangedColumns().Return([]contractsmigration.ColumnDefinition{mockColumn1, mockColumn2}).Once()
},
expectSql: "alter table users alter column name set default 'goravel', alter column name drop not null, alter column age set default '1', alter column age set not null",
},
}

s.grammar = postgres
for _, test := range tests {
s.Run(test.name, func() {
mockBlueprint = mocksmigration.NewBlueprint(s.T())
mockColumn1 = mocksmigration.NewColumnDefinition(s.T())
mockColumn2 = mocksmigration.NewColumnDefinition(s.T())

test.setup()

sql := s.grammar.CompileChange(mockBlueprint)

s.Equal(test.expectSql, sql)
})
}
}

func (s *PostgresSuite) TestCompileCreate() {
mockColumn1 := mocksmigration.NewColumnDefinition(s.T())
mockColumn2 := mocksmigration.NewColumnDefinition(s.T())
mockBlueprint := mocksmigration.NewBlueprint(s.T())

// postgres.go::CompileCreate
mockBlueprint.EXPECT().GetTableName().Return("users").Once()
// utils.go::getColumns
mockBlueprint.EXPECT().GetAddedColumns().Return([]contractsmigration.ColumnDefinition{
mockColumn1, mockColumn2,
}).Once()
// utils.go::getColumns
mockColumn1.EXPECT().GetName().Return("id").Once()
// utils.go::getType
mockColumn1.EXPECT().GetType().Return("integer").Once()
// postgres.go::TypeInteger
mockColumn1.EXPECT().GetAutoIncrement().Return(true).Once()
// postgres.go::ModifyDefault
mockColumn1.EXPECT().GetChange().Return(false).Once()
mockColumn1.EXPECT().GetDefault().Return(nil).Once()
// postgres.go::ModifyIncrement
mockColumn1.EXPECT().GetChange().Return(false).Once()
mockBlueprint.EXPECT().HasCommand("primary").Return(false).Once()
mockColumn1.EXPECT().GetType().Return("integer").Once()
mockColumn1.EXPECT().GetAutoIncrement().Return(true).Once()
// postgres.go::ModifyNullable
mockColumn1.EXPECT().GetChange().Return(false).Once()
mockColumn1.EXPECT().GetNullable().Return(false).Once()

mockColumn2 := mocksmigration.NewColumnDefinition(s.T())
// utils.go::getColumns
mockColumn2.EXPECT().GetName().Return("name").Once()
// utils.go::getType
mockColumn2.EXPECT().GetType().Return("string").Once()
// postgres.go::TypeString
mockColumn2.EXPECT().GetLength().Return(100).Once()
// postgres.go::ModifyDefault
mockColumn2.EXPECT().GetChange().Return(false).Once()
mockColumn2.EXPECT().GetDefault().Return(nil).Once()
// postgres.go::ModifyIncrement
mockColumn2.EXPECT().GetChange().Return(false).Once()
mockBlueprint.EXPECT().HasCommand("primary").Return(false).Once()
mockColumn2.EXPECT().GetType().Return("string").Once()
// postgres.go::ModifyNullable
mockColumn2.EXPECT().GetChange().Return(false).Once()
mockColumn2.EXPECT().GetNullable().Return(true).Once()

mockBlueprint := mocksmigration.NewBlueprint(s.T())
mockBlueprint.EXPECT().GetTableName().Return("users").Once()
mockBlueprint.EXPECT().GetAddedColumns().Return([]contractsmigration.ColumnDefinition{
mockColumn1, mockColumn2,
}).Once()

s.Equal("create table users (id serial,name varchar(100))",
s.Equal("create table users (id serial primary key not null,name varchar(100) null)",
s.grammar.CompileCreate(mockBlueprint, nil))
}

Expand Down
8 changes: 8 additions & 0 deletions database/migration/grammars/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,11 @@ func getType(grammar migration.Grammar, column migration.ColumnDefinition) strin

return ""
}

func prefixArray(prefix string, values []string) []string {
for i, value := range values {
values[i] = prefix + " " + value
}

return values
}
5 changes: 5 additions & 0 deletions database/migration/grammars/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,8 @@ func TestGetType(t *testing.T) {

assert.Empty(t, getType(mockGrammar1, mockColumn1))
}

func TestPrefixArray(t *testing.T) {
values := []string{"a", "b", "c"}
assert.Equal(t, []string{"prefix a", "prefix b", "prefix c"}, prefixArray("prefix", values))
}
8 changes: 8 additions & 0 deletions database/migration/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ func (r *Schema) Sql(sql string) {
_, _ = r.orm.Connection(r.connection).Query().Exec(sql)
}

func (r *Schema) Table(table string, callback func(table migration.Blueprint)) error {
blueprint := r.createBlueprint(table)
callback(blueprint)

// TODO catch error and rollback
return r.build(blueprint)
}

func (r *Schema) build(blueprint migration.Blueprint) error {
return blueprint.Build(r.orm.Connection(r.connection).Query(), r.grammar)
}
Expand Down
67 changes: 67 additions & 0 deletions database/migration/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,73 @@ func (s *SchemaSuite) TestDropIfExists() {
}
}

func (s *SchemaSuite) TestTable() {
for driver, testQuery := range s.driverToTestQuery {
s.Run(driver.String(), func() {
schema, mockOrm := initSchema(s.T(), testQuery)

mockOrm.EXPECT().Connection(schema.connection).Return(mockOrm).Once()
mockOrm.EXPECT().Query().Return(testQuery.Query()).Times(3)

err := schema.Create("changes", func(table migration.Blueprint) {
table.String("name")
})
s.NoError(err)
s.True(schema.HasTable("changes"))

tables, err := schema.GetTables()
s.NoError(err)
s.Greater(len(tables), 0)

// Open this after implementing other methods
//s.Require().True(schema.HasColumn("changes", "name"))
//columns, err := schema.GetColumns("changes")
//s.Require().Nil(err)
//for _, column := range columns {
// if column.Name == "name" {
// s.False(column.AutoIncrement)
// s.Empty(column.Collation)
// s.Empty(column.Comment)
// s.Empty(column.Default)
// s.False(column.Nullable)
// s.Equal("character varying(255)", column.Type)
// s.Equal("varchar", column.TypeName)
// }
//}
//
//err = schema.Table("changes", func(table migration.Blueprint) {
// table.Integer("age")
// table.String("name").Comment("This is a name column").Default("goravel").Change()
//})
//s.Nil(err)
//s.True(schema.HasTable("changes"))
//s.Require().True(schema.HasColumns("changes", []string{"name", "age"}))
//columns, err = schema.GetColumns("changes")
//s.Require().Nil(err)
//for _, column := range columns {
// if column.Name == "name" {
// s.False(column.AutoIncrement)
// s.Empty(column.Collation)
// s.Equal("This is a name column", column.Comment)
// s.Equal("'goravel'::character varying", column.Default)
// s.False(column.Nullable)
// s.Equal("character varying(255)", column.Type)
// s.Equal("varchar", column.TypeName)
// }
// if column.Name == "age" {
// s.False(column.AutoIncrement)
// s.Empty(column.Collation)
// s.Empty(column.Comment)
// s.Empty(column.Default)
// s.False(column.Nullable)
// s.Equal("integer", column.Type)
// s.Equal("int4", column.TypeName)
// }
//}
})
}
}

func initSchema(t *testing.T, testQuery *gorm.TestQuery) (*Schema, *mocksorm.Orm) {
mockOrm := mocksorm.NewOrm(t)
schema := NewSchema(testQuery.MockConfig(), testQuery.Docker().Driver().String(), nil, mockOrm)
Expand Down
Loading

0 comments on commit 90b2733

Please sign in to comment.