Skip to content

Commit

Permalink
Add OpCallSafe
Browse files Browse the repository at this point in the history
  • Loading branch information
antonmedv committed Feb 12, 2024
1 parent e53cefe commit 84ac0b8
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 32 deletions.
13 changes: 7 additions & 6 deletions builtin/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,15 +316,16 @@ var Builtins = []*Function{
},
{
Name: "repeat",
ValidateArgs: func(args ...any) (any, error) {
Safe: func(args ...any) (any, uint, error) {
s := args[0].(string)
n := runtime.ToInt(args[1])
if n < 0 {
panic(fmt.Errorf("invalid argument for repeat (expected positive integer, got %d)", n))
return nil, 0, fmt.Errorf("invalid argument for repeat (expected positive integer, got %d)", n)
}
return uint(n), nil
},
Func: func(args ...any) (any, error) {
return strings.Repeat(args[0].(string), runtime.ToInt(args[1])), nil
if n > 1e6 {
return nil, 0, fmt.Errorf("memory budget exceeded")
}
return strings.Repeat(s, n), uint(len(s) * n), nil
},
Types: types(strings.Repeat),
},
Expand Down
14 changes: 7 additions & 7 deletions builtin/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ import (
)

type Function struct {
Name string
Func func(args ...any) (any, error)
Fast func(arg any) any
ValidateArgs func(args ...any) (any, error)
Types []reflect.Type
Validate func(args []reflect.Type) (reflect.Type, error)
Predicate bool
Name string
Fast func(arg any) any
Func func(args ...any) (any, error)
Safe func(args ...any) (any, uint, error)
Types []reflect.Type
Validate func(args []reflect.Type) (reflect.Type, error)
Predicate bool
}

func (f *Function) Type() reflect.Type {
Expand Down
10 changes: 4 additions & 6 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (c *compiler) addConstant(constant any) int {
indexable := true
hash := constant
switch reflect.TypeOf(constant).Kind() {
case reflect.Slice, reflect.Map, reflect.Struct:
case reflect.Slice, reflect.Map, reflect.Struct, reflect.Func:
indexable = false
}
if field, ok := constant.(*runtime.Field); ok {
Expand Down Expand Up @@ -908,13 +908,11 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
c.compile(arg)
}

if f.ValidateArgs != nil {
c.emit(OpLoadFunc, c.addFunction("$_validate_args_"+f.Name, f.ValidateArgs))
c.emit(OpValidateArgs, len(node.Arguments))
}

if f.Fast != nil {
c.emit(OpCallBuiltin1, id)
} else if f.Safe != nil {
c.emit(OpPush, c.addConstant(f.Safe))
c.emit(OpCallSafe, len(node.Arguments))
} else if f.Func != nil {
c.emitFunction(f, len(node.Arguments))
}
Expand Down
2 changes: 1 addition & 1 deletion vm/opcodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ const (
OpCall3
OpCallN
OpCallFast
OpCallSafe
OpCallTyped
OpCallBuiltin1
OpValidateArgs
OpArray
OpMap
OpLen
Expand Down
6 changes: 3 additions & 3 deletions vm/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,16 +290,16 @@ func (program *Program) DisassembleWriter(w io.Writer) {
case OpCallFast:
argument("OpCallFast")

case OpCallSafe:
argument("OpCallSafe")

case OpCallTyped:
signature := reflect.TypeOf(FuncTypes[arg]).Elem().String()
_, _ = fmt.Fprintf(w, "%v\t%v\t<%v>\t%v\n", pp, "OpCallTyped", arg, signature)

case OpCallBuiltin1:
builtinArg("OpCallBuiltin1")

case OpValidateArgs:
argument("OpValidateArgs")

case OpArray:
code("OpArray")

Expand Down
6 changes: 5 additions & 1 deletion vm/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@ import (
"reflect"
)

type Function = func(params ...any) (any, error)
type (
Function = func(params ...any) (any, error)
SafeFunction = func(params ...any) (any, uint, error)
)

// MemoryBudget represents an upper limit of memory usage.
var MemoryBudget uint = 1e6

var errorType = reflect.TypeOf((*error)(nil)).Elem()
22 changes: 14 additions & 8 deletions vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,20 +389,26 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
}
vm.push(fn(in...))

case OpCallSafe:
fn := vm.pop().(SafeFunction)
size := arg
in := make([]any, size)
for i := int(size) - 1; i >= 0; i-- {
in[i] = vm.pop()
}
out, mem, err := fn(in...)
if err != nil {
panic(err)
}
vm.memGrow(mem)
vm.push(out)

case OpCallTyped:
vm.push(vm.call(vm.pop(), arg))

case OpCallBuiltin1:
vm.push(builtin.Builtins[arg].Fast(vm.pop()))

case OpValidateArgs:
fn := vm.pop().(Function)
mem, err := fn(vm.Stack[len(vm.Stack)-arg:]...)
if err != nil {
panic(err)
}
vm.memGrow(mem.(uint))

case OpArray:
size := vm.pop().(int)
vm.memGrow(uint(size))
Expand Down

0 comments on commit 84ac0b8

Please sign in to comment.