Skip to content

Commit

Permalink
Fix: ujrpc_ssl_context_t corruption on move.
Browse files Browse the repository at this point in the history
Use destructors with `new` and `delete`
  • Loading branch information
ishkhan42 committed Apr 10, 2023
1 parent 865217b commit 50ec686
Showing 1 changed file with 101 additions and 100 deletions.
201 changes: 101 additions & 100 deletions src/engine_posix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,66 @@ using time_clock_t = std::chrono::steady_clock;
using time_point_t = std::chrono::time_point<time_clock_t>;

struct ujrpc_ssl_context_t {
ujrpc_ssl_context_t() = default;
ujrpc_ssl_context_t() noexcept = default;
~ujrpc_ssl_context_t() noexcept {
mbedtls_x509_crt_free(&srvcert);
mbedtls_pk_free(&pkey);
mbedtls_ssl_free(&ssl);
mbedtls_ssl_config_free(&conf);
// #if defined(MBEDTLS_SSL_CACHE_C)
// mbedtls_ssl_cache_free(&cache);
// #endif
mbedtls_ctr_drbg_free(&ctr_drbg);
mbedtls_entropy_free(&entropy);
}

int init(const char* pk_path, const char** crts_path, size_t crts_cnt) {
mbedtls_ssl_init(&ssl);
mbedtls_ssl_config_init(&conf);
// #if defined(MBEDTLS_SSL_CACHE_C)
// mbedtls_ssl_cache_init(&cache);
// #endif
mbedtls_x509_crt_init(&srvcert);
mbedtls_pk_init(&pkey);
mbedtls_entropy_init(&entropy);
mbedtls_ctr_drbg_init(&ctr_drbg);
int ret = 0;

// Seed the RNG
if ((ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, NULL, 0)) != 0)
// TODO Use personalization string. Required or Optional ?
return ret;

operator bool() const noexcept { return this->ssl.private_in_buf != nullptr; }
// Load Private Key
if ((ret = mbedtls_pk_parse_keyfile(&pkey, pk_path, NULL, NULL, &ctr_drbg)) != 0)
// TODO Use Password. Required or Optional ?
return ret;

// Load Certificates
for (size_t i = 0; i < crts_cnt; ++i)
if ((ret = mbedtls_x509_crt_parse_file(&srvcert, crts_path[i])) != 0)
// TODO Notify which certificate was invalid ?
return ret;

if ((ret = mbedtls_ssl_config_defaults(&conf, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT)) != 0)
return ret;

mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ctr_drbg);

// #if defined(MBEDTLS_SSL_CACHE_C)
// mbedtls_ssl_conf_session_cache(&conf, &cache, mbedtls_ssl_cache_get, mbedtls_ssl_cache_set);
// #endif

mbedtls_ssl_conf_ca_chain(&conf, srvcert.next, NULL);
if ((ret = mbedtls_ssl_conf_own_cert(&conf, &srvcert, &pkey)) != 0)
return ret;

if ((ret = mbedtls_ssl_setup(&ssl, &conf)) != 0)
return ret;

return 0;
}

mbedtls_ssl_context ssl;
mbedtls_ssl_config conf;
Expand All @@ -46,32 +103,36 @@ struct ujrpc_ssl_context_t {
};

