Skip to content

Commit

Permalink
Add SSL handshake error handling
Browse files Browse the repository at this point in the history
Signed-off-by: Méven Car <[email protected]>
Change-Id: I8ad58d7ae165d81a43744da507a957e81996d56b
  • Loading branch information
meven committed Nov 26, 2024
1 parent 03c6006 commit eed41e5
Show file tree
Hide file tree
Showing 11 changed files with 176 additions and 65 deletions.
92 changes: 64 additions & 28 deletions net/HttpRequest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,7 @@ class Session final : public ProtocolHandlerInterface
void setFinishedHandler(FinishedCallback onFinished) { _onFinished = std::move(onFinished); }

/// The onConnectFail callback handler signature.
using ConnectFailCallback = std::function<void()>;
using ConnectFailCallback = std::function<void(const std::shared_ptr<Session>& session)>;

void setConnectFailHandler(ConnectFailCallback onConnectFail) { _onConnectFail = std::move(onConnectFail); }

Expand Down Expand Up @@ -1380,6 +1380,11 @@ class Session final : public ProtocolHandlerInterface
}
}

net::asyncConnectResult connectionResult()
{
return _result;
}

/// Returns the socket FD, for logging/informational purposes.
int getFD() const { return _fd; }

Expand Down Expand Up @@ -1448,6 +1453,30 @@ class Session final : public ProtocolHandlerInterface
return _response->state() == Response::State::Complete;
}

void callOnFinished()
{
if (_onFinished)
{
LOG_TRC("onFinished calling client");
auto self = shared_from_this();
try
{
[[maybe_unused]] const auto references = self.use_count();
assert(references > 1 && "Expected more than 1 reference to http::Session.");

_onFinished(std::static_pointer_cast<Session>(self));

assert(self.use_count() > 1 &&
"Erroneously onFinish reset 'this'. Use 'addCallback()' on the "
"SocketPoll to reset on idle instead.");
}
catch (const std::exception& exc)
{
LOG_ERR("Error while invoking onFinished client callback: " << exc.what());
}
}
}

/// Set up a new request and response.
void newRequest(const Request& req)
{
Expand All @@ -1468,26 +1497,8 @@ class Session final : public ProtocolHandlerInterface
assert(_response->state() != Response::State::Incomplete &&
"Unexpected response in Incomplete state");
assert(_response->done() && "Must have response in done state");
if (_onFinished)
{
LOG_TRC("onFinished calling client");
auto self = shared_from_this();
try
{
[[maybe_unused]] const auto references = self.use_count();
assert(references > 1 && "Expected more than 1 reference to http::Session.");

_onFinished(std::static_pointer_cast<Session>(self));

assert(self.use_count() > 1 &&
"Erroneously onFinish reset 'this'. Use 'addCallback()' on the "
"SocketPoll to reset on idle instead.");
}
catch (const std::exception& exc)
{
LOG_ERR("Error while invoking onFinished client callback: " << exc.what());
}
}
callOnFinished();

if (_response->header().getConnectionToken() == Header::ConnectionToken::Close)
{
Expand Down Expand Up @@ -1609,11 +1620,34 @@ class Session final : public ProtocolHandlerInterface

if (!socket->send(_request))
{
_result = net::asyncConnectResult::SocketError;
LOG_ERR("Error while writing to socket");
}
}
}

void callOnConnectFail()
{
if (_onConnectFail) {
auto self = shared_from_this();
try
{
[[maybe_unused]] const auto references = self.use_count();
assert(references > 1 && "Expected more than 1 reference to http::Session.");

_onConnectFail(std::static_pointer_cast<Session>(self));

assert(self.use_count() > 1 &&
"Erroneously onConnectFail reset 'this'. Use 'addCallback()' on the "
"SocketPoll to reset on idle instead.");
}
catch (const std::exception& exc)
{
LOG_ERR("Error while invoking onConnectFail client callback: " << exc.what());
}
}
}

// on failure the stream will be discarded, so save the ssl verification
// result while it is still available
void onHandshakeFail() override
Expand All @@ -1623,7 +1657,10 @@ class Session final : public ProtocolHandlerInterface
{
LOG_TRC("onHandshakeFail");
_handshakeSslVerifyFailure = socket->getSslVerifyResult();
_result = net::asyncConnectResult::SSLHandShakeFailure;
}

callOnConnectFail();
}

void onDisconnect() override
Expand Down Expand Up @@ -1659,22 +1696,20 @@ class Session final : public ProtocolHandlerInterface
return socket; // Return the shared pointer.
}

