Skip to content

Commit

Permalink
api: adds CallWithStack to avoid allocations (#1407)
Browse files Browse the repository at this point in the history
Signed-off-by: Nuno Cruces <[email protected]>
  • Loading branch information
ncruces authored Apr 30, 2023
1 parent 0bfb4b5 commit 77e8d72
Show file tree
Hide file tree
Showing 11 changed files with 188 additions and 105 deletions.
28 changes: 28 additions & 0 deletions api/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,34 @@ type Function interface {
// the end-to-end demonstrations of how these terminations can be performed.
Call(ctx context.Context, params ...uint64) ([]uint64, error)

// CallWithStack is an optimized variation of Call that saves memory
// allocations when the stack slice is reused across calls.
//
// Stack length must be at least the max of parameter or result length.
// The caller adds parameters in order to the stack, and reads any results
// in order from the stack, except in the error case.
//
// For example, the following reuses the same stack slice to call searchFn
// repeatedly saving one allocation per iteration:
//
// stack := make([]uint64, 4)
// for i, search := range searchParams {
// // copy the next params to the stack
// copy(stack, search)
// if err := searchFn.CallWithStack(ctx, stack); err != nil {
// return err
// } else if stack[0] == 1 { // found
// return i // searchParams[i] matched!
// }
// }
//
// # Notes
//
// - This is similar to GoModuleFunction, except for using calling functions
// instead of implementing them. Moreover, this is used regardless of
// whether the callee is a host or wasm defined function.
CallWithStack(ctx context.Context, stack []uint64) error

internalapi.WazeroOnly
}

Expand Down
9 changes: 5 additions & 4 deletions internal/emscripten/emscripten.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,11 @@ func (v *InvokeFunc) Call(ctx context.Context, mod api.Module, stack []uint64) {
panic(err)
}

// This needs copy (not reslice) because the stack is reused for results.
// Consider invoke_i (zero arguments, one result): index zero (tableOffset)
// is needed to store the result.
tableOffset := wasm.Index(stack[0]) // position in the module's only table.
params := stack[1:] // parameters to the dynamic function being called
copy(stack, stack[1:]) // pop the tableOffset.

// Lookup the table index we will call.
t := m.Tables[0] // Note: Emscripten doesn't use multiple tables
Expand All @@ -86,10 +89,8 @@ func (v *InvokeFunc) Call(ctx context.Context, mod api.Module, stack []uint64) {
panic(err)
}

ret, err := m.Engine.NewFunction(idx).Call(ctx, params...)
err = m.Engine.NewFunction(idx).CallWithStack(ctx, stack)
if err != nil {
panic(err)
}
// if there are any results, copy them back to the stack
copy(stack, ret)
}
38 changes: 25 additions & 13 deletions internal/engine/compiler/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,24 @@ func (ce *callEngine) Definition() api.FunctionDefinition {

// Call implements the same method as documented on wasm.ModuleEngine.
func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uint64, err error) {
ft := ce.initialFn.funcType
if n := ft.ParamNumInUint64; n != len(params) {
return nil, fmt.Errorf("expected %d params, but passed %d", n, len(params))
}
return ce.call(ctx, params, nil)
}

// CallWithStack implements the same method as documented on wasm.ModuleEngine.
func (ce *callEngine) CallWithStack(ctx context.Context, stack []uint64) error {
params, results, err := wasm.SplitCallStack(ce.initialFn.funcType, stack)
if err != nil {
return err
}
_, err = ce.call(ctx, params, results)
return err
}

func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []uint64, err error) {
m := ce.initialFn.moduleInstance
if ce.ensureTermination {
select {
Expand All @@ -717,13 +735,6 @@ func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uin
}
}

tp := ce.initialFn.funcType

paramCount := len(params)
if tp.ParamNumInUint64 != paramCount {
return nil, fmt.Errorf("expected %d params, but passed %d", ce.initialFn.funcType.ParamNumInUint64, paramCount)
}

// We ensure that this Call method never panics as
// this Call method is indirectly invoked by embedders via store.CallFunction,
// and we have to make sure that all the runtime errors, including the one happening inside
Expand All @@ -736,7 +747,8 @@ func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uin
}
}()

ce.initializeStack(tp, params)
ft := ce.initialFn.funcType
ce.initializeStack(ft, params)

