Skip to content

Commit

Permalink
fix: make allowzero work with auto-detected primary keys
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Sep 18, 2021
1 parent 73ad6f5 commit 82ca87c
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 24 deletions.
32 changes: 18 additions & 14 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -853,27 +853,31 @@ func testScanBytes(t *testing.T, db *bun.DB) {

func testPointers(t *testing.T, db *bun.DB) {
type Model struct {
ID *int64 `bun:",allowzero,default:0"`
ID *int64 `bun:",default:0"`
Str *string
}

ctx := context.Background()

err := db.ResetModel(ctx, (*Model)(nil))
require.NoError(t, err)
for _, id := range []int64{-1, 0, 1} {
err := db.ResetModel(ctx, (*Model)(nil))
require.NoError(t, err)

id := int64(1)
str := "hello"
models := []Model{
{},
{ID: &id, Str: &str},
}
_, err = db.NewInsert().Model(&models).Exec(ctx)
require.NoError(t, err)
var model Model
if id >= 0 {
str := "hello"
model.ID = &id
model.Str = &str

var models2 []Model
err = db.NewSelect().Model(&models2).Order("id ASC").Scan(ctx)
require.NoError(t, err)
}

_, err = db.NewInsert().Model(&model).Exec(ctx)
require.NoError(t, err)

var model2 Model
err = db.NewSelect().Model(&model2).Order("id ASC").Scan(ctx)
require.NoError(t, err)
}
}

func testExists(t *testing.T, db *bun.DB) {
Expand Down
6 changes: 6 additions & 0 deletions internal/dbtest/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,12 @@ func TestQuery(t *testing.T) {
}
return db.NewInsert().Model(new(Model))
},
func(db *bun.DB) schema.QueryAppender {
type Model struct {
ID int `bun:",allowzero"`
}
return db.NewInsert().Model(new(Model))
},
}

timeRE := regexp.MustCompile(`'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d+(\+\d{2}:\d{2})?'`)
Expand Down
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql5-85
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO `models` (`id`) VALUES (0)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql8-85
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO `models` (`id`) VALUES (0)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pg-85
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO "models" ("id") VALUES (0)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pgx-85
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO "models" ("id") VALUES (0)
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-sqlite-85
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO "models" ("id") VALUES (0)
4 changes: 3 additions & 1 deletion schema/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ func (f *Field) ScanValue(strct reflect.Value, src interface{}) error {
func (f *Field) markAsPK() {
f.IsPK = true
f.NotNull = true
f.NullZero = true
if !f.Tag.HasOption("allowzero") {
f.NullZero = true
}
}

func indexEqual(ind1, ind2 []int) bool {
Expand Down
18 changes: 9 additions & 9 deletions schema/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,17 +181,17 @@ func (t *Table) initFields() {
t.FieldMap = make(map[string]*Field, t.Type.NumField())
t.addFields(t.Type, nil)

if len(t.PKs) > 0 {
return
}
for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} {
if field, ok := t.FieldMap[name]; ok {
field.markAsPK()
t.PKs = []*Field{field}
t.DataFields = removeField(t.DataFields, field)
break
if len(t.PKs) == 0 {
for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} {
if field, ok := t.FieldMap[name]; ok {
field.markAsPK()
t.PKs = []*Field{field}
t.DataFields = removeField(t.DataFields, field)
break
}
}
}

if len(t.PKs) == 1 {
pk := t.PKs[0]
if pk.SQLDefault != "" {
Expand Down

0 comments on commit 82ca87c

Please sign in to comment.