Skip to content

Commit

Permalink
added PreFragment and PostFragment support
Browse files Browse the repository at this point in the history
  • Loading branch information
ganigeorgiev committed Dec 10, 2024
1 parent 40a0b32 commit f18776a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
26 changes: 26 additions & 0 deletions select.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ type SelectQuery struct {
ctx context.Context
buildHook BuildHookFunc

preFragment string
postFragment string
selects []string
distinct bool
selectOption string
Expand Down Expand Up @@ -88,6 +90,18 @@ func (q *SelectQuery) WithContext(ctx context.Context) *SelectQuery {
return q
}

// PreFragment sets SQL fragment that should be prepended before the select query (e.g. WITH clause).
func (s *SelectQuery) PreFragment(fragment string) *SelectQuery {
s.preFragment = fragment
return s
}

// PostFragment sets SQL fragment that should be appended at the end of the select query.
func (s *SelectQuery) PostFragment(fragment string) *SelectQuery {
s.postFragment = fragment
return s
}

// Select specifies the columns to be selected.
// Column names will be automatically quoted.
func (s *SelectQuery) Select(cols ...string) *SelectQuery {
Expand Down Expand Up @@ -265,13 +279,15 @@ func (s *SelectQuery) Build() *Query {
qb := s.builder.QueryBuilder()

clauses := []string{
s.preFragment,
qb.BuildSelect(s.selects, s.distinct, s.selectOption),
qb.BuildFrom(s.from),
qb.BuildJoin(s.join, params),
qb.BuildWhere(s.where, params),
qb.BuildGroupBy(s.groupBy),
qb.BuildHaving(s.having, params),
}

sql := ""
for _, clause := range clauses {
if clause != "" {
Expand All @@ -282,7 +298,13 @@ func (s *SelectQuery) Build() *Query {
}
}
}

sql = qb.BuildOrderByAndLimit(sql, s.orderBy, s.limit, s.offset)

if s.postFragment != "" {
sql += " " + s.postFragment
}

if union := qb.BuildUnion(s.union, params); union != "" {
sql = fmt.Sprintf("(%v) %v", sql, union)
}
Expand Down Expand Up @@ -377,6 +399,8 @@ func (s *SelectQuery) Column(a interface{}) error {

// QueryInfo represents a debug/info struct with exported SelectQuery fields.
type QueryInfo struct {
PreFragment string
PostFragment string
Builder Builder
Selects []string
Distinct bool
Expand All @@ -400,6 +424,8 @@ type QueryInfo struct {
func (s *SelectQuery) Info() *QueryInfo {
return &QueryInfo{
Builder: s.builder,
PreFragment: s.preFragment,
PostFragment: s.postFragment,
Selects: s.selects,
Distinct: s.distinct,
SelectOption: s.selectOption,
Expand Down
12 changes: 7 additions & 5 deletions select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ func TestSelectQuery(t *testing.T) {

// a full select query
q = db.Select("id", "name").
PreFragment("pre").
PostFragment("post").
AndSelect("age").
Distinct(true).
SelectOption("CALC").
Expand All @@ -46,18 +48,18 @@ func TestSelectQuery(t *testing.T) {
AndBind(Params{"age": 30}).
Build()

expected = "SELECT DISTINCT CALC `id`, `name`, `age` FROM `users` INNER JOIN `profile` ON user.id=profile.id LEFT JOIN `team` RIGHT JOIN `dept` WHERE ((age>30) AND (status=1)) OR (type=2) GROUP BY `id`, `age` HAVING ((id>10) AND (id<20)) OR (type=3) ORDER BY `age` DESC, `type`, `id` LIMIT 10 OFFSET 20"
expected = "pre SELECT DISTINCT CALC `id`, `name`, `age` FROM `users` INNER JOIN `profile` ON user.id=profile.id LEFT JOIN `team` RIGHT JOIN `dept` WHERE ((age>30) AND (status=1)) OR (type=2) GROUP BY `id`, `age` HAVING ((id>10) AND (id<20)) OR (type=3) ORDER BY `age` DESC, `type`, `id` LIMIT 10 OFFSET 20 post"
assert.Equal(t, q.SQL(), expected, "t3")
assert.Equal(t, len(q.Params()), 2, "t4")

q3 := db.Select().AndBind(Params{"id": 1}).Build()
assert.Equal(t, len(q3.Params()), 1)

// union
q1 := db.Select().From("users").Build()
q2 := db.Select().From("posts").Build()
q = db.Select().From("profiles").Union(q1).UnionAll(q2).Build()
expected = "(SELECT * FROM `profiles`) UNION (SELECT * FROM `users`) UNION ALL (SELECT * FROM `posts`)"
q1 := db.Select().From("users").PreFragment("pre_q1").Build()
q2 := db.Select().From("posts").PostFragment("post_q2").Build()
q = db.Select().From("profiles").Union(q1).UnionAll(q2).PreFragment("pre").PostFragment("post").Build()
expected = "(pre SELECT * FROM `profiles` post) UNION (pre_q1 SELECT * FROM `users`) UNION ALL (SELECT * FROM `posts` post_q2)"
assert.Equal(t, q.SQL(), expected, "t5")
}

Expand Down

0 comments on commit f18776a

Please sign in to comment.