if ce.ensureTermination {
done := m.CloseModuleOnCanceledOrTimeout(ctx)
Expand All @@ -747,12 +759,12 @@ func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uin

// This returns a safe copy of the results, instead of a slice view. If we
// returned a re-slice, the caller could accidentally or purposefully
// corrupt the stack of subsequent calls
if resultCount := tp.ResultNumInUint64; resultCount > 0 {
results = make([]uint64, resultCount)
copy(results, ce.stack[:resultCount])
// corrupt the stack of subsequent calls.
if results == nil && ft.ResultNumInUint64 > 0 {
results = make([]uint64, ft.ResultNumInUint64)
}
return
copy(results, ce.stack)
return results, nil
}

// initializeStack initializes callEngine.stack before entering native code.
Expand Down
10 changes: 10 additions & 0 deletions internal/engine/compiler/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ func TestCompiler_ModuleEngine_Call(t *testing.T) {
`, "\n"+functionLog.String())
}

func TestCompiler_ModuleEngine_CallWithStack(t *testing.T) {
defer functionLog.Reset()
requireSupportedOSArch(t)
enginetest.RunTestModuleEngineCallWithStack(t, et)
require.Equal(t, `
--> .$0(1,2)
<-- (1,2)
`, "\n"+functionLog.String())
}

func TestCompiler_ModuleEngine_Call_HostFn(t *testing.T) {
defer functionLog.Reset()
requireSupportedOSArch(t)
Expand Down
49 changes: 34 additions & 15 deletions internal/engine/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ func (ce *callEngine) pushValue(v uint64) {
ce.stack = append(ce.stack, v)
}

func (ce *callEngine) pushValues(v []uint64) {
ce.stack = append(ce.stack, v...)
}

func (ce *callEngine) popValue() (v uint64) {
// No need to check stack bound
// as we can assume that all the operations
Expand All @@ -129,6 +133,12 @@ func (ce *callEngine) popValue() (v uint64) {
return
}

func (ce *callEngine) popValues(v []uint64) {
stackTopIndex := len(ce.stack) - len(v)
copy(v, ce.stack[stackTopIndex:])
ce.stack = ce.stack[:stackTopIndex]
}

// peekValues peeks api.ValueType values from the stack and returns them.
func (ce *callEngine) peekValues(count int) []uint64 {
if count == 0 {
Expand Down Expand Up @@ -445,10 +455,24 @@ func (ce *callEngine) Definition() api.FunctionDefinition {

// Call implements the same method as documented on api.Function.
func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uint64, err error) {
return ce.call(ctx, ce.compiled, params)
ft := ce.compiled.funcType
if n := ft.ParamNumInUint64; n != len(params) {
return nil, fmt.Errorf("expected %d params, but passed %d", n, len(params))
}
return ce.call(ctx, params, nil)
}

// CallWithStack implements the same method as documented on api.Function.
func (ce *callEngine) CallWithStack(ctx context.Context, stack []uint64) error {
params, results, err := wasm.SplitCallStack(ce.compiled.funcType, stack)
if err != nil {
return err
}
_, err = ce.call(ctx, params, results)
return err
}

func (ce *callEngine) call(ctx context.Context, tf *function, params []uint64) (results []uint64, err error) {
func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []uint64, err error) {
m := ce.compiled.moduleInstance
if ce.compiled.parent.ensureTermination {
select {
Expand All @@ -461,13 +485,6 @@ func (ce *callEngine) call(ctx context.Context, tf *function, params []uint64) (
}
}

ft := tf.funcType
paramSignature := ft.ParamNumInUint64
paramCount := len(params)
if paramSignature != paramCount {
return nil, fmt.Errorf("expected %d params, but passed %d", paramSignature, paramCount)
}

defer func() {
// If the module closed during the call, and the call didn't err for another reason, set an ExitError.
if err == nil {
Expand All @@ -480,22 +497,24 @@ func (ce *callEngine) call(ctx context.Context, tf *function, params []uint64) (
}
}()

for _, param := range params {
ce.pushValue(param)
}
ce.pushValues(params)

if ce.compiled.parent.ensureTermination {
done := m.CloseModuleOnCanceledOrTimeout(ctx)
defer done()
}

ce.callFunction(ctx, m, tf)
ce.callFunction(ctx, m, ce.compiled)

// This returns a safe copy of the results, instead of a slice view. If we
// returned a re-slice, the caller could accidentally or purposefully
// corrupt the stack of subsequent calls.
results = wasm.PopValues(ft.ResultNumInUint64, ce.popValue)
return
ft := ce.compiled.funcType
if results == nil && ft.ResultNumInUint64 > 0 {
results = make([]uint64, ft.ResultNumInUint64)
}
ce.popValues(results)
return results, nil
}

// recoverOnCall takes the recovered value `recoverOnCall`, and wraps it
Expand Down
9 changes: 9 additions & 0 deletions internal/engine/interpreter/interpreter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,15 @@ func TestInterpreter_ModuleEngine_Call(t *testing.T) {
`, "\n"+functionLog.String())
}

