Skip to content

Commit

Permalink
Do not call into the VM unless the VM Context has been created. (isti…
Browse files Browse the repository at this point in the history
…o#214)

Signed-off-by: John Plevyak <[email protected]>

Co-authored-by: John Plevyak <[email protected]>
  • Loading branch information
PiotrSikora and jplevyak authored May 13, 2020
1 parent 7f79e45 commit 6144102
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 14 deletions.
28 changes: 15 additions & 13 deletions source/extensions/common/wasm/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,7 @@ bool Context::onStart(absl::string_view vm_configuration, PluginSharedPtr plugin
wasm_->on_context_create_(this, id_, 0);
plugin_.reset();
}
in_vm_context_created_ = true;
if (wasm_->on_vm_start_) {
configuration_ = vm_configuration;
plugin_ = plugin;
Expand Down Expand Up @@ -1192,6 +1193,7 @@ void Context::onCreate(uint32_t parent_context_id) {
Network::FilterStatus Context::onNetworkNewConnection() {
DeferAfterCallActions actions(this);
onCreate(root_context_id_);
in_vm_context_created_ = true;
if (!wasm_->on_new_connection_) {
return Network::FilterStatus::Continue;
}
Expand All @@ -1202,7 +1204,7 @@ Network::FilterStatus Context::onNetworkNewConnection() {
}

Network::FilterStatus Context::onDownstreamData(int data_length, bool end_of_stream) {
if (!wasm_->on_downstream_data_) {
if (!in_vm_context_created_ || !wasm_->on_downstream_data_) {
return Network::FilterStatus::Continue;
}
DeferAfterCallActions actions(this);
Expand All @@ -1214,7 +1216,7 @@ Network::FilterStatus Context::onDownstreamData(int data_length, bool end_of_str
}

Network::FilterStatus Context::onUpstreamData(int data_length, bool end_of_stream) {
if (!wasm_->on_upstream_data_) {
if (!in_vm_context_created_ || !wasm_->on_upstream_data_) {
return Network::FilterStatus::Continue;
}
DeferAfterCallActions actions(this);
Expand All @@ -1226,7 +1228,7 @@ Network::FilterStatus Context::onUpstreamData(int data_length, bool end_of_strea
}

void Context::onDownstreamConnectionClose(PeerType peer_type) {
if (wasm_->on_downstream_connection_close_) {
if (in_vm_context_created_ && wasm_->on_downstream_connection_close_) {
DeferAfterCallActions actions(this);
wasm_->on_downstream_connection_close_(this, id_, static_cast<uint32_t>(peer_type));
}
Expand All @@ -1239,7 +1241,7 @@ void Context::onDownstreamConnectionClose(PeerType peer_type) {
}

void Context::onUpstreamConnectionClose(PeerType peer_type) {
if (wasm_->on_upstream_connection_close_) {
if (in_vm_context_created_ && wasm_->on_upstream_connection_close_) {
DeferAfterCallActions actions(this);
wasm_->on_upstream_connection_close_(this, id_, static_cast<uint32_t>(peer_type));
}
Expand All @@ -1266,7 +1268,7 @@ Http::FilterHeadersStatus Context::onRequestHeaders() {
}

Http::FilterDataStatus Context::onRequestBody(bool end_of_stream) {
if (!wasm_->on_request_body_) {
if (!in_vm_context_created_ || !wasm_->on_request_body_) {
return Http::FilterDataStatus::Continue;
}
DeferAfterCallActions actions(this);
Expand All @@ -1292,7 +1294,7 @@ Http::FilterDataStatus Context::onRequestBody(bool end_of_stream) {
}

Http::FilterTrailersStatus Context::onRequestTrailers() {
if (!wasm_->on_request_trailers_) {
if (!in_vm_context_created_ || !wasm_->on_request_trailers_) {
return Http::FilterTrailersStatus::Continue;
}
DeferAfterCallActions actions(this);
Expand All @@ -1303,7 +1305,7 @@ Http::FilterTrailersStatus Context::onRequestTrailers() {
}

Http::FilterMetadataStatus Context::onRequestMetadata() {
if (!wasm_->on_request_metadata_) {
if (!in_vm_context_created_ || !wasm_->on_request_metadata_) {
return Http::FilterMetadataStatus::Continue;
}
DeferAfterCallActions actions(this);
Expand Down Expand Up @@ -1333,7 +1335,7 @@ Http::FilterHeadersStatus Context::onResponseHeaders() {
}

Http::FilterDataStatus Context::onResponseBody(bool end_of_stream) {
if (!wasm_->on_response_body_) {
if (!in_vm_context_created_ || !wasm_->on_response_body_) {
return Http::FilterDataStatus::Continue;
}
DeferAfterCallActions actions(this);
Expand All @@ -1359,7 +1361,7 @@ Http::FilterDataStatus Context::onResponseBody(bool end_of_stream) {
}

Http::FilterTrailersStatus Context::onResponseTrailers() {
if (!wasm_->on_response_trailers_) {
if (!in_vm_context_created_ || !wasm_->on_response_trailers_) {
return Http::FilterTrailersStatus::Continue;
}
DeferAfterCallActions actions(this);
Expand All @@ -1370,7 +1372,7 @@ Http::FilterTrailersStatus Context::onResponseTrailers() {
}

Http::FilterMetadataStatus Context::onResponseMetadata() {
if (!wasm_->on_response_metadata_) {
if (!in_vm_context_created_ || !wasm_->on_response_metadata_) {
return Http::FilterMetadataStatus::Continue;
}
DeferAfterCallActions actions(this);
Expand Down Expand Up @@ -1637,22 +1639,22 @@ void Context::onDestroy() {

bool Context::onDone() {
DeferAfterCallActions actions(this);
if (wasm_->on_done_) {
if (in_vm_context_created_ && wasm_->on_done_) {
return wasm_->on_done_(this, id_).u64_ != 0;
}
return true;
}

void Context::onLog() {
DeferAfterCallActions actions(this);
if (wasm_->on_log_) {
if (in_vm_context_created_ && wasm_->on_log_) {
wasm_->on_log_(this, id_);
}
}

void Context::onDelete() {
DeferAfterCallActions actions(this);
if (wasm_->on_delete_) {
if (in_vm_context_created_ && wasm_->on_delete_) {
wasm_->on_delete_(this, id_);
}
}
Expand Down
2 changes: 2 additions & 0 deletions source/extensions/common/wasm/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,8 @@ class Context : public Logger::Loggable<Logger::Id::wasm>,

void addAfterVmCallAction(std::function<void()> f);

void setInVmContextCreatedForTesting() { in_vm_context_created_ = true; }

protected:
friend class Wasm;

Expand Down
3 changes: 2 additions & 1 deletion test/extensions/wasm/wasm_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ TEST_P(WasmTest, DivByZero) {
auto context = std::make_unique<TestContext>(wasm.get());
EXPECT_CALL(*context, scriptLog_(spdlog::level::err, Eq("before div by zero")));
EXPECT_TRUE(wasm->initialize(code, false));
wasm->setContext(context.get());
context->setInVmContextCreatedForTesting();

if (GetParam() == "v8") {
EXPECT_THROW_WITH_MESSAGE(
Expand Down Expand Up @@ -401,6 +401,7 @@ TEST_P(WasmTest, StatsHighLevel) {
"{{ test_rundir }}/test/extensions/wasm/test_data/stats_cpp.wasm"));
EXPECT_FALSE(code.empty());
auto context = std::make_unique<TestContext>(wasm.get());
context->setInVmContextCreatedForTesting();

EXPECT_CALL(*context, scriptLog_(spdlog::level::trace, Eq("get counter = 1")));
EXPECT_CALL(*context, scriptLog_(spdlog::level::debug, Eq("get counter = 2")));
Expand Down

0 comments on commit 6144102

Please sign in to comment.