diff --git a/eng/pipelines/templates/jobs/archetype-sdk-client.yml b/eng/pipelines/templates/jobs/archetype-sdk-client.yml index e755496f90..3081c4a558 100644 --- a/eng/pipelines/templates/jobs/archetype-sdk-client.yml +++ b/eng/pipelines/templates/jobs/archetype-sdk-client.yml @@ -33,10 +33,10 @@ parameters: type: object default: [] - name: PreTestSteps - type: object + type: stepList default: [] - name: PostTestSteps - type: object + type: stepList default: [] jobs: diff --git a/eng/pipelines/templates/jobs/ci.tests.yml b/eng/pipelines/templates/jobs/ci.tests.yml index 1e1ede079b..19d7dd6535 100644 --- a/eng/pipelines/templates/jobs/ci.tests.yml +++ b/eng/pipelines/templates/jobs/ci.tests.yml @@ -47,10 +47,10 @@ parameters: type: boolean default: false - name: PreTestSteps - type: object + type: stepList default: [] - name: PostTestSteps - type: object + type: stepList default: [] @@ -145,7 +145,7 @@ jobs: Env: "$(CmakeEnvArg)" - ${{ parameters.PreTestSteps }} - + - pwsh: | ctest ` -C Debug ` @@ -156,7 +156,7 @@ jobs: -T Test workingDirectory: build displayName: Test - + - ${{ parameters.PostTestSteps }} - task: PublishTestResults@2 diff --git a/eng/pipelines/templates/jobs/live.tests.yml b/eng/pipelines/templates/jobs/live.tests.yml index 19e73bb0b5..de57d3db24 100644 --- a/eng/pipelines/templates/jobs/live.tests.yml +++ b/eng/pipelines/templates/jobs/live.tests.yml @@ -37,10 +37,10 @@ parameters: type: boolean default: false - name: PreTestSteps - type: object + type: stepList default: [] - name: PostTestSteps - type: object + type: stepList default: [] jobs: diff --git a/eng/pipelines/templates/stages/archetype-sdk-client.yml b/eng/pipelines/templates/stages/archetype-sdk-client.yml index 334e580647..6a86decd0f 100644 --- a/eng/pipelines/templates/stages/archetype-sdk-client.yml +++ b/eng/pipelines/templates/stages/archetype-sdk-client.yml @@ -68,10 +68,10 @@ parameters: type: string default: '' - name: PreTestSteps - type: object + type: stepList default: [] - name: PostTestSteps - type: object + type: stepList default: [] diff --git a/eng/pipelines/templates/stages/archetype-sdk-tests.yml b/eng/pipelines/templates/stages/archetype-sdk-tests.yml index a4aba5565b..ed3d8f5459 100644 --- a/eng/pipelines/templates/stages/archetype-sdk-tests.yml +++ b/eng/pipelines/templates/stages/archetype-sdk-tests.yml @@ -30,10 +30,10 @@ parameters: type: string default: '' - name: PreTestSteps - type: object + type: stepList default: [] - name: PostTestSteps - type: object + type: stepList default: [] stages: diff --git a/sdk/core/azure-core/inc/azure/core/http/curl_transport.hpp b/sdk/core/azure-core/inc/azure/core/http/curl_transport.hpp index 5a8a6d91c0..cf6ef84ae3 100644 --- a/sdk/core/azure-core/inc/azure/core/http/curl_transport.hpp +++ b/sdk/core/azure-core/inc/azure/core/http/curl_transport.hpp @@ -38,10 +38,31 @@ namespace Azure { namespace Core { namespace Http { * * @remark Libcurl does revocation list check by default for SSL backends that supports this * feature. However, the Azure SDK overrides libcurl's behavior and disables the revocation list - * check by default. - * + * check by default. This ensures that the LibCURL behavior matches the WinHTTP behavior. */ bool EnableCertificateRevocationListCheck = false; + + /** + * @brief This option allows SSL connections to proceed even if there is an error retrieving the + * Certificate Revocation List. + * + * @remark Note that this only works when LibCURL is configured to use openssl as its TLS + * provider. That functionally limits this check to Linux only, and then only when openssl is + * configured (the default). + */ + bool AllowFailedCrlRetrieval = false; + + /** + * @brief A set of PEM encoded X.509 certificates and CRLs describing the certificates used to + * validate the server. + * + * @remark The Azure SDK will not directly validate these certificates. + * + * @remark More about this option: + * https://curl.haxx.se/libcurl/c/CURLOPT_CAINFO_BLOB.html + * + */ + std::string PemEncodedExpectedRootCertificates; }; /** @@ -86,7 +107,8 @@ namespace Azure { namespace Core { namespace Http { */ std::string ProxyPassword; /** - * @brief The string for the certificate authenticator is sent to libcurl handle directly. + * @brief Path to a PEM encoded file containing the certificate authorities sent to libcurl + * handle directly. * * @remark The Azure SDK will not check if the path is valid or not. * @@ -95,6 +117,7 @@ namespace Azure { namespace Core { namespace Http { * */ std::string CAInfo; + /** * @brief All HTTP requests will keep the connection channel open to the service. * @@ -106,6 +129,7 @@ namespace Azure { namespace Core { namespace Http { * handle. It is `true` by default. */ bool HttpKeepAlive = true; + /** * @brief This option determines whether libcurl verifies the authenticity of the peer's * certificate. diff --git a/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp b/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp index 645c387108..61e144939d 100644 --- a/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp +++ b/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp @@ -144,17 +144,35 @@ namespace Azure { namespace Core { namespace Http { namespace Policies { * * @remark The URL for the proxy server to use for this connection. */ - Azure::Nullable HttpProxy; + Azure::Nullable HttpProxy{}; /** * @brief The username to use when authenticating with the proxy server. */ - std::string ProxyUserName; + std::string ProxyUserName{}; /** * @brief The password to use when authenticating with the proxy server. */ - std::string ProxyPassword; + std::string ProxyPassword{}; + + /** + * @brief Enable TLS Certificate validation against a certificate revocation list. + * + * @remark Note that by default CRL validation is *disabled*. + */ + bool EnableCertificateRevocationListCheck{false}; + + /** + * @brief Base64 encoded DER representation of an X.509 certificate expected in the certificate + * chain used in TLS connections. + * + * @remark Note that with the schannel and sectransp crypto backends, settting the + * expected root certificate disables access to the system certificate store. + * This means that the expected root certificate is the only certificate that will be trusted. + */ + std::string ExpectedTlsRootCertificate{}; + #endif // defined(CURL_ADAPTER) || defined(WINHTTP_ADAPTER) #endif // !defined(BUILD_TRANSPORT_CUSTOM_ADAPTER) diff --git a/sdk/core/azure-core/inc/azure/core/http/win_http_transport.hpp b/sdk/core/azure-core/inc/azure/core/http/win_http_transport.hpp index 7547f05be2..9001cbcbb5 100644 --- a/sdk/core/azure-core/inc/azure/core/http/win_http_transport.hpp +++ b/sdk/core/azure-core/inc/azure/core/http/win_http_transport.hpp @@ -4,6 +4,7 @@ /** * @file * @brief #Azure::Core::Http::HttpTransport implementation via WinHTTP. + * cspell:words HCERTIFICATECHAIN PCCERT CCERT HCERTCHAINENGINE HCERTSTORE */ #pragma once @@ -22,11 +23,13 @@ #define NOMINMAX #endif #include + #endif #include #include #include +#include #include namespace Azure { namespace Core { namespace Http { @@ -49,6 +52,61 @@ namespace Azure { namespace Core { namespace Http { }; using unique_HINTERNET = std::unique_ptr; + // unique_ptr class wrapping a PCCERT_CHAIN_CONTEXT + struct HCERTIFICATECHAIN_deleter + { + void operator()(PCCERT_CHAIN_CONTEXT handle) + { + // unique_ptr class wrapping an HINTERNET handle + { + CertFreeCertificateChain(handle); + } + } + }; + using unique_CCERT_CHAIN_CONTEXT + = std::unique_ptr; + + // unique_ptr class wrapping an HCERTCHAINENGINE handle + struct HCERTCHAINENGINE_deleter + { + void operator()(HCERTCHAINENGINE handle) noexcept + { + if (handle != nullptr) + { + CertFreeCertificateChainEngine(handle); + } + } + }; + using unique_HCERTCHAINENGINE = std::unique_ptr; + + // unique_ptr class wrapping an HCERTSTORE handle + struct HCERTSTORE_deleter + { + public: + void operator()(HCERTSTORE handle) noexcept + { + if (handle != nullptr) + { + CertCloseStore(handle, 0); + } + } + }; + using unique_HCERTSTORE = std::unique_ptr; + + // unique_ptr class wrapping a PCCERT_CONTEXT + struct CERTCONTEXT_deleter + { + public: + void operator()(PCCERT_CONTEXT handle) noexcept + { + if (handle != nullptr) + { + CertFreeCertificateContext(handle); + } + } + }; + using unique_PCCERT_CONTEXT = std::unique_ptr; + class WinHttpStream final : public Azure::Core::IO::BodyStream { private: _detail::unique_HINTERNET m_requestHandle; @@ -122,6 +180,11 @@ namespace Azure { namespace Core { namespace Http { */ bool EnableSystemDefaultProxy{false}; + /** + * @brief If True, enables checks for certificate revocation. + */ + bool EnableCertificateRevocationListCheck{false}; + /** * @brief Proxy information. * @@ -142,6 +205,13 @@ namespace Azure { namespace Core { namespace Http { * @brief Password for proxy authentication. */ std::string ProxyPassword; + + /** + * @brief Array of Base64 encoded DER encoded X.509 certificate. These certificates should form + * a chain of certificates which will be used to validate the server certificate sent by the + * server. + */ + std::vector ExpectedTlsRootCertificates; }; /** @@ -155,6 +225,7 @@ namespace Azure { namespace Core { namespace Http { // This should remain immutable and not be modified after calling the ctor, to avoid threading // issues. _detail::unique_HINTERNET m_sessionHandle; + bool m_requestHandleClosed{false}; _detail::unique_HINTERNET CreateSessionHandle(); _detail::unique_HINTERNET CreateConnectionHandle( @@ -183,6 +254,34 @@ namespace Azure { namespace Core { namespace Http { _detail::unique_HINTERNET& requestHandle, HttpMethod requestMethod); + /* + * Callback from WinHTTP called after the TLS certificates are received when the caller sets + * expected TLS root certificates. + */ + static void CALLBACK StatusCallback( + HINTERNET hInternet, + DWORD_PTR dwContext, + DWORD dwInternetStatus, + LPVOID lpvStatusInformation, + DWORD dwStatusInformationLength) noexcept; + /* + * Callback from WinHTTP called after the TLS certificates are received when the caller sets + * expected TLS root certificates. + */ + void OnHttpStatusOperation(HINTERNET hInternet, DWORD dwInternetStatus); + /* + * Adds the specified trusted certificates to the specified certificate store. + */ + bool AddCertificatesToStore( + std::vector const& trustedCertificates, + _detail::unique_HCERTSTORE const& hCertStore); + /* + * Verifies that the certificate context is in the trustedCertificates set of certificates. + */ + bool VerifyCertificatesInChain( + std::vector const& trustedCertificates, + _detail::unique_PCCERT_CONTEXT const& serverCertificate); + // Callback to allow a derived transport to extract the request handle. Used for WebSocket // transports. protected: diff --git a/sdk/core/azure-core/inc/azure/core/platform.hpp b/sdk/core/azure-core/inc/azure/core/platform.hpp index 155f23c8ad..d46d3d587a 100644 --- a/sdk/core/azure-core/inc/azure/core/platform.hpp +++ b/sdk/core/azure-core/inc/azure/core/platform.hpp @@ -47,4 +47,10 @@ #define AZ_PLATFORM_WINDOWS #elif defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__)) #define AZ_PLATFORM_POSIX +#if defined(__APPLE__) || defined(__MACH__) +#define AZ_PLATFORM_MAC +#else +#define AZ_PLATFORM_LINUX +#endif + #endif diff --git a/sdk/core/azure-core/src/http/curl/curl.cpp b/sdk/core/azure-core/src/http/curl/curl.cpp index 79ce8c0f06..e15aa1569a 100644 --- a/sdk/core/azure-core/src/http/curl/curl.cpp +++ b/sdk/core/azure-core/src/http/curl/curl.cpp @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // SPDX-License-Identifier: MIT +// cspell:words OCSP crls #include "azure/core/base64.hpp" #include "azure/core/platform.hpp" @@ -25,6 +26,12 @@ #include "curl_session_private.hpp" #if defined(AZ_PLATFORM_POSIX) +#include +#include +#include +#include +#include +#include #include // for poll() #include // for socket shutdown #elif defined(AZ_PLATFORM_WINDOWS) @@ -34,6 +41,7 @@ #include #include #include +#include #include #include #include @@ -1099,7 +1107,9 @@ size_t CurlSession::ResponseBufferParser::Parse( { // Should never happen that parser is not statusLIne or Headers and we still try // to parse more. + // LCOV_EXCL_START AZURE_UNREACHABLE_CODE(); + // LCOV_EXCL_STOP } // clean internal buffer this->m_internalBuffer.clear(); @@ -1137,7 +1147,9 @@ size_t CurlSession::ResponseBufferParser::Parse( { // Should never happen that parser is not statusLIne or Headers and we still try // to parse more. + // LCOV_EXCL_START AZURE_UNREACHABLE_CODE(); + // LCOV_EXCL_STOP } } } @@ -1166,135 +1178,6 @@ size_t CurlSession::ResponseBufferParser::Parse( return index; } -// Finds delimiter '\r' as the end of the -size_t CurlSession::ResponseBufferParser::BuildStatusCode( - uint8_t const* const buffer, - size_t const bufferSize) -{ - if (this->state != ResponseParserState::StatusLine) - { - return 0; // Wrong internal state to call this method. - } - - uint8_t endOfStatusLine = '\r'; - auto endOfBuffer = buffer + bufferSize; - - // Look for the end of status line in buffer - auto indexOfEndOfStatusLine = std::find(buffer, endOfBuffer, endOfStatusLine); - - if (indexOfEndOfStatusLine == endOfBuffer) - { - // did not find the delimiter yet, copy to internal buffer - this->m_internalBuffer.append(buffer, endOfBuffer); - return bufferSize; // all buffer read and requesting for more - } - - // Delimiter found, check if there is data in the internal buffer - if (this->m_internalBuffer.size() > 0) - { - // If the index is same as buffer it means delimiter is at position 0, meaning that - // internalBuffer contains the status line and we don't need to add anything else - if (indexOfEndOfStatusLine > buffer) - { - // Append and build response minus the delimiter - this->m_internalBuffer.append(buffer, indexOfEndOfStatusLine); - } - this->m_response = CreateHTTPResponse(this->m_internalBuffer); - } - else - { - // Internal Buffer was not required, create response directly from buffer - this->m_response = CreateHTTPResponse(std::string(buffer, indexOfEndOfStatusLine)); - } - - // update control - this->state = ResponseParserState::Headers; - this->m_internalBuffer.clear(); - - // Return the index of the next char to read after delimiter - // No need to advance one more char ('\n') (since we might be at the end of the array) - // Parsing Headers will make sure to move one possition - return indexOfEndOfStatusLine + 1 - buffer; -} - -// Finds delimiter '\r' as the end of the -size_t CurlSession::ResponseBufferParser::BuildHeader( - uint8_t const* const buffer, - size_t const bufferSize) -{ - if (this->state != ResponseParserState::Headers) - { - return 0; // can't run this if state is not Headers. - } - - uint8_t delimiter = '\r'; - auto start = buffer; - auto endOfBuffer = buffer + bufferSize; - - if (bufferSize == 1 && buffer[0] == '\n') - { - // rare case of using buffer of size 1 to read. In this case, if the char is next value - // after headers or previous header, just consider it as read and return - return bufferSize; - } - else if (bufferSize > 1 && this->m_internalBuffer.size() == 0) // only if nothing in - // buffer, advance - { - // move offset one possition. This is because readStatusLine and readHeader will read up - // to - // '\r' then next delimiter is '\n' and we don't care - start = buffer + 1; - } - - // Look for the end of status line in buffer - auto indexOfEndOfStatusLine = std::find(start, endOfBuffer, delimiter); - - if (indexOfEndOfStatusLine == start && this->m_internalBuffer.size() == 0) - { - // \r found at the start means the end of headers - this->m_internalBuffer.clear(); - this->m_parseCompleted = true; - return 1; // can't return more than the found delimiter. On read remaining we need to - // also remove first char - } - - if (indexOfEndOfStatusLine == endOfBuffer) - { - // did not find the delimiter yet, copy to internal buffer - this->m_internalBuffer.append(start, endOfBuffer); - return bufferSize; // all buffer read and requesting for more - } - - // Delimiter found, check if there is data in the internal buffer - if (this->m_internalBuffer.size() > 0) - { - // If the index is same as buffer it means delimiter is at position 0, meaning that - // internalBuffer contains the status line and we don't need to add anything else - if (indexOfEndOfStatusLine > buffer) - { - // Append and build response minus the delimiter - this->m_internalBuffer.append(start, indexOfEndOfStatusLine); - } - // will throw if header is invalid - SetHeader(*m_response, this->m_internalBuffer); - } - else - { - // Internal Buffer was not required, create response directly from buffer - std::string header(std::string(start, indexOfEndOfStatusLine)); - // will throw if header is invalid - SetHeader(*this->m_response, header); - } - - // reuse buffer - this->m_internalBuffer.clear(); - - // Return the index of the next char to read after delimiter - // No need to advance one more char ('\n') (since we might be at the end of the array) - // Parsing Headers will make sure to move one position - return indexOfEndOfStatusLine + 1 - buffer; -} - namespace { // Calculate the connection key. // The connection key is a tuple of host, proxy info, TLS info, etc. Basically any characteristics @@ -1318,6 +1201,13 @@ inline std::string GetConnectionKey(std::string const& host, CurlTransportOption key.append(","); key.append(options.NoSignal ? "1" : "0"); key.append(","); + key.append(options.SslOptions.AllowFailedCrlRetrieval ? "FC" : "0"); + key.append(","); + key.append( + !options.SslOptions.PemEncodedExpectedRootCertificates.empty() ? std::to_string( + std::hash{}(options.SslOptions.PemEncodedExpectedRootCertificates)) + : "0"); + key.append(","); // using DefaultConnectionTimeout or 0 result in the same setting key.append( (options.ConnectionTimeout == Azure::Core::Http::_detail::DefaultConnectionTimeout @@ -1372,12 +1262,7 @@ void DumpCurlInfoToLog(std::string const& text, uint8_t* ptr, size_t size) } // namespace -int CurlConnectionPool::CurlLoggingCallback( - CURL*, - curl_infotype type, - char* data, - size_t size, - void*) +int CurlConnection::CurlLoggingCallback(CURL*, curl_infotype type, char* data, size_t size, void*) { if (type == CURLINFO_TEXT) { @@ -1420,6 +1305,558 @@ int CurlConnectionPool::CurlLoggingCallback( } return 0; } +#if !defined(AZ_PLATFORM_WINDOWS) +namespace Azure { namespace Core { namespace Http { + namespace _detail { + + // Helpers to provide RAII wrappers for OpenSSL types. + template struct openssl_deleter + { + void operator()(T* obj) { Deleter(obj); } + }; + template + using basic_openssl_unique_ptr = std::unique_ptr>; + + // *** Given just T, map it to the corresponding FreeFunc: + template struct type_map_helper; + template <> struct type_map_helper + { + using type = basic_openssl_unique_ptr; + }; + template <> struct type_map_helper + { + using type = basic_openssl_unique_ptr; + }; + + template <> struct type_map_helper + { + using type = basic_openssl_unique_ptr; + }; + + template <> struct type_map_helper + { + static void FreeCrlStack(STACK_OF(X509_CRL) * obj) + { + sk_X509_CRL_pop_free(obj, X509_CRL_free); + } + using type = basic_openssl_unique_ptr; + }; + + // *** Now users can say openssl_unique_ptr if they want: + template using openssl_unique_ptr = typename type_map_helper::type; + + // *** Or the current solution's convenience aliases: + using openssl_bio = openssl_unique_ptr; + using openssl_x509 = openssl_unique_ptr; + using openssl_x509_crl = openssl_unique_ptr; + using openssl_x509_crl_stack = openssl_unique_ptr; + + template + auto make_openssl_unique(Api& OpensslApi, Args&&... args) + { + auto raw = OpensslApi(std::forward( + args)...); // forwarding is probably unnecessary, could use const Args&... + // check raw + using T = std::remove_pointer_t; // no need to request T when we can see + // what OpensslApi returned + return openssl_unique_ptr{raw}; + } + + // Disable Code Coverage across GetOpenSSLError because we don't have a good way of forcing + // OpenSSL to fail. + // LCOV_EXCL_START + std::string GetOpenSSLError(std::string const& what) + { + auto bio(make_openssl_unique(BIO_new, BIO_s_mem())); + + BIO_printf(bio.get(), "Error in %hs: ", what.c_str()); + if (ERR_peek_error() != 0) + { + ERR_print_errors(bio.get()); + } + else + { + BIO_printf(bio.get(), "Unknown error."); + } + + uint8_t* bioData; + long bufferSize = BIO_get_mem_data(bio.get(), &bioData); + std::string returnValue; + returnValue.resize(bufferSize); + memcpy(&returnValue[0], bioData, bufferSize); + + return returnValue; + } + // LCOV_EXCL_STOP + + } // namespace _detail + + namespace { + // int g_ssl_crl_max_size_in_kb = 20; + /** + * @brief THe Cryptography class provides a set of basic cryptographic primatives required + * by the attestation samples. + */ + + _detail::openssl_x509_crl LoadCrlFromUrl(std::string const& url) + { + Log::Write(Logger::Level::Informational, "Load CRL from Url: " + url); + auto crl = _detail::make_openssl_unique(X509_CRL_load_http, url.c_str(), nullptr, nullptr, 5); + if (!crl) + { + Log::Write(Logger::Level::Error, _detail::GetOpenSSLError("Load CRL")); + } + + return crl; + } + + enum class CrlFormat : int + { + Http, + Asn1, + PEM, + }; + + _detail::openssl_x509_crl LoadCrl(std::string const& source, CrlFormat format) + { + _detail::openssl_x509_crl x; + _detail::openssl_bio in; + + if (format == CrlFormat::Http) + { + return LoadCrlFromUrl(source); + } + return x; + } + + bool IsCrlValid(X509_CRL* crl) + { + const ASN1_TIME* at = X509_CRL_get0_nextUpdate(crl); + + int day = -1; + int sec = -1; + if (!ASN1_TIME_diff(&day, &sec, nullptr, at)) + { + Log::Write(Logger::Level::Error, "Could not check expiration"); + return false; /* Safe default, invalid */ + } + + if (day > 0 || sec > 0) + { + return true; /* Later, valid */ + } + return false; /* Before or same, invalid */ + } + + const char* GetDistributionPointUrl(DIST_POINT* dp) + { + GENERAL_NAMES* gens; + GENERAL_NAME* gen; + int i, nameType; + ASN1_STRING* uri; + + if (!dp->distpoint) + { + Log::Write(Logger::Level::Informational, "returning, dp->distpoint is null"); + return nullptr; + } + + if (dp->distpoint->type != 0) + { + Log::Write( + Logger::Level::Informational, + "returning, dp->distpoint->type is " + std::to_string(dp->distpoint->type)); + return nullptr; + } + + gens = dp->distpoint->name.fullname; + + for (i = 0; i < sk_GENERAL_NAME_num(gens); i++) + { + gen = sk_GENERAL_NAME_value(gens, i); + uri = static_cast(GENERAL_NAME_get0_value(gen, &nameType)); + + if (nameType == GEN_URI && ASN1_STRING_length(uri) > 6) + { + const char* uptr = reinterpret_cast(ASN1_STRING_get0_data(uri)); + if (strncmp(uptr, "http://", 7) == 0) + { + return uptr; + } + } + } + + return nullptr; + } + + std::mutex crl_cache_lock; + std::vector crl_cache; + + bool SaveCertificateCrlToMemory(X509* cert, _detail::openssl_x509_crl const& crl) + { + std::unique_lock lockResult(crl_cache_lock); + + // update existing + X509_NAME* cert_issuer = cert ? X509_get_issuer_name(cert) : nullptr; + for (auto it = crl_cache.begin(); it != crl_cache.end(); ++it) + { + X509_CRL* cacheEntry = *it; + if (!cacheEntry) + { + continue; + } + + X509_NAME* crl_issuer = X509_CRL_get_issuer(cacheEntry); + if (!crl_issuer || !cert_issuer) + { + continue; + } + + // If we are getting a new CRL for an existing CRL, update the + // CRL with the new CRL. + if (0 == X509_NAME_cmp(crl_issuer, cert_issuer)) + { + // Bump the refcount on the new CRL before adding it to the cache. + X509_CRL_free(*it); + X509_CRL_up_ref(crl.get()); + *it = crl.get(); + return true; + } + } + + // not found, so try to find slot by purging outdated + for (auto it = crl_cache.begin(); it != crl_cache.end(); ++it) + { + if (!*it) + { + // set new + X509_CRL_free(*it); + X509_CRL_up_ref(crl.get()); + *it = crl.get(); + return true; + } + + if (!IsCrlValid(*it)) + { + // remove stale + X509_CRL_free(*it); + X509_CRL_up_ref(crl.get()); + *it = crl.get(); + return true; + } + } + + // Clone the certificate and add it to the cache. + X509_CRL_up_ref(crl.get()); + crl_cache.push_back(crl.get()); + return true; + } + + _detail::openssl_x509_crl LoadCertificateCrlFromMemory(X509* cert) + { + X509_NAME* cert_issuer = cert ? X509_get_issuer_name(cert) : nullptr; + + std::unique_lock lockResult(crl_cache_lock); + + for (auto it = crl_cache.begin(); it != crl_cache.end(); ++it) + { + X509_CRL* crl = *it; + if (!*it) + { + continue; + } + + // names don't match up. probably a hash collision + // so lets test if there is another crl on disk. + X509_NAME* crl_issuer = X509_CRL_get_issuer(crl); + if (!crl_issuer || !cert_issuer) + { + continue; + } + + if (0 != X509_NAME_cmp(crl_issuer, cert_issuer)) + { + continue; + } + + if (!IsCrlValid(crl)) + { + Log::Write(Logger::Level::Informational, "Discarding outdated CRL"); + X509_CRL_free(*it); + *it = nullptr; + continue; + } + + X509_CRL_up_ref(crl); + return _detail::openssl_x509_crl(crl); + } + return nullptr; + } + + _detail::openssl_x509_crl LoadCrlFromCacheAndDistributionPoint( + X509* cert, + STACK_OF(DIST_POINT) * crlDistributionPointStack) + { + int i; + + _detail::openssl_x509_crl crl = LoadCertificateCrlFromMemory(cert); + if (crl) + { + return crl; + } + + // file was not found on disk cache, + // so, now loading from web. + // Walk through the possible CRL distribution points + // looking for one which has a URL that we can download. + const char* urlptr = nullptr; + for (i = 0; i < sk_DIST_POINT_num(crlDistributionPointStack); i++) + { + DIST_POINT* dp = sk_DIST_POINT_value(crlDistributionPointStack, i); + + urlptr = GetDistributionPointUrl(dp); + if (urlptr) + { + // try to load from web, exit loop if + // successfully downloaded + crl = LoadCrl(urlptr, CrlFormat::Http); + if (crl) + break; + } + } + + if (!urlptr) + { + Log::Write(Logger::Level::Error, "No CRL dist point qualified for downloading."); + } + + if (crl) + { + // save it to memory + SaveCertificateCrlToMemory(cert, crl); + } + + return crl; + } + + /** + * @brief Retrieve the CRL associated with the provided store context, if available. + * + */ + STACK_OF(X509_CRL) * CrlHttpCallback(const X509_STORE_CTX* context, const X509_NAME*) + { + _detail::openssl_x509_crl crl; + STACK_OF(DIST_POINT) * crlDistributionPoint; + + _detail::openssl_x509_crl_stack crlStack + = _detail::openssl_x509_crl_stack(sk_X509_CRL_new_null()); + if (crlStack == nullptr) + { + Log::Write(Logger::Level::Error, "Failed to allocate STACK_OF(X509_CRL)"); + return nullptr; + } + + X509* currentCertificate = X509_STORE_CTX_get_current_cert(context); + + // try to download Crl + crlDistributionPoint = static_cast( + X509_get_ext_d2i(currentCertificate, NID_crl_distribution_points, nullptr, nullptr)); + if (!crlDistributionPoint + && X509_NAME_cmp( + X509_get_issuer_name(currentCertificate), + X509_get_subject_name(currentCertificate)) + != 0) + { + Log::Write( + Logger::Level::Error, + "No CRL distribution points defined on non self-issued cert, CRL check may fail."); + return nullptr; + } + + crl = LoadCrlFromCacheAndDistributionPoint(currentCertificate, crlDistributionPoint); + + sk_DIST_POINT_pop_free(crlDistributionPoint, DIST_POINT_free); + if (!crl) + { + Log::Write(Logger::Level::Error, "Unable to retrieve CRL, CRL check may fail."); + return nullptr; + } + + sk_X509_CRL_push(crlStack.get(), X509_CRL_dup(crl.get())); + + // try to download delta Crl + crlDistributionPoint = static_cast( + X509_get_ext_d2i(currentCertificate, NID_freshest_crl, nullptr, nullptr)); + if (crlDistributionPoint != nullptr) + { + crl = LoadCrlFromCacheAndDistributionPoint(currentCertificate, crlDistributionPoint); + + sk_DIST_POINT_pop_free(crlDistributionPoint, DIST_POINT_free); + if (crl) + { + sk_X509_CRL_push(crlStack.get(), X509_CRL_dup(crl.get())); + } + } + + return crlStack.release(); + } + + int GetOpenSSLContextConnectionIndex() + { + static int openSslConnectionIndex = -1; + if (openSslConnectionIndex < 0) + { + openSslConnectionIndex = X509_STORE_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); + } + return openSslConnectionIndex; + } + int GetOpenSSLContextLastVerifyFunction() + { + static int openSslLastVerifyFunctionIndex = -1; + if (openSslLastVerifyFunctionIndex < 0) + { + openSslLastVerifyFunctionIndex + = X509_STORE_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); + } + return openSslLastVerifyFunctionIndex; + } + } // namespace +}}} // namespace Azure::Core::Http + +// OpenSSL X509 Certificate Validation function - based off of the example found at: +// https://linux.die.net/man/3/x509_store_ctx_set_verify_cb +// +int CurlConnection::VerifyCertificateError(int ok, X509_STORE_CTX* storeContext) +{ + X509_STORE* certStore = X509_STORE_CTX_get0_store(storeContext); + X509* err_cert; + int err, depth; + _detail::openssl_bio bio_err(_detail::make_openssl_unique(BIO_new, BIO_s_mem())); + + err_cert = X509_STORE_CTX_get_current_cert(storeContext); + err = X509_STORE_CTX_get_error(storeContext); + depth = X509_STORE_CTX_get_error_depth(storeContext); + + BIO_printf(bio_err.get(), "depth=%d ", depth); + if (err_cert) + { + X509_NAME_print_ex(bio_err.get(), X509_get_subject_name(err_cert), 0, XN_FLAG_ONELINE); + BIO_puts(bio_err.get(), "\n"); + } + else + { + BIO_puts(bio_err.get(), "\n"); + } + if (!ok) + { + BIO_printf(bio_err.get(), "verify error:num=%d: %s\n", err, X509_verify_cert_error_string(err)); + } + + switch (err) + { + case X509_V_ERR_UNABLE_TO_GET_CRL: + BIO_printf(bio_err.get(), "Unable to retrieve CRL."); + break; + } + if (err == X509_V_OK && ok == 2) + { + /* print out policies */ + BIO_printf(bio_err.get(), "verify return:%d\n", ok); + } + + // Handle certificate specific errors here based on configuration options. + { + if (err == X509_V_ERR_UNABLE_TO_GET_CRL) + { + if (m_allowFailedCrlRetrieval) + { + BIO_printf(bio_err.get(), "Ignoring CRL retrieval error by configuration.\n"); + // Clear the X509 error in the store context, because CURL retrieves it, + // and it overwrites the successful result. + X509_STORE_CTX_set_error(storeContext, X509_V_OK); + // Return true, indicating that things are all good. + ok = 1; + } + else + { + BIO_printf( + bio_err.get(), "Fail TLS negotiation because CRL retrieval is not configured.\n"); + } + } + } + + char outputString[128]; + int len; + while ((len = BIO_gets(bio_err.get(), outputString, sizeof(outputString))) >= 0) + { + if (len == 0) + { + break; + } + if (outputString[len - 1] == '\n') + { + outputString[len - 1] = '\0'; + } + Log::Write(Logger::Level::Informational, std::string(outputString)); + } + + if (ok) + { + // We've done our stuff, call the pre-existing callback. + auto existingCallback = reinterpret_cast( + X509_STORE_get_ex_data(certStore, GetOpenSSLContextLastVerifyFunction())); + if (existingCallback != nullptr) + { + ok = existingCallback(ok, storeContext); + } + } + return (ok); +} + +int CurlConnection::CurlSslCtxCallback(CURL* curl, void* sslctx, void* parm) +{ + CurlConnection* connection = static_cast(parm); + return connection->SslCtxCallback(curl, sslctx); +} + +int CurlConnection::SslCtxCallback(CURL*, void* sslctx) +{ + SSL_CTX* ctx = reinterpret_cast(sslctx); + + // Note: SSL_CTX_get_cert_store does NOT increase the store reference count. + X509_STORE* certStore = SSL_CTX_get_cert_store(ctx); + X509_VERIFY_PARAM* verifyParam = X509_STORE_get0_param(certStore); + if (m_enableCrlValidation) + { + + // Store our connection handle in the store extended data so it can be retrieved + // in later callbacks. This allows setting options on a per-connection basis. + X509_STORE_set_ex_data(certStore, GetOpenSSLContextConnectionIndex(), this); + + X509_VERIFY_PARAM_set_flags(verifyParam, X509_V_FLAG_CRL_CHECK | X509_V_FLAG_CRL_CHECK_ALL); + X509_STORE_set_lookup_crls_cb(certStore, CrlHttpCallback); + + X509_STORE_set_ex_data( + certStore, + GetOpenSSLContextLastVerifyFunction(), + reinterpret_cast(X509_STORE_get_verify_cb(certStore))); + + X509_STORE_set_verify_cb(certStore, [](int ok, X509_STORE_CTX* storeContext) { + X509_STORE* certStore = X509_STORE_CTX_get0_store(storeContext); + + CurlConnection* thisConnection = reinterpret_cast( + X509_STORE_get_ex_data(certStore, GetOpenSSLContextConnectionIndex())); + + return thisConnection->VerifyCertificateError(ok, storeContext); + }); + } + else + { + X509_VERIFY_PARAM_clear_flags(verifyParam, X509_V_FLAG_CRL_CHECK); + } + return CURLE_OK; +} +#endif std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlConnection( Request& request, @@ -1480,8 +1917,82 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo // Creating a new connection is thread safe. No need to lock mutex here. // No available connection for the pool for the required host. Create one Log::Write(Logger::Level::Verbose, LogMsgPrefix + "Spawn new connection."); - unique_CURL newHandle(curl_easy_init(), CURL_deleter{}); - if (!newHandle) + + return std::make_unique(request, options, hostDisplayName, connectionKey); +} + +// Move the connection back to the connection pool. Push it to the front so it becomes the +// first connection to be picked next time some one ask for a connection to the pool (LIFO) +void CurlConnectionPool::MoveConnectionBackToPool( + std::unique_ptr connection, + HttpStatusCode lastStatusCode) +{ + auto code = static_cast::type>(lastStatusCode); + // laststatusCode = 0 + if (code < 200 || code >= 300) + { + // A handler with previous response with Error can't be re-use. + return; + } + + if (connection->IsShutdown()) + { + // Can't re-used a shut down connection + return; + } + + Log::Write(Logger::Level::Verbose, "Moving connection to pool..."); + + decltype(CurlConnectionPool::g_curlConnectionPool + .ConnectionPoolIndex)::mapped_type::value_type connectionToBeRemoved; + + // Lock mutex to access connection pool. mutex is unlock as soon as lock is out of scope + std::unique_lock lock(CurlConnectionPool::ConnectionPoolMutex); + auto& poolId = connection->GetConnectionKey(); + auto& hostPool = g_curlConnectionPool.ConnectionPoolIndex[poolId]; + + if (hostPool.size() >= _detail::MaxConnectionsPerIndex && !hostPool.empty()) + { + // Remove the last connection from the pool to insert this one. + auto lastConnection = --hostPool.end(); + connectionToBeRemoved = std::move(*lastConnection); + hostPool.erase(lastConnection); + } + + // update the time when connection was moved back to pool + connection->UpdateLastUsageTime(); + hostPool.push_front(std::move(connection)); + + if (m_cleanThread.joinable() && !IsCleanThreadRunning) + { + // Clean thread was running before but it's finished, join it to finalize + m_cleanThread.join(); + } + + // Cleanup will start a background thread which will close abandoned connections from the pool. + // This will free-up resources from the app + // This is the only call to cleanup. + if (!m_cleanThread.joinable()) + { + Log::Write(Logger::Level::Verbose, "Start clean thread"); + IsCleanThreadRunning = true; + m_cleanThread = std::thread(CleanupThread); + } + else + { + Log::Write(Logger::Level::Verbose, "Clean thread running. Won't start a new one."); + } +} + +CurlConnection::CurlConnection( + Request& request, + CurlTransportOptions const& options, + std::string const& hostDisplayName, + std::string const& connectionPropertiesKey) + : m_connectionKey(std::move(connectionPropertiesKey)) +{ + m_handle = _detail::unique_CURL(curl_easy_init(), _detail::CURL_deleter{}); + if (!m_handle) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + ". " @@ -1491,16 +2002,15 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo if (options.EnableCurlTracing) { - if (!SetLibcurlOption( - newHandle, CURLOPT_DEBUGFUNCTION, CurlConnectionPool::CurlLoggingCallback, &result)) + m_handle, CURLOPT_DEBUGFUNCTION, CurlConnection::CurlLoggingCallback, &result)) { throw TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + std::string(". Could not enable logging callback.") + std::string(curl_easy_strerror(result))); } - if (!SetLibcurlOption(newHandle, CURLOPT_VERBOSE, 1, &result)) + if (!SetLibcurlOption(m_handle, CURLOPT_VERBOSE, 1, &result)) { throw TransportException( _detail::DefaultFailedToGetNewConnectionTemplate @@ -1510,31 +2020,32 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo } // Libcurl setup before open connection (url, connect_only, timeout) - if (!SetLibcurlOption(newHandle, CURLOPT_URL, request.GetUrl().GetAbsoluteUrl().data(), &result)) + if (!SetLibcurlOption(m_handle, CURLOPT_URL, request.GetUrl().GetAbsoluteUrl().data(), &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + ". " + std::string(curl_easy_strerror(result))); } - if (port != 0 && !SetLibcurlOption(newHandle, CURLOPT_PORT, port, &result)) + if (request.GetUrl().GetPort() != 0 + && !SetLibcurlOption(m_handle, CURLOPT_PORT, request.GetUrl().GetPort(), &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + ". " + std::string(curl_easy_strerror(result))); } - if (!SetLibcurlOption(newHandle, CURLOPT_CONNECT_ONLY, 1L, &result)) + if (!SetLibcurlOption(m_handle, CURLOPT_CONNECT_ONLY, 1L, &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + ". " + std::string(curl_easy_strerror(result))); } - // Set timeout to 24h. Libcurl will fail uploading on windows if timeout is: + // Set timeout to 24h. Libcurl will fail uploading on windows if timeout is: // timeout >= 25 days. Fails as soon as trying to upload any data // 25 days < timeout > 1 days. Fail on huge uploads ( > 1GB) - if (!SetLibcurlOption(newHandle, CURLOPT_TIMEOUT, 60L * 60L * 24L, &result)) + if (!SetLibcurlOption(m_handle, CURLOPT_TIMEOUT, 60L * 60L * 24L, &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + ". " @@ -1543,7 +2054,7 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo if (options.ConnectionTimeout != Azure::Core::Http::_detail::DefaultConnectionTimeout) { - if (!SetLibcurlOption(newHandle, CURLOPT_CONNECTTIMEOUT_MS, options.ConnectionTimeout, &result)) + if (!SetLibcurlOption(m_handle, CURLOPT_CONNECTTIMEOUT_MS, options.ConnectionTimeout, &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName @@ -1558,7 +2069,7 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo */ if (options.Proxy) { - if (!SetLibcurlOption(newHandle, CURLOPT_PROXY, options.Proxy->c_str(), &result)) + if (!SetLibcurlOption(m_handle, CURLOPT_PROXY, options.Proxy->c_str(), &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName @@ -1569,7 +2080,7 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo if (!options.ProxyUsername.empty()) { - if (!SetLibcurlOption(newHandle, CURLOPT_PROXYUSERNAME, options.ProxyUsername.c_str(), &result)) + if (!SetLibcurlOption(m_handle, CURLOPT_PROXYUSERNAME, options.ProxyUsername.c_str(), &result)) { throw TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName @@ -1579,7 +2090,7 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo } if (!options.ProxyPassword.empty()) { - if (!SetLibcurlOption(newHandle, CURLOPT_PROXYPASSWORD, options.ProxyPassword.c_str(), &result)) + if (!SetLibcurlOption(m_handle, CURLOPT_PROXYPASSWORD, options.ProxyPassword.c_str(), &result)) { throw TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName @@ -1590,7 +2101,23 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo if (!options.CAInfo.empty()) { - if (!SetLibcurlOption(newHandle, CURLOPT_CAINFO, options.CAInfo.c_str(), &result)) + if (!SetLibcurlOption(m_handle, CURLOPT_CAINFO, options.CAInfo.c_str(), &result)) + { + throw Azure::Core::Http::TransportException( + _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + + ". Failed to set CA cert file to:" + options.CAInfo + ". " + + std::string(curl_easy_strerror(result))); + } + } + + if (!options.SslOptions.PemEncodedExpectedRootCertificates.empty()) + { + curl_blob rootCertBlob + = {const_cast(reinterpret_cast( + options.SslOptions.PemEncodedExpectedRootCertificates.c_str())), + options.SslOptions.PemEncodedExpectedRootCertificates.size(), + CURL_BLOB_COPY}; + if (!SetLibcurlOption(m_handle, CURLOPT_CAINFO_BLOB, &rootCertBlob, &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName @@ -1599,23 +2126,52 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo } } +#if defined(AZ_PLATFORM_WINDOWS) long sslOption = 0; if (!options.SslOptions.EnableCertificateRevocationListCheck) { sslOption |= CURLSSLOPT_NO_REVOKE; } - if (!SetLibcurlOption(newHandle, CURLOPT_SSL_OPTIONS, sslOption, &result)) + if (!SetLibcurlOption(m_handle, CURLOPT_SSL_OPTIONS, sslOption, &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + ". Failed to set ssl options to long bitmask:" + std::to_string(sslOption) + ". " + std::string(curl_easy_strerror(result))); } +#elif !defined(AZ_PLATFORM_MAC) + if (options.SslOptions.EnableCertificateRevocationListCheck) + { + if (!SetLibcurlOption( + m_handle, CURLOPT_SSL_CTX_FUNCTION, CurlConnection::CurlSslCtxCallback, &result)) + { + throw TransportException( + _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + + ". Failed to set SSL context callback. " + std::string(curl_easy_strerror(result))); + } + if (!SetLibcurlOption(m_handle, CURLOPT_SSL_CTX_DATA, this, &result)) + { + throw TransportException( + _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + + ". Failed to set SSL context callback data. " + + std::string(curl_easy_strerror(result))); + } + // if (!SetLibcurlOption(m_handle, CURLOPT_SSL_VERIFYSTATUS, 1, &result)) + // { + // throw TransportException( + // _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + // + ". Failed to enable OCSP chaining. " + + // std::string(curl_easy_strerror(result))); + // } + } + m_allowFailedCrlRetrieval = options.SslOptions.AllowFailedCrlRetrieval; +#endif + m_enableCrlValidation = options.SslOptions.EnableCertificateRevocationListCheck; if (!options.SslVerifyPeer) { - if (!SetLibcurlOption(newHandle, CURLOPT_SSL_VERIFYPEER, 0L, &result)) + if (!SetLibcurlOption(m_handle, CURLOPT_SSL_VERIFYPEER, 0L, &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName @@ -1625,7 +2181,7 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo if (options.NoSignal) { - if (!SetLibcurlOption(newHandle, CURLOPT_NOSIGNAL, 1L, &result)) + if (!SetLibcurlOption(m_handle, CURLOPT_NOSIGNAL, 1L, &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName @@ -1637,91 +2193,58 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo // curl-transport adapter supports only HTTP/1.1 // https://github.com/Azure/azure-sdk-for-cpp/issues/2848 // The libcurl uses HTTP/2 by default, if it can be negotiated with a server on handshake. - if (!SetLibcurlOption(newHandle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_1_1, &result)) + if (!SetLibcurlOption(m_handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_1_1, &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + ". Failed to set libcurl HTTP/1.1" + ". " + std::string(curl_easy_strerror(result))); } - // Make libcurl to support only TLS v1.2 or later - if (!SetLibcurlOption(newHandle, CURLOPT_SSLVERSION, CURL_SSLVERSION_TLSv1_2, &result)) + // Make libcurl to support only TLS v1.2 or later + if (!SetLibcurlOption(m_handle, CURLOPT_SSLVERSION, CURL_SSLVERSION_TLSv1_2, &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + ". Failed enforcing TLS v1.2 or greater. " + std::string(curl_easy_strerror(result))); } - auto performResult = curl_easy_perform(newHandle.get()); + auto performResult = curl_easy_perform(m_handle.get()); if (performResult != CURLE_OK) { - throw Http::TransportException( - _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + ". " - + std::string(curl_easy_strerror(performResult))); - } - - return std::make_unique(std::move(newHandle), connectionKey); -} - -// Move the connection back to the connection pool. Push it to the front so it becomes the -// first connection to be picked next time some one ask for a connection to the pool (LIFO) -void CurlConnectionPool::MoveConnectionBackToPool( - std::unique_ptr connection, - HttpStatusCode lastStatusCode) -{ - auto code = static_cast::type>(lastStatusCode); - // laststatusCode = 0 - if (code < 200 || code >= 300) - { - // A handler with previous response with Error can't be re-use. - return; - } - - if (connection->IsShutdown()) - { - // Can't re-used a shut down connection - return; - } - - Log::Write(Logger::Level::Verbose, "Moving connection to pool..."); - - decltype(CurlConnectionPool::g_curlConnectionPool - .ConnectionPoolIndex)::mapped_type::value_type connectionToBeRemoved; - - // Lock mutex to access connection pool. mutex is unlock as soon as lock is out of scope - std::unique_lock lock(CurlConnectionPool::ConnectionPoolMutex); - auto& poolId = connection->GetConnectionKey(); - auto& hostPool = g_curlConnectionPool.ConnectionPoolIndex[poolId]; - - if (hostPool.size() >= _detail::MaxConnectionsPerIndex && !hostPool.empty()) - { - // Remove the last connection from the pool to insert this one. - auto lastConnection = --hostPool.end(); - connectionToBeRemoved = std::move(*lastConnection); - hostPool.erase(lastConnection); - } - - // update the time when connection was moved back to pool - connection->UpdateLastUsageTime(); - hostPool.push_front(std::move(connection)); - - if (m_cleanThread.joinable() && !IsCleanThreadRunning) - { - // Clean thread was running before but it's finished, join it to finalize - m_cleanThread.join(); +#if defined(AZ_PLATFORM_LINUX) + if (performResult == CURLE_SSL_PEER_CERTIFICATE) + { + curl_easy_getinfo(m_handle.get(), CURLINFO_SSL_VERIFYRESULT, &result); + throw Http::TransportException( + _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + ". " + + std::string(curl_easy_strerror(performResult)) + + ". Underlying error: " + X509_verify_cert_error_string(result)); + } + else +#endif + { + throw Http::TransportException( + _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + ". " + + std::string(curl_easy_strerror(performResult))); + } } - // Cleanup will start a background thread which will close abandoned connections from the pool. - // This will free-up resources from the app - // This is the only call to cleanup. - if (!m_cleanThread.joinable()) - { - Log::Write(Logger::Level::Verbose, "Start clean thread"); - IsCleanThreadRunning = true; - m_cleanThread = std::thread(CleanupThread); - } - else + // Get the socket that libcurl is using from handle. Will use this to wait while + // reading/writing + // into wire +#if defined(_MSC_VER) +#pragma warning(push) +// C26812: The enum type 'CURLcode' is un-scoped. Prefer 'enum class' over 'enum' (Enum.3) +#pragma warning(disable : 26812) +#endif + result = curl_easy_getinfo(m_handle.get(), CURLINFO_ACTIVESOCKET, &m_curlSocket); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (result != CURLE_OK) { - Log::Write(Logger::Level::Verbose, "Clean thread running. Won't start a new one."); + throw Http::TransportException( + "Broken connection. Couldn't get the active sockect for it." + + std::string(curl_easy_strerror(result))); } } diff --git a/sdk/core/azure-core/src/http/curl/curl_connection_pool_private.hpp b/sdk/core/azure-core/src/http/curl/curl_connection_pool_private.hpp index 605136c875..df455984e3 100644 --- a/sdk/core/azure-core/src/http/curl/curl_connection_pool_private.hpp +++ b/sdk/core/azure-core/src/http/curl/curl_connection_pool_private.hpp @@ -120,13 +120,6 @@ namespace Azure { namespace Core { namespace Http { namespace _detail { // private constructor to keep this as singleton. CurlConnectionPool() { curl_global_init(CURL_GLOBAL_ALL); } - static int CurlLoggingCallback( - CURL* handle, - curl_infotype type, - char* data, - size_t size, - void* userp); - // Makes possible to know the number of current connections in the connection pool for an // index size_t ConnectionsOnPool(std::string const& host) { return ConnectionPoolIndex[host].size(); } diff --git a/sdk/core/azure-core/src/http/curl/curl_connection_private.hpp b/sdk/core/azure-core/src/http/curl/curl_connection_private.hpp index 705ec2d862..ccf1b217b5 100644 --- a/sdk/core/azure-core/src/http/curl/curl_connection_private.hpp +++ b/sdk/core/azure-core/src/http/curl/curl_connection_private.hpp @@ -10,7 +10,6 @@ #pragma once #include "azure/core/http/http.hpp" - #include #include @@ -26,6 +25,8 @@ #pragma warning(pop) #endif +typedef struct x509_store_ctx_st X509_STORE_CTX; + namespace Azure { namespace Core { namespace Http { namespace _detail { @@ -143,92 +144,88 @@ namespace Azure { namespace Core { namespace Http { curl_socket_t m_curlSocket; std::chrono::steady_clock::time_point m_lastUseTime; std::string m_connectionKey; + // CRL validation is disabled by default to be consistent with WinHTTP behavior + bool m_enableCrlValidation{false}; + // Allow the connection to proceed if retrieving the CRL failed. + bool m_allowFailedCrlRetrieval{true}; + + static int CurlLoggingCallback( + CURL* handle, + curl_infotype type, + char* data, + size_t size, + void* userp); + + static int CurlSslCtxCallback(CURL* curl, void* sslctx, void* parm); + int SslCtxCallback(CURL* curl, void* sslctx); + int VerifyCertificateError(int ok, X509_STORE_CTX* storeContext); public: /** * @brief Construct CURL HTTP connection. * - * @param handle CURL handle. + * @param request Remote request + * @param options Connection options. + * @param hostDisplayName Display name for remote host, used for diagnostics. * * @param connectionPropertiesKey CURL connection properties key */ - CurlConnection(_detail::unique_CURL&& handle, std::string connectionPropertiesKey) - : m_handle(std::move(handle)), m_connectionKey(std::move(connectionPropertiesKey)) - { - // Get the socket that libcurl is using from handle. Will use this to wait while - // reading/writing - // into wire -#if defined(_MSC_VER) -#pragma warning(push) -// C26812: The enum type 'CURLcode' is un-scoped. Prefer 'enum class' over 'enum' (Enum.3) -#pragma warning(disable : 26812) -#endif - auto result = curl_easy_getinfo(m_handle.get(), CURLINFO_ACTIVESOCKET, &m_curlSocket); -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - if (result != CURLE_OK) - { - throw Http::TransportException( - "Broken connection. Couldn't get the active sockect for it." - + std::string(curl_easy_strerror(result))); - } - } + CurlConnection( + Azure::Core::Http::Request& request, + Azure::Core::Http::CurlTransportOptions const& options, + std::string const& hostDisplayName, + std::string const& connectionPropertiesKey); - /** - * @brief Destructor. - * @details Cleans up CURL (invokes `curl_easy_cleanup()`). - */ - ~CurlConnection() override {} + /** + * @brief Destructor. + * @details Cleans up CURL (invokes `curl_easy_cleanup()`). + */ + ~CurlConnection() override {} - std::string const& GetConnectionKey() const override { return this->m_connectionKey; } + std::string const& GetConnectionKey() const override { return this->m_connectionKey; } - /** - * @brief Update last usage time for the connection. - * - */ - void UpdateLastUsageTime() override - { - this->m_lastUseTime = std::chrono::steady_clock::now(); - } + /** + * @brief Update last usage time for the connection. + * + */ + void UpdateLastUsageTime() override { this->m_lastUseTime = std::chrono::steady_clock::now(); } - /** - * @brief Checks whether this CURL connection is expired. - * @return `true` if this connection is considered expired; otherwise, `false`. - */ - bool IsExpired() override - { - auto connectionOnWaitingTimeMs = std::chrono::duration_cast( - std::chrono::steady_clock::now() - this->m_lastUseTime); - return connectionOnWaitingTimeMs.count() >= _detail::DefaultConnectionExpiredMilliseconds; - } + /** + * @brief Checks whether this CURL connection is expired. + * @return `true` if this connection is considered expired; otherwise, `false`. + */ + bool IsExpired() override + { + auto connectionOnWaitingTimeMs = std::chrono::duration_cast( + std::chrono::steady_clock::now() - this->m_lastUseTime); + return connectionOnWaitingTimeMs.count() >= _detail::DefaultConnectionExpiredMilliseconds; + } - /** - * @brief This function is used when working with streams to pull more data from the wire. - * Function will try to keep pulling data from socket until the buffer is all written or until - * there is no more data to get from the socket. - * - * @param context A context to control the request lifetime. - * @param buffer ptr to buffer where to copy bytes from socket. - * @param bufferSize size of the buffer and the requested bytes to be pulled from wire. - * @return return the numbers of bytes pulled from socket. It can be less than what it was - * requested. - */ - size_t ReadFromSocket(uint8_t* buffer, size_t bufferSize, Context const& context) override; - - /** - * @brief This method will use libcurl socket to write all the bytes from buffer. - * - * @remarks Hardcoded timeout is used in case a socket stop responding. - * - * @param context A context to control the request lifetime. - * @param buffer ptr to the data to be sent to wire. - * @param bufferSize size of the buffer to send. - * @return CURL_OK when response is sent successfully. - */ - CURLcode SendBuffer(uint8_t const* buffer, size_t bufferSize, Context const& context) - override; - - void Shutdown() override; - }; + /** + * @brief This function is used when working with streams to pull more data from the wire. + * Function will try to keep pulling data from socket until the buffer is all written or until + * there is no more data to get from the socket. + * + * @param context A context to control the request lifetime. + * @param buffer ptr to buffer where to copy bytes from socket. + * @param bufferSize size of the buffer and the requested bytes to be pulled from wire. + * @return return the numbers of bytes pulled from socket. It can be less than what it was + * requested. + */ + size_t ReadFromSocket(uint8_t* buffer, size_t bufferSize, Context const& context) override; + + /** + * @brief This method will use libcurl socket to write all the bytes from buffer. + * + * @remarks Hardcoded timeout is used in case a socket stop responding. + * + * @param context A context to control the request lifetime. + * @param buffer ptr to the data to be sent to wire. + * @param bufferSize size of the buffer to send. + * @return CURL_OK when response is sent successfully. + */ + CURLcode SendBuffer(uint8_t const* buffer, size_t bufferSize, Context const& context) override; + + void Shutdown() override; + }; }}} // namespace Azure::Core::Http diff --git a/sdk/core/azure-core/src/http/transport_policy.cpp b/sdk/core/azure-core/src/http/transport_policy.cpp index 6f73124f49..f2d9e4eec8 100644 --- a/sdk/core/azure-core/src/http/transport_policy.cpp +++ b/sdk/core/azure-core/src/http/transport_policy.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: MIT #include "azure/core/http/policies/policy.hpp" +#include "azure/core/platform.hpp" #if defined(BUILD_CURL_HTTP_TRANSPORT_ADAPTER) #include "azure/core/http/curl_transport.hpp" @@ -11,6 +12,9 @@ #include "azure/core/http/win_http_transport.hpp" #endif +#include +#include + using Azure::Core::Context; using namespace Azure::Core::IO; using namespace Azure::Core::Http; @@ -21,12 +25,31 @@ namespace Azure { namespace Core { namespace Http { namespace Policies { namespa namespace { bool AnyTransportOptionsSpecified(TransportOptions const& transportOptions) { - return !( - !transportOptions.HttpProxy.HasValue() && transportOptions.ProxyPassword.empty() - && transportOptions.ProxyUserName.empty()); + return ( + transportOptions.HttpProxy.HasValue() || !transportOptions.ProxyPassword.empty() + || !transportOptions.ProxyUserName.empty() + || transportOptions.EnableCertificateRevocationListCheck + || !transportOptions.ExpectedTlsRootCertificate.empty()); + } + + std::string PemEncodeFromBase64(std::string const& base64, std::string const& pemType) + { + std::stringstream rv; + rv << "-----BEGIN " << pemType << "-----" << std::endl; + std::string encodedValue(base64); + + // Insert crlf characters every 80 characters into the base64 encoded key to make it + // prettier. + size_t insertPos = 80; + while (insertPos < encodedValue.length()) + { + encodedValue.insert(insertPos, "\r\n"); + insertPos += 82; /* 80 characters plus the \r\n we just inserted */ + } + + rv << encodedValue << std::endl << "-----END " << pemType << "-----" << std::endl; + return rv.str(); } - // std::once_flag createTransportOnce; - // std::shared_ptr defaultTransport; } // namespace std::shared_ptr GetTransportAdapter(TransportOptions const& transportOptions) @@ -56,6 +79,24 @@ namespace Azure { namespace Core { namespace Http { namespace Policies { namespa } httpOptions.ProxyUserName = transportOptions.ProxyUserName; httpOptions.ProxyPassword = transportOptions.ProxyPassword; + // Note that WinHTTP accepts a set of root certificates, even though transportOptions only + // specifies a single one. + if (!transportOptions.ExpectedTlsRootCertificate.empty()) + { + httpOptions.ExpectedTlsRootCertificates.push_back( + transportOptions.ExpectedTlsRootCertificate); + } + if (transportOptions.EnableCertificateRevocationListCheck) + { + httpOptions.EnableCertificateRevocationListCheck; + } + // If you specify an expected TLS root certificate, you also need to enable ignoring unknown + // CAs. + if (!transportOptions.ExpectedTlsRootCertificate.empty()) + { + httpOptions.IgnoreUnknownCertificateAuthority; + } + return std::make_shared(httpOptions); } else @@ -83,6 +124,16 @@ namespace Azure { namespace Core { namespace Http { namespace Policies { namespa { curlOptions.ProxyPassword = transportOptions.ProxyPassword; } + + curlOptions.SslOptions.EnableCertificateRevocationListCheck + = transportOptions.EnableCertificateRevocationListCheck; + + if (!transportOptions.ExpectedTlsRootCertificate.empty()) + { + curlOptions.SslOptions.PemEncodedExpectedRootCertificates + = PemEncodeFromBase64(transportOptions.ExpectedTlsRootCertificate, "CERTIFICATE"); + } + return std::make_shared(curlOptions); } return defaultTransport; diff --git a/sdk/core/azure-core/src/http/winhttp/win_http_transport.cpp b/sdk/core/azure-core/src/http/winhttp/win_http_transport.cpp index d5ec6f559b..feadd948b2 100644 --- a/sdk/core/azure-core/src/http/winhttp/win_http_transport.cpp +++ b/sdk/core/azure-core/src/http/winhttp/win_http_transport.cpp @@ -1,8 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // SPDX-License-Identifier: MIT +// cspell:words HCERTIFICATECHAIN PCCERT CCERT HCERTCHAINENGINE HCERTSTORE #include "azure/core/http/http.hpp" +#include "azure/core/base64.hpp" +#include "azure/core/diagnostics/logger.hpp" +#include "azure/core/internal/diagnostics/log.hpp" #include "azure/core/internal/strings.hpp" #if defined(BUILD_TRANSPORT_WINHTTP_ADAPTER) @@ -11,11 +15,15 @@ #include #include +#include #include +#include #include using Azure::Core::Context; using namespace Azure::Core::Http; +using namespace Azure::Core::Diagnostics; +using namespace Azure::Core::Diagnostics::_internal; namespace { @@ -196,9 +204,217 @@ std::string GetHeadersAsString(Azure::Core::Http::Request const& request) return requestHeaderString; } - } // namespace +// For each certificate specified in trustedCertificate, add to certificateStore. +bool WinHttpTransport::AddCertificatesToStore( + std::vector const& trustedCertificates, + _detail::unique_HCERTSTORE const& certificateStore) +{ + for (auto const& trustedCertificate : trustedCertificates) + { + auto derCertificate = Azure::Core::Convert::Base64Decode(trustedCertificate); + + if (!CertAddEncodedCertificateToStore( + certificateStore.get(), + X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, + derCertificate.data(), + static_cast(derCertificate.size()), + CERT_STORE_ADD_NEW, + NULL)) + { + GetErrorAndThrow("CertAddEncodedCertificateToStore failed"); + } + } + return true; +} + +// VerifyCertificateInChain determines whether the certificate in serverCertificate +// chains up to the PEM represented by trustedCertificate or not. +bool WinHttpTransport::VerifyCertificatesInChain( + std::vector const& trustedCertificates, + _detail::unique_PCCERT_CONTEXT const& serverCertificate) +{ + if ((trustedCertificates.empty()) || !serverCertificate) + { + return false; + } + + // Creates an in-memory certificate store that is destroyed at end of this function. + _detail::unique_HCERTSTORE certificateStore(CertOpenStore( + CERT_STORE_PROV_MEMORY, + X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, + 0, + CERT_STORE_CREATE_NEW_FLAG, + nullptr)); + if (!certificateStore) + { + GetErrorAndThrow("CertOpenStore failed"); + } + + // Add the trusted certificates to that store. + if (!AddCertificatesToStore(trustedCertificates, certificateStore)) + { + Log::Write(Logger::Level::Error, "Cannot add certificates to store"); + return false; + } + + _detail::unique_HCERTCHAINENGINE certificateChainEngine; + { + CERT_CHAIN_ENGINE_CONFIG EngineConfig{}; + EngineConfig.cbSize = sizeof(EngineConfig); + EngineConfig.dwFlags = CERT_CHAIN_ENABLE_CACHE_AUTO_UPDATE | CERT_CHAIN_ENABLE_SHARE_STORE; + EngineConfig.hExclusiveRoot = certificateStore.get(); + + HCERTCHAINENGINE engineHandle; + if (!CertCreateCertificateChainEngine(&EngineConfig, &engineHandle)) + { + GetErrorAndThrow("CertCreateCertificateChainEngine failed"); + } + certificateChainEngine.reset(engineHandle); + } + + // Generate a certificate chain using the local chain engine and the certificate store containing + // the trusted certificates. + _detail::unique_CCERT_CHAIN_CONTEXT chainContextToVerify; + { + CERT_CHAIN_PARA ChainPara{}; + ChainPara.cbSize = sizeof(ChainPara); + PCCERT_CHAIN_CONTEXT chainContext; + if (!CertGetCertificateChain( + certificateChainEngine.get(), + serverCertificate.get(), + nullptr, + certificateStore.get(), + &ChainPara, + 0, + nullptr, + &chainContext)) + { + GetErrorAndThrow("CertGetCertificateChain failed"); + } + chainContextToVerify.reset(chainContext); + } + + // And make sure that the certificate chain which was created matches the SSL chain. + { + CERT_CHAIN_POLICY_PARA PolicyPara{}; + PolicyPara.cbSize = sizeof(PolicyPara); + + CERT_CHAIN_POLICY_STATUS PolicyStatus{}; + PolicyStatus.cbSize = sizeof(PolicyStatus); + + if (!CertVerifyCertificateChainPolicy( + CERT_CHAIN_POLICY_SSL, chainContextToVerify.get(), &PolicyPara, &PolicyStatus)) + { + GetErrorAndThrow("CertVerifyCertificateChainPolicy"); + } + if (PolicyStatus.dwError != 0) + { + Log::Write( + Logger::Level::Error, + "CertVerifyCertificateChainPolicy sets certificateStatus " + + std::to_string(PolicyStatus.dwError)); + return false; + } + } + return true; +} + +/** + * Called by WinHTTP when sending a request to the server. This callback allows us to inspect the + * TLS certificate before sending it to the server. + */ +void WinHttpTransport::StatusCallback( + HINTERNET hInternet, + DWORD_PTR dwContext, + DWORD dwInternetStatus, + LPVOID, + DWORD) noexcept +{ + // If we're called before our context has been set (on Open and Close callbacks), ignore the + // status callback. + if (dwContext == 0) + { + return; + } + + try + { + WinHttpTransport* httpTransport = reinterpret_cast(dwContext); + httpTransport->OnHttpStatusOperation(hInternet, dwInternetStatus); + } + catch (Azure::Core::RequestFailedException& rfe) + { + // If an exception is thrown in the handler, log the error and terminate the connection. + Log::Write( + Logger::Level::Error, + "Request Failed Exception Thrown: " + std::string(rfe.what()) + rfe.Message); + WinHttpCloseHandle(hInternet); + } + catch (std::exception& ex) + { + // If an exception is thrown in the handler, log the error and terminate the connection. + Log::Write(Logger::Level::Error, "Exception Thrown: " + std::string(ex.what())); + } +} + +/** + * @brief HTTP Callback to enable private certificate checks. + * + * This method is called by WinHTTP when a certificate is received. This method is called multiple + * times based on the state of the TLS connection. We are only interested in + * WINHTTP_CALLBACK_STATUS_SENDING_REQUEST, which is called during the TLS handshake. + * + * When called, we verify that the certificate chain sent from the server contains the certificate + * the HTTP client was configured with. If it is, we accept the connection, if it is not, + * we abort the connection, closing the incoming request handle. + */ +void WinHttpTransport::OnHttpStatusOperation(HINTERNET hInternet, DWORD dwInternetStatus) +{ + if (dwInternetStatus != WINHTTP_CALLBACK_STATUS_SENDING_REQUEST) + { + if (dwInternetStatus == WINHTTP_CALLBACK_STATUS_SECURE_FAILURE) + { + Log::Write(Logger::Level::Error, "Security failure. :("); + } + // Silently ignore if there's any statuses we get that we can't handle + return; + } + + // We will only set the Status callback if a root certificate has been set. + AZURE_ASSERT(!m_options.ExpectedTlsRootCertificates.empty()); + + // Ask WinHTTP for the server certificate - this won't be valid outside a status callback. + _detail::unique_PCCERT_CONTEXT serverCertificate; + { + PCCERT_CONTEXT certContext; + DWORD bufferLength = sizeof(certContext); + if (!WinHttpQueryOption( + hInternet, + WINHTTP_OPTION_SERVER_CERT_CONTEXT, + reinterpret_cast(&certContext), + &bufferLength)) + { + GetErrorAndThrow("Could not retrieve TLS server certificate."); + } + serverCertificate.reset(certContext); + } + + if (!VerifyCertificatesInChain(m_options.ExpectedTlsRootCertificates, serverCertificate)) + { + Log::Write(Logger::Level::Error, "Server certificate is not trusted. Aborting HTTP request"); + + // To signal to caller that the request is to be terminated, the callback closes the handle. + // This ensures that no message is sent to the server. + WinHttpCloseHandle(hInternet); + + // To avoid a double free of this handle record that we've + // already closed the handle. + m_requestHandleClosed = true; + } +} + void WinHttpTransport::GetErrorAndThrow(const std::string& exceptionMessage, DWORD error) { std::string errorMessage = exceptionMessage + " Error Code: " + std::to_string(error); @@ -277,6 +493,21 @@ _detail::unique_HINTERNET WinHttpTransport::CreateSessionHandle() GetErrorAndThrow("Error while enforcing TLS 1.2 for connection request."); } + if (!m_options.ExpectedTlsRootCertificates.empty()) + { + + // Set the callback function to be called when a server certificate is received. + if (WinHttpSetStatusCallback( + sessionHandle.get(), + &WinHttpTransport::StatusCallback, + WINHTTP_CALLBACK_FLAG_SEND_REQUEST /* WINHTTP_CALLBACK_FLAG_ALL_NOTIFICATIONS*/, + 0) + == WINHTTP_INVALID_STATUS_CALLBACK) + { + GetErrorAndThrow("Error while setting up the status callback."); + } + } + return sessionHandle; } @@ -396,7 +627,7 @@ _detail::unique_HINTERNET WinHttpTransport::CreateRequestHandle( } } - if (m_options.IgnoreUnknownCertificateAuthority) + if (m_options.IgnoreUnknownCertificateAuthority || !m_options.ExpectedTlsRootCertificates.empty()) { auto option = SECURITY_FLAG_IGNORE_UNKNOWN_CA; if (!WinHttpSetOption(request.get(), WINHTTP_OPTION_SECURITY_FLAGS, &option, sizeof(option))) @@ -405,6 +636,15 @@ _detail::unique_HINTERNET WinHttpTransport::CreateRequestHandle( } } + if (m_options.EnableCertificateRevocationListCheck) + { + DWORD value = WINHTTP_ENABLE_SSL_REVOCATION; + if (!WinHttpSetOption(request.get(), WINHTTP_OPTION_ENABLE_FEATURE, &value, sizeof(value))) + { + GetErrorAndThrow("Error while enabling CRL validation."); + } + } + // If we are supporting WebSockets, then let WinHTTP know that it should // prepare to upgrade the HttpRequest to a WebSocket. #pragma warning(push) @@ -491,7 +731,7 @@ void WinHttpTransport::SendRequest( WINHTTP_NO_REQUEST_DATA, 0, streamLength > 0 ? static_cast(streamLength) : 0, - 0)) + reinterpret_cast(this))) { // Errors include: // ERROR_WINHTTP_CANNOT_CONNECT @@ -729,8 +969,22 @@ std::unique_ptr WinHttpTransport::Send(Request& request, Context co _detail::unique_HINTERNET connectionHandle = CreateConnectionHandle(request.GetUrl(), context); _detail::unique_HINTERNET requestHandle = CreateRequestHandle(connectionHandle, request.GetUrl(), request.GetMethod()); + try + { + SendRequest(requestHandle, request, context); + } + catch (TransportException&) + { + // If there was a TLS validation error, then we will have closed the request handle + // during the TLS validation callback. So if an exception was thrown, if we force closed the + // request handle, clear the handle in the requestHandle to prevent a double free. + if (m_requestHandleClosed) + { + requestHandle.release(); + } - SendRequest(requestHandle, request, context); + throw; + } ReceiveResponse(requestHandle, context); diff --git a/sdk/core/azure-core/test/ut/CMakeLists.txt b/sdk/core/azure-core/test/ut/CMakeLists.txt index db39d8f34b..610962a11c 100644 --- a/sdk/core/azure-core/test/ut/CMakeLists.txt +++ b/sdk/core/azure-core/test/ut/CMakeLists.txt @@ -33,6 +33,11 @@ endif() include(GoogleTest) +if (NOT WIN32) +find_package(OpenSSL REQUIRED) +SET(OPENSSLCRYPTO OpenSSL::Crypto) +endif() + add_executable ( azure-core-test azure_core_test.cpp @@ -119,7 +124,7 @@ add_custom_command(TARGET azure-core-test POST_BUILD # Adding private headers from CORE to the tests so we can test the private APIs with no relative paths include. target_include_directories (azure-core-test PRIVATE $) -target_link_libraries(azure-core-test PRIVATE azure-core gtest gmock) +target_link_libraries(azure-core-test PRIVATE azure-core gtest gmock ${OPENSSLCRYPTO}) create_map_file(azure-core-test azure-core-test.map) diff --git a/sdk/core/azure-core/test/ut/curl_connection_pool_test.cpp b/sdk/core/azure-core/test/ut/curl_connection_pool_test.cpp index 5c5d8cb126..94dc7fa1ad 100644 --- a/sdk/core/azure-core/test/ut/curl_connection_pool_test.cpp +++ b/sdk/core/azure-core/test/ut/curl_connection_pool_test.cpp @@ -55,7 +55,7 @@ namespace Azure { namespace Core { namespace Test { Azure::Core::Http::Request req( Azure::Core::Http::HttpMethod::Get, Azure::Core::Url(AzureSdkHttpbinServer::Get())); std::string const expectedConnectionKey(CreateConnectionKey( - AzureSdkHttpbinServer::Schema(), AzureSdkHttpbinServer::Host(), ",0,0,0,0,1,1,0,0")); + AzureSdkHttpbinServer::Schema(), AzureSdkHttpbinServer::Host(), ",0,0,0,0,1,1,0,0,0,0")); { // Creating a new connection with default options @@ -124,7 +124,7 @@ namespace Azure { namespace Core { namespace Test { // Now test that using a different connection config won't re-use the same connection std::string const secondExpectedKey = AzureSdkHttpbinServer::Schema() + "://" - + AzureSdkHttpbinServer::Host() + ",0,0,0,0,1,0,0,200000"; + + AzureSdkHttpbinServer::Host() + ",0,0,0,0,1,0,0,0,0,200000"; { // Creating a new connection with options Azure::Core::Http::CurlTransportOptions options; @@ -433,7 +433,9 @@ namespace Azure { namespace Core { namespace Test { Azure::Core::Http::Request req( Azure::Core::Http::HttpMethod::Get, Azure::Core::Url(authority)); std::string const expectedConnectionKey(CreateConnectionKey( - AzureSdkHttpbinServer::Schema(), AzureSdkHttpbinServer::Host(), ",0,0,0,0,1,1,0,0")); + AzureSdkHttpbinServer::Schema(), + AzureSdkHttpbinServer::Host(), + ",0,0,0,0,1,1,0,0,0,0")); // Creating a new connection with default options auto connection = Azure::Core::Http::_detail::CurlConnectionPool::g_curlConnectionPool @@ -471,7 +473,7 @@ namespace Azure { namespace Core { namespace Test { std::string const expectedConnectionKey(CreateConnectionKey( AzureSdkHttpbinServer::Schema(), AzureSdkHttpbinServer::Host(), - ":443,0,0,0,0,1,1,0,0")); + ":443,0,0,0,0,1,1,0,0,0,0")); // Creating a new connection with default options auto connection = Azure::Core::Http::_detail::CurlConnectionPool::g_curlConnectionPool @@ -508,7 +510,9 @@ namespace Azure { namespace Core { namespace Test { Azure::Core::Http::Request req( Azure::Core::Http::HttpMethod::Get, Azure::Core::Url(authority)); std::string const expectedConnectionKey(CreateConnectionKey( - AzureSdkHttpbinServer::Schema(), AzureSdkHttpbinServer::Host(), ",0,0,0,0,1,1,0,0")); + AzureSdkHttpbinServer::Schema(), + AzureSdkHttpbinServer::Host(), + ",0,0,0,0,1,1,0,0,0,0")); // Creating a new connection with default options auto connection = Azure::Core::Http::_detail::CurlConnectionPool::g_curlConnectionPool @@ -545,7 +549,7 @@ namespace Azure { namespace Core { namespace Test { std::string const expectedConnectionKey(CreateConnectionKey( AzureSdkHttpbinServer::Schema(), AzureSdkHttpbinServer::Host(), - ":443,0,0,0,0,1,1,0,0")); + ":443,0,0,0,0,1,1,0,0,0,0")); // Creating a new connection with default options auto connection = Azure::Core::Http::_detail::CurlConnectionPool::g_curlConnectionPool diff --git a/sdk/core/azure-core/test/ut/curl_options_test.cpp b/sdk/core/azure-core/test/ut/curl_options_test.cpp index 75fb899292..c8dc62238c 100644 --- a/sdk/core/azure-core/test/ut/curl_options_test.cpp +++ b/sdk/core/azure-core/test/ut/curl_options_test.cpp @@ -78,13 +78,16 @@ namespace Azure { namespace Core { namespace Test { std::unique_ptr response; EXPECT_NO_THROW(response = pipeline.Send(request, Azure::Core::Context::ApplicationContext)); - auto responseCode = response->GetStatusCode(); - int expectedCode = 200; - EXPECT_PRED2( - [](int expected, int code) { return expected == code; }, - expectedCode, - static_cast::type>( - responseCode)); + if (response) + { + auto responseCode = response->GetStatusCode(); + int expectedCode = 200; + EXPECT_PRED2( + [](int expected, int code) { return expected == code; }, + expectedCode, + static_cast::type>( + responseCode)); + } // Clean the connection from the pool *Windows fails to clean if we leave to be clean upon // app-destruction diff --git a/sdk/core/azure-core/test/ut/transport_policy_options.cpp b/sdk/core/azure-core/test/ut/transport_policy_options.cpp index eda51ba327..398f40f3c2 100644 --- a/sdk/core/azure-core/test/ut/transport_policy_options.cpp +++ b/sdk/core/azure-core/test/ut/transport_policy_options.cpp @@ -4,6 +4,7 @@ #include "azure/core/context.hpp" #include "azure/core/http/curl_transport.hpp" #include "azure/core/http/policies/policy.hpp" +#include "azure/core/internal/client_options.hpp" #include "azure/core/internal/environment.hpp" #include "azure/core/internal/http/pipeline.hpp" #include "azure/core/internal/json/json.hpp" @@ -14,7 +15,8 @@ #include #include -#if !defined(DISABLE_PROXY_TESTS) +using namespace std::chrono_literals; + namespace Azure { namespace Core { namespace Test { namespace { constexpr static const char AzureSdkHttpbinServerSchema[] = "https"; @@ -90,33 +92,8 @@ namespace Azure { namespace Core { namespace Test { Azure::Core::Http::HttpStatusCode code, Azure::Core::Http::HttpStatusCode expectedCode = Azure::Core::Http::HttpStatusCode::Ok); - std::string HttpProxyServer() - { - std::string anonymousServer{ - Azure::Core::_internal::Environment::GetVariable("ANONYMOUSCONTAINERIPV4ADDRESS")}; - GTEST_LOG_(INFO) << "Anonymous server: " << anonymousServer; - if (anonymousServer.empty()) - { - GTEST_LOG_(ERROR) - << "Could not find value for ANONYMOUSCONTAINERIPV4ADDRESS, Assuming local."; - anonymousServer = "127.0.0.1"; - } - - return "http://" + anonymousServer + ":3128"; - } - std::string HttpProxyServerWithPassword() - { - std::string authenticatedServer{ - Azure::Core::_internal::Environment::GetVariable("AUTHENTICATEDCONTAINERIPV4ADDRESS")}; - GTEST_LOG_(INFO) << "Authenticated server: " << authenticatedServer; - if (authenticatedServer.empty()) - { - GTEST_LOG_(ERROR) - << "Could not find value for AUTHENTICATEDCONTAINERIPV4ADDRESS, Assuming local."; - authenticatedServer = "127.0.0.1"; - } - return "http://" + authenticatedServer + ":3129"; - } + std::string HttpProxyServer() { return "http://127.0.0.1:3128"; } + std::string HttpProxyServerWithPassword() { return "http://127.0.0.1:3129"; } protected: // Create @@ -247,6 +224,7 @@ namespace Azure { namespace Core { namespace Test { using namespace Azure::Core::Http::_internal; using namespace Azure::Core::Http::Policies::_internal; +#if !defined(DISABLE_PROXY_TESTS) // constexpr char SocksProxyServer[] = "socks://98.162.96.41:4145"; TEST_F(TransportAdapterOptions, SimpleProxyTests) { @@ -377,5 +355,484 @@ namespace Azure { namespace Core { namespace Test { CheckBodyFromBuffer(*response, expectedResponseBodySize); } } -}}} // namespace Azure::Core::Test + #endif // defined(DISABLE_PROXY_TESTS) + + TEST_F(TransportAdapterOptions, DisableCrlValidation) + { + Azure::Core::Url testUrl(AzureSdkHttpbinServer::Get()); + // Azure::Core::Url testUrl("https://www.microsoft.com/"); + // HTTP Connections. + { + Azure::Core::Http::Policies::TransportOptions transportOptions; + + // Note that the default is to *disable* CRL checks, because they are disabled + // by default. So we test *enabling* CRL validation checks. + transportOptions.EnableCertificateRevocationListCheck = true; + HttpPipeline pipeline(CreateHttpPipeline(transportOptions)); + + auto request = Azure::Core::Http::Request(Azure::Core::Http::HttpMethod::Get, testUrl); + auto response = pipeline.Send(request, Azure::Core::Context::ApplicationContext); + EXPECT_EQ(response->GetStatusCode(), Azure::Core::Http::HttpStatusCode::Ok); + } +#if !defined(DISABLE_PROXY_TESTS) + { + Azure::Core::Http::Policies::TransportOptions transportOptions; + + transportOptions.HttpProxy = HttpProxyServerWithPassword(); + transportOptions.ProxyUserName = "user"; + transportOptions.ProxyPassword = "password"; + // Disable CA checks on proxy pipelines too. + transportOptions.EnableCertificateRevocationListCheck = true; + + HttpPipeline pipeline(CreateHttpPipeline(transportOptions)); + + auto request = Azure::Core::Http::Request(Azure::Core::Http::HttpMethod::Get, testUrl); + auto response = pipeline.Send(request, Azure::Core::Context::ApplicationContext); + checkResponseCode(response->GetStatusCode()); + auto expectedResponseBodySize = std::stoull(response->GetHeaders().at("content-length")); + CheckBodyFromBuffer(*response, expectedResponseBodySize); + } +#endif + } + + TEST_F(TransportAdapterOptions, CheckFailedCrlValidation) + { + // By default, for the Windows and Mac platforms, Curl uses + // SCHANNEL/SECTRANSP for CRL validation. Those SSL protocols + // don't have the same behaviors as OpenSSL does. +#if !defined(AZ_PLATFORM_WINDOWS) && !defined(AZ_PLATFORM_MAC) + // Azure::Core::Url + // testUrl("https://github.com/Azure/azure-sdk-for-cpp/blob/main/README.md"); + Azure::Core::Url testUrl("https://www.wikipedia.org"); + // For , github URLs work just fine if CRL validation is off, but if enabled, + // they fail. Let's use that fact to verify that CRL validation causes github + // URLs to fail. + { + Azure::Core::Http::Policies::TransportOptions transportOptions; + + // Note that the default is to *disable* CRL checks, because they are disabled + // by default. So we test *enabling* CRL validation checks. + transportOptions.EnableCertificateRevocationListCheck = false; + HttpPipeline pipeline(CreateHttpPipeline(transportOptions)); + + { + auto request = Azure::Core::Http::Request(Azure::Core::Http::HttpMethod::Get, testUrl); + auto response = pipeline.Send(request, Azure::Core::Context::ApplicationContext); + EXPECT_EQ(response->GetStatusCode(), Azure::Core::Http::HttpStatusCode::Ok); + } + } + { + Azure::Core::Http::Policies::TransportOptions transportOptions; + + // Note that the default is to *disable* CRL checks, because they are disabled + // by default. So we test *enabling* CRL validation checks. + transportOptions.EnableCertificateRevocationListCheck = true; + HttpPipeline pipeline(CreateHttpPipeline(transportOptions)); + + { + auto request = Azure::Core::Http::Request(Azure::Core::Http::HttpMethod::Get, testUrl); + EXPECT_THROW( + pipeline.Send(request, Azure::Core::Context::ApplicationContext), + Azure::Core::Http::TransportException); + } + } + { + Azure::Core::Http::Policies::TransportOptions transportOptions; + + // Note that the default is to *disable* CRL checks, because they are disabled + // by default. So we test *enabling* CRL validation checks. + // + // Retrieving the test URL should succeed if we allow failed CRL retrieval because + // the certificate for the test URL doesn't contain a CRL distribution points extension, + // and by default there is no platform CRL present. + Azure::Core::Http::CurlTransportOptions curlOptions; + curlOptions.SslOptions.AllowFailedCrlRetrieval = true; + curlOptions.SslOptions.EnableCertificateRevocationListCheck = true; + transportOptions.Transport = std::make_shared(curlOptions); + + HttpPipeline pipeline(CreateHttpPipeline(transportOptions)); + + { + auto request = Azure::Core::Http::Request(Azure::Core::Http::HttpMethod::Get, testUrl); + auto response = pipeline.Send(request, Azure::Core::Context::ApplicationContext); + EXPECT_EQ(response->GetStatusCode(), Azure::Core::Http::HttpStatusCode::Ok); + } + } +#endif + } + + TEST_F(TransportAdapterOptions, MultipleCrlOperations) + { + std::vector testUrls{ + AzureSdkHttpbinServer::Get(), + "https://www.microsoft.com/", + "https://www.example.com/", + "https://www.google.com/", + }; + + GTEST_LOG_(INFO) << "Basic test calls."; + { + Azure::Core::Http::Policies::TransportOptions transportOptions; + + // FIrst verify connectivity to the test servers. + transportOptions.EnableCertificateRevocationListCheck = false; + HttpPipeline pipeline(CreateHttpPipeline(transportOptions)); + + for (auto const& target : testUrls) + { + GTEST_LOG_(INFO) << "Test " << target; + Azure::Core::Url url(target); + auto request = Azure::Core::Http::Request(Azure::Core::Http::HttpMethod::Get, url); + std::unique_ptr response; + EXPECT_NO_THROW( + response = pipeline.Send(request, Azure::Core::Context::ApplicationContext)); + if (response && response->GetStatusCode() != Azure::Core::Http::HttpStatusCode::Found) + { + EXPECT_EQ(response->GetStatusCode(), Azure::Core::Http::HttpStatusCode::Ok); + } + } + } + + // Now verify that once we enable CRL checks, we can still access the URLs. + GTEST_LOG_(INFO) << "Test with CRL checks enabled"; + { + Azure::Core::Http::Policies::TransportOptions transportOptions; + + // Note that the default is to *disable* CRL checks, because they are disabled + // by default. So we test *enabling* CRL validation checks. + transportOptions.EnableCertificateRevocationListCheck = true; + HttpPipeline pipeline(CreateHttpPipeline(transportOptions)); + + for (auto const& target : testUrls) + { + GTEST_LOG_(INFO) << "Test " << target; + Azure::Core::Url url(target); + auto request = Azure::Core::Http::Request(Azure::Core::Http::HttpMethod::Get, url); + std::unique_ptr response; + EXPECT_NO_THROW( + response = pipeline.Send(request, Azure::Core::Context::ApplicationContext)); + if (response && response->GetStatusCode() != Azure::Core::Http::HttpStatusCode::Found) + { + EXPECT_EQ(response->GetStatusCode(), Azure::Core::Http::HttpStatusCode::Ok); + } + } + } + + // Now verify that once we enable CRL checks, we can still access the URLs. + GTEST_LOG_(INFO) << "Test with CRL checks enabled. Iteration 2."; + { + Azure::Core::Http::Policies::TransportOptions transportOptions; + + // Note that the default is to *disable* CRL checks, because they are disabled + // by default. So we test *enabling* CRL validation checks. + transportOptions.EnableCertificateRevocationListCheck = true; + HttpPipeline pipeline(CreateHttpPipeline(transportOptions)); + + for (auto const& target : testUrls) + { + GTEST_LOG_(INFO) << "Test " << target; + Azure::Core::Url url(target); + auto request = Azure::Core::Http::Request(Azure::Core::Http::HttpMethod::Get, url); + std::unique_ptr response; + EXPECT_NO_THROW( + response = pipeline.Send(request, Azure::Core::Context::ApplicationContext)); + if (response && response->GetStatusCode() != Azure::Core::Http::HttpStatusCode::Found) + { + EXPECT_EQ(response->GetStatusCode(), Azure::Core::Http::HttpStatusCode::Ok); + } + } + } + } + + TEST_F(TransportAdapterOptions, TestRootCertificate) + { + // On Windows and OSX, setting a root certificate disables the default system certificate + // store. That means that if we set the expected certificate, we won't be able to connect to + // the server because the certificates root CA is not in the store. +#if defined(AZ_PLATFORM_LINUX) + // cspell:disable + std::string azurewebsitesCertificate + = "MIIF8zCCBNugAwIBAgIQCq+mxcpjxFFB6jvh98dTFzANBgkqhkiG9w0BAQwFADBh" + "MQswCQYDVQQGEwJVUzEVMBMGA1UEChMMRGlnaUNlcnQgSW5jMRkwFwYDVQQLExB3" + "d3cuZGlnaWNlcnQuY29tMSAwHgYDVQQDExdEaWdpQ2VydCBHbG9iYWwgUm9vdCBH" + "MjAeFw0yMDA3MjkxMjMwMDBaFw0yNDA2MjcyMzU5NTlaMFkxCzAJBgNVBAYTAlVT" + "MR4wHAYDVQQKExVNaWNyb3NvZnQgQ29ycG9yYXRpb24xKjAoBgNVBAMTIU1pY3Jv" + "c29mdCBBenVyZSBUTFMgSXNzdWluZyBDQSAwMTCCAiIwDQYJKoZIhvcNAQEBBQAD" + "ggIPADCCAgoCggIBAMedcDrkXufP7pxVm1FHLDNA9IjwHaMoaY8arqqZ4Gff4xyr" + "RygnavXL7g12MPAx8Q6Dd9hfBzrfWxkF0Br2wIvlvkzW01naNVSkHp+OS3hL3W6n" + "l/jYvZnVeJXjtsKYcXIf/6WtspcF5awlQ9LZJcjwaH7KoZuK+THpXCMtzD8XNVdm" + "GW/JI0C/7U/E7evXn9XDio8SYkGSM63aLO5BtLCv092+1d4GGBSQYolRq+7Pd1kR" + "EkWBPm0ywZ2Vb8GIS5DLrjelEkBnKCyy3B0yQud9dpVsiUeE7F5sY8Me96WVxQcb" + "OyYdEY/j/9UpDlOG+vA+YgOvBhkKEjiqygVpP8EZoMMijephzg43b5Qi9r5UrvYo" + "o19oR/8pf4HJNDPF0/FJwFVMW8PmCBLGstin3NE1+NeWTkGt0TzpHjgKyfaDP2tO" + "4bCk1G7pP2kDFT7SYfc8xbgCkFQ2UCEXsaH/f5YmpLn4YPiNFCeeIida7xnfTvc4" + "7IxyVccHHq1FzGygOqemrxEETKh8hvDR6eBdrBwmCHVgZrnAqnn93JtGyPLi6+cj" + "WGVGtMZHwzVvX1HvSFG771sskcEjJxiQNQDQRWHEh3NxvNb7kFlAXnVdRkkvhjpR" + "GchFhTAzqmwltdWhWDEyCMKC2x/mSZvZtlZGY+g37Y72qHzidwtyW7rBetZJAgMB" + "AAGjggGtMIIBqTAdBgNVHQ4EFgQUDyBd16FXlduSzyvQx8J3BM5ygHYwHwYDVR0j" + "BBgwFoAUTiJUIBiV5uNu5g/6+rkS7QYXjzkwDgYDVR0PAQH/BAQDAgGGMB0GA1Ud" + "JQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjASBgNVHRMBAf8ECDAGAQH/AgEAMHYG" + "CCsGAQUFBwEBBGowaDAkBggrBgEFBQcwAYYYaHR0cDovL29jc3AuZGlnaWNlcnQu" + "Y29tMEAGCCsGAQUFBzAChjRodHRwOi8vY2FjZXJ0cy5kaWdpY2VydC5jb20vRGln" + "aUNlcnRHbG9iYWxSb290RzIuY3J0MHsGA1UdHwR0MHIwN6A1oDOGMWh0dHA6Ly9j" + "cmwzLmRpZ2ljZXJ0LmNvbS9EaWdpQ2VydEdsb2JhbFJvb3RHMi5jcmwwN6A1oDOG" + "MWh0dHA6Ly9jcmw0LmRpZ2ljZXJ0LmNvbS9EaWdpQ2VydEdsb2JhbFJvb3RHMi5j" + "cmwwHQYDVR0gBBYwFDAIBgZngQwBAgEwCAYGZ4EMAQICMBAGCSsGAQQBgjcVAQQD" + "AgEAMA0GCSqGSIb3DQEBDAUAA4IBAQAlFvNh7QgXVLAZSsNR2XRmIn9iS8OHFCBA" + "WxKJoi8YYQafpMTkMqeuzoL3HWb1pYEipsDkhiMnrpfeYZEA7Lz7yqEEtfgHcEBs" + "K9KcStQGGZRfmWU07hPXHnFz+5gTXqzCE2PBMlRgVUYJiA25mJPXfB00gDvGhtYa" + "+mENwM9Bq1B9YYLyLjRtUz8cyGsdyTIG/bBM/Q9jcV8JGqMU/UjAdh1pFyTnnHEl" + "Y59Npi7F87ZqYYJEHJM2LGD+le8VsHjgeWX2CJQko7klXvcizuZvUEDTjHaQcs2J" + "+kPgfyMIOY1DMJ21NxOJ2xPRC/wAh/hzSBRVtoAnyuxtkZ4VjIOh"; + // cspell:enable + + { + Azure::Core::Http::Policies::TransportOptions transportOptions; + + // Note that the default is to *disable* CRL checks, because they are disabled + // by default. So we test *enabling* CRL validation checks. + transportOptions.ExpectedTlsRootCertificate = azurewebsitesCertificate; + HttpPipeline pipeline(CreateHttpPipeline(transportOptions)); + + Azure::Core::Url url(AzureSdkHttpbinServer::Get()); + auto request = Azure::Core::Http::Request(Azure::Core::Http::HttpMethod::Get, url); + auto response = pipeline.Send(request, Azure::Core::Context::ApplicationContext); + EXPECT_EQ(response->GetStatusCode(), Azure::Core::Http::HttpStatusCode::Ok); + } +#endif + } + + const std::string TestProxyHttpsCertificate = + // cspell:disable + "MIIDSDCCAjCgAwIBAgIUIoKu8Oao7j10TLNxaUG2Bs0FrRwwDQYJKoZIhvcNAQEL" + "BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIyMDgwNTIxMTcyM1oXDTIzMDgw" + "NTIxMTcyM1owFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF" + "AAOCAQ8AMIIBCgKCAQEA0UPG7ER++5/9D/qa4SCtt7QvdHwcpidbwktPNU8iRW7V" + "pIDPWS4goLp/+7+maT0Z/mqwSO3JDtm/dtdlr3F/5EMgyUExnYcvUixZAiyFyEwj" + "j6wnAtNvqsg4rDqBlD17fuqTVsZm9Yo7QYub6p5PeznWYucOxRrczqFCiW4uj0Yk" + "GgUHPPmCvhSDKowV8CYRHfkD6R8R4SFkoP3/uejXHxeXoYJNMWq5K0GqGaOZtNFB" + "F7QWZHoLrRpZcY4h+DxwP3c+/FdlVcs9nstkF+EnTnwx5IRyKsaWb/pUEmYKvNDz" + "wi6qnRUdu+DghZuvyZZDgwoYrSZokcbKumk0MsLC3QIDAQABo4GRMIGOMA8GA1Ud" + "EwEB/wQFMAMBAf8wDgYDVR0PAQH/BAQDAgGmMBYGA1UdJQEB/wQMMAoGCCsGAQUF" + "BwMBMBcGA1UdEQEB/wQNMAuCCWxvY2FsaG9zdDA6BgorBgEEAYI3VAEBBCwMKkFT" + "UC5ORVQgQ29yZSBIVFRQUyBkZXZlbG9wbWVudCBjZXJ0aWZpY2F0ZTANBgkqhkiG" + "9w0BAQsFAAOCAQEARX4NxGbycdPVuqvu/CO+/LpWrEm1OcOl7N57/mD5npTIJT78" + "TYtXk1J61akumKdf5CaBgCDRcl35LhioFZIMEsiOidffAp6t493xocncFBhIYYrZ" + "HS6aKsZKPu8h3wOLpYu+zh7f0Hx6pkHPAfw4+knmQjDYomz/hTwuo/MuT8k6Ee7B" + "NGWqxUamLI8bucuf2ZfT1XOq83uWaFF5KwAuVLhpzo39/TmPyYGnaoKRYf9QjabS" + "LUjecMNLJFWHUSD4cKHvXJjDYZEiCiy+MdUDytWIsfw0fzAUjz9Qaz8YpZ+fXufM" + "MNMNfyJHSMEMFIT2D1UaQiwryXWQWJ93OiSdjA=="; + + const std::string InvalidTestProxyHttpsCertificate + = "MIIIujCCBqKgAwIBAgITMwAxS6DhmVCLBf6MWwAAADFLoDANBgkqhkiG9w0BAQwF" + "ADBZMQswCQYDVQQGEwJVUzEeMBwGA1UEChMVTWljcm9zb2Z0IENvcnBvcmF0aW9u" + "MSowKAYDVQQDEyFNaWNyb3NvZnQgQXp1cmUgVExTIElzc3VpbmcgQ0EgMDEwHhcN" + "MjIwMzE0MTgzOTU1WhcNMjMwMzA5MTgzOTU1WjBqMQswCQYDVQQGEwJVUzELMAkG" + "A1UECBMCV0ExEDAOBgNVBAcTB1JlZG1vbmQxHjAcBgNVBAoTFU1pY3Jvc29mdCBD" + "b3Jwb3JhdGlvbjEcMBoGA1UEAwwTKi5henVyZXdlYnNpdGVzLm5ldDCCASIwDQYJ" + "KoZIhvcNAQEBBQADggEPADCCAQoCggEBAM3heDMqn7v8cmh4A9vECuEfuiUnKBIw" + "7y0Sf499Z7WW92HDkIvV3eJ6jcyq41f2UJcG8ivCu30eMnYyyI+aRHIedkvOBA2i" + "PqG78e99qGTuKCj9lrJGVfeTBJ1VIlPvfuHFv/3JaKIBpRtuqxCdlgsGAJQmvHEn" + "vIHUV2jgj4iWNBDoC83ShtWg6qV2ol7yiaClB20Af5byo36jVdMN6vS+/othn3jG" + "pn+NP00DWYbP5y4qhs5XLH9wQZaTUPKIaUxmHewErcM0rMAaWl8wMqQTeNYf3l5D" + "ax50yuEg9VVjtbDdSmvOkslGpVqsOl1NrmyN7gCvcvcRUQcxIiXJQc0CAwEAAaOC" + "BGgwggRkMIIBfwYKKwYBBAHWeQIEAgSCAW8EggFrAWkAdgCt9776fP8QyIudPZwe" + "PhhqtGcpXc+xDCTKhYY069yCigAAAX+Jw/reAAAEAwBHMEUCIE8AAjvwO4AffPn7" + "un67WykJ2hGB4n8qJE7pk4QYjWW+AiEA/pio1E9ALt30Kh/Ga4gRefH1ILbQ8n4h" + "bHFatezIcvYAdwB6MoxU2LcttiDqOOBSHumEFnAyE4VNO9IrwTpXo1LrUgAAAX+J" + "w/qlAAAEAwBIMEYCIQCdbj6FOX6wK+dLoqjWKuCgkKSsZsJKpVik6HjlRgomzQIh" + "AM7mYp5dBFmNLas3fFcP0rMMK+17n8u0GhFH2KpkPr1SAHYA6D7Q2j71BjUy51co" + "vIlryQPTy9ERa+zraeF3fW0GvW4AAAF/icP6jgAABAMARzBFAiAhjTz3PBjqRrpY" + "eH7us44lESC7c0dzdTcehTeAwmEyrgIhAOCaqmqA+ercv+39jzFWkctG36bazRFX" + "4gGNiKU0bctcMCcGCSsGAQQBgjcVCgQaMBgwCgYIKwYBBQUHAwIwCgYIKwYBBQUH" + "AwEwPAYJKwYBBAGCNxUHBC8wLQYlKwYBBAGCNxUIh73XG4Hn60aCgZ0ujtAMh/Da" + "HV2ChOVpgvOnPgIBZAIBJTCBrgYIKwYBBQUHAQEEgaEwgZ4wbQYIKwYBBQUHMAKG" + "YWh0dHA6Ly93d3cubWljcm9zb2Z0LmNvbS9wa2lvcHMvY2VydHMvTWljcm9zb2Z0" + "JTIwQXp1cmUlMjBUTFMlMjBJc3N1aW5nJTIwQ0ElMjAwMSUyMC0lMjB4c2lnbi5j" + "cnQwLQYIKwYBBQUHMAGGIWh0dHA6Ly9vbmVvY3NwLm1pY3Jvc29mdC5jb20vb2Nz" + "cDAdBgNVHQ4EFgQUiiks5RXI6IIQccflfDtgAHndN7owDgYDVR0PAQH/BAQDAgSw" + "MHwGA1UdEQR1MHOCEyouYXp1cmV3ZWJzaXRlcy5uZXSCFyouc2NtLmF6dXJld2Vi" + "c2l0ZXMubmV0ghIqLmF6dXJlLW1vYmlsZS5uZXSCFiouc2NtLmF6dXJlLW1vYmls" + "ZS5uZXSCFyouc3NvLmF6dXJld2Vic2l0ZXMubmV0MAwGA1UdEwEB/wQCMAAwZAYD" + "VR0fBF0wWzBZoFegVYZTaHR0cDovL3d3dy5taWNyb3NvZnQuY29tL3BraW9wcy9j" + "cmwvTWljcm9zb2Z0JTIwQXp1cmUlMjBUTFMlMjBJc3N1aW5nJTIwQ0ElMjAwMS5j" + "cmwwZgYDVR0gBF8wXTBRBgwrBgEEAYI3TIN9AQEwQTA/BggrBgEFBQcCARYzaHR0" + "cDovL3d3dy5taWNyb3NvZnQuY29tL3BraW9wcy9Eb2NzL1JlcG9zaXRvcnkuaHRt" + "MAgGBmeBDAECAjAfBgNVHSMEGDAWgBQPIF3XoVeV25LPK9DHwncEznKAdjAdBgNV" + "HSUEFjAUBggrBgEFBQcDAgYIKwYBBQUHAwEwDQYJKoZIhvcNAQEMBQADggIBAKtk" + "4nEDfqxbP80uaoBoPaeeX4G/tBNcfpR2sf6soW8atAqGOohdLPcE0n5/KJn+H4u7" + "CsZdTJyUVxBxAlpqAc9JABl4urWNbhv4pueGBZXOn5K5Lpup/gp1HhCx4XKFno/7" + "T22NVDol4LRLUTeTkrpNyYLU5QYBQpqlFMAcvem/2seiPPYghFtLr5VWVEikUvnf" + "wSlECNk84PT7mOdbrX7T3CbG9WEZVmSYxMCS4pwcW3caXoSzUzZ0H1sJndCJW8La" + "9tekRKkMVkN558S+FFwaY1yARNqCFeK+yiwvkkkojqHbgwFJgCFWYy37kFR9uPiv" + "3sTHvs8IZ5K8TY7rHk3pSMYqoBTODCs7wKGiByWSDMcfAgGBzjt95SKfq0p6sj0C" + "+HWFiyKR+PTi2esFP9Vr9sC9jfRM6zwa7KnONqLefHauJPdNMt5l1FQGWvyco4IN" + "lwK3Z9FfEOFZA4YcjsqnkNacKZqLjgis3FvD8VPXETgRuffVc75lJxH6WmkwqdXj" + "BlU8wOcJyXTmM1ehYpziCpWvGBSEIsFuK6BC/iBnQEuWKdctAdbHIDlLctGgDWjx" + "xYDPZ/TtORGL8YaDnj6QHeOURIAHCtt6NCWKV6OR2HtMx+tCEvfi5ION1dyJ9hAX" + "+4K9FXc71ab7tdV/GLPkWc8Q0x1nk7ogDYcqKbiF"; + + class TestProxy { + // cspell:enable + + std::unique_ptr m_pipeline; + + public: + struct TestProxyOptions : Azure::Core::_internal::ClientOptions + { + TestProxyOptions() : Azure::Core::_internal::ClientOptions() {} + }; + TestProxy(TestProxyOptions options = TestProxyOptions()) + { + if (options.Transport.ExpectedTlsRootCertificate.empty()) + { + options.Transport.ExpectedTlsRootCertificate = TestProxyHttpsCertificate; + } + std::vector> perRetryPolicies; + std::vector> perCallPolicies; + m_pipeline = std::make_unique( + options, + "Test Proxy", + "2021-11", + std::move(perRetryPolicies), + std::move(perCallPolicies)); + } + + Azure::Response PostStartRecording(std::string const& recordingFile) + { + std::string proxyServerRequest; + proxyServerRequest = "{ \"x-recording-file\": \""; + proxyServerRequest += Azure::Core::Url::Encode(recordingFile); + proxyServerRequest += "\"}"; + std::vector bodyVector{proxyServerRequest.begin(), proxyServerRequest.end()}; + Azure::Core::IO::MemoryBodyStream postBody(bodyVector); + auto request = Azure::Core::Http::Request( + Azure::Core::Http::HttpMethod::Post, + Azure::Core::Url("https://localhost:5001/record/start"), + &postBody); + + auto response = m_pipeline->Send(request, Azure::Core::Context::ApplicationContext); + auto& responseHeaders = response->GetHeaders(); + auto responseId = responseHeaders.find("x-recording-id"); + return Azure::Response(responseId->second, std::move(response)); + } + Azure::Response PostStopRecording( + std::string const& recordingId) + { + auto request = Azure::Core::Http::Request( + Azure::Core::Http::HttpMethod::Post, + Azure::Core::Url("https://localhost:5001/record/stop")); + request.SetHeader("x-recording-id", recordingId); + + auto response = m_pipeline->Send(request, Azure::Core::Context::ApplicationContext); + auto responseCode = response->GetStatusCode(); + return Azure::Response(responseCode, std::move(response)); + } + + Azure::Response PostStartPlayback(std::string const& recordingFile) + { + std::string proxyServerRequest; + proxyServerRequest = "{ \"x-recording-file\": \""; + proxyServerRequest += Azure::Core::Url::Encode(recordingFile); + proxyServerRequest += "\"}"; + std::vector bodyVector{proxyServerRequest.begin(), proxyServerRequest.end()}; + Azure::Core::IO::MemoryBodyStream postBody(bodyVector); + auto request = Azure::Core::Http::Request( + Azure::Core::Http::HttpMethod::Post, + Azure::Core::Url("https://localhost:5001/playback/start"), + &postBody); + + auto response = m_pipeline->Send(request, Azure::Core::Context::ApplicationContext); + auto& responseHeaders = response->GetHeaders(); + auto responseId = responseHeaders.find("x-recording-id"); + return Azure::Response(responseId->second, std::move(response)); + } + + Azure::Response PostStopPlayback( + std::string const& recordingId) + { + auto request = Azure::Core::Http::Request( + Azure::Core::Http::HttpMethod::Post, + Azure::Core::Url("https://localhost:5001/playback/stop")); + request.SetHeader("x-recording-id", recordingId); + + auto response = m_pipeline->Send(request, Azure::Core::Context::ApplicationContext); + auto responseCode = response->GetStatusCode(); + return Azure::Response(responseCode, std::move(response)); + } + + Azure::Response ProxyServerGetUrl( + std::string const& recordingId, + bool isRecording, + std::string const& urlToRecord) + { + Azure::Core::Url targetUrl{urlToRecord}; + auto request = Azure::Core::Http::Request( + Azure::Core::Http::HttpMethod::Get, + Azure::Core::Url("https://localhost:5001/" + targetUrl.GetRelativeUrl())); + request.SetHeader( + "x-recording-upstream-base-uri", targetUrl.GetScheme() + "://" + targetUrl.GetHost()); + request.SetHeader("x-recording-id", recordingId); + request.SetHeader("x-recording-mode", (isRecording ? "record" : "playback")); + + auto response = m_pipeline->Send(request, Azure::Core::Context::ApplicationContext); + std::string responseBody(response->GetBody().begin(), response->GetBody().end()); + return Azure::Response(responseBody, std::move(response)); + } + + Azure::Response IsAlive() + { + auto request = Azure::Core::Http::Request( + Azure::Core::Http::HttpMethod::Get, + Azure::Core::Url("https://localhost:5001/Admin/IsAlive")); + auto response = m_pipeline->Send(request, Azure::Core::Context::ApplicationContext); + auto statusCode = response->GetStatusCode(); + return Azure::Response(statusCode, std::move(response)); + } + + ~TestProxy() {} + }; + + TEST_F(TransportAdapterOptions, AccessTestProxyServer) + { + TestProxy proxyServer; + + EXPECT_EQ(Azure::Core::Http::HttpStatusCode::Ok, proxyServer.IsAlive().Value); + + std::string recordingId; + EXPECT_NO_THROW(recordingId = proxyServer.PostStartRecording("testRecording.json").Value); + + GTEST_LOG_(INFO) << "Started recording with ID " << recordingId; + + std::string response; + EXPECT_NO_THROW( + response + = proxyServer.ProxyServerGetUrl(recordingId, true, AzureSdkHttpbinServer::Get()).Value); + + GTEST_LOG_(INFO) << "Response for recording " << recordingId << "is: " << response; + + EXPECT_NO_THROW(proxyServer.PostStopRecording(recordingId)); + + EXPECT_NO_THROW(recordingId = proxyServer.PostStartPlayback("testRecording.json").Value); + GTEST_LOG_(INFO) << "Started playback with ID " << recordingId; + + EXPECT_NO_THROW( + response + = proxyServer.ProxyServerGetUrl(recordingId, false, AzureSdkHttpbinServer::Get()).Value); + + GTEST_LOG_(INFO) << "Recorded Response for " << recordingId << "is: " << response; + + EXPECT_NO_THROW(proxyServer.PostStopPlayback(recordingId)); + } + + TEST_F(TransportAdapterOptions, TestProxyServerWithInvalidCertificate) + { + TestProxy::TestProxyOptions options; + options.Transport.ExpectedTlsRootCertificate = InvalidTestProxyHttpsCertificate; + TestProxy proxyServer(options); + + EXPECT_THROW(proxyServer.IsAlive(), Azure::Core::Http::TransportException); + } + +}}} // namespace Azure::Core::Test diff --git a/sdk/core/ci.yml b/sdk/core/ci.yml index 45e6e3e0e2..71fea0a3d9 100644 --- a/sdk/core/ci.yml +++ b/sdk/core/ci.yml @@ -94,6 +94,12 @@ stages: displayName: Verify Proxy Server Working Correctly. condition: and(succeeded(), variables.RunProxyTests, contains(variables.CmakeArgs, 'BUILD_TESTING=ON')) + # Start the test proxy tool for testing. Only start it when previous steps had succeeded and we're enabling testing. + - template: /eng/common/testproxy/test-proxy-tool.yml + parameters: + runProxy: true + targetVersion: 1.0.0-dev.20220810.2 + PostTestSteps: # Shut down the test server. This uses curl to send a request to the "terminateserver" websocket endpoint. # When the test server receives a request on terminateserver, it shuts down gracefully. @@ -115,7 +121,7 @@ stages: docker stop $_ ` } displayName: Shutdown Squid Proxy. - condition: and(contains(variables.CmakeArgs, 'BUILD_TESTING=ON'), variables.RunProxyTests, contains(variables['OSVmImage'], 'linux')) + condition: and(variables.RunProxyTests, contains(variables.CmakeArgs, 'BUILD_TESTING=ON'), contains(variables['OSVmImage'], 'linux')) - template: /eng/common/pipelines/templates/steps/publish-artifact.yml parameters: ArtifactPath: '$(Build.SourcesDirectory)/WebSocketServer.log'