Skip to content

Commit

Permalink
Add mem grow to builtin (#505)
Browse files Browse the repository at this point in the history
* Add mem grow to builtin

* Add OpValidateArgs opcode
  • Loading branch information
antonmedv authored Dec 30, 2023
1 parent 7b890a1 commit dd925fd
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 24 deletions.
24 changes: 14 additions & 10 deletions builtin/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ import (
)

type Function struct {
Name string
Func func(args ...any) (any, error)
Fast func(arg any) any
Types []reflect.Type
Validate func(args []reflect.Type) (reflect.Type, error)
Predicate bool
Name string
Func func(args ...any) (any, error)
Fast func(arg any) any
ValidateArgs func(args ...any) (any, error)
Types []reflect.Type
Validate func(args []reflect.Type) (reflect.Type, error)
Predicate bool
}

var (
Expand Down Expand Up @@ -325,12 +326,15 @@ var Builtins = []*Function{
},
{
Name: "repeat",
Func: func(args ...any) (any, error) {
ValidateArgs: func(args ...any) (any, error) {
n := runtime.ToInt(args[1])
if n > 1e6 {
panic("memory budget exceeded")
if n < 0 {
panic(fmt.Errorf("invalid argument for repeat (expected positive integer, got %d)", n))
}
return strings.Repeat(args[0].(string), n), nil
return uint(n), nil
},
Func: func(args ...any) (any, error) {
return strings.Repeat(args[0].(string), runtime.ToInt(args[1])), nil
},
Types: types(strings.Repeat),
},
Expand Down
23 changes: 20 additions & 3 deletions builtin/builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,26 @@ func TestBuiltin_memory_limits(t *testing.T) {

for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
_, err := expr.Eval(test.input, nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "memory budget exceeded")
timeout := make(chan bool, 1)
go func() {
time.Sleep(time.Second)
timeout <- true
}()

done := make(chan bool, 1)
go func() {
_, err := expr.Eval(test.input, nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "memory budget exceeded")
done <- true
}()

select {
case <-done:
// Success.
case <-timeout:
t.Fatal("timeout")
}
})
}
}
Expand Down
26 changes: 16 additions & 10 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,31 +147,31 @@ func (c *compiler) addVariable(name string) int {
func (c *compiler) emitFunction(fn *builtin.Function, argsLen int) {
switch argsLen {
case 0:
c.emit(OpCall0, c.addFunction(fn))
c.emit(OpCall0, c.addFunction(fn.Name, fn.Func))
case 1:
c.emit(OpCall1, c.addFunction(fn))
c.emit(OpCall1, c.addFunction(fn.Name, fn.Func))
case 2:
c.emit(OpCall2, c.addFunction(fn))
c.emit(OpCall2, c.addFunction(fn.Name, fn.Func))
case 3:
c.emit(OpCall3, c.addFunction(fn))
c.emit(OpCall3, c.addFunction(fn.Name, fn.Func))
default:
c.emit(OpLoadFunc, c.addFunction(fn))
c.emit(OpLoadFunc, c.addFunction(fn.Name, fn.Func))
c.emit(OpCallN, argsLen)
}
}

// addFunction adds builtin.Function.Func to the program.functions and returns its index.
func (c *compiler) addFunction(fn *builtin.Function) int {
func (c *compiler) addFunction(name string, fn Function) int {
if fn == nil {
panic("function is nil")
}
if p, ok := c.functionsIndex[fn.Name]; ok {
if p, ok := c.functionsIndex[name]; ok {
return p
}
p := len(c.functions)
c.functions = append(c.functions, fn.Func)
c.functionsIndex[fn.Name] = p
c.debugInfo[fmt.Sprintf("func_%d", p)] = fn.Name
c.functions = append(c.functions, fn)
c.functionsIndex[name] = p
c.debugInfo[fmt.Sprintf("func_%d", p)] = name
return p
}

Expand Down Expand Up @@ -904,6 +904,12 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
for _, arg := range node.Arguments {
c.compile(arg)
}

if f.ValidateArgs != nil {
c.emit(OpLoadFunc, c.addFunction("$_validate_args_"+f.Name, f.ValidateArgs))
c.emit(OpValidateArgs, len(node.Arguments))
}

if f.Fast != nil {
c.emit(OpCallBuiltin1, id)
} else if f.Func != nil {
Expand Down
1 change: 1 addition & 0 deletions vm/opcodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ const (
OpCallFast
OpCallTyped
OpCallBuiltin1
OpValidateArgs
OpArray
OpMap
OpLen
Expand Down
5 changes: 4 additions & 1 deletion vm/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (program *Program) DisassembleWriter(w io.Writer) {
constant("OpLoadMethod")

case OpLoadFunc:
argument("OpLoadFunc")
argumentWithInfo("OpLoadFunc", "func")

case OpLoadEnv:
code("OpLoadEnv")
Expand Down Expand Up @@ -278,6 +278,9 @@ func (program *Program) DisassembleWriter(w io.Writer) {
case OpCallBuiltin1:
builtinArg("OpCallBuiltin1")

case OpValidateArgs:
argument("OpValidateArgs")

case OpArray:
code("OpArray")

Expand Down
8 changes: 8 additions & 0 deletions vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,14 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
case OpCallBuiltin1:
vm.push(builtin.Builtins[arg].Fast(vm.pop()))

case OpValidateArgs:
fn := vm.pop().(Function)
mem, err := fn(vm.stack[len(vm.stack)-arg:]...)
if err != nil {
panic(err)
}
vm.memGrow(mem.(uint))

case OpArray:
size := vm.pop().(int)
vm.memGrow(uint(size))
Expand Down

0 comments on commit dd925fd

Please sign in to comment.