From 30818499eacddd3b1a3e749091ba6a1468125641 Mon Sep 17 00:00:00 2001 From: tim Date: Thu, 16 Jun 2022 23:11:20 +0500 Subject: [PATCH] fix: merge apply --- example/rel-join-condition/main.go | 19 ++++++++++++++----- query_select.go | 21 +++++++++++++++++---- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/example/rel-join-condition/main.go b/example/rel-join-condition/main.go index 021de1297..a132cf46b 100644 --- a/example/rel-join-condition/main.go +++ b/example/rel-join-condition/main.go @@ -21,7 +21,7 @@ type Profile struct { type User struct { ID int64 `bun:",pk,autoincrement"` Name string - Profiles []*Profile `bun:"rel:has-many,join:id=user_id,join_on:active IS TRUE,join_on:lang='ru'"` + Profiles []*Profile `bun:"rel:has-many,join:id=user_id,join_on: active IS TRUE"` } func main() { @@ -45,15 +45,23 @@ func main() { if err := db.NewSelect(). Model(user). Column("user.*"). - Relation("Profiles"). + Relation("Profiles", func(q *bun.SelectQuery) *bun.SelectQuery { + return q.Where("lang = 'ru'") + }). OrderExpr("user.id ASC"). Limit(1). Scan(ctx); err != nil { panic(err) } - fmt.Println(user.ID, user.Name, user.Profiles[0]) - // Output: 1 user &{2 ru true 1} + fmt.Printf("user.ID: %d, user.Name: %q\n", user.ID, user.Name) + fmt.Printf("user.Profiles: ") + for _, p := range user.Profiles { + fmt.Printf("%v, ", p) + } + fmt.Println() + // Output: user.ID: 1, user.Name: "user 1" + // user.Profiles: &{2 ru true 1}, } func createSchema(ctx context.Context, db *bun.DB) error { @@ -78,7 +86,8 @@ func createSchema(ctx context.Context, db *bun.DB) error { profiles := []*Profile{ {ID: 1, Lang: "en", Active: true, UserID: 1}, {ID: 2, Lang: "ru", Active: true, UserID: 1}, - {ID: 3, Lang: "md", Active: false, UserID: 1}, + {ID: 3, Lang: "ru", Active: false, UserID: 1}, + {ID: 4, Lang: "md", Active: false, UserID: 1}, } if _, err := db.NewInsert().Model(&profiles).Exec(ctx); err != nil { return err diff --git a/query_select.go b/query_select.go index a239a4fef..32f624c73 100644 --- a/query_select.go +++ b/query_select.go @@ -305,19 +305,32 @@ func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQ return q } + var apply1, apply2 func(*SelectQuery) *SelectQuery + if len(join.Relation.Condition) > 0 { - apl := func(q *SelectQuery) *SelectQuery { + apply1 = func(q *SelectQuery) *SelectQuery { + for _, opt := range join.Relation.Condition { q.addWhere(schema.SafeQueryWithSep(opt, nil, " AND ")) } + return q } - - join.apply = apl } if len(apply) == 1 { - join.apply = apply[0] + apply2 = apply[0] + } + + join.apply = func(q *SelectQuery) *SelectQuery { + if apply1 != nil { + q = apply1(q) + } + if apply2 != nil { + q = apply2(q) + } + + return q } return q