Skip to content

Commit

Permalink
dns: refactor QueryWrap lifetime management
Browse files Browse the repository at this point in the history
- Prefer RAII-style management over manual resource management.
- Prefer `env->SetImmediate()` over a separate `uv_async_t`.
- Perform `ares_destroy()` before possibly tearing down c-ares state.
- Verify that the number of active queries is non-negative.
- Let pending callbacks know when their underlying `QueryWrap` object
  has been destroyed.

The last item fixes a real bug: when Workers shut down while DNS
queries are still in flight, they may run into use-after-free
situations, because:

1. Shutting the `Worker` down leads to the cleanup code deleting
   the `QueryWrap` objects first; then
2. deleting the `ChannelWrap` object (as it has been created before
   the `QueryWrap`s), whose destructor runs `ares_destroy()`, which
   in turn invokes all pending query callbacks with `ARES_ECANCELLED`,
3. which led to a use-after-free, as the callback tried to access the
   already-deleted `QueryWrap` object.

The added test verifies that this is no longer an issue.

PR-URL: #26253
Reviewed-By: Colin Ihrig <[email protected]>
Reviewed-By: James M Snell <[email protected]>
  • Loading branch information
addaleax committed Mar 1, 2019
1 parent 4561cf3 commit ea26ac0
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 66 deletions.
130 changes: 64 additions & 66 deletions src/cares_wrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -407,14 +407,11 @@ void safe_free_hostent(struct hostent* host) {
host->h_aliases = nullptr;
}

if (host->h_name != nullptr) {
free(host->h_name);
}

host->h_addrtype = host->h_length = 0;
free(host->h_name);
free(host);
}

