Skip to content

Commit

Permalink
Merge pull request #307 from upper/issue-297
Browse files Browse the repository at this point in the history
Honor omitempty on InsertInto(). Closes #297
  • Loading branch information
José Carlos authored Dec 13, 2016
2 parents 8232c84 + f611a71 commit 51cf352
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 45 deletions.
8 changes: 3 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@ DB_HOST ?= 127.0.0.1
export DB_HOST

test:
go test -v -benchtime=500ms -bench=. ./lib/... & \
go test -v -benchtime=500ms -bench=. ./internal/... & \
wait && \
go test -v -benchtime=500ms -bench=. ./lib/... && \
go test -v -benchtime=500ms -bench=. ./internal/... && \
for ADAPTER in postgresql mysql sqlite ql mongo; do \
$(MAKE) -C $$ADAPTER test & \
$(MAKE) -C $$ADAPTER test; \
done && \
wait && \
go test -v
6 changes: 3 additions & 3 deletions lib/sqlbuilder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ func columnFragments(template *templateWithUtils, columns []interface{}) ([]exql
for i := 0; i < l; i++ {
switch v := columns[i].(type) {
case *selector:
expanded, rawArgs := expandPlaceholders(v.statement().Compile(v.stringer.t), v.Arguments())
expanded, rawArgs := Preprocess(v.statement().Compile(v.stringer.t), v.Arguments())
f[i] = exql.RawValue(expanded)
args = append(args, rawArgs...)
case db.Function:
Expand All @@ -336,11 +336,11 @@ func columnFragments(template *templateWithUtils, columns []interface{}) ([]exql
} else {
fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")"
}
expanded, fnArgs := expandPlaceholders(fnName, fnArgs)
expanded, fnArgs := Preprocess(fnName, fnArgs)
f[i] = exql.RawValue(expanded)
args = append(args, fnArgs...)
case db.RawValue:
expanded, rawArgs := expandPlaceholders(v.Raw(), v.Arguments())
expanded, rawArgs := Preprocess(v.Raw(), v.Arguments())
f[i] = exql.RawValue(expanded)
args = append(args, rawArgs...)
case exql.Fragment:
Expand Down
83 changes: 83 additions & 0 deletions lib/sqlbuilder/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,89 @@ func TestInsert(t *testing.T) {
)
}

{
type artistStruct struct {
ID int `db:"id,omitempty"`
Name string `db:"name,omitempty"`
}

assert.Equal(
`INSERT INTO "artist" ("name") VALUES ($1)`,
b.InsertInto("artist").
Values(artistStruct{Name: "Chavela Vargas"}).
String(),
)

assert.Equal(
`INSERT INTO "artist" ("id") VALUES ($1)`,
b.InsertInto("artist").
Values(artistStruct{ID: 1}).
String(),
)
}

{
type artistStruct struct {
ID int `db:"id,omitempty"`
Name string `db:"name,omitempty"`
}

{
q := b.InsertInto("artist").Values(artistStruct{Name: "Chavela Vargas"})

assert.Equal(
`INSERT INTO "artist" ("name") VALUES ($1)`,
q.String(),
)
assert.Equal(
[]interface{}{"Chavela Vargas"},
q.Arguments(),
)
}

{
q := b.InsertInto("artist").Values(artistStruct{Name: "Chavela Vargas"}).Values(artistStruct{Name: "Alondra de la Parra"})

assert.Equal(
`INSERT INTO "artist" ("name") VALUES ($1), ($2)`,
q.String(),
)
assert.Equal(
[]interface{}{"Chavela Vargas", "Alondra de la Parra"},
q.Arguments(),
)
}

{
q := b.InsertInto("artist").Values(artistStruct{ID: 1})

assert.Equal(
`INSERT INTO "artist" ("id") VALUES ($1)`,
q.String(),
)

assert.Equal(
[]interface{}{1},
q.Arguments(),
)
}

{
q := b.InsertInto("artist").Values(artistStruct{ID: 1}).Values(artistStruct{ID: 2})

assert.Equal(
`INSERT INTO "artist" ("id") VALUES ($1), ($2)`,
q.String(),
)

assert.Equal(
[]interface{}{1, 2},
q.Arguments(),
)
}

}

{
intRef := func(i int) *int {
if i == 0 {
Expand Down
13 changes: 4 additions & 9 deletions lib/sqlbuilder/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,6 @@ func Preprocess(in string, args []interface{}) (string, []interface{}) {
return expandQuery(in, args, preprocessFn)
}

func expandPlaceholders(in string, args []interface{}) (string, []interface{}) {
// TODO: Remove after immutable query builder
return in, args
}

// ToWhereWithArguments converts the given parameters into a exql.Where
// value.
func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql.Where, args []interface{}) {
Expand All @@ -93,7 +88,7 @@ func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql.
if len(t) > 0 {
if s, ok := t[0].(string); ok {
if strings.ContainsAny(s, "?") || len(t) == 1 {
s, args = expandPlaceholders(s, t[1:])
s, args = Preprocess(s, t[1:])
where.Conditions = []exql.Fragment{exql.RawValue(s)}
} else {
var val interface{}
Expand Down Expand Up @@ -122,7 +117,7 @@ func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql.
}
return
case db.RawValue:
r, v := expandPlaceholders(t.Raw(), t.Arguments())
r, v := Preprocess(t.Raw(), t.Arguments())
where.Conditions = []exql.Fragment{exql.RawValue(r)}
args = append(args, v...)
return
Expand Down Expand Up @@ -294,11 +289,11 @@ func (tu *templateWithUtils) ToColumnValues(term interface{}) (cv exql.ColumnVal
// A function with one or more arguments.
fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")"
}
expanded, fnArgs := expandPlaceholders(fnName, fnArgs)
expanded, fnArgs := Preprocess(fnName, fnArgs)
columnValue.Value = exql.RawValue(expanded)
args = append(args, fnArgs...)
case db.RawValue:
expanded, rawArgs := expandPlaceholders(value.Raw(), value.Arguments())
expanded, rawArgs := Preprocess(value.Raw(), value.Arguments())
columnValue.Value = exql.RawValue(expanded)
args = append(args, rawArgs...)
default:
Expand Down
104 changes: 78 additions & 26 deletions lib/sqlbuilder/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@ package sqlbuilder

import (
"database/sql"
"sync"

"upper.io/db.v2/internal/sqladapter/exql"
)

type inserter struct {
*stringer
builder *sqlBuilder
table string
values []*exql.Values
builder *sqlBuilder
table string

enqueuedValues [][]interface{}
mu sync.Mutex

returning []exql.Fragment
columns []exql.Fragment
arguments []interface{}
Expand All @@ -28,6 +32,7 @@ func (qi *inserter) Batch(n int) *BatchInserter {
}

func (qi *inserter) Arguments() []interface{} {
_ = qi.statement()
return qi.arguments
}

Expand Down Expand Up @@ -69,34 +74,77 @@ func (qi *inserter) Columns(columns ...string) Inserter {
}

func (qi *inserter) Values(values ...interface{}) Inserter {
if len(values) == 1 {
ff, vv, err := Map(values[0], &MapOptions{IncludeZeroed: true, IncludeNil: true})
if err == nil {
columns, vals, arguments, _ := qi.builder.t.ToColumnsValuesAndArguments(ff, vv)

qi.arguments = append(qi.arguments, arguments...)
qi.values = append(qi.values, vals)
if len(qi.columns) == 0 {
for _, c := range columns.Columns {
qi.columns = append(qi.columns, c)
qi.mu.Lock()
defer qi.mu.Unlock()

if qi.enqueuedValues == nil {
qi.enqueuedValues = [][]interface{}{}
}
qi.enqueuedValues = append(qi.enqueuedValues, values)
return qi
}

func (qi *inserter) processValues() (values []*exql.Values, arguments []interface{}) {
// TODO: simplify with immutable queries
var insertNils bool

for _, enqueuedValue := range qi.enqueuedValues {
if len(enqueuedValue) == 1 {
ff, vv, err := Map(enqueuedValue[0], nil)
if err == nil {
columns, vals, args, _ := qi.builder.t.ToColumnsValuesAndArguments(ff, vv)

values, arguments = append(values, vals), append(arguments, args...)

if len(qi.columns) == 0 {
for _, c := range columns.Columns {
qi.columns = append(qi.columns, c)
}
} else {
if len(qi.columns) != len(columns.Columns) {
insertNils = true
break
}
}
continue
}
return qi
}
}

if len(qi.columns) == 0 || len(values) == len(qi.columns) {
qi.arguments = append(qi.arguments, values...)
if len(qi.columns) == 0 || len(enqueuedValue) == len(qi.columns) {
arguments = append(arguments, enqueuedValue...)

l := len(values)
placeholders := make([]exql.Fragment, l)
for i := 0; i < l; i++ {
placeholders[i] = exql.RawValue(`?`)
l := len(enqueuedValue)
placeholders := make([]exql.Fragment, l)
for i := 0; i < l; i++ {
placeholders[i] = exql.RawValue(`?`)
}
values = append(values, exql.NewValueGroup(placeholders...))
}
qi.values = append(qi.values, exql.NewValueGroup(placeholders...))
}

return qi
if insertNils {
values, arguments = values[0:0], arguments[0:0]

for _, enqueuedValue := range qi.enqueuedValues {
if len(enqueuedValue) == 1 {
ff, vv, err := Map(enqueuedValue[0], &MapOptions{IncludeZeroed: true, IncludeNil: true})
if err == nil {
columns, vals, args, _ := qi.builder.t.ToColumnsValuesAndArguments(ff, vv)
values, arguments = append(values, vals), append(arguments, args...)

if len(qi.columns) != len(columns.Columns) {
qi.columns = qi.columns[0:0]
for _, c := range columns.Columns {
qi.columns = append(qi.columns, c)
}
}
}
continue
}
}
}

return
}

func (qi *inserter) statement() *exql.Statement {
Expand All @@ -105,14 +153,18 @@ func (qi *inserter) statement() *exql.Statement {
Table: exql.TableWithName(qi.table),
}

if len(qi.values) > 0 {
stmt.Values = exql.JoinValueGroups(qi.values...)
}
values, arguments := qi.processValues()

qi.arguments = arguments

if len(qi.columns) > 0 {
stmt.Columns = exql.JoinColumns(qi.columns...)
}

if len(values) > 0 {
stmt.Values = exql.JoinValueGroups(values...)
}

if len(qi.returning) > 0 {
stmt.Returning = exql.ReturningColumns(qi.returning...)
}
Expand Down
4 changes: 2 additions & 2 deletions lib/sqlbuilder/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector {

switch value := columns[i].(type) {
case db.RawValue:
col, args := expandPlaceholders(value.Raw(), value.Arguments())
col, args := Preprocess(value.Raw(), value.Arguments())
sort = &exql.SortColumn{
Column: exql.RawValue(col),
}
Expand All @@ -170,7 +170,7 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector {
} else {
fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")"
}
expanded, fnArgs := expandPlaceholders(fnName, fnArgs)
expanded, fnArgs := Preprocess(fnName, fnArgs)
sort = &exql.SortColumn{
Column: exql.RawValue(expanded),
}
Expand Down

0 comments on commit 51cf352

Please sign in to comment.