diff --git a/api/wasm.go b/api/wasm.go index ffab892866..38728b4273 100644 --- a/api/wasm.go +++ b/api/wasm.go @@ -366,6 +366,34 @@ type Function interface { // the end-to-end demonstrations of how these terminations can be performed. Call(ctx context.Context, params ...uint64) ([]uint64, error) + // CallWithStack is an optimized variation of Call that saves memory + // allocations when the stack slice is reused across calls. + // + // Stack length must be at least the max of parameter or result length. + // The caller adds parameters in order to the stack, and reads any results + // in order from the stack, except in the error case. + // + // For example, the following reuses the same stack slice to call searchFn + // repeatedly saving one allocation per iteration: + // + // stack := make([]uint64, 4) + // for i, search := range searchParams { + // // copy the next params to the stack + // copy(stack, search) + // if err := searchFn.CallWithStack(ctx, stack); err != nil { + // return err + // } else if stack[0] == 1 { // found + // return i // searchParams[i] matched! + // } + // } + // + // # Notes + // + // - This is similar to GoModuleFunction, except for using calling functions + // instead of implementing them. Moreover, this is used regardless of + // whether the callee is a host or wasm defined function. + CallWithStack(ctx context.Context, stack []uint64) error + internalapi.WazeroOnly } diff --git a/internal/emscripten/emscripten.go b/internal/emscripten/emscripten.go index b014d09cb6..6cda874c8c 100644 --- a/internal/emscripten/emscripten.go +++ b/internal/emscripten/emscripten.go @@ -76,8 +76,11 @@ func (v *InvokeFunc) Call(ctx context.Context, mod api.Module, stack []uint64) { panic(err) } + // This needs copy (not reslice) because the stack is reused for results. + // Consider invoke_i (zero arguments, one result): index zero (tableOffset) + // is needed to store the result. tableOffset := wasm.Index(stack[0]) // position in the module's only table. - params := stack[1:] // parameters to the dynamic function being called + copy(stack, stack[1:]) // pop the tableOffset. // Lookup the table index we will call. t := m.Tables[0] // Note: Emscripten doesn't use multiple tables @@ -86,10 +89,8 @@ func (v *InvokeFunc) Call(ctx context.Context, mod api.Module, stack []uint64) { panic(err) } - ret, err := m.Engine.NewFunction(idx).Call(ctx, params...) + err = m.Engine.NewFunction(idx).CallWithStack(ctx, stack) if err != nil { panic(err) } - // if there are any results, copy them back to the stack - copy(stack, ret) } diff --git a/internal/engine/compiler/engine.go b/internal/engine/compiler/engine.go index 24da7889f5..b6f1104a66 100644 --- a/internal/engine/compiler/engine.go +++ b/internal/engine/compiler/engine.go @@ -705,6 +705,24 @@ func (ce *callEngine) Definition() api.FunctionDefinition { // Call implements the same method as documented on wasm.ModuleEngine. func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uint64, err error) { + ft := ce.initialFn.funcType + if n := ft.ParamNumInUint64; n != len(params) { + return nil, fmt.Errorf("expected %d params, but passed %d", n, len(params)) + } + return ce.call(ctx, params, nil) +} + +// CallWithStack implements the same method as documented on wasm.ModuleEngine. +func (ce *callEngine) CallWithStack(ctx context.Context, stack []uint64) error { + params, results, err := wasm.SplitCallStack(ce.initialFn.funcType, stack) + if err != nil { + return err + } + _, err = ce.call(ctx, params, results) + return err +} + +func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []uint64, err error) { m := ce.initialFn.moduleInstance if ce.ensureTermination { select { @@ -717,13 +735,6 @@ func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uin } } - tp := ce.initialFn.funcType - - paramCount := len(params) - if tp.ParamNumInUint64 != paramCount { - return nil, fmt.Errorf("expected %d params, but passed %d", ce.initialFn.funcType.ParamNumInUint64, paramCount) - } - // We ensure that this Call method never panics as // this Call method is indirectly invoked by embedders via store.CallFunction, // and we have to make sure that all the runtime errors, including the one happening inside @@ -736,7 +747,8 @@ func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uin } }() - ce.initializeStack(tp, params) + ft := ce.initialFn.funcType + ce.initializeStack(ft, params) if ce.ensureTermination { done := m.CloseModuleOnCanceledOrTimeout(ctx) @@ -747,12 +759,12 @@ func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uin // This returns a safe copy of the results, instead of a slice view. If we // returned a re-slice, the caller could accidentally or purposefully - // corrupt the stack of subsequent calls - if resultCount := tp.ResultNumInUint64; resultCount > 0 { - results = make([]uint64, resultCount) - copy(results, ce.stack[:resultCount]) + // corrupt the stack of subsequent calls. + if results == nil && ft.ResultNumInUint64 > 0 { + results = make([]uint64, ft.ResultNumInUint64) } - return + copy(results, ce.stack) + return results, nil } // initializeStack initializes callEngine.stack before entering native code. diff --git a/internal/engine/compiler/engine_test.go b/internal/engine/compiler/engine_test.go index b61442b000..e624ebe94f 100644 --- a/internal/engine/compiler/engine_test.go +++ b/internal/engine/compiler/engine_test.go @@ -67,6 +67,16 @@ func TestCompiler_ModuleEngine_Call(t *testing.T) { `, "\n"+functionLog.String()) } +func TestCompiler_ModuleEngine_CallWithStack(t *testing.T) { + defer functionLog.Reset() + requireSupportedOSArch(t) + enginetest.RunTestModuleEngineCallWithStack(t, et) + require.Equal(t, ` +--> .$0(1,2) +<-- (1,2) +`, "\n"+functionLog.String()) +} + func TestCompiler_ModuleEngine_Call_HostFn(t *testing.T) { defer functionLog.Reset() requireSupportedOSArch(t) diff --git a/internal/engine/interpreter/interpreter.go b/internal/engine/interpreter/interpreter.go index ac8e055348..d10f257151 100644 --- a/internal/engine/interpreter/interpreter.go +++ b/internal/engine/interpreter/interpreter.go @@ -116,6 +116,10 @@ func (ce *callEngine) pushValue(v uint64) { ce.stack = append(ce.stack, v) } +func (ce *callEngine) pushValues(v []uint64) { + ce.stack = append(ce.stack, v...) +} + func (ce *callEngine) popValue() (v uint64) { // No need to check stack bound // as we can assume that all the operations @@ -129,6 +133,12 @@ func (ce *callEngine) popValue() (v uint64) { return } +func (ce *callEngine) popValues(v []uint64) { + stackTopIndex := len(ce.stack) - len(v) + copy(v, ce.stack[stackTopIndex:]) + ce.stack = ce.stack[:stackTopIndex] +} + // peekValues peeks api.ValueType values from the stack and returns them. func (ce *callEngine) peekValues(count int) []uint64 { if count == 0 { @@ -445,10 +455,24 @@ func (ce *callEngine) Definition() api.FunctionDefinition { // Call implements the same method as documented on api.Function. func (ce *callEngine) Call(ctx context.Context, params ...uint64) (results []uint64, err error) { - return ce.call(ctx, ce.compiled, params) + ft := ce.compiled.funcType + if n := ft.ParamNumInUint64; n != len(params) { + return nil, fmt.Errorf("expected %d params, but passed %d", n, len(params)) + } + return ce.call(ctx, params, nil) +} + +// CallWithStack implements the same method as documented on api.Function. +func (ce *callEngine) CallWithStack(ctx context.Context, stack []uint64) error { + params, results, err := wasm.SplitCallStack(ce.compiled.funcType, stack) + if err != nil { + return err + } + _, err = ce.call(ctx, params, results) + return err } -func (ce *callEngine) call(ctx context.Context, tf *function, params []uint64) (results []uint64, err error) { +func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []uint64, err error) { m := ce.compiled.moduleInstance if ce.compiled.parent.ensureTermination { select { @@ -461,13 +485,6 @@ func (ce *callEngine) call(ctx context.Context, tf *function, params []uint64) ( } } - ft := tf.funcType - paramSignature := ft.ParamNumInUint64 - paramCount := len(params) - if paramSignature != paramCount { - return nil, fmt.Errorf("expected %d params, but passed %d", paramSignature, paramCount) - } - defer func() { // If the module closed during the call, and the call didn't err for another reason, set an ExitError. if err == nil { @@ -480,22 +497,24 @@ func (ce *callEngine) call(ctx context.Context, tf *function, params []uint64) ( } }() - for _, param := range params { - ce.pushValue(param) - } + ce.pushValues(params) if ce.compiled.parent.ensureTermination { done := m.CloseModuleOnCanceledOrTimeout(ctx) defer done() } - ce.callFunction(ctx, m, tf) + ce.callFunction(ctx, m, ce.compiled) // This returns a safe copy of the results, instead of a slice view. If we // returned a re-slice, the caller could accidentally or purposefully // corrupt the stack of subsequent calls. - results = wasm.PopValues(ft.ResultNumInUint64, ce.popValue) - return + ft := ce.compiled.funcType + if results == nil && ft.ResultNumInUint64 > 0 { + results = make([]uint64, ft.ResultNumInUint64) + } + ce.popValues(results) + return results, nil } // recoverOnCall takes the recovered value `recoverOnCall`, and wraps it diff --git a/internal/engine/interpreter/interpreter_test.go b/internal/engine/interpreter/interpreter_test.go index 65723e64d0..7f03b9e039 100644 --- a/internal/engine/interpreter/interpreter_test.go +++ b/internal/engine/interpreter/interpreter_test.go @@ -105,6 +105,15 @@ func TestInterpreter_ModuleEngine_Call(t *testing.T) { `, "\n"+functionLog.String()) } +func TestCompiler_ModuleEngine_CallWithStack(t *testing.T) { + defer functionLog.Reset() + enginetest.RunTestModuleEngineCallWithStack(t, et) + require.Equal(t, ` +--> .$0(1,2) +<-- (1,2) +`, "\n"+functionLog.String()) +} + func TestInterpreter_ModuleEngine_Call_HostFn(t *testing.T) { defer functionLog.Reset() enginetest.RunTestModuleEngineCallHostFn(t, et) diff --git a/internal/integration_test/bench/hostfunc_bench_test.go b/internal/integration_test/bench/hostfunc_bench_test.go index fcd5f795fc..e69c4bad49 100644 --- a/internal/integration_test/bench/hostfunc_bench_test.go +++ b/internal/integration_test/bench/hostfunc_bench_test.go @@ -58,6 +58,23 @@ func BenchmarkHostFunctionCall(b *testing.B) { } } }) + + b.Run(fn+"_with_stack", func(b *testing.B) { + ce := getCallEngine(m, fn) + + b.ResetTimer() + stack := make([]uint64, 1) + for i := 0; i < b.N; i++ { + stack[0] = offset + err := ce.CallWithStack(testCtx, stack) + if err != nil { + b.Fatal(err) + } + if uint32(stack[0]) != math.Float32bits(val) { + b.Fail() + } + } + }) } } diff --git a/internal/testing/enginetest/enginetest.go b/internal/testing/enginetest/enginetest.go index b3300bf6e3..d443bb37bb 100644 --- a/internal/testing/enginetest/enginetest.go +++ b/internal/testing/enginetest/enginetest.go @@ -191,6 +191,58 @@ func RunTestModuleEngineCall(t *testing.T, et EngineTester) { }) } +func RunTestModuleEngineCallWithStack(t *testing.T, et EngineTester) { + e := et.NewEngine(api.CoreFeaturesV2) + + // Define a basic function which defines two parameters and two results. + // This is used to test results when incorrect arity is used. + m := &wasm.Module{ + TypeSection: []wasm.FunctionType{ + { + Params: []wasm.ValueType{i64, i64}, + Results: []wasm.ValueType{i64, i64}, + ParamNumInUint64: 2, + ResultNumInUint64: 2, + }, + }, + FunctionSection: []wasm.Index{0}, + CodeSection: []wasm.Code{ + {Body: []byte{wasm.OpcodeLocalGet, 0, wasm.OpcodeLocalGet, 1, wasm.OpcodeEnd}}, + }, + } + + m.BuildFunctionDefinitions() + listeners := buildListeners(et.ListenerFactory(), m) + err := e.CompileModule(testCtx, m, listeners, false) + require.NoError(t, err) + + // To use the function, we first need to add it to a module. + module := &wasm.ModuleInstance{ + ModuleName: t.Name(), TypeIDs: []wasm.FunctionTypeID{0}, + Definitions: m.FunctionDefinitionSection, + } + + // Compile the module + me, err := e.NewModuleEngine(m, module) + require.NoError(t, err) + linkModuleToEngine(module, me) + + // Ensure the base case doesn't fail: A single parameter should work as that matches the function signature. + const funcIndex = 0 + ce := me.NewFunction(funcIndex) + + stack := []uint64{1, 2} + err = ce.CallWithStack(testCtx, stack) + require.NoError(t, err) + require.Equal(t, []uint64{1, 2}, stack) + + t.Run("errs when not enough parameters", func(t *testing.T) { + ce := me.NewFunction(funcIndex) + err = ce.CallWithStack(testCtx, nil) + require.EqualError(t, err, "need 2 params, but stack size is 0") + }) +} + func RunTestModuleEngineLookupFunction(t *testing.T, et EngineTester) { e := et.NewEngine(api.CoreFeaturesV1) diff --git a/internal/wasm/gofunc.go b/internal/wasm/gofunc.go index 6f77a831cf..9510c2588e 100644 --- a/internal/wasm/gofunc.go +++ b/internal/wasm/gofunc.go @@ -79,26 +79,6 @@ func (f *reflectGoFunction) Call(ctx context.Context, stack []uint64) { callGoFunc(ctx, nil, f.fn, stack) } -// PopValues pops the specified number of api.ValueType parameters off the -// stack into a parameter slice for use in api.GoFunction or api.GoModuleFunction. -// -// For example, if the host function F requires the (x1 uint32, x2 float32) -// parameters, and the stack is [..., A, B], then the function is called as -// F(A, B) where A and B are interpreted as uint32 and float32 respectively. -// -// Note: the popper intentionally doesn't return bool or error because the -// caller's stack depth is trusted. -func PopValues(count int, popper func() uint64) []uint64 { - if count == 0 { - return nil - } - params := make([]uint64, count) - for i := count - 1; i >= 0; i-- { - params[i] = popper() - } - return params -} - // callGoFunc executes the reflective function by converting params to Go // types. The results of the function call are converted back to api.ValueType. func callGoFunc(ctx context.Context, mod api.Module, fn *reflect.Value, stack []uint64) { diff --git a/internal/wasm/gofunc_test.go b/internal/wasm/gofunc_test.go index 64269bf846..8b2e8b8c26 100644 --- a/internal/wasm/gofunc_test.go +++ b/internal/wasm/gofunc_test.go @@ -124,55 +124,6 @@ func Test_parseGoFunc_Errors(t *testing.T) { } } -// stack simulates the value stack in a way easy to be tested. -type stack struct { - vals []uint64 -} - -func (s *stack) pop() (result uint64) { - stackTopIndex := len(s.vals) - 1 - result = s.vals[stackTopIndex] - s.vals = s.vals[:stackTopIndex] - return -} - -func TestPopValues(t *testing.T) { - stackVals := []uint64{1, 2, 3, 4, 5, 6, 7} - tests := []struct { - name string - count int - expected []uint64 - }{ - { - name: "pop zero doesn't allocate a slice ", - }, - { - name: "pop 1", - count: 1, - expected: []uint64{7}, - }, - { - name: "pop 2", - count: 2, - expected: []uint64{6, 7}, - }, - { - name: "pop 3", - count: 3, - expected: []uint64{5, 6, 7}, - }, - } - - for _, tt := range tests { - tc := tt - - t.Run(tc.name, func(t *testing.T) { - vals := PopValues(tc.count, (&stack{stackVals}).pop) - require.Equal(t, tc.expected, vals) - }) - } -} - func Test_callGoFunc(t *testing.T) { tPtr := uintptr(unsafe.Pointer(t)) inst := &ModuleInstance{} diff --git a/internal/wasm/store_test.go b/internal/wasm/store_test.go index 6d736b404b..135f688c6f 100644 --- a/internal/wasm/store_test.go +++ b/internal/wasm/store_test.go @@ -488,12 +488,16 @@ func (e *mockModuleEngine) Close(context.Context) { func (ce *mockCallEngine) Definition() api.FunctionDefinition { return nil } // Call implements the same method as documented on api.Function. -func (ce *mockCallEngine) Call(_ context.Context, _ ...uint64) (results []uint64, err error) { +func (ce *mockCallEngine) Call(ctx context.Context, _ ...uint64) (results []uint64, err error) { + return nil, ce.CallWithStack(ctx, nil) +} + +// CallWithStack implements the same method as documented on api.Function. +func (ce *mockCallEngine) CallWithStack(_ context.Context, _ []uint64) error { if ce.callFailIndex >= 0 && ce.index == Index(ce.callFailIndex) { - err = errors.New("call failed") - return + return errors.New("call failed") } - return + return nil } func TestStore_getFunctionTypeID(t *testing.T) {