From 38f949658f8e27c1e9367487b4769b51128c36e9 Mon Sep 17 00:00:00 2001 From: Anton Medvedev Date: Sat, 13 Apr 2024 15:37:02 +0200 Subject: [PATCH] Add predicate to sum() builtin (#592) * Add predicate to sum() builtin * go mod tidy --- builtin/builtin.go | 12 +++++------- builtin/builtin_test.go | 4 ---- builtin/lib.go | 39 --------------------------------------- checker/checker.go | 23 +++++++++++++++++++++++ compiler/compiler.go | 19 +++++++++++++++++++ parser/parser.go | 1 + test/fuzz/fuzz_corpus.txt | 1 - testdata/examples.txt | 14 -------------- 8 files changed, 48 insertions(+), 65 deletions(-) diff --git a/builtin/builtin.go b/builtin/builtin.go index 7bf377df..efc01fc2 100644 --- a/builtin/builtin.go +++ b/builtin/builtin.go @@ -83,6 +83,11 @@ var Builtins = []*Function{ Predicate: true, Types: types(new(func([]any, func(any) bool) int)), }, + { + Name: "sum", + Predicate: true, + Types: types(new(func([]any, func(any) bool) int)), + }, { Name: "groupBy", Predicate: true, @@ -387,13 +392,6 @@ var Builtins = []*Function{ return validateAggregateFunc("min", args) }, }, - { - Name: "sum", - Func: sum, - Validate: func(args []reflect.Type) (reflect.Type, error) { - return validateAggregateFunc("sum", args) - }, - }, { Name: "mean", Func: func(args ...any) (any, error) { diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index 7f5045f4..307d4a86 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -90,9 +90,6 @@ func TestBuiltin(t *testing.T) { {`sum([.5, 1.5, 2.5])`, 4.5}, {`sum([])`, 0}, {`sum([1, 2, 3.0, 4])`, 10.0}, - {`sum(10, [1, 2, 3], 1..9)`, 61}, - {`sum(-10, [1, 2, 3, 4])`, 0}, - {`sum(-10.9, [1, 2, 3, 4, 9])`, 8.1}, {`mean(1..9)`, 5.0}, {`mean([.5, 1.5, 2.5])`, 1.5}, {`mean([])`, 0.0}, @@ -219,7 +216,6 @@ func TestBuiltin_errors(t *testing.T) { {`min([1, "2"])`, `invalid argument for min (type string)`}, {`median(1..9, "t")`, "invalid argument for median (type string)"}, {`mean("s", 1..9)`, "invalid argument for mean (type string)"}, - {`sum("s", "h")`, "invalid argument for sum (type string)"}, {`duration("error")`, `invalid duration`}, {`date("error")`, `invalid date`}, {`get()`, `invalid number of arguments (expected 2, got 0)`}, diff --git a/builtin/lib.go b/builtin/lib.go index e3a6c0ae..e3cd61b9 100644 --- a/builtin/lib.go +++ b/builtin/lib.go @@ -258,45 +258,6 @@ func String(arg any) any { return fmt.Sprintf("%v", arg) } -func sum(args ...any) (any, error) { - var total int - var fTotal float64 - - for _, arg := range args { - rv := reflect.ValueOf(deref.Deref(arg)) - - switch rv.Kind() { - case reflect.Array, reflect.Slice: - size := rv.Len() - for i := 0; i < size; i++ { - elemSum, err := sum(rv.Index(i).Interface()) - if err != nil { - return nil, err - } - switch elemSum := elemSum.(type) { - case int: - total += elemSum - case float64: - fTotal += elemSum - } - } - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - total += int(rv.Int()) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - total += int(rv.Uint()) - case reflect.Float32, reflect.Float64: - fTotal += rv.Float() - default: - return nil, fmt.Errorf("invalid argument for sum (type %T)", arg) - } - } - - if fTotal != 0.0 { - return fTotal + float64(total), nil - } - return total, nil -} - func minMax(name string, fn func(any, any) bool, args ...any) (any, error) { var val any for _, arg := range args { diff --git a/checker/checker.go b/checker/checker.go index b46178d4..a6daa27b 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -668,6 +668,29 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) { } return v.error(node.Arguments[1], "predicate should has one input and one output param") + case "sum": + collection, _ := v.visit(node.Arguments[0]) + if !isArray(collection) && !isAny(collection) { + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + } + + if len(node.Arguments) == 2 { + v.begin(collection) + closure, _ := v.visit(node.Arguments[1]) + v.end() + + if isFunc(closure) && + closure.NumOut() == 1 && + closure.NumIn() == 1 && isAny(closure.In(0)) { + return closure.Out(0), info{} + } + } else { + if isAny(collection) { + return anyType, info{} + } + return collection.Elem(), info{} + } + case "find", "findLast": collection, _ := v.visit(node.Arguments[0]) if !isArray(collection) && !isAny(collection) { diff --git a/compiler/compiler.go b/compiler/compiler.go index a38d977d..07bc58b8 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -809,6 +809,25 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { c.emit(OpEnd) return + case "sum": + c.compile(node.Arguments[0]) + c.emit(OpBegin) + c.emit(OpInt, 0) + c.emit(OpSetAcc) + c.emitLoop(func() { + if len(node.Arguments) == 2 { + c.compile(node.Arguments[1]) + } else { + c.emit(OpPointer) + } + c.emit(OpGetAcc) + c.emit(OpAdd) + c.emit(OpSetAcc) + }) + c.emit(OpGetAcc) + c.emit(OpEnd) + return + case "find": c.compile(node.Arguments[0]) c.emit(OpBegin) diff --git a/parser/parser.go b/parser/parser.go index 9cb79cbb..641369b1 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -34,6 +34,7 @@ var predicates = map[string]struct { "filter": {[]arg{expr, closure}}, "map": {[]arg{expr, closure}}, "count": {[]arg{expr, closure}}, + "sum": {[]arg{expr, closure | optional}}, "find": {[]arg{expr, closure}}, "findIndex": {[]arg{expr, closure}}, "findLast": {[]arg{expr, closure}}, diff --git a/test/fuzz/fuzz_corpus.txt b/test/fuzz/fuzz_corpus.txt index 2bb8ab7f..16654863 100644 --- a/test/fuzz/fuzz_corpus.txt +++ b/test/fuzz/fuzz_corpus.txt @@ -10455,7 +10455,6 @@ max(f64, i64) max(false ? 1 : 0.5) max(false ? 1 : nil) max(false ? add : ok) -max(false ? half : list) max(false ? i : nil) max(false ? i32 : score) max(false ? true : 1) diff --git a/testdata/examples.txt b/testdata/examples.txt index 712aa91c..5b9d2cdd 100644 --- a/testdata/examples.txt +++ b/testdata/examples.txt @@ -7419,12 +7419,6 @@ get(ok ? score : foo, String?.foo()) get(ok ? score : i64, foo) get(reduce(list, array), i32) get(sort(array), i32) -get(sum(array), Qux) -get(sum(array), String) -get(sum(array), f32) -get(sum(array), f64 == list) -get(sum(array), greet) -get(sum(array), i) get(take(list, i), i64) get(true ? "bar" : ok, score(i)) get(true ? "foo" : half, list) @@ -7460,7 +7454,6 @@ greet != nil ? list : false greet != score greet != score != false greet != score or ok -greet != sum(array) greet == add greet == add ? i : list greet == add or ok @@ -12200,7 +12193,6 @@ last(ok ? ok : 0.5) last(reduce(array, list)) last(reduce(list, array)) last(sort(array)) -last(sum(array)) last(true ? "bar" : half) last(true ? add : list) last(true ? foo : 1) @@ -14818,7 +14810,6 @@ ok != nil ? nil : array ok != not ok ok != ok ok != ok ? false : "bar" -ok != sum(array) ok && !false ok && !ok ok && "foo" matches "bar" @@ -16970,7 +16961,6 @@ string(groupBy(list, i)) string(half != nil) string(half != score) string(half == nil) -string(half == sum(array)) string(half(0.5)) string(half(1)) string(half(f64)) @@ -17297,18 +17287,14 @@ sum([0.5]) sum([f32]) sum(array) sum(array) != f32 -sum(array) != half -sum(array) != ok sum(array) % i sum(array) % i64 sum(array) - f32 sum(array) / -f64 sum(array) < i -sum(array) == div sum(array) == i64 - i sum(array) ^ f64 sum(array) not in array -sum(array) not in list sum(filter(array, ok)) sum(groupBy(array, i32).String) sum(groupBy(list, #)?.greet)