void asyncConnectCompleted(SocketPoll& poll, std::shared_ptr<StreamSocket> socket)
void asyncConnectCompleted(SocketPoll& poll, const std::shared_ptr<StreamSocket> &socket, net::asyncConnectResult result)
{
assert((!socket || _fd == socket->getFD()) &&
"The socket FD must have been set in onConnect");

// When used with proxy.php we may indeed get nullptr here.
// assert(socket && "Unexpected nullptr returned from net::connect");
_socket = socket; // Hold a weak pointer to it.
_result = result;

if (!socket)
{
LOG_ERR("Failed to connect to " << _host << ':' << _port);

if (_onConnectFail)
_onConnectFail();

callOnConnectFail();
return;
}

Expand All @@ -1689,9 +1724,9 @@ class Session final : public ProtocolHandlerInterface
{
_socket.reset(); // Reset to make sure we are disconnected.

auto pushConnectCompleteToPoll = [this, &poll](std::shared_ptr<StreamSocket> socket) {
poll.addCallback([selfLifecycle = shared_from_this(), this, &poll, socket=std::move(socket)]() {
asyncConnectCompleted(poll, socket);
auto pushConnectCompleteToPoll = [this, &poll](std::shared_ptr<StreamSocket> socket, net::asyncConnectResult result ) {
poll.addCallback([selfLifecycle = shared_from_this(), this, &poll, socket=std::move(socket), &result]() {
asyncConnectCompleted(poll, socket, result);
});
};

Expand Down Expand Up @@ -1744,6 +1779,7 @@ class Session final : public ProtocolHandlerInterface
ConnectFailCallback _onConnectFail;
std::shared_ptr<Response> _response;
std::weak_ptr<StreamSocket> _socket; ///< Must be the last member.
net::asyncConnectResult _result; // last connection tentative result
};

/// HTTP Get a URL synchronously.
Expand Down
13 changes: 10 additions & 3 deletions net/NetUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL,
if (host.empty() || port.empty())
{
LOG_ERR("Invalid host/port " << host << ':' << port);
asyncCb(nullptr);
asyncCb(nullptr, asyncConnectResult::HostNameError);
return;
}

Expand All @@ -389,7 +389,7 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL,
if (isSSL)
{
LOG_ERR("Error: isSSL socket requested but SSL is not compiled in.");
asyncCb(nullptr);
asyncCb(nullptr, asyncConnectResult::MissingSSLError);
return;
}
#endif
Expand All @@ -399,6 +399,8 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL,
{
std::shared_ptr<StreamSocket> socket;

asyncConnectResult result = asyncConnectResult::UnknownHostError;

if (const addrinfo* ainfo = hostEntry.getAddrInfo())
{
for (const addrinfo* ai = ainfo; ai; ai = ai->ai_next)
Expand All @@ -408,13 +410,15 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL,
int fd = ::socket(ai->ai_addr->sa_family, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0);
if (fd < 0)
{
result = asyncConnectResult::SocketError;
LOG_SYS("Failed to create socket");
continue;
}

int res = ::connect(fd, ai->ai_addr, ai->ai_addrlen);
if (res < 0 && errno != EINPROGRESS)
{
result = asyncConnectResult::ConnectionError;
LOG_SYS("Failed to connect to " << host);
::close(fd);
}
Expand All @@ -439,9 +443,12 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL,
{
LOG_DBG('#' << fd << " New socket connected to " << host << ':' << port
<< " (" << (isSSL ? "SSL)" : "Unencrypted)"));
result = asyncConnectResult::Ok;
break;
}

result = asyncConnectResult::SocketError;

LOG_ERR("Failed to allocate socket for client websocket " << host);
::close(fd);
break;
Expand All @@ -452,7 +459,7 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL,
else
LOG_SYS("Failed to lookup host [" << host << "]. Skipping");

asyncCb(std::move(socket));
asyncCb(std::move(socket), result);
};

net::AsyncDNS::DNSThreadDumpStateFn dumpState = [host, port]() -> std::string
Expand Down
11 changes: 10 additions & 1 deletion net/NetUtil.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,16 @@ std::shared_ptr<StreamSocket>
connect(const std::string& host, const std::string& port, const bool isSSL,
const std::shared_ptr<ProtocolHandlerInterface>& protocolHandler);

typedef std::function<void(std::shared_ptr<StreamSocket>)> asyncConnectCB;
enum class asyncConnectResult{
Ok = 0,
SocketError,
ConnectionError,
HostNameError,
UnknownHostError,
SSLHandShakeFailure,
};

typedef std::function<void(std::shared_ptr<StreamSocket>, asyncConnectResult result)> asyncConnectCB;

void
asyncConnect(const std::string& host, const std::string& port, const bool isSSL,
Expand Down
6 changes: 3 additions & 3 deletions test/HttpRequestTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ void HttpRequestTests::testSimpleGet()

std::unique_lock<std::mutex> lock(mutex);

httpSession->setConnectFailHandler([]() {
httpSession->setConnectFailHandler([](const std::shared_ptr<http::Session>&) {
LOK_ASSERT_FAIL("Unexpected connection failure");
});

Expand Down Expand Up @@ -535,7 +535,7 @@ void HttpRequestTests::test500GetStatuses()
std::unique_lock<std::mutex> lock(mutex);
timedout = true; // Assume we timed out until we prove otherwise.

httpSession->setConnectFailHandler([]() {
httpSession->setConnectFailHandler([](const std::shared_ptr<http::Session>&) {
LOK_ASSERT_FAIL("Unexpected connection failure");
});

Expand Down Expand Up @@ -628,7 +628,7 @@ void HttpRequestTests::testSimplePost_External()

std::unique_lock<std::mutex> lock(mutex);

httpSession->setConnectFailHandler([]() {
httpSession->setConnectFailHandler([](const std::shared_ptr<http::Session>&) {
LOK_ASSERT_FAIL("Unexpected connection failure");
});

Expand Down
2 changes: 1 addition & 1 deletion test/UnitProxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class UnitProxy : public UnitWSD
// Request from rating.collaboraonline.com.
_req = http::Request("/browser/a90f83c/foo/remote/static/lokit-extra-img.svg");

httpSession->setConnectFailHandler([this]() {
httpSession->setConnectFailHandler([this](const std::shared_ptr<http::Session>&) {
LOK_ASSERT_FAIL("Unexpected connection failure");
});

Expand Down
Loading

0 comments on commit eed41e5

Please sign in to comment.