Skip to content

Commit

Permalink
Delay sendLocalReply till after the call to prevent rentrant calls. (e…
Browse files Browse the repository at this point in the history
…nvoyproxy#432)

Signed-off-by: John Plevyak <[email protected]>
  • Loading branch information
jplevyak authored Feb 26, 2020
1 parent b238eb4 commit 7d648e4
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 2 deletions.
39 changes: 39 additions & 0 deletions source/extensions/common/wasm/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ namespace Wasm {

namespace {

class DeferAfterCallActions {
public:
DeferAfterCallActions(Context* context) : wasm_(context->wasm()) {}
~DeferAfterCallActions() { wasm_->doAfterVmCallActions(); }

private:
Wasm* const wasm_;
};

using HashPolicy = envoy::config::route::v3::RouteAction::HashPolicy;

class SharedData {
Expand Down Expand Up @@ -265,6 +274,8 @@ std::string Context::makeRootLogPrefix(absl::string_view vm_id) const {
WasmVm* Context::wasmVm() const { return wasm_->wasm_vm(); }
Upstream::ClusterManager& Context::clusterManager() const { return wasm_->clusterManager(); }

void Context::addAfterVmCallAction(std::function<void()> f) { wasm_->addAfterVmCallAction(f); }

WasmResult Context::setTickPeriod(std::chrono::milliseconds tick_period) {
wasm_->setTickPeriod(root_context_id_ ? root_context_id_ : id_, tick_period);
return WasmResult::Ok;
Expand Down Expand Up @@ -1059,6 +1070,7 @@ bool Context::isSsl() { return decoder_callbacks_->connection()->ssl() != nullpt
// Calls into the WASM code.
//
bool Context::onStart(absl::string_view vm_configuration, PluginSharedPtr plugin) {
DeferAfterCallActions actions(this);
bool result = 0;
if (wasm_->on_context_create_) {
plugin_ = plugin;
Expand Down Expand Up @@ -1094,6 +1106,7 @@ bool Context::onConfigure(absl::string_view plugin_configuration, PluginSharedPt
if (!wasm_->on_configure_) {
return true;
}
DeferAfterCallActions actions(this);
configuration_ = plugin_configuration;
plugin_ = plugin;
auto result =
Expand All @@ -1111,17 +1124,20 @@ std::pair<uint32_t, absl::string_view> Context::getStatus() {

void Context::onTick() {
if (wasm_->on_tick_) {
DeferAfterCallActions actions(this);
wasm_->on_tick_(this, id_);
}
}

void Context::onCreate(uint32_t parent_context_id) {
if (wasm_->on_context_create_) {
DeferAfterCallActions actions(this);
wasm_->on_context_create_(this, id_, parent_context_id);
}
}

Network::FilterStatus Context::onNetworkNewConnection() {
DeferAfterCallActions actions(this);
onCreate(root_context_id_);
if (!wasm_->on_new_connection_) {
return Network::FilterStatus::Continue;
Expand All @@ -1136,6 +1152,7 @@ Network::FilterStatus Context::onDownstreamData(int data_length, bool end_of_str
if (!wasm_->on_downstream_data_) {
return Network::FilterStatus::Continue;
}
DeferAfterCallActions actions(this);
end_of_stream_ = end_of_stream;
auto result = wasm_->on_downstream_data_(this, id_, static_cast<uint32_t>(data_length),
static_cast<uint32_t>(end_of_stream));
Expand All @@ -1147,6 +1164,7 @@ Network::FilterStatus Context::onUpstreamData(int data_length, bool end_of_strea
if (!wasm_->on_upstream_data_) {
return Network::FilterStatus::Continue;
}
DeferAfterCallActions actions(this);
end_of_stream_ = end_of_stream;
auto result = wasm_->on_upstream_data_(this, id_, static_cast<uint32_t>(data_length),
static_cast<uint32_t>(end_of_stream));
Expand All @@ -1156,12 +1174,14 @@ Network::FilterStatus Context::onUpstreamData(int data_length, bool end_of_strea

void Context::onDownstreamConnectionClose(PeerType peer_type) {
if (wasm_->on_downstream_connection_close_) {
DeferAfterCallActions actions(this);
wasm_->on_downstream_connection_close_(this, id_, static_cast<uint32_t>(peer_type));
}
}

void Context::onUpstreamConnectionClose(PeerType peer_type) {
if (wasm_->on_upstream_connection_close_) {
DeferAfterCallActions actions(this);
wasm_->on_upstream_connection_close_(this, id_, static_cast<uint32_t>(peer_type));
}
}
Expand All @@ -1170,6 +1190,7 @@ void Context::onUpstreamConnectionClose(PeerType peer_type) {
template <typename P> static uint32_t headerSize(const P& p) { return p ? p->size() : 0; }

Http::FilterHeadersStatus Context::onRequestHeaders() {
DeferAfterCallActions actions(this);
onCreate(root_context_id_);
in_vm_context_created_ = true;
if (!wasm_->on_request_headers_) {
Expand All @@ -1185,6 +1206,7 @@ Http::FilterDataStatus Context::onRequestBody(int body_buffer_length, bool end_o
if (!wasm_->on_request_body_) {
return Http::FilterDataStatus::Continue;
}
DeferAfterCallActions actions(this);
switch (wasm_
->on_request_body_(this, id_, static_cast<uint32_t>(body_buffer_length),
static_cast<uint32_t>(end_of_stream))
Expand All @@ -1204,6 +1226,7 @@ Http::FilterTrailersStatus Context::onRequestTrailers() {
if (!wasm_->on_request_trailers_) {
return Http::FilterTrailersStatus::Continue;
}
DeferAfterCallActions actions(this);
if (wasm_->on_request_trailers_(this, id_, headerSize(request_trailers_)).u64_ == 0) {
return Http::FilterTrailersStatus::Continue;
}
Expand All @@ -1214,13 +1237,15 @@ Http::FilterMetadataStatus Context::onRequestMetadata() {
if (!wasm_->on_request_metadata_) {
return Http::FilterMetadataStatus::Continue;
}
DeferAfterCallActions actions(this);
if (wasm_->on_request_metadata_(this, id_, headerSize(request_metadata_)).u64_ == 0) {
return Http::FilterMetadataStatus::Continue;
}
return Http::FilterMetadataStatus::Continue; // This is currently the only return code.
}

Http::FilterHeadersStatus Context::onResponseHeaders() {
DeferAfterCallActions actions(this);
if (!in_vm_context_created_) {
// If the request is invalid then onRequestHeaders() will not be called and neither will
// onCreate() then sendLocalReply be called which will call this function. In this case we
Expand All @@ -1242,6 +1267,7 @@ Http::FilterDataStatus Context::onResponseBody(int body_buffer_length, bool end_
if (!wasm_->on_response_body_) {
return Http::FilterDataStatus::Continue;
}
DeferAfterCallActions actions(this);
switch (wasm_
->on_response_body_(this, id_, static_cast<uint32_t>(body_buffer_length),
static_cast<uint32_t>(end_of_stream))
Expand All @@ -1261,6 +1287,7 @@ Http::FilterTrailersStatus Context::onResponseTrailers() {
if (!wasm_->on_response_trailers_) {
return Http::FilterTrailersStatus::Continue;
}
DeferAfterCallActions actions(this);
if (wasm_->on_response_trailers_(this, id_, headerSize(response_trailers_)).u64_ == 0) {
return Http::FilterTrailersStatus::Continue;
}
Expand All @@ -1271,6 +1298,7 @@ Http::FilterMetadataStatus Context::onResponseMetadata() {
if (!wasm_->on_response_metadata_) {
return Http::FilterMetadataStatus::Continue;
}
DeferAfterCallActions actions(this);
if (wasm_->on_response_metadata_(this, id_, headerSize(response_metadata_)).u64_ == 0) {
return Http::FilterMetadataStatus::Continue;
}
Expand All @@ -1282,11 +1310,13 @@ void Context::onHttpCallResponse(uint32_t token, uint32_t headers, uint32_t body
if (!wasm_->on_http_call_response_) {
return;
}
DeferAfterCallActions actions(this);
wasm_->on_http_call_response_(this, id_, token, headers, body_size, trailers);
}

void Context::onQueueReady(uint32_t token) {
if (wasm_->on_queue_ready_) {
DeferAfterCallActions actions(this);
wasm_->on_queue_ready_(this, id_, token);
}
}
Expand All @@ -1295,6 +1325,7 @@ void Context::onGrpcCreateInitialMetadata(uint32_t token, Http::HeaderMap& metad
if (!wasm_->on_grpc_create_initial_metadata_) {
return;
}
DeferAfterCallActions actions(this);
grpc_create_initial_metadata_ = &metadata;
wasm_->on_grpc_create_initial_metadata_(this, id_, token,
headerSize(grpc_create_initial_metadata_));
Expand All @@ -1305,6 +1336,7 @@ void Context::onGrpcReceiveInitialMetadata(uint32_t token, Http::HeaderMapPtr&&
if (!wasm_->on_grpc_receive_initial_metadata_) {
return;
}
DeferAfterCallActions actions(this);
grpc_receive_initial_metadata_ = std::move(metadata);
wasm_->on_grpc_receive_initial_metadata_(this, id_, token,
headerSize(grpc_receive_initial_metadata_));
Expand All @@ -1315,6 +1347,7 @@ void Context::onGrpcReceiveTrailingMetadata(uint32_t token, Http::HeaderMapPtr&&
if (!wasm_->on_grpc_receive_trailing_metadata_) {
return;
}
DeferAfterCallActions actions(this);
grpc_receive_trailing_metadata_ = std::move(metadata);
wasm_->on_grpc_receive_trailing_metadata_(this, id_, token,
headerSize(grpc_receive_trailing_metadata_));
Expand Down Expand Up @@ -1452,13 +1485,15 @@ Context::~Context() {
Network::FilterStatus Context::onNewConnection() { return onNetworkNewConnection(); };

Network::FilterStatus Context::onData(Buffer::Instance& data, bool end_stream) {
DeferAfterCallActions actions(this);
network_downstream_data_buffer_ = &data;
auto result = onDownstreamData(data.length(), end_stream);
network_downstream_data_buffer_ = nullptr;
return result;
}

Network::FilterStatus Context::onWrite(Buffer::Instance& data, bool end_stream) {
DeferAfterCallActions actions(this);
network_upstream_data_buffer_ = &data;
auto result = onUpstreamData(data.length(), end_stream);
network_upstream_data_buffer_ = nullptr;
Expand All @@ -1471,6 +1506,7 @@ Network::FilterStatus Context::onWrite(Buffer::Instance& data, bool end_stream)
}

void Context::onEvent(Network::ConnectionEvent event) {
DeferAfterCallActions actions(this);
switch (event) {
case Network::ConnectionEvent::LocalClose:
onDownstreamConnectionClose(PeerType::Local);
Expand Down Expand Up @@ -1528,19 +1564,22 @@ void Context::onDestroy() {
}

bool Context::onDone() {
DeferAfterCallActions actions(this);
if (wasm_->on_done_) {
return wasm_->on_done_(this, id_).u64_ != 0;
}
return true;
}

void Context::onLog() {
DeferAfterCallActions actions(this);
if (wasm_->on_log_) {
wasm_->on_log_(this, id_);
}
}

void Context::onDelete() {
DeferAfterCallActions actions(this);
if (wasm_->on_delete_) {
wasm_->on_delete_(this, id_);
}
Expand Down
2 changes: 2 additions & 0 deletions source/extensions/common/wasm/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,8 @@ class Context : public Logger::Loggable<Logger::Id::wasm>,
// Connection
virtual bool isSsl();

void addAfterVmCallAction(std::function<void()> f);

protected:
friend class Wasm;

Expand Down
8 changes: 6 additions & 2 deletions source/extensions/common/wasm/exports.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,12 @@ Word send_local_response(void* raw_context, Word response_code, Word response_co
auto grpc_status_opt = (grpc_status != Grpc::Status::WellKnownGrpcStatus::InvalidCode)
? absl::optional<Grpc::Status::WellKnownGrpcStatus>(grpc_status)
: absl::optional<Grpc::Status::WellKnownGrpcStatus>();
context->sendLocalResponse(static_cast<Envoy::Http::Code>(response_code.u64_), body.value(),
modify_headers, grpc_status_opt, details.value());
context->addAfterVmCallAction([context, response_code, body = std::string(body.value()),
modify_headers = std::move(modify_headers), grpc_status_opt,
details = std::string(details.value())] {
context->sendLocalResponse(static_cast<Envoy::Http::Code>(response_code.u64_), body,
modify_headers, grpc_status_opt, details);
});
return wasmResultToWord(WasmResult::Ok);
}

Expand Down
13 changes: 13 additions & 0 deletions source/extensions/common/wasm/wasm.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <atomic>
#include <deque>
#include <map>
#include <memory>

Expand Down Expand Up @@ -141,6 +142,15 @@ class Wasm : public Logger::Loggable<Logger::Id::wasm>, public std::enable_share
return true;
}

void addAfterVmCallAction(std::function<void()> f) { after_vm_call_actions_.push_back(f); }
void doAfterVmCallActions() {
while (!after_vm_call_actions_.empty()) {
auto f = std::move(after_vm_call_actions_.front());
after_vm_call_actions_.pop_front();
f();
}
}

private:
friend class Context;
class ShutdownHandle;
Expand Down Expand Up @@ -268,6 +278,9 @@ class Wasm : public Logger::Loggable<Logger::Id::wasm>, public std::enable_share

// Foreign Functions.
absl::flat_hash_map<std::string, WasmForeignFunction> foreign_functions_;

// Actions to be done after the call into the VM returns.
std::deque<std::function<void()>> after_vm_call_actions_;
};
using WasmSharedPtr = std::shared_ptr<Wasm>;

Expand Down

0 comments on commit 7d648e4

Please sign in to comment.