struct engine_t {
descriptor_t socket{};
std::size_t max_batch_size{};
engine_t() = default;

~engine_t() noexcept { delete ssl_ctx; }

descriptor_t socket;
std::size_t max_batch_size;

/// @brief Establishes an SSL connection if SSL is enabled, otherwise the `ssl_ctx` is unused and uninitialized.
ujrpc_ssl_context_t ssl_ctx{};
ujrpc_ssl_context_t* ssl_ctx = nullptr;

/// @brief The file descriptor of the stateful connection over TCP.
descriptor_t connection{};
descriptor_t connection;
/// @brief A small memory buffer to store small requests.
alignas(align_k) char packet_buffer[ram_page_size_k + sj::SIMDJSON_PADDING]{};
alignas(align_k) char packet_buffer[ram_page_size_k + sj::SIMDJSON_PADDING];
/// @brief An array of function callbacks. Can be in dozens.
array_gt<named_callback_t> callbacks{};
array_gt<named_callback_t> callbacks;
/// @brief Statically allocated memory to process small requests.
scratch_space_t scratch{};
scratch_space_t scratch;
/// @brief For batch-requests in synchronous connections we need a place to
struct batch_response_t {
buffer_gt<struct iovec> iovecs{};
buffer_gt<char*> copies{};
std::size_t iovecs_count{};
std::size_t copies_count{};
} batch_response{};

stats_t stats{};
std::int32_t logs_file_descriptor{};
std::string_view logs_format{};
time_point_t log_last_time{};
buffer_gt<struct iovec> iovecs;
buffer_gt<char*> copies;
std::size_t iovecs_count;
std::size_t copies_count;
} batch_response;

stats_t stats;
std::int32_t logs_file_descriptor;
std::string_view logs_format;
time_point_t log_last_time;
};

