Skip to content

Commit

Permalink
use async_read for websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
Arkrissym committed Oct 3, 2024
1 parent 377a970 commit 5b150f0
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 172 deletions.
5 changes: 1 addition & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@ Set(FETCHCONTENT_QUIET FALSE)

FetchContent_Declare(certify GIT_REPOSITORY https://github.com/djarek/certify)
FetchContent_GetProperties(certify)
if(NOT certify_POPULATED)
FetchContent_Populate(certify)
endif()

set(BOOST_INCLUDE_LIBRARIES thread filesystem system process asio endian logic static_string)
set(BOOST_ENABLE_CMAKE ON)
Expand Down Expand Up @@ -44,7 +41,7 @@ FetchContent_Declare(zlib

FetchContent_Declare(sodium
GIT_REPOSITORY https://github.com/jedisct1/libsodium.git
GIT_TAG 1.0.19
GIT_TAG 1.0.20-RELEASE
)


Expand Down
161 changes: 84 additions & 77 deletions Discord.C++/Gateway.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ void DiscordCPP::Gateway::start_heartbeating() {
if (_last_heartbeat_ack * 1000 + static_cast<time_t>(_heartbeat_interval * 2) < time(nullptr) * 1000) {
_log.warning("Gateway stopped responding. Closing and restarting websocket...");
try {
_client->close(websocket::close_reason(websocket::close_code::going_away, "Server not responding"));
get_lowest_layer(*_client).close();
//_client->close(websocket::close_reason(websocket::close_code::going_away, "Server not responding"));
} catch (std::exception& e) {
_log.error("Cannot close websocket: " + std::string(e.what()));
}
Expand Down Expand Up @@ -75,16 +76,48 @@ void DiscordCPP::Gateway::on_websocket_disconnnect() {
}

try {
connect(_resume_url).get();
io_context.restart();
connect(_resume_url);
_log.info("reconnected");
_reconnect_timeout = 0;
_last_heartbeat_ack = time(nullptr);
} catch (const beast::system_error& e) {
_log.error("Failed to reconnect: " + std::string(e.what()));
on_websocket_disconnnect();
}
});
}

void DiscordCPP::Gateway::on_read(boost::system::error_code error_code, std::size_t bytes) {
_log.debug("Received " + std::to_string(bytes) + " bytes");

if (error_code == boost::beast::websocket::error::closed || error_code == boost::beast::errc::operation_canceled || error_code == boost::asio::ssl::error::stream_truncated || bytes <= 0) {
on_websocket_disconnnect();
return;
}

if (error_code) {
_log.error("Error while reading message: " + error_code.message());
} else {
std::stringstream message_stream;
message_stream << beast::make_printable(buffer.data());
std::string message = message_stream.str();

threadpool->execute([this, message]() {
try {
on_websocket_incoming_message(message);
} catch (const std::exception& e) {
_log.error("Error while handling incoming message: " + std::string(e.what()));
}
});
}

buffer.clear();
_client->async_read(buffer, [this](boost::system::error_code ec, std::size_t b) {
on_read(ec, b);
});
}

DiscordCPP::Gateway::Gateway(std::string token, const std::shared_ptr<Threadpool>& threadpool)
: threadpool(threadpool),
io_context(),
Expand All @@ -111,99 +144,74 @@ void DiscordCPP::Gateway::set_message_handler(
_message_handler = handler;
}

DiscordCPP::SharedFuture<void> DiscordCPP::Gateway::connect(const std::string& url) {
void DiscordCPP::Gateway::connect(const std::string& url) {
if (_url.empty()) {
_url = url;
}
return threadpool
->execute([this, url]() {
std::string tmp_url = url;

// cut protocol
auto index = tmp_url.find("://");
if (index != std::string::npos) {
tmp_url = tmp_url.substr(index + 3, std::string::npos);
}

auto port_index = tmp_url.find(':');
auto query_index = tmp_url.find('?');

std::string host;
if (port_index != std::string::npos) {
host = tmp_url.substr(0, port_index);
} else {
host = tmp_url.substr(0, query_index);
}

std::string query = "/";
if (query_index != std::string::npos) {
query = "/" + tmp_url.substr(query_index, std::string::npos);
}
std::string tmp_url = url;

_log.debug("host: " + host + "\t\tquery: " + query);
_log.info("connecting to websocket: " + url);
// cut protocol
auto index = tmp_url.find("://");
if (index != std::string::npos) {
tmp_url = tmp_url.substr(index + 3, std::string::npos);
}

tcp::resolver resolver{io_context};
auto results = resolver.resolve(host, "443");
auto port_index = tmp_url.find(':');
auto query_index = tmp_url.find('?');

_client = std::make_unique<websocket::stream<beast::ssl_stream<tcp::socket>>>(io_context, ssl_context);
std::string host;
if (port_index != std::string::npos) {
host = tmp_url.substr(0, port_index);
} else {
host = tmp_url.substr(0, query_index);
}

auto endpoint = net::connect(get_lowest_layer(*_client), results);
std::string query = "/";
if (query_index != std::string::npos) {
query = "/" + tmp_url.substr(query_index, std::string::npos);
}

if (!SSL_set_tlsext_host_name(_client->next_layer().native_handle(), host.c_str()))
throw beast::system_error(
beast::error_code(
static_cast<int>(::ERR_get_error()),
net::error::get_ssl_category()),
"Failed to set SNI Hostname");
_log.debug("host: " + host + "\tquery: " + query);
_log.info("connecting to websocket: " + url);

_client->set_option(websocket::stream_base::decorator([](websocket::request_type& req) {
req.set(http::field::user_agent, "Discord.C++ DiscordBot");
}));
tcp::resolver resolver{io_context};
auto results = resolver.resolve(host, "443");

_client->next_layer().handshake(ssl::stream_base::client);
_client->handshake(host + std::string(":") + std::to_string(endpoint.port()), query);
_client = std::make_unique<websocket::stream<beast::ssl_stream<tcp::socket>>>(io_context, ssl_context);

if (_heartbeat_task.get_id() == std::thread::id()) {
start_heartbeating();
}
auto endpoint = net::connect(get_lowest_layer(*_client), results);

_log.info("Successfully connected to endpoint: " + endpoint.address().to_string() + ":" + std::to_string(endpoint.port()));
_connected = true;
})
.then([this]() {
threadpool->execute([this]() {
_log.debug("Starting message reveiving loop.");
while (_connected) {
beast::flat_buffer buffer;
beast::error_code error_code;
size_t bytes = _client->read(buffer, error_code);
if (!SSL_set_tlsext_host_name(_client->next_layer().native_handle(), host.c_str()))
throw beast::system_error(
beast::error_code(
static_cast<int>(::ERR_get_error()),
net::error::get_ssl_category()),
"Failed to set SNI Hostname");

_log.debug("Received " + std::to_string(bytes) + " bytes");
_client->set_option(websocket::stream_base::decorator([](websocket::request_type& req) {
req.set(http::field::user_agent, "Discord.C++ DiscordBot");
}));

if (error_code == boost::beast::websocket::error::closed || error_code == boost::beast::errc::operation_canceled || error_code == boost::asio::ssl::error::stream_truncated || bytes <= 0) {
on_websocket_disconnnect();
_client->next_layer().handshake(ssl::stream_base::client);
_client->handshake(host + std::string(":") + std::to_string(endpoint.port()), query);

break;
}
if (_heartbeat_task.get_id() == std::thread::id()) {
start_heartbeating();
}

if (error_code) {
_log.error("Error while reading message: " + error_code.message());
continue;
}
_log.info("Successfully connected to endpoint: " + endpoint.address().to_string() + ":" + std::to_string(endpoint.port()));
_connected = true;

std::stringstream message_stream;
message_stream << beast::make_printable(buffer.data());
std::string message = message_stream.str();
_client->async_read(buffer, [this](boost::system::error_code error_code, std::size_t bytes) {
on_read(error_code, bytes);
});

try {
on_websocket_incoming_message(message);
} catch (const std::exception& e) {
_log.error("Error while handling incoming message: " + std::string(e.what()));
}
}
});
});
threadpool->execute([this]() {
_log.debug("Start io_context");
io_context.run();
_log.debug("Stop io_context");
});
}

///@throws ClientException
Expand Down Expand Up @@ -233,7 +241,6 @@ DiscordCPP::SharedFuture<void> DiscordCPP::Gateway::close() {

return threadpool->execute([this]() {
try {
//_client->close(websocket::close_code::normal);
get_lowest_layer(*_client).close();
} catch (const std::exception& e) {
_log.error("Error while closing websocket: " + std::string(e.what()));
Expand Down
5 changes: 4 additions & 1 deletion Discord.C++/Gateway.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class Gateway {
boost::asio::io_context io_context;
/// ssl context used by the websocket client
boost::asio::ssl::context ssl_context;
/// buffer for websocket connection
boost::beast::flat_buffer buffer;
/// websocket client
std::unique_ptr<boost::beast::websocket::stream<boost::beast::ssl_stream<boost::asio::ip::tcp::socket>>> _client;
/// the url of the gateway
Expand Down Expand Up @@ -50,6 +52,7 @@ class Gateway {
DLL_EXPORT void start_heartbeating();
DLL_EXPORT virtual json get_heartbeat_payload() = 0;
DLL_EXPORT virtual void identify() = 0;
DLL_EXPORT void on_read(boost::system::error_code error_code, std::size_t bytes);
DLL_EXPORT virtual void on_websocket_incoming_message(const std::string& message) = 0;
DLL_EXPORT void on_websocket_disconnnect();

Expand All @@ -59,7 +62,7 @@ class Gateway {

DLL_EXPORT void set_message_handler(const std::function<void(json payload)>& handler);

DLL_EXPORT virtual SharedFuture<void> connect(const std::string& url);
DLL_EXPORT virtual void connect(const std::string& url);
DLL_EXPORT SharedFuture<void> send(const json& message);
DLL_EXPORT SharedFuture<void> close();
};
Expand Down
112 changes: 55 additions & 57 deletions Discord.C++/MainGateway.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,63 +52,61 @@ void DiscordCPP::MainGateway::on_websocket_incoming_message(
_sequence_number = payload["s"].get<int>();
}

threadpool->execute([this, payload, op] {
switch (op) {
case 0:
if (payload["t"].get<std::string>() == "READY") {
_reconnect_timeout = 0;
_last_heartbeat_ack = time(nullptr);

_invalid_session = false;

_session_id = payload["d"]["session_id"].get<std::string>();

_resume_url = payload["d"]["resume_gateway_url"].get<std::string>() + "?v=10&encoding=json&compress=zlib-stream";

std::string str = set_trace(payload);

_log.info("connected to: " + str + " ]");
_log.info("session id: " + _session_id);
} else if (payload["t"].get<std::string>() == "RESUMED") {
_reconnect_timeout = 0;
_last_heartbeat_ack = time(nullptr);

std::string str = set_trace(payload);

_log.info("successfully resumed session " + _session_id +
" with trace " + str + " ]");
}
break;
case 1:
send_heartbeat_ack();
break;
case 7:
_log.info("received opcode 7: reconnecting to the gateway");
try {
_client->close(boost::beast::websocket::close_reason(boost::beast::websocket::close_code::going_away, "Server requested reconnect"));
} catch (std::exception& e) {
_log.error("Cannot close websocket: " + std::string(e.what()));
}
break;
case 9:
_invalid_session = true;
break;
case 10:
_heartbeat_interval = payload["d"]["heartbeat_interval"].get<int>();
_log.debug("set heartbeat_interval: " +
std::to_string(_heartbeat_interval));
identify();
break;
case 11:
_log.debug("received heartbeat ACK");
switch (op) {
case 0:
if (payload["t"].get<std::string>() == "READY") {
_reconnect_timeout = 0;
_last_heartbeat_ack = time(nullptr);
break;
default:
break;
}

_message_handler(payload);
});
_invalid_session = false;

_session_id = payload["d"]["session_id"].get<std::string>();

_resume_url = payload["d"]["resume_gateway_url"].get<std::string>() + "?v=10&encoding=json&compress=zlib-stream";

std::string str = set_trace(payload);

_log.info("connected to: " + str + " ]");
_log.info("session id: " + _session_id);
} else if (payload["t"].get<std::string>() == "RESUMED") {
_reconnect_timeout = 0;
_last_heartbeat_ack = time(nullptr);

std::string str = set_trace(payload);

_log.info("successfully resumed session " + _session_id +
" with trace " + str + " ]");
}
break;
case 1:
send_heartbeat_ack();
break;
case 7:
_log.info("received opcode 7: reconnecting to the gateway");
try {
_client->close(boost::beast::websocket::close_reason(boost::beast::websocket::close_code::going_away, "Server requested reconnect"));
} catch (std::exception& e) {
_log.error("Cannot close websocket: " + std::string(e.what()));
}
break;
case 9:
_invalid_session = true;
break;
case 10:
_heartbeat_interval = payload["d"]["heartbeat_interval"].get<int>();
_log.debug("set heartbeat_interval: " +
std::to_string(_heartbeat_interval));
identify();
break;
case 11:
_log.debug("received heartbeat ACK");
_last_heartbeat_ack = time(nullptr);
break;
default:
break;
}

_message_handler(payload);
}

DiscordCPP::SharedFuture<void> DiscordCPP::MainGateway::send_heartbeat_ack() {
Expand Down Expand Up @@ -210,7 +208,7 @@ unsigned int DiscordCPP::MainGateway::get_shard_id() {
return _shard_id;
}

DiscordCPP::SharedFuture<void> DiscordCPP::MainGateway::connect(const std::string& url) {
void DiscordCPP::MainGateway::connect(const std::string& url) {
zs.zalloc = Z_NULL;
zs.zfree = Z_NULL;
zs.opaque = Z_NULL;
Expand All @@ -221,5 +219,5 @@ DiscordCPP::SharedFuture<void> DiscordCPP::MainGateway::connect(const std::strin
throw ClientException("Failed to initialize zlib");
}

return Gateway::connect(url);
Gateway::connect(url);
}
Loading

0 comments on commit 5b150f0

Please sign in to comment.