diff --git a/net/HttpRequest.hpp b/net/HttpRequest.hpp index 6ad623ada669..9e67cfde512a 100644 --- a/net/HttpRequest.hpp +++ b/net/HttpRequest.hpp @@ -1236,7 +1236,7 @@ class Session final : public ProtocolHandlerInterface void setFinishedHandler(FinishedCallback onFinished) { _onFinished = std::move(onFinished); } /// The onConnectFail callback handler signature. - using ConnectFailCallback = std::function; + using ConnectFailCallback = std::function& session)>; void setConnectFailHandler(ConnectFailCallback onConnectFail) { _onConnectFail = std::move(onConnectFail); } @@ -1380,6 +1380,11 @@ class Session final : public ProtocolHandlerInterface } } + net::AsyncConnectResult connectionResult() + { + return _result; + } + /// Returns the socket FD, for logging/informational purposes. int getFD() const { return _fd; } @@ -1448,6 +1453,30 @@ class Session final : public ProtocolHandlerInterface return _response->state() == Response::State::Complete; } + void callOnFinished() + { + if (!_onFinished) + return; + + LOG_TRC("onFinished calling client"); + std::shared_ptr self = shared_from_this(); + try + { + [[maybe_unused]] const long references = self.use_count(); + assert(references > 1 && "Expected more than 1 reference to http::Session."); + + _onFinished(self); + + assert(self.use_count() > 1 && + "Erroneously onFinish reset 'this'. Use 'addCallback()' on the " + "SocketPoll to reset on idle instead."); + } + catch (const std::exception& exc) + { + LOG_ERR("Error while invoking onFinished client callback: " << exc.what()); + } + } + /// Set up a new request and response. void newRequest(const Request& req) { @@ -1468,26 +1497,8 @@ class Session final : public ProtocolHandlerInterface assert(_response->state() != Response::State::Incomplete && "Unexpected response in Incomplete state"); assert(_response->done() && "Must have response in done state"); - if (_onFinished) - { - LOG_TRC("onFinished calling client"); - auto self = shared_from_this(); - try - { - [[maybe_unused]] const auto references = self.use_count(); - assert(references > 1 && "Expected more than 1 reference to http::Session."); - - _onFinished(std::static_pointer_cast(self)); - assert(self.use_count() > 1 && - "Erroneously onFinish reset 'this'. Use 'addCallback()' on the " - "SocketPoll to reset on idle instead."); - } - catch (const std::exception& exc) - { - LOG_ERR("Error while invoking onFinished client callback: " << exc.what()); - } - } + callOnFinished(); if (_response->header().getConnectionToken() == Header::ConnectionToken::Close) { @@ -1609,11 +1620,40 @@ class Session final : public ProtocolHandlerInterface if (!socket->send(_request)) { + _result = net::AsyncConnectResult::SocketError; LOG_ERR("Error while writing to socket"); } } } + std::shared_ptr shared_from_this() + { + return std::static_pointer_cast(ProtocolHandlerInterface::shared_from_this()); + } + + void callOnConnectFail() + { + if (!_onConnectFail) + return; + + std::shared_ptr self = shared_from_this(); + try + { + [[maybe_unused]] const long references = self.use_count(); + assert(references > 1 && "Expected more than 1 reference to http::Session."); + + _onConnectFail(self); + + assert(self.use_count() > 1 && + "Erroneously onConnectFail reset 'this'. Use 'addCallback()' on the " + "SocketPoll to reset on idle instead."); + } + catch (const std::exception& exc) + { + LOG_ERR("Error while invoking onConnectFail client callback: " << exc.what()); + } + } + // on failure the stream will be discarded, so save the ssl verification // result while it is still available void onHandshakeFail() override @@ -1623,7 +1663,10 @@ class Session final : public ProtocolHandlerInterface { LOG_TRC("onHandshakeFail"); _handshakeSslVerifyFailure = socket->getSslVerifyResult(); + _result = net::AsyncConnectResult::SSLHandShakeFailure; } + + callOnConnectFail(); } void onDisconnect() override @@ -1659,7 +1702,7 @@ class Session final : public ProtocolHandlerInterface return socket; // Return the shared pointer. } - void asyncConnectCompleted(SocketPoll& poll, std::shared_ptr socket) + void asyncConnectCompleted(SocketPoll& poll, const std::shared_ptr &socket, net::AsyncConnectResult result) { assert((!socket || _fd == socket->getFD()) && "The socket FD must have been set in onConnect"); @@ -1667,14 +1710,12 @@ class Session final : public ProtocolHandlerInterface // When used with proxy.php we may indeed get nullptr here. // assert(socket && "Unexpected nullptr returned from net::connect"); _socket = socket; // Hold a weak pointer to it. + _result = result; if (!socket) { LOG_ERR("Failed to connect to " << _host << ':' << _port); - - if (_onConnectFail) - _onConnectFail(); - + callOnConnectFail(); return; } @@ -1689,9 +1730,9 @@ class Session final : public ProtocolHandlerInterface { _socket.reset(); // Reset to make sure we are disconnected. - auto pushConnectCompleteToPoll = [this, &poll](std::shared_ptr socket) { - poll.addCallback([selfLifecycle = shared_from_this(), this, &poll, socket=std::move(socket)]() { - asyncConnectCompleted(poll, socket); + auto pushConnectCompleteToPoll = [this, &poll](std::shared_ptr socket, net::AsyncConnectResult result ) { + poll.addCallback([selfLifecycle = shared_from_this(), this, &poll, socket=std::move(socket), &result]() { + asyncConnectCompleted(poll, socket, result); }); }; @@ -1744,6 +1785,7 @@ class Session final : public ProtocolHandlerInterface ConnectFailCallback _onConnectFail; std::shared_ptr _response; std::weak_ptr _socket; ///< Must be the last member. + net::AsyncConnectResult _result; // last connection tentative result }; /// HTTP Get a URL synchronously. diff --git a/net/NetUtil.cpp b/net/NetUtil.cpp index 271cd4a28faa..29691567785c 100644 --- a/net/NetUtil.cpp +++ b/net/NetUtil.cpp @@ -379,7 +379,7 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL, if (host.empty() || port.empty()) { LOG_ERR("Invalid host/port " << host << ':' << port); - asyncCb(nullptr); + asyncCb(nullptr, AsyncConnectResult::HostNameError); return; } @@ -389,7 +389,7 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL, if (isSSL) { LOG_ERR("Error: isSSL socket requested but SSL is not compiled in."); - asyncCb(nullptr); + asyncCb(nullptr, asyncConnectResult::MissingSSLError); return; } #endif @@ -399,6 +399,8 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL, { std::shared_ptr socket; + AsyncConnectResult result = AsyncConnectResult::UnknownHostError; + if (const addrinfo* ainfo = hostEntry.getAddrInfo()) { for (const addrinfo* ai = ainfo; ai; ai = ai->ai_next) @@ -408,6 +410,7 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL, int fd = ::socket(ai->ai_addr->sa_family, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0); if (fd < 0) { + result = AsyncConnectResult::SocketError; LOG_SYS("Failed to create socket"); continue; } @@ -415,6 +418,7 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL, int res = ::connect(fd, ai->ai_addr, ai->ai_addrlen); if (res < 0 && errno != EINPROGRESS) { + result = AsyncConnectResult::ConnectionError; LOG_SYS("Failed to connect to " << host); ::close(fd); } @@ -439,9 +443,12 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL, { LOG_DBG('#' << fd << " New socket connected to " << host << ':' << port << " (" << (isSSL ? "SSL)" : "Unencrypted)")); + result = AsyncConnectResult::Ok; break; } + result = AsyncConnectResult::SocketError; + LOG_ERR("Failed to allocate socket for client websocket " << host); ::close(fd); break; @@ -452,7 +459,7 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL, else LOG_SYS("Failed to lookup host [" << host << "]. Skipping"); - asyncCb(std::move(socket)); + asyncCb(std::move(socket), result); }; net::AsyncDNS::DNSThreadDumpStateFn dumpState = [host, port]() -> std::string @@ -570,7 +577,7 @@ connect(std::string uri, const std::shared_ptr& protoc } bool parseUri(std::string uri, std::string& scheme, std::string& host, std::string& port, - std::string& url) + std::string& pathAndQuery) { const auto itScheme = uri.find("://"); if (itScheme != uri.npos) @@ -587,12 +594,12 @@ bool parseUri(std::string uri, std::string& scheme, std::string& host, std::stri const auto itUrl = uri.find('/'); if (itUrl != uri.npos) { - url = uri.substr(itUrl); // Including the first foreslash. + pathAndQuery = uri.substr(itUrl); // Including the first foreslash. uri = uri.substr(0, itUrl); } else { - url.clear(); + pathAndQuery.clear(); } const auto itPort = uri.find(':'); diff --git a/net/NetUtil.hpp b/net/NetUtil.hpp index fad832750ace..721c65595321 100644 --- a/net/NetUtil.hpp +++ b/net/NetUtil.hpp @@ -89,7 +89,16 @@ std::shared_ptr connect(const std::string& host, const std::string& port, const bool isSSL, const std::shared_ptr& protocolHandler); -typedef std::function)> asyncConnectCB; +enum class AsyncConnectResult{ + Ok = 0, + SocketError, + ConnectionError, + HostNameError, + UnknownHostError, + SSLHandShakeFailure, +}; + +typedef std::function, AsyncConnectResult result)> asyncConnectCB; void asyncConnect(const std::string& host, const std::string& port, const bool isSSL, @@ -103,14 +112,14 @@ connect(std::string uri, const std::shared_ptr& protoc /// Decomposes a URI into its components. /// Returns true if parsing was successful. bool parseUri(std::string uri, std::string& scheme, std::string& host, std::string& port, - std::string& url); + std::string& pathAndQuery); /// Decomposes a URI into its components. /// Returns true if parsing was successful. inline bool parseUri(std::string uri, std::string& scheme, std::string& host, std::string& port) { - std::string url; - return parseUri(std::move(uri), scheme, host, port, url); + std::string pathAndQuery; + return parseUri(std::move(uri), scheme, host, port, pathAndQuery); } /// Return the locator given a URI. diff --git a/test/HttpRequestTests.cpp b/test/HttpRequestTests.cpp index f18eafe30572..5201123db0c1 100644 --- a/test/HttpRequestTests.cpp +++ b/test/HttpRequestTests.cpp @@ -311,7 +311,7 @@ void HttpRequestTests::testSimpleGet() std::unique_lock lock(mutex); - httpSession->setConnectFailHandler([]() { + httpSession->setConnectFailHandler([](const std::shared_ptr&) { LOK_ASSERT_FAIL("Unexpected connection failure"); }); @@ -535,7 +535,7 @@ void HttpRequestTests::test500GetStatuses() std::unique_lock lock(mutex); timedout = true; // Assume we timed out until we prove otherwise. - httpSession->setConnectFailHandler([]() { + httpSession->setConnectFailHandler([](const std::shared_ptr&) { LOK_ASSERT_FAIL("Unexpected connection failure"); }); @@ -628,7 +628,7 @@ void HttpRequestTests::testSimplePost_External() std::unique_lock lock(mutex); - httpSession->setConnectFailHandler([]() { + httpSession->setConnectFailHandler([](const std::shared_ptr&) { LOK_ASSERT_FAIL("Unexpected connection failure"); }); diff --git a/test/UnitProxy.cpp b/test/UnitProxy.cpp index b44963835751..443b6aca16fa 100644 --- a/test/UnitProxy.cpp +++ b/test/UnitProxy.cpp @@ -53,7 +53,7 @@ class UnitProxy : public UnitWSD // Request from rating.collaboraonline.com. _req = http::Request("/browser/a90f83c/foo/remote/static/lokit-extra-img.svg"); - httpSession->setConnectFailHandler([this]() { + httpSession->setConnectFailHandler([this](const std::shared_ptr&) { LOK_ASSERT_FAIL("Unexpected connection failure"); }); diff --git a/test/WhiteBoxTests.cpp b/test/WhiteBoxTests.cpp index 0ad7fc4ae761..6877092e2ebc 100644 --- a/test/WhiteBoxTests.cpp +++ b/test/WhiteBoxTests.cpp @@ -1099,72 +1099,72 @@ void WhiteBoxTests::testParseUriUrl() std::string scheme = "***"; std::string host = "***"; std::string port = "***"; - std::string url = "***"; + std::string pathAndQuery = "***"; - LOK_ASSERT(!net::parseUri(std::string(), scheme, host, port, url)); + LOK_ASSERT(!net::parseUri(std::string(), scheme, host, port, pathAndQuery)); LOK_ASSERT(scheme.empty()); LOK_ASSERT(host.empty()); LOK_ASSERT(port.empty()); - LOK_ASSERT(url.empty()); + LOK_ASSERT(pathAndQuery.empty()); - LOK_ASSERT(net::parseUri("localhost", scheme, host, port, url)); + LOK_ASSERT(net::parseUri("localhost", scheme, host, port, pathAndQuery)); LOK_ASSERT(scheme.empty()); LOK_ASSERT_EQUAL(std::string("localhost"), host); LOK_ASSERT(port.empty()); - LOK_ASSERT(url.empty()); + LOK_ASSERT(pathAndQuery.empty()); - LOK_ASSERT(net::parseUri("127.0.0.1", scheme, host, port, url)); + LOK_ASSERT(net::parseUri("127.0.0.1", scheme, host, port, pathAndQuery)); LOK_ASSERT(scheme.empty()); LOK_ASSERT_EQUAL(std::string("127.0.0.1"), host); LOK_ASSERT(port.empty()); - LOK_ASSERT(url.empty()); + LOK_ASSERT(pathAndQuery.empty()); - LOK_ASSERT(net::parseUri("domain.com", scheme, host, port, url)); + LOK_ASSERT(net::parseUri("domain.com", scheme, host, port, pathAndQuery)); LOK_ASSERT(scheme.empty()); LOK_ASSERT_EQUAL(std::string("domain.com"), host); LOK_ASSERT(port.empty()); - LOK_ASSERT(url.empty()); + LOK_ASSERT(pathAndQuery.empty()); - LOK_ASSERT(net::parseUri("127.0.0.1:9999", scheme, host, port, url)); + LOK_ASSERT(net::parseUri("127.0.0.1:9999", scheme, host, port, pathAndQuery)); LOK_ASSERT(scheme.empty()); LOK_ASSERT_EQUAL(std::string("127.0.0.1"), host); LOK_ASSERT_EQUAL(std::string("9999"), port); - LOK_ASSERT(url.empty()); + LOK_ASSERT(pathAndQuery.empty()); - LOK_ASSERT(net::parseUri("domain.com:88", scheme, host, port, url)); + LOK_ASSERT(net::parseUri("domain.com:88", scheme, host, port, pathAndQuery)); LOK_ASSERT(scheme.empty()); LOK_ASSERT_EQUAL(std::string("domain.com"), host); LOK_ASSERT_EQUAL(std::string("88"), port); - LOK_ASSERT(url.empty()); + LOK_ASSERT(pathAndQuery.empty()); - LOK_ASSERT(net::parseUri("http://domain.com", scheme, host, port, url)); + LOK_ASSERT(net::parseUri("http://domain.com", scheme, host, port, pathAndQuery)); LOK_ASSERT_EQUAL(std::string("http://"), scheme); LOK_ASSERT_EQUAL(std::string("domain.com"), host); LOK_ASSERT(port.empty()); - LOK_ASSERT(url.empty()); + LOK_ASSERT(pathAndQuery.empty()); - LOK_ASSERT(net::parseUri("https://domain.com:88", scheme, host, port, url)); + LOK_ASSERT(net::parseUri("https://domain.com:88", scheme, host, port, pathAndQuery)); LOK_ASSERT_EQUAL(std::string("https://"), scheme); LOK_ASSERT_EQUAL(std::string("domain.com"), host); LOK_ASSERT_EQUAL(std::string("88"), port); - LOK_ASSERT(net::parseUri("http://domain.com/path/to/file", scheme, host, port, url)); + LOK_ASSERT(net::parseUri("http://domain.com/path/to/file", scheme, host, port, pathAndQuery)); LOK_ASSERT_EQUAL(std::string("http://"), scheme); LOK_ASSERT_EQUAL(std::string("domain.com"), host); LOK_ASSERT(port.empty()); - LOK_ASSERT_EQUAL(std::string("/path/to/file"), url); + LOK_ASSERT_EQUAL(std::string("/path/to/file"), pathAndQuery); - LOK_ASSERT(net::parseUri("https://domain.com:88/path/to/file", scheme, host, port, url)); + LOK_ASSERT(net::parseUri("https://domain.com:88/path/to/file", scheme, host, port, pathAndQuery)); LOK_ASSERT_EQUAL(std::string("https://"), scheme); LOK_ASSERT_EQUAL(std::string("domain.com"), host); LOK_ASSERT_EQUAL(std::string("88"), port); - LOK_ASSERT_EQUAL(std::string("/path/to/file"), url); + LOK_ASSERT_EQUAL(std::string("/path/to/file"), pathAndQuery); - LOK_ASSERT(net::parseUri("wss://127.0.0.1:9999/", scheme, host, port, url)); + LOK_ASSERT(net::parseUri("wss://127.0.0.1:9999/", scheme, host, port, pathAndQuery)); LOK_ASSERT_EQUAL(std::string("wss://"), scheme); LOK_ASSERT_EQUAL(std::string("127.0.0.1"), host); LOK_ASSERT_EQUAL(std::string("9999"), port); - LOK_ASSERT_EQUAL(std::string("/"), url); + LOK_ASSERT_EQUAL(std::string("/"), pathAndQuery); } void WhiteBoxTests::testParseUrl() diff --git a/wsd/ClientRequestDispatcher.cpp b/wsd/ClientRequestDispatcher.cpp index f48cbc24a7dc..70dc59de142e 100644 --- a/wsd/ClientRequestDispatcher.cpp +++ b/wsd/ClientRequestDispatcher.cpp @@ -16,6 +16,7 @@ #endif #include +#include #include #include #include @@ -23,6 +24,8 @@ #include #include #include +#include +#include #include #include #include @@ -868,6 +871,8 @@ void ClientRequestDispatcher::handleIncomingMessage(SocketDisposition& dispositi servedSync = handleWopiDiscoveryRequest(requestDetails, socket); else if (requestDetails.equals(1, "capabilities")) servedSync = handleCapabilitiesRequest(request, socket); + else if (requestDetails.equals(1, "wopiAccessCheck")) + handleWopiAccessCheckRequest(request, message, socket); else HttpHelper::sendErrorAndShutdown(http::StatusCode::BadRequest, socket); } @@ -1096,6 +1101,245 @@ bool ClientRequestDispatcher::handleWopiDiscoveryRequest( return true; } + +// NB: these names are part of the published API, and should not be renamed or altered but can be expanded +STATE_ENUM(CheckStatus, + Ok, + NotHttpSucess, + HostNotFound, + WopiHostNotAllowed, + HostUnReachable, + UnspecifiedError, + ConnectionAborted, + ConnectionRefused, + CertificateValidation, + SSLHandshakeFail, + MissingSsl, + NotHttps, + NoScheme, + Timeout, +); + +bool ClientRequestDispatcher::handleWopiAccessCheckRequest(const Poco::Net::HTTPRequest& request, + Poco::MemoryInputStream& message, + const std::shared_ptr& socket) +{ + assert(socket && "Must have a valid socket"); + + LOG_DBG("Wopi Access Check request: " << request.getURI()); + + Poco::MemoryInputStream startmessage(&socket->getInBuffer()[0], socket->getInBuffer().size()); + StreamSocket::MessageMap map; + + Poco::JSON::Object::Ptr jsonObject; + + std::string text(std::istreambuf_iterator(message), {}); + LOG_TRC("Wopi Access Check request text: " << text); + + std::string callbackUrlStr; + + if (!JsonUtil::parseJSON(text, jsonObject)) + { + LOG_WRN_S("Wopi Access Check request error, json object expected got [" + << text << "] on request to URL: " << request.getURI()); + + HttpHelper::sendErrorAndShutdown(http::StatusCode::BadRequest, socket); + return false; + } + + if (!JsonUtil::findJSONValue(jsonObject, "callbackUrl", callbackUrlStr)) + { + LOG_WRN_S("Wopi Access Check request error, missing key callbackUrl expected got [" + << text << "] on request to URL: " << request.getURI()); + + HttpHelper::sendErrorAndShutdown(http::StatusCode::BadRequest, socket); + return false; + } + + LOG_TRC("Wopi Access Check request callbackUrlStr: " << callbackUrlStr); + + std::string scheme, host, portStr, pathAndQuery; + if (!net::parseUri(callbackUrlStr, scheme, host, portStr, pathAndQuery)) { + LOG_WRN_S("Wopi Access Check request error, invalid url [" + << callbackUrlStr << "] on request to URL: " << request.getURI() << scheme); + + HttpHelper::sendErrorAndShutdown(http::StatusCode::BadRequest, socket); + return false; + } + http::Session::Protocol protocol = http::Session::Protocol::HttpSsl; + ulong port = 443; + if (scheme == "https://" || scheme.empty()) { + // empty scheme assumes https + } else if (scheme == "http://") { + protocol = http::Session::Protocol::HttpUnencrypted; + port = 80; + } else { + LOG_WRN_S("Wopi Access Check request error, bad request protocol [" + << text << "] on request to URL: " << request.getURI() << scheme); + + HttpHelper::sendErrorAndShutdown(http::StatusCode::BadRequest, socket); + return false; + } + + if (!portStr.empty()) { + try { + port = std::stoul(portStr); + + } catch(std::invalid_argument &exception) { + LOG_WRN("Wopi Access Check error parsing invalid_argument portStr:" << portStr); + HttpHelper::sendErrorAndShutdown(http::StatusCode::BadRequest, socket); + return false; + } catch(std::exception &exception) { + + LOG_WRN_S("Wopi Access Check request error, bad request invalid porl [" + << text << "] on request to URL: " << request.getURI()); + + HttpHelper::sendErrorAndShutdown(http::StatusCode::BadRequest, socket); + return false; + } + } + + LOG_TRC("Wopi Access Check request scheme: " << scheme << " " << port); + + const auto sendResult = [this, socket](CheckStatus result) + { + const auto output = "{\"status\": \"" + JsonUtil::escapeJSONValue(name(result)) + "\"}\n"; + + http::Response jsonResponse(http::StatusCode::OK); + FileServerRequestHandler::hstsHeaders(jsonResponse); + jsonResponse.set("Last-Modified", Util::getHttpTimeNow()); + jsonResponse.setBody(output, "application/json"); + jsonResponse.set("X-Content-Type-Options", "nosniff"); + + socket->sendAndShutdown(jsonResponse); + LOG_INF("Wopi Access Check request, result" << name(result)); + }; + + if (scheme.empty()) + { + sendResult(CheckStatus::NoScheme); + return true; + } + // if the wopi hosts uses https, so must cool or it will have Mixed Content errors + if (protocol == http::Session::Protocol::HttpSsl && +#if ENABLE_SSL + !(ConfigUtil::isSslEnabled() || ConfigUtil::isSSLTermination()) +#else + false +#endif + ) + { + sendResult(CheckStatus::NotHttps); + return true; + } + + if (HostUtil::isWopiHostsEmpty()) + // make sure the wopi hosts settings are loaded + StorageBase::initialize(); + + bool wopiHostAllowed = false; + if (Util::iequal(ConfigUtil::getString("storage.wopi.alias_groups[@mode]", "first"), "first")) + // if first mode was selected and wopi Hosts are empty + // the domain is allowed, as it will be the effective "first" host + wopiHostAllowed = HostUtil::isWopiHostsEmpty(); + + if (!wopiHostAllowed) { + // port and scheme from wopi host config are currently ignored by HostUtil + LOG_TRC("Wopi Access Check, matching allowed wopi host for host " << host); + wopiHostAllowed = HostUtil::allowedWopiHost(host); + } + if (!wopiHostAllowed) + { + LOG_TRC("Wopi Access Check, wopi host not allowed " << host); + sendResult(CheckStatus::WopiHostNotAllowed); + return true; + } + + http::Request httpRequest(pathAndQuery.empty() ? "/" : pathAndQuery); + auto httpProbeSession = http::Session::create(host, protocol, port); + httpProbeSession->setTimeout(std::chrono::seconds(2)); + + httpProbeSession->setConnectFailHandler( + [=, this] (const std::shared_ptr& probeSession){ + + CheckStatus status = CheckStatus::UnspecifiedError; + + const auto result = probeSession->connectionResult(); + + if (result == net::AsyncConnectResult::UnknownHostError || result == net::AsyncConnectResult::HostNameError) + { + status = CheckStatus::HostNotFound; + } + + if (result == net::AsyncConnectResult::SSLHandShakeFailure) { + status = CheckStatus::SSLHandshakeFail; + } + + if (!probeSession->getSslVerifyMessage().empty()) + { + status = CheckStatus::CertificateValidation; + + LOG_DBG("Result ssl: " << probeSession->getSslVerifyMessage()); + } + + sendResult(status); + }); + + auto finishHandler = [=, this](const std::shared_ptr& probeSession) + { + LOG_TRC("finishHandler "); + + CheckStatus status = CheckStatus::Ok; + const auto lastErrno = errno; + + const auto httpResponse = probeSession->response(); + const auto responseState = httpResponse->state(); + LOG_DBG("Wopi Access Check: got response state: " << responseState << " " + << ", response status code: " <statusCode() << " " + << ", last errno: " << lastErrno); + + if (responseState != http::Response::State::Complete) + { + // are TLS errors here ? + status = CheckStatus::UnspecifiedError; + } + + if (responseState == http::Response::State::Timeout) + status = CheckStatus::Timeout; + + + const auto result = probeSession->connectionResult(); + + if (result == net::AsyncConnectResult::UnknownHostError) + status = CheckStatus::HostNotFound; + + if (protocol == http::Session::Protocol::HttpSsl && lastErrno == ENOTCONN) + status = CheckStatus::MissingSsl; + + if (result == net::AsyncConnectResult::ConnectionError) + status = CheckStatus::ConnectionAborted; + + // TODO complete error coverage + // certificate errors + // self-signed + // expired + + if (!probeSession->getSslVerifyMessage().empty()) + { + status = CheckStatus::CertificateValidation; + + LOG_DBG("Result ssl: " << probeSession->getSslVerifyMessage()); + } + + sendResult(status); + }; + + httpProbeSession->setFinishedHandler(std::move(finishHandler)); + httpProbeSession->asyncRequest(httpRequest, *COOLWSD::getWebServerPoll()); + + return true; +} + bool ClientRequestDispatcher::handleClipboardRequest(const Poco::Net::HTTPRequest& request, Poco::MemoryInputStream& message, SocketDisposition& disposition, diff --git a/wsd/ClientRequestDispatcher.hpp b/wsd/ClientRequestDispatcher.hpp index b19b035e9d67..141a1e6251f5 100644 --- a/wsd/ClientRequestDispatcher.hpp +++ b/wsd/ClientRequestDispatcher.hpp @@ -71,6 +71,10 @@ class ClientRequestDispatcher final : public SimpleSocketHandler bool handleCapabilitiesRequest(const Poco::Net::HTTPRequest& request, const std::shared_ptr& socket); + bool handleWopiAccessCheckRequest(const Poco::Net::HTTPRequest& request, + Poco::MemoryInputStream& message, + const std::shared_ptr& socket); + /// @return true if request has been handled synchronously and response sent, otherwise false static bool handleClipboardRequest(const Poco::Net::HTTPRequest& request, Poco::MemoryInputStream& message, diff --git a/wsd/ClientSession.cpp b/wsd/ClientSession.cpp index 6b4332328a6a..ac32e3ad8ef3 100644 --- a/wsd/ClientSession.cpp +++ b/wsd/ClientSession.cpp @@ -369,7 +369,7 @@ void ClientSession::handleClipboardRequest(DocumentBroker::ClipboardRequest { httpSession->setFinishedHandler(std::move(finishedCallback)); - http::Session::ConnectFailCallback connectFailCallback = [this, url]() + http::Session::ConnectFailCallback connectFailCallback = [this, url](const std::shared_ptr& /* session */) { LOG_ERR( "Failed to start an async clipboard download request with URL [" diff --git a/wsd/HostUtil.cpp b/wsd/HostUtil.cpp index 4ee79d0c6023..0f89e270b5c6 100644 --- a/wsd/HostUtil.cpp +++ b/wsd/HostUtil.cpp @@ -240,6 +240,11 @@ void HostUtil::setFirstHost(const Poco::URI& uri) } } +bool HostUtil::isWopiHostsEmpty() +{ + return WopiHosts.empty(); +} + #endif // !MOBILEAPP /* vim:set shiftwidth=4 softtabstop=4 expandtab: */ diff --git a/wsd/HostUtil.hpp b/wsd/HostUtil.hpp index 317636625e0b..faa8b61437bd 100644 --- a/wsd/HostUtil.hpp +++ b/wsd/HostUtil.hpp @@ -55,6 +55,8 @@ class HostUtil static void setFirstHost(const Poco::URI& uri); + static bool isWopiHostsEmpty(); + private: /// add host to WopiHosts static void addWopiHost(const std::string& host, bool allow); diff --git a/wsd/ProxyRequestHandler.cpp b/wsd/ProxyRequestHandler.cpp index c210cb152600..983b513038bf 100644 --- a/wsd/ProxyRequestHandler.cpp +++ b/wsd/ProxyRequestHandler.cpp @@ -90,7 +90,7 @@ void ProxyRequestHandler::handleRequest(const std::string& relPath, sessionProxy->setFinishedHandler(std::move(proxyCallback)); http::Session::ConnectFailCallback connectFailCallback = - [socket]() { + [socket](const std::shared_ptr& /* session */) { HttpHelper::sendErrorAndShutdown(http::StatusCode::BadRequest, socket); }; sessionProxy->setConnectFailHandler(std::move(connectFailCallback)); diff --git a/wsd/wopi/CheckFileInfo.cpp b/wsd/wopi/CheckFileInfo.cpp index 263337659a23..5bcbc3807f34 100644 --- a/wsd/wopi/CheckFileInfo.cpp +++ b/wsd/wopi/CheckFileInfo.cpp @@ -145,7 +145,7 @@ void CheckFileInfo::checkFileInfo(int redirectLimit) _httpSession->setFinishedHandler(std::move(finishedCallback)); http::Session::ConnectFailCallback connectFailCallback = - [this]() + [this](const std::shared_ptr& /* httpSession */) { _state = State::Fail; LOG_ERR("Failed to start an async CheckFileInfo request"); diff --git a/wsd/wopi/WopiStorage.cpp b/wsd/wopi/WopiStorage.cpp index 90399f97571a..5ac9b2d13355 100644 --- a/wsd/wopi/WopiStorage.cpp +++ b/wsd/wopi/WopiStorage.cpp @@ -899,7 +899,7 @@ std::size_t WopiStorage::uploadLocalFileToStorageAsync( LOG_DBG(wopiLog << " async upload request: " << httpRequest.header().toString()); - _uploadHttpSession->setConnectFailHandler([asyncUploadCallback]() { + _uploadHttpSession->setConnectFailHandler([asyncUploadCallback](const std::shared_ptr& /* httpSession */) { LOG_ERR("Cannot connect for uploading to wopi storage."); asyncUploadCallback(AsyncUpload(AsyncUpload::State::Error, UploadResult(UploadResult::Result::FAILED, "Connection failed.")));