sj::simdjson_result<sjd::element> param_at(ujrpc_call_t call, ujrpc_str_t name, size_t name_len) noexcept {
Expand All @@ -93,7 +154,7 @@ void send_message(engine_t& engine, struct msghdr& message) noexcept {
size_t sz = 0;
for (size_t i = 0; i < message.msg_iovlen; ++i)
sz += message.msg_iov[i].iov_len;
bytes_sent = mbedtls_ssl_write(&engine.ssl_ctx.ssl, (uint8_t*)message.msg_iov->iov_base, sz);
bytes_sent = mbedtls_ssl_write(&engine.ssl_ctx->ssl, (uint8_t*)message.msg_iov->iov_base, sz);
} else
bytes_sent = sendmsg(engine.connection, &message, 0);

Expand Down Expand Up @@ -222,66 +283,6 @@ int ssl_recv(void* ctx, unsigned char* buf, size_t len) {
return ret;
}

int init_ssl(ujrpc_ssl_context_t* ctx, const char* pk_path, const char** crts_path, size_t crts_cnt) {
mbedtls_ssl_init(&ctx->ssl);
mbedtls_ssl_config_init(&ctx->conf);
// #if defined(MBEDTLS_SSL_CACHE_C)
// mbedtls_ssl_cache_init(&cache);
// #endif
mbedtls_x509_crt_init(&ctx->srvcert);
mbedtls_pk_init(&ctx->pkey);
mbedtls_entropy_init(&ctx->entropy);
mbedtls_ctr_drbg_init(&ctx->ctr_drbg);
int ret = 0;

// Seed the RNG
if ((ret = mbedtls_ctr_drbg_seed(&ctx->ctr_drbg, mbedtls_entropy_func, &ctx->entropy, NULL, 0)) != 0)
// TODO Use personalization string. Required or Optional ?
return ret;

// Load Private Key
if ((ret = mbedtls_pk_parse_keyfile(&ctx->pkey, pk_path, NULL, NULL, &ctx->ctr_drbg)) != 0)
// TODO Use Password. Required or Optional ?
return ret;

// Load Certificates
for (size_t i = 0; i < crts_cnt; ++i)
if ((ret = mbedtls_x509_crt_parse_file(&ctx->srvcert, crts_path[i])) != 0)
// TODO Notify which certificate was invalid ?
return ret;

if ((ret = mbedtls_ssl_config_defaults(&ctx->conf, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT)) != 0)
return ret;

mbedtls_ssl_conf_rng(&ctx->conf, mbedtls_ctr_drbg_random, &ctx->ctr_drbg);

// #if defined(MBEDTLS_SSL_CACHE_C)
// mbedtls_ssl_conf_session_cache(&conf, &cache, mbedtls_ssl_cache_get, mbedtls_ssl_cache_set);
// #endif

mbedtls_ssl_conf_ca_chain(&ctx->conf, ctx->srvcert.next, NULL);
if ((ret = mbedtls_ssl_conf_own_cert(&ctx->conf, &ctx->srvcert, &ctx->pkey)) != 0)
return ret;

if ((ret = mbedtls_ssl_setup(&ctx->ssl, &ctx->conf)) != 0)
return ret;

return 0;
}

void ssl_free(ujrpc_ssl_context_t* ssl_ctx) {
mbedtls_x509_crt_free(&ssl_ctx->srvcert);
mbedtls_pk_free(&ssl_ctx->pkey);
mbedtls_ssl_free(&ssl_ctx->ssl);
mbedtls_ssl_config_free(&ssl_ctx->conf);
// #if defined(MBEDTLS_SSL_CACHE_C)
// mbedtls_ssl_cache_free(&ssl_ctx->cache);
// #endif
mbedtls_ctr_drbg_free(&ssl_ctx->ctr_drbg);
mbedtls_entropy_free(&ssl_ctx->entropy);
}

void ujrpc_take_call(ujrpc_server_t server, uint16_t) {
engine_t& engine = *reinterpret_cast<engine_t*>(server);
scratch_space_t& scratch = engine.scratch;
Expand Down Expand Up @@ -314,12 +315,12 @@ void ujrpc_take_call(ujrpc_server_t server, uint16_t) {

if (engine.ssl_ctx) {
client_ctx.fd = connection_fd;
mbedtls_ssl_set_bio(&engine.ssl_ctx.ssl, &client_ctx, ssl_send, ssl_recv, NULL);
mbedtls_ssl_set_bio(&engine.ssl_ctx->ssl, &client_ctx, ssl_send, ssl_recv, NULL);
int ret = 0;
while ((ret = mbedtls_ssl_handshake(&engine.ssl_ctx.ssl)) != 0)
while ((ret = mbedtls_ssl_handshake(&engine.ssl_ctx->ssl)) != 0)
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
mbedtls_net_free(&client_ctx);
mbedtls_ssl_session_reset(&engine.ssl_ctx.ssl);
mbedtls_ssl_session_reset(&engine.ssl_ctx->ssl);
return;
}
}
Expand All @@ -333,8 +334,8 @@ void ujrpc_take_call(ujrpc_server_t server, uint16_t) {
size_t bytes_received = 0, bytes_expected = 0;
if (engine.ssl_ctx) {
// https://esp32.com/viewtopic.php?t=1101
mbedtls_ssl_read(&engine.ssl_ctx.ssl, NULL, 0);
bytes_expected = mbedtls_ssl_get_bytes_avail(&engine.ssl_ctx.ssl);
mbedtls_ssl_read(&engine.ssl_ctx->ssl, NULL, 0);
bytes_expected = mbedtls_ssl_get_bytes_avail(&engine.ssl_ctx->ssl);
} else {
bytes_received = recv(engine.connection, buffer_ptr, http_head_max_size_k, MSG_PEEK);
auto json_or_error = split_body_headers(std::string_view(buffer_ptr, bytes_received));
Expand All @@ -353,7 +354,7 @@ void ujrpc_take_call(ujrpc_server_t server, uint16_t) {
if (bytes_expected <= ram_page_size_k) {
if (engine.ssl_ctx)
bytes_received =
mbedtls_ssl_read(&engine.ssl_ctx.ssl, reinterpret_cast<uint8_t*>(buffer_ptr), bytes_expected);
mbedtls_ssl_read(&engine.ssl_ctx->ssl, reinterpret_cast<uint8_t*>(buffer_ptr), bytes_expected);
else
bytes_received = recv(engine.connection, buffer_ptr, bytes_expected, MSG_WAITALL);
scratch.dynamic_parser = &scratch.parser;
Expand All @@ -372,7 +373,7 @@ void ujrpc_take_call(ujrpc_server_t server, uint16_t) {

if (engine.ssl_ctx)
bytes_received =
mbedtls_ssl_read(&engine.ssl_ctx.ssl, reinterpret_cast<uint8_t*>(buffer_ptr), bytes_expected);
mbedtls_ssl_read(&engine.ssl_ctx->ssl, reinterpret_cast<uint8_t*>(buffer_ptr), bytes_expected);
else
bytes_received = recv(engine.connection, buffer_ptr, bytes_expected, MSG_WAITALL);
scratch.dynamic_parser = &parser;
Expand All @@ -386,11 +387,11 @@ void ujrpc_take_call(ujrpc_server_t server, uint16_t) {

if (engine.ssl_ctx) {
int ret = 0;
while ((ret = mbedtls_ssl_close_notify(&engine.ssl_ctx.ssl)) < 0)
while ((ret = mbedtls_ssl_close_notify(&engine.ssl_ctx->ssl)) < 0)
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE)
break;

mbedtls_ssl_session_reset(&engine.ssl_ctx.ssl);
mbedtls_ssl_session_reset(&engine.ssl_ctx->ssl);
}
shutdown(connection_fd, SHUT_WR);
// If later on some UB is detected for client not recieving full data,
Expand Down Expand Up @@ -432,7 +433,7 @@ void ujrpc_init(ujrpc_config_t* config_inout, ujrpc_server_t* server_out) {
buffer_gt<struct iovec> embedded_iovecs;
buffer_gt<char*> embedded_copies;
array_gt<named_callback_t> embedded_callbacks;
ujrpc_ssl_context_t ssl_context;
ujrpc_ssl_context_t* ssl_context = nullptr;
sjd::parser parser;

// By default, let's open TCP port for IPv4.
Expand Down Expand Up @@ -466,8 +467,11 @@ void ujrpc_init(ujrpc_config_t* config_inout, ujrpc_server_t* server_out) {
goto cleanup;
if (listen(socket_descriptor, config.queue_depth) < 0)
goto cleanup;
if (config.use_ssl && init_ssl(&ssl_context, config.ssl_pk_path, config.ssl_crts_path, config.ssl_crts_cnt) != 0)
goto cleanup;
if (config.use_ssl) {
ssl_context = new ujrpc_ssl_context_t();
if (ssl_context->init(config.ssl_pk_path, config.ssl_crts_path, config.ssl_crts_cnt) != 0)
goto cleanup;
}
if (parser.allocate(ram_page_size_k, ram_page_size_k / 2) != sj::SUCCESS)
goto cleanup;

Expand All @@ -482,7 +486,7 @@ void ujrpc_init(ujrpc_config_t* config_inout, ujrpc_server_t* server_out) {
server_ptr->logs_file_descriptor = config.logs_file_descriptor;
server_ptr->logs_format = config.logs_format ? std::string_view(config.logs_format) : std::string_view();
server_ptr->log_last_time = time_clock_t::now();
server_ptr->ssl_ctx = std::move(ssl_context);
server_ptr->ssl_ctx = ssl_context;
*server_out = (ujrpc_server_t)server_ptr;
return;

Expand All @@ -492,8 +496,7 @@ void ujrpc_init(ujrpc_config_t* config_inout, ujrpc_server_t* server_out) {
close(socket_descriptor);
std::free(server_ptr);
*server_out = nullptr;
if (config.use_ssl)
ssl_free(&ssl_context);
delete ssl_context;
}

void ujrpc_add_procedure(ujrpc_server_t server, ujrpc_str_t name, ujrpc_callback_t callback,
Expand All @@ -512,11 +515,9 @@ void ujrpc_free(ujrpc_server_t server) {
if (!server)
return;

engine_t& engine = *reinterpret_cast<engine_t*>(server);
close(engine.socket);
ssl_free(&engine.ssl_ctx);
engine.~engine_t();
std::free(server);
engine_t* engine = reinterpret_cast<engine_t*>(server);
close(engine->socket);
delete engine;
}

void prepend_http_headers(iovec* buffers, size_t content_len, char* http_buffer) {
Expand Down

0 comments on commit 50ec686

Please sign in to comment.