Skip to content

Commit

Permalink
wasm: move implementation details into WasmVm abstraction layer. (env…
Browse files Browse the repository at this point in the history
…oyproxy#47)

Signed-off-by: Piotr Sikora <[email protected]>
  • Loading branch information
PiotrSikora authored Mar 22, 2019
1 parent 712f43f commit 0e61789
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 56 deletions.
13 changes: 6 additions & 7 deletions source/extensions/common/wasm/wasm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,7 @@ Wasm::Wasm(absl::string_view vm, absl::string_view id, absl::string_view initial
}

void Wasm::registerCallbacks() {
#define _REGISTER(_fn) registerCallback(wasm_vm_.get(), "envoy", #_fn, &_fn##Handler);
#define _REGISTER(_fn) wasm_vm_->registerCallback("envoy", #_fn, &_fn##Handler);
if (is_emscripten_) {
_REGISTER(getTotalMemory);
_REGISTER(_emscripten_get_heap_size);
Expand All @@ -962,8 +962,7 @@ void Wasm::registerCallbacks() {
#undef _REGISTER

// Calls with the "_proxy_" prefix.
#define _REGISTER_PROXY(_fn) \
registerCallback(wasm_vm_.get(), "envoy", "_proxy_" #_fn, &_fn##Handler);
#define _REGISTER_PROXY(_fn) wasm_vm_->registerCallback("envoy", "_proxy_" #_fn, &_fn##Handler);
_REGISTER_PROXY(log);

_REGISTER_PROXY(getRequestStreamInfoProtocol);
Expand Down Expand Up @@ -1018,19 +1017,19 @@ void Wasm::registerCallbacks() {
void Wasm::establishEnvironment() {
if (is_emscripten_) {
wasm_vm_->makeModule("global");
emscripten_NaN_ = makeGlobal(wasm_vm_.get(), "global", "NaN", std::nan("0"));
emscripten_NaN_ = wasm_vm_->makeGlobal("global", "NaN", std::nan("0"));
emscripten_Infinity_ =
makeGlobal(wasm_vm_.get(), "global", "Infinity", std::numeric_limits<double>::infinity());
wasm_vm_->makeGlobal("global", "Infinity", std::numeric_limits<double>::infinity());
}
}

void Wasm::getFunctions() {
#define _GET(_fn) getFunction(wasm_vm_.get(), "_" #_fn, &_fn##_);
#define _GET(_fn) wasm_vm_->getFunction("_" #_fn, &_fn##_);
_GET(malloc);
_GET(free);
#undef _GET

#define _GET_PROXY(_fn) getFunction(wasm_vm_.get(), "_proxy_" #_fn, &_fn##_);
#define _GET_PROXY(_fn) wasm_vm_->getFunction("_proxy_" #_fn, &_fn##_);
_GET_PROXY(onStart);
_GET_PROXY(onConfigure);
_GET_PROXY(onTick);
Expand Down
107 changes: 58 additions & 49 deletions source/extensions/common/wasm/wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,33 @@ class WasmVm;
using Pairs = std::vector<std::pair<absl::string_view, absl::string_view>>;
using PairsWithStringValues = std::vector<std::pair<absl::string_view, std::string>>;

// 1st arg is always a pointer to Context (Context*).
using WasmCall0Void = std::function<void(Context*)>;
using WasmCall1Void = std::function<void(Context*, uint32_t)>;
using WasmCall1Int = std::function<uint32_t(Context*, uint32_t)>;
using WasmCall2Void = std::function<void(Context*, uint32_t, uint32_t)>;

using WasmContextCall0Void = std::function<void(Context*, uint32_t context_id)>;
using WasmContextCall7Void = std::function<void(Context*, uint32_t context_id, uint32_t, uint32_t,
uint32_t, uint32_t, uint32_t, uint32_t, uint32_t)>;

using WasmContextCall0Int = std::function<uint32_t(Context*, uint32_t context_id)>;
using WasmContextCall2Int =
std::function<uint32_t(Context*, uint32_t context_id, uint32_t, uint32_t)>;
using WasmCall8Void = std::function<void(Context*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t,
uint32_t, uint32_t, uint32_t)>;
using WasmCall1Int = std::function<uint32_t(Context*, uint32_t)>;
using WasmCall3Int = std::function<uint32_t(Context*, uint32_t, uint32_t, uint32_t)>;

// 1st arg is always a context_id (uint32_t).
using WasmContextCall0Void = WasmCall1Void;
using WasmContextCall7Void = WasmCall8Void;
using WasmContextCall0Int = WasmCall1Int;
using WasmContextCall2Int = WasmCall3Int;

// 1st arg is always a pointer to raw_context (void*).
using WasmCallback0Void = void (*)(void*);
using WasmCallback1Void = void (*)(void*, uint32_t);
using WasmCallback2Void = void (*)(void*, uint32_t, uint32_t);
using WasmCallback3Void = void (*)(void*, uint32_t, uint32_t, uint32_t);
using WasmCallback4Void = void (*)(void*, uint32_t, uint32_t, uint32_t, uint32_t);
using WasmCallback5Void = void (*)(void*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t);
using WasmCallback0Int = uint32_t (*)(void*);
using WasmCallback3Int = uint32_t (*)(void*, uint32_t, uint32_t, uint32_t);
using WasmCallback5Int = uint32_t (*)(void*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t);
using WasmCallback9Int = uint32_t (*)(void*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t,
uint32_t, uint32_t, uint32_t, uint32_t);

// A context which will be the target of callbacks for a particular session
// e.g. a handler of a stream.
Expand Down Expand Up @@ -435,6 +450,40 @@ class WasmVm : public Logger::Loggable<Logger::Id::wasm> {
// Get the contents of the user section with the given name or "" if it does not exist and
// optionally a presence indicator.
virtual absl::string_view getUserSection(absl::string_view name, bool* present = nullptr) PURE;

// Get typed function exported by the WASM module.
virtual void getFunction(absl::string_view functionName, WasmCall0Void* f) PURE;
virtual void getFunction(absl::string_view functionName, WasmCall1Void* f) PURE;
virtual void getFunction(absl::string_view functionName, WasmCall2Void* f) PURE;
virtual void getFunction(absl::string_view functionName, WasmCall8Void* f) PURE;
virtual void getFunction(absl::string_view functionName, WasmCall1Int* f) PURE;
virtual void getFunction(absl::string_view functionName, WasmCall3Int* f) PURE;

// Register typed callbacks exported by the host environment.
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback0Void f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback1Void f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback2Void f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback3Void f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback4Void f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback5Void f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback0Int f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback3Int f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback5Int f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback9Int f) PURE;

// Register typed value exported by the host environment.
virtual std::unique_ptr<Global<double>>
makeGlobal(absl::string_view moduleName, absl::string_view name, double initialValue) PURE;
};

// Create a new low-level WASM VM of the give type (e.g. "envoy.wasm.vm.wavm").
Expand Down Expand Up @@ -466,46 +515,6 @@ class WasmVmException : public EnvoyException {

inline Context::Context(Wasm* wasm) : wasm_(wasm), id_(wasm->allocContextId()) {}

// Forward declarations for VM implemenations.
template <typename R, typename... Args>
void registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
R (*)(Args...));
template <typename R, typename... Args>
void getFunctionWavm(WasmVm* vm, absl::string_view functionName,
std::function<R(Context*, Args...)>*);

template <typename T>
std::unique_ptr<Global<T>> makeGlobalWavm(WasmVm* vm, absl::string_view moduleName,
absl::string_view name, T initialValue);

template <typename R, typename... Args>
void registerCallback(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
R (*f)(Args...)) {
if (vm->vm() == WasmVmNames::get().Wavm) {
registerCallbackWavm(vm, moduleName, functionName, f);
} else {
throw WasmVmException("unsupported wasm vm");
}
}

template <typename F> void getFunction(WasmVm* vm, absl::string_view functionName, F* function) {
if (vm->vm() == WasmVmNames::get().Wavm) {
getFunctionWavm(vm, functionName, function);
} else {
throw WasmVmException("unsupported wasm vm");
}
}

template <typename T>
std::unique_ptr<Global<T>> makeGlobal(WasmVm* vm, absl::string_view moduleName,
absl::string_view name, T initialValue) {
if (vm->vm() == WasmVmNames::get().Wavm) {
return makeGlobalWavm(vm, moduleName, name, initialValue);
} else {
throw WasmVmException("unsupported wasm vm");
}
}

inline void* Wasm::allocMemory(uint32_t size, uint32_t* address) {
uint32_t a = malloc_(generalContext(), size);
*address = a;
Expand Down
45 changes: 45 additions & 0 deletions source/extensions/common/wasm/wavm/wavm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ namespace Wasm {

extern thread_local Envoy::Extensions::Common::Wasm::Context* current_context_;

// Forward declarations.
template <typename R, typename... Args>
void getFunctionWavm(WasmVm* vm, absl::string_view functionName,
std::function<R(Context*, Args...)>* function);
template <typename R, typename... Args>
void registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
R (*)(Args...));
template <typename T>
std::unique_ptr<Global<T>> makeGlobalWavm(WasmVm* vm, absl::string_view moduleName,
absl::string_view name, T initialValue);

namespace Wavm {

struct Wavm;
Expand Down Expand Up @@ -221,6 +232,40 @@ struct Wavm : public WasmVm {

void getInstantiatedGlobals();

#define _GET_FUNCTION(_type) \
void getFunction(absl::string_view functionName, _type* f) override { \
getFunctionWavm(this, functionName, f); \
};
_GET_FUNCTION(WasmCall0Void);
_GET_FUNCTION(WasmCall1Void);
_GET_FUNCTION(WasmCall2Void);
_GET_FUNCTION(WasmCall8Void);
_GET_FUNCTION(WasmCall1Int);
_GET_FUNCTION(WasmCall3Int);
#undef _GET_FUNCTION

#define _REGISTER_CALLBACK(_type) \
void registerCallback(absl::string_view moduleName, absl::string_view functionName, \
_type f) override { \
registerCallbackWavm(this, moduleName, functionName, f); \
};
_REGISTER_CALLBACK(WasmCallback0Void);
_REGISTER_CALLBACK(WasmCallback1Void);
_REGISTER_CALLBACK(WasmCallback2Void);
_REGISTER_CALLBACK(WasmCallback3Void);
_REGISTER_CALLBACK(WasmCallback4Void);
_REGISTER_CALLBACK(WasmCallback5Void);
_REGISTER_CALLBACK(WasmCallback0Int);
_REGISTER_CALLBACK(WasmCallback3Int);
_REGISTER_CALLBACK(WasmCallback5Int);
_REGISTER_CALLBACK(WasmCallback9Int);
#undef _REGISTER_CALLBACK

std::unique_ptr<Global<double>> makeGlobal(absl::string_view moduleName, absl::string_view name,
double initialValue) override {
return makeGlobalWavm(this, moduleName, name, initialValue);
};

bool hasInstantiatedModule_ = false;
IR::Module irModule_;
WAVM::Runtime::ModuleRef module_ = nullptr;
Expand Down

0 comments on commit 0e61789

Please sign in to comment.