Skip to content

Commit

Permalink
Add predicate to sum() builtin (#592)
Browse files Browse the repository at this point in the history
* Add predicate to sum() builtin

* go mod tidy
  • Loading branch information
antonmedv authored Apr 13, 2024
1 parent d66ffcd commit 38f9496
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 65 deletions.
12 changes: 5 additions & 7 deletions builtin/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ var Builtins = []*Function{
Predicate: true,
Types: types(new(func([]any, func(any) bool) int)),
},
{
Name: "sum",
Predicate: true,
Types: types(new(func([]any, func(any) bool) int)),
},
{
Name: "groupBy",
Predicate: true,
Expand Down Expand Up @@ -387,13 +392,6 @@ var Builtins = []*Function{
return validateAggregateFunc("min", args)
},
},
{
Name: "sum",
Func: sum,
Validate: func(args []reflect.Type) (reflect.Type, error) {
return validateAggregateFunc("sum", args)
},
},
{
Name: "mean",
Func: func(args ...any) (any, error) {
Expand Down
4 changes: 0 additions & 4 deletions builtin/builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,6 @@ func TestBuiltin(t *testing.T) {
{`sum([.5, 1.5, 2.5])`, 4.5},
{`sum([])`, 0},
{`sum([1, 2, 3.0, 4])`, 10.0},
{`sum(10, [1, 2, 3], 1..9)`, 61},
{`sum(-10, [1, 2, 3, 4])`, 0},
{`sum(-10.9, [1, 2, 3, 4, 9])`, 8.1},
{`mean(1..9)`, 5.0},
{`mean([.5, 1.5, 2.5])`, 1.5},
{`mean([])`, 0.0},
Expand Down Expand Up @@ -219,7 +216,6 @@ func TestBuiltin_errors(t *testing.T) {
{`min([1, "2"])`, `invalid argument for min (type string)`},
{`median(1..9, "t")`, "invalid argument for median (type string)"},
{`mean("s", 1..9)`, "invalid argument for mean (type string)"},
{`sum("s", "h")`, "invalid argument for sum (type string)"},
{`duration("error")`, `invalid duration`},
{`date("error")`, `invalid date`},
{`get()`, `invalid number of arguments (expected 2, got 0)`},
Expand Down
39 changes: 0 additions & 39 deletions builtin/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,45 +258,6 @@ func String(arg any) any {
return fmt.Sprintf("%v", arg)
}

func sum(args ...any) (any, error) {
var total int
var fTotal float64

for _, arg := range args {
rv := reflect.ValueOf(deref.Deref(arg))

switch rv.Kind() {
case reflect.Array, reflect.Slice:
size := rv.Len()
for i := 0; i < size; i++ {
elemSum, err := sum(rv.Index(i).Interface())
if err != nil {
return nil, err
}
switch elemSum := elemSum.(type) {
case int:
total += elemSum
case float64:
fTotal += elemSum
}
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
total += int(rv.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
total += int(rv.Uint())
case reflect.Float32, reflect.Float64:
fTotal += rv.Float()
default:
return nil, fmt.Errorf("invalid argument for sum (type %T)", arg)
}
}

if fTotal != 0.0 {
return fTotal + float64(total), nil
}
return total, nil
}

func minMax(name string, fn func(any, any) bool, args ...any) (any, error) {
var val any
for _, arg := range args {
Expand Down
23 changes: 23 additions & 0 deletions checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,29 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
}
return v.error(node.Arguments[1], "predicate should has one input and one output param")

case "sum":
collection, _ := v.visit(node.Arguments[0])
if !isArray(collection) && !isAny(collection) {
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
}

if len(node.Arguments) == 2 {
v.begin(collection)
closure, _ := v.visit(node.Arguments[1])
v.end()

if isFunc(closure) &&
closure.NumOut() == 1 &&
closure.NumIn() == 1 && isAny(closure.In(0)) {
return closure.Out(0), info{}
}
} else {
if isAny(collection) {
return anyType, info{}
}
return collection.Elem(), info{}
}

case "find", "findLast":
collection, _ := v.visit(node.Arguments[0])
if !isArray(collection) && !isAny(collection) {
Expand Down
19 changes: 19 additions & 0 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,25 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
c.emit(OpEnd)
return

case "sum":
c.compile(node.Arguments[0])
c.emit(OpBegin)
c.emit(OpInt, 0)
c.emit(OpSetAcc)
c.emitLoop(func() {
if len(node.Arguments) == 2 {
c.compile(node.Arguments[1])
} else {
c.emit(OpPointer)
}
c.emit(OpGetAcc)
c.emit(OpAdd)
c.emit(OpSetAcc)
})
c.emit(OpGetAcc)
c.emit(OpEnd)
return

case "find":
c.compile(node.Arguments[0])
c.emit(OpBegin)
Expand Down
1 change: 1 addition & 0 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ var predicates = map[string]struct {
"filter": {[]arg{expr, closure}},
"map": {[]arg{expr, closure}},
"count": {[]arg{expr, closure}},
"sum": {[]arg{expr, closure | optional}},
"find": {[]arg{expr, closure}},
"findIndex": {[]arg{expr, closure}},
"findLast": {[]arg{expr, closure}},
Expand Down
1 change: 0 additions & 1 deletion test/fuzz/fuzz_corpus.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10455,7 +10455,6 @@ max(f64, i64)
max(false ? 1 : 0.5)
max(false ? 1 : nil)
max(false ? add : ok)
max(false ? half : list)
max(false ? i : nil)
max(false ? i32 : score)
max(false ? true : 1)
Expand Down
14 changes: 0 additions & 14 deletions testdata/examples.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7419,12 +7419,6 @@ get(ok ? score : foo, String?.foo())
get(ok ? score : i64, foo)
get(reduce(list, array), i32)
get(sort(array), i32)
get(sum(array), Qux)
get(sum(array), String)
get(sum(array), f32)
get(sum(array), f64 == list)
get(sum(array), greet)
get(sum(array), i)
get(take(list, i), i64)
get(true ? "bar" : ok, score(i))
get(true ? "foo" : half, list)
Expand Down Expand Up @@ -7460,7 +7454,6 @@ greet != nil ? list : false
greet != score
greet != score != false
greet != score or ok
greet != sum(array)
greet == add
greet == add ? i : list
greet == add or ok
Expand Down Expand Up @@ -12200,7 +12193,6 @@ last(ok ? ok : 0.5)
last(reduce(array, list))
last(reduce(list, array))
last(sort(array))
last(sum(array))
last(true ? "bar" : half)
last(true ? add : list)
last(true ? foo : 1)
Expand Down Expand Up @@ -14818,7 +14810,6 @@ ok != nil ? nil : array
ok != not ok
ok != ok
ok != ok ? false : "bar"
ok != sum(array)
ok && !false
ok && !ok
ok && "foo" matches "bar"
Expand Down Expand Up @@ -16970,7 +16961,6 @@ string(groupBy(list, i))
string(half != nil)
string(half != score)
string(half == nil)
string(half == sum(array))
string(half(0.5))
string(half(1))
string(half(f64))
Expand Down Expand Up @@ -17297,18 +17287,14 @@ sum([0.5])
sum([f32])
sum(array)
sum(array) != f32
sum(array) != half
sum(array) != ok
sum(array) % i
sum(array) % i64
sum(array) - f32
sum(array) / -f64
sum(array) < i
sum(array) == div
sum(array) == i64 - i
sum(array) ^ f64
sum(array) not in array
sum(array) not in list
sum(filter(array, ok))
sum(groupBy(array, i32).String)
sum(groupBy(list, #)?.greet)
Expand Down

0 comments on commit 38f9496

Please sign in to comment.