func TestCompiler_ModuleEngine_CallWithStack(t *testing.T) {
defer functionLog.Reset()
enginetest.RunTestModuleEngineCallWithStack(t, et)
require.Equal(t, `
--> .$0(1,2)
<-- (1,2)
`, "\n"+functionLog.String())
}

func TestInterpreter_ModuleEngine_Call_HostFn(t *testing.T) {
defer functionLog.Reset()
enginetest.RunTestModuleEngineCallHostFn(t, et)
Expand Down
17 changes: 17 additions & 0 deletions internal/integration_test/bench/hostfunc_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,23 @@ func BenchmarkHostFunctionCall(b *testing.B) {
}
}
})

b.Run(fn+"_with_stack", func(b *testing.B) {
ce := getCallEngine(m, fn)

b.ResetTimer()
stack := make([]uint64, 1)
for i := 0; i < b.N; i++ {
stack[0] = offset
err := ce.CallWithStack(testCtx, stack)
if err != nil {
b.Fatal(err)
}
if uint32(stack[0]) != math.Float32bits(val) {
b.Fail()
}
}
})
}
}

Expand Down
52 changes: 52 additions & 0 deletions internal/testing/enginetest/enginetest.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,58 @@ func RunTestModuleEngineCall(t *testing.T, et EngineTester) {
})
}

func RunTestModuleEngineCallWithStack(t *testing.T, et EngineTester) {
e := et.NewEngine(api.CoreFeaturesV2)

// Define a basic function which defines two parameters and two results.
// This is used to test results when incorrect arity is used.
m := &wasm.Module{
TypeSection: []wasm.FunctionType{
{
Params: []wasm.ValueType{i64, i64},
Results: []wasm.ValueType{i64, i64},
ParamNumInUint64: 2,
ResultNumInUint64: 2,
},
},
FunctionSection: []wasm.Index{0},
CodeSection: []wasm.Code{
{Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeLocalGet, 1, wasm.OpcodeEnd}},
},
}

m.BuildFunctionDefinitions()
listeners := buildListeners(et.ListenerFactory(), m)
err := e.CompileModule(testCtx, m, listeners, false)
require.NoError(t, err)

// To use the function, we first need to add it to a module.
module := &wasm.ModuleInstance{
ModuleName: t.Name(), TypeIDs: []wasm.FunctionTypeID{0},
Definitions: m.FunctionDefinitionSection,
}

// Compile the module
me, err := e.NewModuleEngine(m, module)
require.NoError(t, err)
linkModuleToEngine(module, me)

// Ensure the base case doesn't fail: A single parameter should work as that matches the function signature.
const funcIndex = 0
ce := me.NewFunction(funcIndex)

stack := []uint64{1, 2}
err = ce.CallWithStack(testCtx, stack)
require.NoError(t, err)
require.Equal(t, []uint64{1, 2}, stack)

t.Run("errs when not enough parameters", func(t *testing.T) {
ce := me.NewFunction(funcIndex)
err = ce.CallWithStack(testCtx, nil)
require.EqualError(t, err, "need 2 params, but stack size is 0")
})
}

func RunTestModuleEngineLookupFunction(t *testing.T, et EngineTester) {
e := et.NewEngine(api.CoreFeaturesV1)

Expand Down
20 changes: 0 additions & 20 deletions internal/wasm/gofunc.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,26 +79,6 @@ func (f *reflectGoFunction) Call(ctx context.Context, stack []uint64) {
callGoFunc(ctx, nil, f.fn, stack)
}

// PopValues pops the specified number of api.ValueType parameters off the
// stack into a parameter slice for use in api.GoFunction or api.GoModuleFunction.
//
// For example, if the host function F requires the (x1 uint32, x2 float32)
// parameters, and the stack is [..., A, B], then the function is called as
// F(A, B) where A and B are interpreted as uint32 and float32 respectively.
//
// Note: the popper intentionally doesn't return bool or error because the
// caller's stack depth is trusted.
func PopValues(count int, popper func() uint64) []uint64 {
if count == 0 {
return nil
}
params := make([]uint64, count)
for i := count - 1; i >= 0; i-- {
params[i] = popper()
}
return params
}

// callGoFunc executes the reflective function by converting params to Go
// types. The results of the function call are converted back to api.ValueType.
func callGoFunc(ctx context.Context, mod api.Module, fn *reflect.Value, stack []uint64) {
Expand Down
Loading

0 comments on commit 77e8d72

Please sign in to comment.