diff --git a/builtin/builtin.go b/builtin/builtin.go index 2417446d..cd943bd0 100644 --- a/builtin/builtin.go +++ b/builtin/builtin.go @@ -13,12 +13,13 @@ import ( ) type Function struct { - Name string - Func func(args ...any) (any, error) - Fast func(arg any) any - Types []reflect.Type - Validate func(args []reflect.Type) (reflect.Type, error) - Predicate bool + Name string + Func func(args ...any) (any, error) + Fast func(arg any) any + ValidateArgs func(args ...any) (any, error) + Types []reflect.Type + Validate func(args []reflect.Type) (reflect.Type, error) + Predicate bool } var ( @@ -325,12 +326,15 @@ var Builtins = []*Function{ }, { Name: "repeat", - Func: func(args ...any) (any, error) { + ValidateArgs: func(args ...any) (any, error) { n := runtime.ToInt(args[1]) - if n > 1e6 { - panic("memory budget exceeded") + if n < 0 { + panic(fmt.Errorf("invalid argument for repeat (expected positive integer, got %d)", n)) } - return strings.Repeat(args[0].(string), n), nil + return uint(n), nil + }, + Func: func(args ...any) (any, error) { + return strings.Repeat(args[0].(string), runtime.ToInt(args[1])), nil }, Types: types(strings.Repeat), }, diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index a861d0fa..23273a24 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -260,9 +260,26 @@ func TestBuiltin_memory_limits(t *testing.T) { for _, test := range tests { t.Run(test.input, func(t *testing.T) { - _, err := expr.Eval(test.input, nil) - assert.Error(t, err) - assert.Contains(t, err.Error(), "memory budget exceeded") + timeout := make(chan bool, 1) + go func() { + time.Sleep(time.Second) + timeout <- true + }() + + done := make(chan bool, 1) + go func() { + _, err := expr.Eval(test.input, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "memory budget exceeded") + done <- true + }() + + select { + case <-done: + // Success. + case <-timeout: + t.Fatal("timeout") + } }) } } diff --git a/compiler/compiler.go b/compiler/compiler.go index f4bb3bd5..1f1c00c9 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -147,31 +147,31 @@ func (c *compiler) addVariable(name string) int { func (c *compiler) emitFunction(fn *builtin.Function, argsLen int) { switch argsLen { case 0: - c.emit(OpCall0, c.addFunction(fn)) + c.emit(OpCall0, c.addFunction(fn.Name, fn.Func)) case 1: - c.emit(OpCall1, c.addFunction(fn)) + c.emit(OpCall1, c.addFunction(fn.Name, fn.Func)) case 2: - c.emit(OpCall2, c.addFunction(fn)) + c.emit(OpCall2, c.addFunction(fn.Name, fn.Func)) case 3: - c.emit(OpCall3, c.addFunction(fn)) + c.emit(OpCall3, c.addFunction(fn.Name, fn.Func)) default: - c.emit(OpLoadFunc, c.addFunction(fn)) + c.emit(OpLoadFunc, c.addFunction(fn.Name, fn.Func)) c.emit(OpCallN, argsLen) } } // addFunction adds builtin.Function.Func to the program.functions and returns its index. -func (c *compiler) addFunction(fn *builtin.Function) int { +func (c *compiler) addFunction(name string, fn Function) int { if fn == nil { panic("function is nil") } - if p, ok := c.functionsIndex[fn.Name]; ok { + if p, ok := c.functionsIndex[name]; ok { return p } p := len(c.functions) - c.functions = append(c.functions, fn.Func) - c.functionsIndex[fn.Name] = p - c.debugInfo[fmt.Sprintf("func_%d", p)] = fn.Name + c.functions = append(c.functions, fn) + c.functionsIndex[name] = p + c.debugInfo[fmt.Sprintf("func_%d", p)] = name return p } @@ -904,6 +904,12 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { for _, arg := range node.Arguments { c.compile(arg) } + + if f.ValidateArgs != nil { + c.emit(OpLoadFunc, c.addFunction("$_validate_args_"+f.Name, f.ValidateArgs)) + c.emit(OpValidateArgs, len(node.Arguments)) + } + if f.Fast != nil { c.emit(OpCallBuiltin1, id) } else if f.Func != nil { diff --git a/vm/opcodes.go b/vm/opcodes.go index 1106cd3f..57b84050 100644 --- a/vm/opcodes.go +++ b/vm/opcodes.go @@ -60,6 +60,7 @@ const ( OpCallFast OpCallTyped OpCallBuiltin1 + OpValidateArgs OpArray OpMap OpLen diff --git a/vm/program.go b/vm/program.go index 27b8f609..085b18a7 100644 --- a/vm/program.go +++ b/vm/program.go @@ -137,7 +137,7 @@ func (program *Program) DisassembleWriter(w io.Writer) { constant("OpLoadMethod") case OpLoadFunc: - argument("OpLoadFunc") + argumentWithInfo("OpLoadFunc", "func") case OpLoadEnv: code("OpLoadEnv") @@ -278,6 +278,9 @@ func (program *Program) DisassembleWriter(w io.Writer) { case OpCallBuiltin1: builtinArg("OpCallBuiltin1") + case OpValidateArgs: + argument("OpValidateArgs") + case OpArray: code("OpArray") diff --git a/vm/vm.go b/vm/vm.go index 20e1594f..fa2c7589 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -397,6 +397,14 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { case OpCallBuiltin1: vm.push(builtin.Builtins[arg].Fast(vm.pop())) + case OpValidateArgs: + fn := vm.pop().(Function) + mem, err := fn(vm.stack[len(vm.stack)-arg:]...) + if err != nil { + panic(err) + } + vm.memGrow(mem.(uint)) + case OpArray: size := vm.pop().(int) vm.memGrow(uint(size))