void cares_wrap_hostent_cpy(struct hostent* dest, struct hostent* src) {
void cares_wrap_hostent_cpy(struct hostent* dest, const struct hostent* src) {
dest->h_addr_list = nullptr;
dest->h_addrtype = 0;
dest->h_aliases = nullptr;
Expand Down Expand Up @@ -461,18 +458,6 @@ void cares_wrap_hostent_cpy(struct hostent* dest, struct hostent* src) {
}

class QueryWrap;
struct CaresAsyncData {
QueryWrap* wrap;
int status;
bool is_host;
union {
hostent* host;
unsigned char* buf;
} data;
int len;

uv_async_t async_handle;
};

void ChannelWrap::Setup() {
struct ares_options options;
Expand Down Expand Up @@ -525,20 +510,21 @@ void ChannelWrap::CloseTimer() {
}

ChannelWrap::~ChannelWrap() {
ares_destroy(channel_);

if (library_inited_) {
Mutex::ScopedLock lock(ares_library_mutex);
// This decreases the reference counter increased by ares_library_init().
ares_library_cleanup();
}

ares_destroy(channel_);
CloseTimer();
}


void ChannelWrap::ModifyActivityQueryCount(int count) {
active_query_count_ += count;
if (active_query_count_ < 0) active_query_count_ = 0;
CHECK_GE(active_query_count_, 0);
}


Expand Down Expand Up @@ -602,6 +588,10 @@ class QueryWrap : public AsyncWrap {

~QueryWrap() override {
CHECK_EQ(false, persistent().IsEmpty());

// Let Callback() know that this object no longer exists.
if (callback_ptr_ != nullptr)
*callback_ptr_ = nullptr;
}

// Subclasses should implement the appropriate Send method.
Expand All @@ -624,89 +614,93 @@ class QueryWrap : public AsyncWrap {
TRACING_CATEGORY_NODE2(dns, native), trace_name_, this,
"name", TRACE_STR_COPY(name));
ares_query(channel_->cares_channel(), name, dnsclass, type, Callback,
static_cast<void*>(this));
MakeCallbackPointer());
}

static void CaresAsyncClose(uv_async_t* async) {
auto data = static_cast<struct CaresAsyncData*>(async->data);
delete data->wrap;
delete data;
}
struct ResponseData {
int status;
bool is_host;
DeleteFnPtr<hostent, safe_free_hostent> host;
MallocedBuffer<unsigned char> buf;
};

static void CaresAsyncCb(uv_async_t* handle) {
auto data = static_cast<struct CaresAsyncData*>(handle->data);
void AfterResponse() {
CHECK(response_data_);

QueryWrap* wrap = data->wrap;
int status = data->status;
const int status = response_data_->status;

if (status != ARES_SUCCESS) {
wrap->ParseError(status);
} else if (!data->is_host) {
unsigned char* buf = data->data.buf;
wrap->Parse(buf, data->len);
free(buf);
ParseError(status);
} else if (!response_data_->is_host) {
Parse(response_data_->buf.data, response_data_->buf.size);
} else {
hostent* host = data->data.host;
wrap->Parse(host);
safe_free_hostent(host);
free(host);
Parse(response_data_->host.get());
}

wrap->env()->CloseHandle(handle, CaresAsyncClose);
delete this;
}

void* MakeCallbackPointer() {
CHECK_NULL(callback_ptr_);
callback_ptr_ = new QueryWrap*(this);
return callback_ptr_;
}

static QueryWrap* FromCallbackPointer(void* arg) {
std::unique_ptr<QueryWrap*> wrap_ptr { static_cast<QueryWrap**>(arg) };
QueryWrap* wrap = *wrap_ptr.get();
if (wrap == nullptr) return nullptr;
wrap->callback_ptr_ = nullptr;
return wrap;
}

static void Callback(void* arg, int status, int timeouts,
unsigned char* answer_buf, int answer_len) {
QueryWrap* wrap = static_cast<QueryWrap*>(arg);
QueryWrap* wrap = FromCallbackPointer(arg);
if (wrap == nullptr) return;

unsigned char* buf_copy = nullptr;
if (status == ARES_SUCCESS) {
buf_copy = node::Malloc<unsigned char>(answer_len);
memcpy(buf_copy, answer_buf, answer_len);
}

CaresAsyncData* data = new CaresAsyncData();
wrap->response_data_.reset(new ResponseData());
ResponseData* data = wrap->response_data_.get();
data->status = status;
data->wrap = wrap;
data->is_host = false;
data->data.buf = buf_copy;
data->len = answer_len;

uv_async_t* async_handle = &data->async_handle;
CHECK_EQ(0, uv_async_init(wrap->env()->event_loop(),
async_handle,
CaresAsyncCb));
data->buf = MallocedBuffer<unsigned char>(buf_copy, answer_len);

wrap->channel_->set_query_last_ok(status != ARES_ECONNREFUSED);
wrap->channel_->ModifyActivityQueryCount(-1);
async_handle->data = data;
uv_async_send(async_handle);
wrap->QueueResponseCallback(status);
}

static void Callback(void* arg, int status, int timeouts,
struct hostent* host) {
QueryWrap* wrap = static_cast<QueryWrap*>(arg);
QueryWrap* wrap = FromCallbackPointer(arg);
if (wrap == nullptr) return;

struct hostent* host_copy = nullptr;
if (status == ARES_SUCCESS) {
host_copy = node::Malloc<hostent>(1);
cares_wrap_hostent_cpy(host_copy, host);
}

CaresAsyncData* data = new CaresAsyncData();
wrap->response_data_.reset(new ResponseData());
ResponseData* data = wrap->response_data_.get();
data->status = status;
data->data.host = host_copy;
data->wrap = wrap;
data->host.reset(host_copy);
data->is_host = true;

uv_async_t* async_handle = &data->async_handle;
CHECK_EQ(0, uv_async_init(wrap->env()->event_loop(),
async_handle,
CaresAsyncCb));
wrap->QueueResponseCallback(status);
}

void QueueResponseCallback(int status) {
env()->SetImmediate([](Environment*, void* data) {
static_cast<QueryWrap*>(data)->AfterResponse();
}, this, object());

wrap->channel_->set_query_last_ok(status != ARES_ECONNREFUSED);
async_handle->data = data;
uv_async_send(async_handle);
channel_->set_query_last_ok(status != ARES_ECONNREFUSED);
channel_->ModifyActivityQueryCount(-1);
}

void CallOnComplete(Local<Value> answer,
Expand Down Expand Up @@ -749,7 +743,11 @@ class QueryWrap : public AsyncWrap {
ChannelWrap* channel_;

private:
std::unique_ptr<ResponseData> response_data_;
const char* trace_name_;
// Pointer to pointer to 'this' that can be reset from the destructor,
// in order to let Callback() know that 'this' no longer exists.
QueryWrap** callback_ptr_ = nullptr;
};


Expand Down Expand Up @@ -1768,7 +1766,7 @@ class GetHostByAddrWrap: public QueryWrap {
length,
family,
Callback,
static_cast<void*>(static_cast<QueryWrap*>(this)));
MakeCallbackPointer());
return 0;
}

Expand Down
23 changes: 23 additions & 0 deletions test/parallel/test-worker-dns-terminate-during-query.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
'use strict';
const common = require('../common');
const { Resolver } = require('dns');
const dgram = require('dgram');
const { Worker, isMainThread } = require('worker_threads');

// Test that Workers can terminate while DNS queries are outstanding.

if (isMainThread) {
return new Worker(__filename);
}

const socket = dgram.createSocket('udp4');

socket.bind(0, common.mustCall(() => {
const resolver = new Resolver();
resolver.setServers([`127.0.0.1:${socket.address().port}`]);
resolver.resolve4('example.org', common.mustNotCall());
}));

socket.on('message', common.mustCall(() => {
process.exit();
}));

0 comments on commit ea26ac0

Please sign in to comment.