Skip to content

Commit

Permalink
Merge pull request #317 from upper/issue-316
Browse files Browse the repository at this point in the history
Issue 316
  • Loading branch information
José Carlos authored Jan 5, 2017
2 parents 51cf352 + 2c780be commit 9b757d6
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 26 deletions.
9 changes: 6 additions & 3 deletions internal/sqladapter/exql/column_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,12 @@ func (c *ColumnValue) Compile(layout *Template) (compiled string) {
}

data := columnValueT{
c.Column.Compile(layout),
c.Operator,
c.Value.Compile(layout),
Column: c.Column.Compile(layout),
Operator: c.Operator,
}

if c.Value != nil {
data.Value = c.Value.Compile(layout)
}

compiled = mustParse(layout.ColumnValue, data)
Expand Down
4 changes: 4 additions & 0 deletions lib/sqlbuilder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,10 @@ func Map(item interface{}, options *MapOptions) ([]string, []interface{}, error)
return nil, nil, ErrExpectingPointerToEitherMapOrStruct
}

if len(fv.fields) == 0 {
return nil, nil, errors.New("No values mapped.")
}

sort.Sort(&fv)

return fv.fields, fv.values, nil
Expand Down
44 changes: 44 additions & 0 deletions lib/sqlbuilder/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,50 @@ func TestUpdate(t *testing.T) {
"id = id + ?", 10,
).Where("id > ?", 0).String(),
)

{
q := b.Update("posts").Set("column = ?", "foo")

assert.Equal(
`UPDATE "posts" SET "column" = $1`,
q.String(),
)

assert.Equal(
[]interface{}{"foo"},
q.Arguments(),
)
}

{
q := b.Update("posts").Set(db.Raw("column = ?", "foo"))

assert.Equal(
`UPDATE "posts" SET column = $1`,
q.String(),
)

assert.Equal(
[]interface{}{"foo"},
q.Arguments(),
)
}

{
q := b.Update("posts").Set(
db.Cond{"tags": db.Raw("array_remove(tags, ?)", "foo")},
).Where(db.Raw("hub_id = ? AND ? = ANY(tags) AND ? = ANY(tags)", 1, "bar", "baz"))

assert.Equal(
`UPDATE "posts" SET "tags" = array_remove(tags, $1) WHERE (hub_id = $2 AND $3 = ANY(tags) AND $4 = ANY(tags))`,
q.String(),
)

assert.Equal(
[]interface{}{"foo", 1, "bar", "baz"},
q.Arguments(),
)
}
}

func TestDelete(t *testing.T) {
Expand Down
20 changes: 18 additions & 2 deletions lib/sqlbuilder/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql.
func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, []interface{}) {
switch t := in.(type) {
case db.RawValue:
return exql.RawValue(t.String()), nil
return exql.RawValue(t.String()), t.Arguments()
case db.Function:
fnName := t.Name()
fnArgs := []interface{}{}
Expand Down Expand Up @@ -230,7 +230,14 @@ func (tu *templateWithUtils) ToColumnValues(term interface{}) (cv exql.ColumnVal
case []interface{}:
l := len(t)
for i := 0; i < l; i++ {
column := t[i].(string)
column, ok := t[i].(string)

if !ok {
p, q := tu.ToColumnValues(t[i])
cv.ColumnValues = append(cv.ColumnValues, p.ColumnValues...)
args = append(args, q...)
continue
}

if !strings.ContainsAny(column, "=") {
column = fmt.Sprintf("%s = ?", column)
Expand Down Expand Up @@ -337,6 +344,15 @@ func (tu *templateWithUtils) ToColumnValues(term interface{}) (cv exql.ColumnVal

cv.ColumnValues = append(cv.ColumnValues, &columnValue)

return cv, args
case db.RawValue:
columnValue := exql.ColumnValue{}
p, q := Preprocess(t.Raw(), t.Arguments())

columnValue.Column = exql.RawValue(p)
args = append(args, q...)

cv.ColumnValues = append(cv.ColumnValues, &columnValue)
return cv, args
case db.Constraints:
for _, c := range t.Constraints() {
Expand Down
46 changes: 25 additions & 21 deletions lib/sqlbuilder/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,34 +23,38 @@ type updater struct {
mu sync.Mutex
}

func (qu *updater) Set(terms ...interface{}) Updater {
if len(terms) == 1 {
ff, vv, _ := Map(terms[0], nil)
func (qu *updater) Set(columns ...interface{}) Updater {

cvs := make([]exql.Fragment, 0, len(ff))
args := make([]interface{}, 0, len(vv))
if len(columns) == 1 {
ff, vv, err := Map(columns[0], nil)
if err == nil {

for i := range ff {
cv := &exql.ColumnValue{
Column: exql.ColumnWithName(ff[i]),
Operator: qu.builder.t.AssignmentOperator,
}
cvs := make([]exql.Fragment, 0, len(ff))
args := make([]interface{}, 0, len(vv))

var localArgs []interface{}
cv.Value, localArgs = qu.builder.t.PlaceholderValue(vv[i])
for i := range ff {
cv := &exql.ColumnValue{
Column: exql.ColumnWithName(ff[i]),
Operator: qu.builder.t.AssignmentOperator,
}

args = append(args, localArgs...)
cvs = append(cvs, cv)
}
var localArgs []interface{}
cv.Value, localArgs = qu.builder.t.PlaceholderValue(vv[i])

args = append(args, localArgs...)
cvs = append(cvs, cv)
}

qu.columnValues.Insert(cvs...)
qu.columnValuesArgs = append(qu.columnValuesArgs, args...)
} else if len(terms) > 1 {
cv, arguments := qu.builder.t.ToColumnValues(terms)
qu.columnValues.Insert(cv.ColumnValues...)
qu.columnValuesArgs = append(qu.columnValuesArgs, arguments...)
qu.columnValues.Insert(cvs...)
qu.columnValuesArgs = append(qu.columnValuesArgs, args...)
return qu
}
}

cv, arguments := qu.builder.t.ToColumnValues(columns)
qu.columnValues.Insert(cv.ColumnValues...)
qu.columnValuesArgs = append(qu.columnValuesArgs, arguments...)

return qu
}

Expand Down

0 comments on commit 9b757d6

Please sign in to comment.