diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index 2baa79229b..abc9f97fcc 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -10,6 +10,8 @@ ### Other Changes +- Added support for Identity token caching, and for configuring token refresh offset in `BearerTokenAuthenticationPolicy`. + ## 1.8.0-beta.1 (2022-10-06) ### Features Added diff --git a/sdk/core/azure-core/inc/azure/core/credentials/credentials.hpp b/sdk/core/azure-core/inc/azure/core/credentials/credentials.hpp index 94a033eebb..27a9e62d33 100644 --- a/sdk/core/azure-core/inc/azure/core/credentials/credentials.hpp +++ b/sdk/core/azure-core/inc/azure/core/credentials/credentials.hpp @@ -48,6 +48,12 @@ namespace Azure { namespace Core { namespace Credentials { * */ std::vector Scopes; + + /** + * @brief Minimum token expiration suggestion. + * + */ + DateTime::duration MinimumExpiration = std::chrono::minutes(2); }; /** @@ -61,6 +67,8 @@ namespace Azure { namespace Core { namespace Credentials { * @param tokenRequestContext A context to get the token in. * @param context A context to control the request lifetime. * + * @return Authentication token. + * * @throw Azure::Core::Credentials::AuthenticationException Authentication error occurred. */ virtual AccessToken GetToken( diff --git a/sdk/core/azure-core/src/http/bearer_token_authentication_policy.cpp b/sdk/core/azure-core/src/http/bearer_token_authentication_policy.cpp index 46a71a58b7..7099584f4d 100644 --- a/sdk/core/azure-core/src/http/bearer_token_authentication_policy.cpp +++ b/sdk/core/azure-core/src/http/bearer_token_authentication_policy.cpp @@ -18,8 +18,8 @@ std::unique_ptr BearerTokenAuthenticationPolicy::Send( { std::lock_guard lock(m_accessTokenMutex); - // Refresh the token in 2 or less minutes before the actual expiration. - if (std::chrono::system_clock::now() > (m_accessToken.ExpiresOn - std::chrono::minutes(2))) + if (std::chrono::system_clock::now() + > (m_accessToken.ExpiresOn - m_tokenRequestContext.MinimumExpiration)) { m_accessToken = m_credential->GetToken(m_tokenRequestContext, context); } diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 14e620b1d6..552e371cc4 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -4,6 +4,8 @@ ### Features Added +- Added token caching. + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/identity/azure-identity/CMakeLists.txt b/sdk/identity/azure-identity/CMakeLists.txt index 8cc8f7414c..6b686d6337 100644 --- a/sdk/identity/azure-identity/CMakeLists.txt +++ b/sdk/identity/azure-identity/CMakeLists.txt @@ -60,6 +60,8 @@ set( AZURE_IDENTITY_SOURCE src/private/managed_identity_source.hpp src/private/package_version.hpp + src/private/token_cache.hpp + src/private/token_cache_internals.hpp src/private/token_credential_impl.hpp src/chained_token_credential.cpp src/client_certificate_credential.cpp @@ -67,6 +69,7 @@ set( src/environment_credential.cpp src/managed_identity_credential.cpp src/managed_identity_source.cpp + src/token_cache.cpp src/token_credential_impl.cpp ) @@ -106,6 +109,9 @@ az_rtti_setup( ) if(BUILD_TESTING) + # define a symbol that enables some test hooks in code + add_compile_definitions(TESTING_BUILD) + # tests if (NOT AZ_ALL_LIBRARIES OR FETCH_SOURCE_DEPS) include(AddGoogleTest) diff --git a/sdk/identity/azure-identity/inc/azure/identity/client_secret_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/client_secret_credential.hpp index 8e1d43b02d..206c77108e 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/client_secret_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/client_secret_credential.hpp @@ -51,13 +51,18 @@ namespace Azure { namespace Identity { std::unique_ptr<_detail::TokenCredentialImpl> m_tokenCredentialImpl; Core::Url m_requestUrl; std::string m_requestBody; + + std::string m_tenantId; + std::string m_clientId; + std::string m_authorityHost; + bool m_isAdfs; ClientSecretCredential( - std::string const& tenantId, - std::string const& clientId, + std::string tenantId, + std::string clientId, std::string const& clientSecret, - std::string const& authorityHost, + std::string authorityHost, Core::Credentials::TokenCredentialOptions const& options); public: @@ -70,8 +75,8 @@ namespace Azure { namespace Identity { * @param options Options for token retrieval. */ explicit ClientSecretCredential( - std::string const& tenantId, - std::string const& clientId, + std::string tenantId, + std::string clientId, std::string const& clientSecret, ClientSecretCredentialOptions const& options); @@ -86,7 +91,7 @@ namespace Azure { namespace Identity { explicit ClientSecretCredential( std::string tenantId, std::string clientId, - std::string clientSecret, + std::string const& clientSecret, Core::Credentials::TokenCredentialOptions const& options = Core::Credentials::TokenCredentialOptions()); @@ -102,6 +107,8 @@ namespace Azure { namespace Identity { * @param tokenRequestContext A context to get the token in. * @param context A context to control the request lifetime. * + * @return Authentication token. + * * @throw Azure::Core::Credentials::AuthenticationException Authentication error occurred. */ Core::Credentials::AccessToken GetToken( diff --git a/sdk/identity/azure-identity/inc/azure/identity/managed_identity_credential.hpp b/sdk/identity/azure-identity/inc/azure/identity/managed_identity_credential.hpp index 3e4863ec5c..0017bd8817 100644 --- a/sdk/identity/azure-identity/inc/azure/identity/managed_identity_credential.hpp +++ b/sdk/identity/azure-identity/inc/azure/identity/managed_identity_credential.hpp @@ -59,6 +59,8 @@ namespace Azure { namespace Identity { * @param tokenRequestContext A context to get the token in. * @param context A context to control the request lifetime. * + * @return Authentication token. + * * @throw Azure::Core::Credentials::AuthenticationException Authentication error occurred. */ Core::Credentials::AccessToken GetToken( diff --git a/sdk/identity/azure-identity/src/client_secret_credential.cpp b/sdk/identity/azure-identity/src/client_secret_credential.cpp index 5b04416c59..78cdac88a0 100644 --- a/sdk/identity/azure-identity/src/client_secret_credential.cpp +++ b/sdk/identity/azure-identity/src/client_secret_credential.cpp @@ -3,9 +3,11 @@ #include "azure/identity/client_secret_credential.hpp" +#include "private/token_cache.hpp" #include "private/token_credential_impl.hpp" #include +#include using namespace Azure::Identity; @@ -13,29 +15,33 @@ std::string const Azure::Identity::_detail::g_aadGlobalAuthority = "https://login.microsoftonline.com/"; ClientSecretCredential::ClientSecretCredential( - std::string const& tenantId, - std::string const& clientId, + std::string tenantId, + std::string clientId, std::string const& clientSecret, - std::string const& authorityHost, + std::string authorityHost, Azure::Core::Credentials::TokenCredentialOptions const& options) : m_tokenCredentialImpl(std::make_unique<_detail::TokenCredentialImpl>(options)), - m_isAdfs(tenantId == "adfs") + m_tenantId(std::move(tenantId)), m_clientId(std::move(clientId)), + m_authorityHost(std::move(authorityHost)) { using Azure::Core::Url; - m_requestUrl = Url(authorityHost); - m_requestUrl.AppendPath(tenantId); + + m_isAdfs = (m_tenantId == "adfs"); + + m_requestUrl = Url(m_authorityHost); + m_requestUrl.AppendPath(m_tenantId); m_requestUrl.AppendPath(m_isAdfs ? "oauth2/token" : "oauth2/v2.0/token"); std::ostringstream body; - body << "grant_type=client_credentials&client_id=" << Url::Encode(clientId) + body << "grant_type=client_credentials&client_id=" << Url::Encode(m_clientId) << "&client_secret=" << Url::Encode(clientSecret); m_requestBody = body.str(); } ClientSecretCredential::ClientSecretCredential( - std::string const& tenantId, - std::string const& clientId, + std::string tenantId, + std::string clientId, std::string const& clientSecret, ClientSecretCredentialOptions const& options) : ClientSecretCredential(tenantId, clientId, clientSecret, options.AuthorityHost, options) @@ -45,7 +51,7 @@ ClientSecretCredential::ClientSecretCredential( ClientSecretCredential::ClientSecretCredential( std::string tenantId, std::string clientId, - std::string clientSecret, + std::string const& clientSecret, Core::Credentials::TokenCredentialOptions const& options) : ClientSecretCredential( tenantId, @@ -62,28 +68,49 @@ Azure::Core::Credentials::AccessToken ClientSecretCredential::GetToken( Azure::Core::Credentials::TokenRequestContext const& tokenRequestContext, Azure::Core::Context const& context) const { - return m_tokenCredentialImpl->GetToken(context, [&]() { - using _detail::TokenCredentialImpl; - using Azure::Core::Http::HttpMethod; - - std::ostringstream body; - body << m_requestBody; - { - auto const& scopes = tokenRequestContext.Scopes; - if (!scopes.empty()) - { - body << "&scope=" << TokenCredentialImpl::FormatScopes(scopes, m_isAdfs); - } - } - - auto request = std::make_unique( - HttpMethod::Post, m_requestUrl, body.str()); + using _detail::TokenCache; + using _detail::TokenCredentialImpl; - if (m_isAdfs) + std::string scopesStr; + { + auto const& scopes = tokenRequestContext.Scopes; + if (!scopes.empty()) { - request->HttpRequest.SetHeader("Host", m_requestUrl.GetHost()); + scopesStr = TokenCredentialImpl::FormatScopes(scopes, m_isAdfs); } - - return request; - }); + } + + // TokenCache::GetToken() and m_tokenCredentialImpl->GetToken() can only use the lambda argument + // when they are being executed. They are not supposed to keep a reference to lambda argument to + // call it later. Therefore, any capture made here will outlive the possible time frame when the + // lambda might get called. + return TokenCache::GetToken( + m_tenantId, + m_clientId, + m_authorityHost, + scopesStr, + tokenRequestContext.MinimumExpiration, + [&]() { + return m_tokenCredentialImpl->GetToken(context, [&]() { + using Azure::Core::Http::HttpMethod; + + std::ostringstream body; + body << m_requestBody; + + if (!scopesStr.empty()) + { + body << "&scope=" << scopesStr; + } + + auto request = std::make_unique( + HttpMethod::Post, m_requestUrl, body.str()); + + if (m_isAdfs) + { + request->HttpRequest.SetHeader("Host", m_requestUrl.GetHost()); + } + + return request; + }); + }); } diff --git a/sdk/identity/azure-identity/src/managed_identity_source.cpp b/sdk/identity/azure-identity/src/managed_identity_source.cpp index 56fbf51273..3dcb4d38ad 100644 --- a/sdk/identity/azure-identity/src/managed_identity_source.cpp +++ b/sdk/identity/azure-identity/src/managed_identity_source.cpp @@ -3,6 +3,8 @@ #include "private/managed_identity_source.hpp" +#include "private/token_cache.hpp" + #include #include @@ -59,7 +61,7 @@ AppServiceManagedIdentitySource::AppServiceManagedIdentitySource( std::string const& apiVersion, std::string const& secretHeaderName, std::string const& clientIdHeaderName) - : ManagedIdentitySource(options), + : ManagedIdentitySource(clientId, endpointUrl.GetHost(), options), m_request(Azure::Core::Http::HttpMethod::Get, std::move(endpointUrl)) { { @@ -81,18 +83,37 @@ Azure::Core::Credentials::AccessToken AppServiceManagedIdentitySource::GetToken( Azure::Core::Credentials::TokenRequestContext const& tokenRequestContext, Azure::Core::Context const& context) const { - return TokenCredentialImpl::GetToken(context, [&]() { - auto request = std::make_unique(m_request); + std::string scopesStr; + { + auto const& scopes = tokenRequestContext.Scopes; + if (!scopes.empty()) { - auto const& scopes = tokenRequestContext.Scopes; - if (!scopes.empty()) - { - request->HttpRequest.GetUrl().AppendQueryParameter("resource", FormatScopes(scopes, true)); - } + scopesStr = TokenCredentialImpl::FormatScopes(scopes, true); } + } - return request; - }); + // TokenCache::GetToken() and TokenCredentialImpl::GetToken() can only use the lambda argument + // when they are being executed. They are not supposed to keep a reference to lambda argument to + // call it later. Therefore, any capture made here will outlive the possible time frame when the + // lambda might get called. + return TokenCache::GetToken( + std::string(), + GetClientId(), + GetAuthorityHost(), + scopesStr, + tokenRequestContext.MinimumExpiration, + [&]() { + return TokenCredentialImpl::GetToken(context, [&]() { + auto request = std::make_unique(m_request); + + if (!scopesStr.empty()) + { + request->HttpRequest.GetUrl().AppendQueryParameter("resource", scopesStr); + } + + return request; + }); + }); } std::unique_ptr AppServiceV2017ManagedIdentitySource::Create( @@ -128,7 +149,7 @@ CloudShellManagedIdentitySource::CloudShellManagedIdentitySource( std::string const& clientId, Azure::Core::Credentials::TokenCredentialOptions const& options, Azure::Core::Url endpointUrl) - : ManagedIdentitySource(options), m_url(std::move(endpointUrl)) + : ManagedIdentitySource(clientId, endpointUrl.GetHost(), options), m_url(std::move(endpointUrl)) { using Azure::Core::Url; if (!clientId.empty()) @@ -141,28 +162,47 @@ Azure::Core::Credentials::AccessToken CloudShellManagedIdentitySource::GetToken( Azure::Core::Credentials::TokenRequestContext const& tokenRequestContext, Azure::Core::Context const& context) const { - return TokenCredentialImpl::GetToken(context, [&]() { - using Azure::Core::Url; - using Azure::Core::Http::HttpMethod; - - std::string resource; + std::string scopesStr; + { + auto const& scopes = tokenRequestContext.Scopes; + if (!scopes.empty()) { - auto const& scopes = tokenRequestContext.Scopes; - if (!scopes.empty()) - { - resource = "resource=" + FormatScopes(scopes, true); - if (!m_body.empty()) - { - resource += "&"; - } - } + scopesStr = TokenCredentialImpl::FormatScopes(scopes, true); } + } - auto request = std::make_unique(HttpMethod::Post, m_url, resource + m_body); - request->HttpRequest.SetHeader("Metadata", "true"); - - return request; - }); + // TokenCache::GetToken() and TokenCredentialImpl::GetToken() can only use the lambda argument + // when they are being executed. They are not supposed to keep a reference to lambda argument to + // call it later. Therefore, any capture made here will outlive the possible time frame when the + // lambda might get called. + return TokenCache::GetToken( + std::string(), + GetClientId(), + GetAuthorityHost(), + scopesStr, + tokenRequestContext.MinimumExpiration, + [&]() { + return TokenCredentialImpl::GetToken(context, [&]() { + using Azure::Core::Url; + using Azure::Core::Http::HttpMethod; + + std::string resource; + + if (!scopesStr.empty()) + { + resource = "resource=" + scopesStr; + if (!m_body.empty()) + { + resource += "&"; + } + } + + auto request = std::make_unique(HttpMethod::Post, m_url, resource + m_body); + request->HttpRequest.SetHeader("Metadata", "true"); + + return request; + }); + }); } std::unique_ptr AzureArcManagedIdentitySource::Create( @@ -194,7 +234,8 @@ std::unique_ptr AzureArcManagedIdentitySource::Create( AzureArcManagedIdentitySource::AzureArcManagedIdentitySource( Azure::Core::Credentials::TokenCredentialOptions const& options, Azure::Core::Url endpointUrl) - : ManagedIdentitySource(options), m_url(std::move(endpointUrl)) + : ManagedIdentitySource(std::string(), endpointUrl.GetHost(), options), + m_url(std::move(endpointUrl)) { m_url.AppendQueryParameter("api-version", "2019-11-01"); @@ -204,6 +245,15 @@ Azure::Core::Credentials::AccessToken AzureArcManagedIdentitySource::GetToken( Azure::Core::Credentials::TokenRequestContext const& tokenRequestContext, Azure::Core::Context const& context) const { + std::string scopesStr; + { + auto const& scopes = tokenRequestContext.Scopes; + if (!scopes.empty()) + { + scopesStr = TokenCredentialImpl::FormatScopes(scopes, true); + } + } + auto const createRequest = [&]() { using Azure::Core::Http::HttpMethod; using Azure::Core::Http::Request; @@ -212,59 +262,70 @@ Azure::Core::Credentials::AccessToken AzureArcManagedIdentitySource::GetToken( { auto& httpRequest = request->HttpRequest; httpRequest.SetHeader("Metadata", "true"); + + if (!scopesStr.empty()) { - auto const& scopes = tokenRequestContext.Scopes; - if (!scopes.empty()) - { - httpRequest.GetUrl().AppendQueryParameter("resource", FormatScopes(scopes, true)); - } + httpRequest.GetUrl().AppendQueryParameter("resource", scopesStr); } } return request; }; - return TokenCredentialImpl::GetToken( - context, - createRequest, - [&](auto const statusCode, auto const& response) -> std::unique_ptr { - using Core::Credentials::AuthenticationException; - using Core::Http::HttpStatusCode; - - if (statusCode != HttpStatusCode::Unauthorized) - { - return nullptr; - } - - auto const& headers = response.GetHeaders(); - auto authHeader = headers.find("WWW-Authenticate"); - if (authHeader == headers.end()) - { - throw AuthenticationException( - "Did not receive expected WWW-Authenticate header " - "in the response from Azure Arc Managed Identity Endpoint."); - } - - constexpr auto ChallengeValueSeparator = '='; - auto const& challenge = authHeader->second; - auto eq = challenge.find(ChallengeValueSeparator); - if (eq == std::string::npos - || challenge.find(ChallengeValueSeparator, eq + 1) != std::string::npos) - { - throw AuthenticationException("The WWW-Authenticate header in the response from Azure " - "Arc Managed Identity Endpoint " - "did not match the expected format."); - } - - auto request = createRequest(); - std::ifstream secretFile(challenge.substr(eq + 1)); - request->HttpRequest.SetHeader( - "Authorization", - "Basic " - + std::string( - std::istreambuf_iterator(secretFile), std::istreambuf_iterator())); - - return request; + // TokenCache::GetToken() and TokenCredentialImpl::GetToken() can only use the lambda argument + // when they are being executed. They are not supposed to keep a reference to lambda argument to + // call it later. Therefore, any capture made here will outlive the possible time frame when the + // lambda might get called. + return TokenCache::GetToken( + std::string(), + GetClientId(), + GetAuthorityHost(), + scopesStr, + tokenRequestContext.MinimumExpiration, + [&]() { + return TokenCredentialImpl::GetToken( + context, + createRequest, + [&](auto const statusCode, auto const& response) -> std::unique_ptr { + using Core::Credentials::AuthenticationException; + using Core::Http::HttpStatusCode; + + if (statusCode != HttpStatusCode::Unauthorized) + { + return nullptr; + } + + auto const& headers = response.GetHeaders(); + auto authHeader = headers.find("WWW-Authenticate"); + if (authHeader == headers.end()) + { + throw AuthenticationException( + "Did not receive expected WWW-Authenticate header " + "in the response from Azure Arc Managed Identity Endpoint."); + } + + constexpr auto ChallengeValueSeparator = '='; + auto const& challenge = authHeader->second; + auto eq = challenge.find(ChallengeValueSeparator); + if (eq == std::string::npos + || challenge.find(ChallengeValueSeparator, eq + 1) != std::string::npos) + { + throw AuthenticationException( + "The WWW-Authenticate header in the response from Azure Arc " + "Managed Identity Endpoint did not match the expected format."); + } + + auto request = createRequest(); + std::ifstream secretFile(challenge.substr(eq + 1)); + request->HttpRequest.SetHeader( + "Authorization", + "Basic " + + std::string( + std::istreambuf_iterator(secretFile), + std::istreambuf_iterator())); + + return request; + }); }); } @@ -278,7 +339,7 @@ std::unique_ptr ImdsManagedIdentitySource::Create( ImdsManagedIdentitySource::ImdsManagedIdentitySource( std::string const& clientId, Azure::Core::Credentials::TokenCredentialOptions const& options) - : ManagedIdentitySource(options), + : ManagedIdentitySource(clientId, std::string(), options), m_request( Azure::Core::Http::HttpMethod::Get, Azure::Core::Url("http://169.254.169.254/metadata/identity/oauth2/token")) @@ -302,16 +363,35 @@ Azure::Core::Credentials::AccessToken ImdsManagedIdentitySource::GetToken( Azure::Core::Credentials::TokenRequestContext const& tokenRequestContext, Azure::Core::Context const& context) const { - return TokenCredentialImpl::GetToken(context, [&]() { - auto request = std::make_unique(m_request); + std::string scopesStr; + { + auto const& scopes = tokenRequestContext.Scopes; + if (!scopes.empty()) { - auto const& scopes = tokenRequestContext.Scopes; - if (!scopes.empty()) - { - request->HttpRequest.GetUrl().AppendQueryParameter("resource", FormatScopes(scopes, true)); - } + scopesStr = TokenCredentialImpl::FormatScopes(scopes, true); } + } - return request; - }); + // TokenCache::GetToken() and TokenCredentialImpl::GetToken() can only use the lambda argument + // when they are being executed. They are not supposed to keep a reference to lambda argument to + // call it later. Therefore, any capture made here will outlive the possible time frame when the + // lambda might get called. + return TokenCache::GetToken( + std::string(), + GetClientId(), + GetAuthorityHost(), + scopesStr, + tokenRequestContext.MinimumExpiration, + [&]() { + return TokenCredentialImpl::GetToken(context, [&]() { + auto request = std::make_unique(m_request); + + if (!scopesStr.empty()) + { + request->HttpRequest.GetUrl().AppendQueryParameter("resource", scopesStr); + } + + return request; + }); + }); } diff --git a/sdk/identity/azure-identity/src/private/managed_identity_source.hpp b/sdk/identity/azure-identity/src/private/managed_identity_source.hpp index 40d66fc589..04ec7b6a70 100644 --- a/sdk/identity/azure-identity/src/private/managed_identity_source.hpp +++ b/sdk/identity/azure-identity/src/private/managed_identity_source.hpp @@ -11,9 +11,14 @@ #include #include +#include namespace Azure { namespace Identity { namespace _detail { class ManagedIdentitySource : protected TokenCredentialImpl { + private: + std::string m_clientId; + std::string m_authorityHost; + public: virtual Core::Credentials::AccessToken GetToken( Core::Credentials::TokenRequestContext const& tokenRequestContext, @@ -22,10 +27,17 @@ namespace Azure { namespace Identity { namespace _detail { protected: static Core::Url ParseEndpointUrl(std::string const& url, char const* envVarName); - explicit ManagedIdentitySource(Core::Credentials::TokenCredentialOptions const& options) - : TokenCredentialImpl(options) + explicit ManagedIdentitySource( + std::string clientId, + std::string authorityHost, + Core::Credentials::TokenCredentialOptions const& options) + : TokenCredentialImpl(options), m_clientId(std::move(clientId)), + m_authorityHost(std::move(authorityHost)) { } + + std::string const& GetClientId() const { return m_clientId; } + std::string const& GetAuthorityHost() const { return m_authorityHost; } }; class AppServiceManagedIdentitySource : public ManagedIdentitySource { diff --git a/sdk/identity/azure-identity/src/private/token_cache.hpp b/sdk/identity/azure-identity/src/private/token_cache.hpp new file mode 100644 index 0000000000..fa88a2b125 --- /dev/null +++ b/sdk/identity/azure-identity/src/private/token_cache.hpp @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +/** + * @file + * @brief Token cache. + */ + +#pragma once + +#include +#include + +#include +#include + +namespace Azure { namespace Identity { namespace _detail { + /** + * @brief Implements an access token cache. + * + */ + class TokenCache final { + TokenCache() = delete; + ~TokenCache() = delete; + + public: + /** + * @brief Attempts to get token from cache, and if not found, gets the token using the function + * provided, caches it, and returns its value. + * + * @param tenantId Azure Tenant ID. + * @param clientId Azure Client ID. + * @param authorityHost Authentication authority URL. + * @param scopes Authentication scopes. + * + * @return Authentication token. + * + * @throw Azure::Core::Credentials::AuthenticationException Authentication error occurred. + */ + static Core::Credentials::AccessToken GetToken( + std::string const& tenantId, + std::string const& clientId, + std::string const& authorityHost, + std::string const& scopes, + DateTime::duration minimumExpiration, + std::function const& getNewToken); + + /** + * @brief Provides access to internal aspects of the cache as a test hook. + * + */ + class Internals; + +#if defined(TESTING_BUILD) + /** + * @brief Clears token cache. Intended to only be used in tests. + * + */ + static void Clear(); +#endif + }; +}}} // namespace Azure::Identity::_detail diff --git a/sdk/identity/azure-identity/src/private/token_cache_internals.hpp b/sdk/identity/azure-identity/src/private/token_cache_internals.hpp new file mode 100644 index 0000000000..d3d67b30f6 --- /dev/null +++ b/sdk/identity/azure-identity/src/private/token_cache_internals.hpp @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +/** + * @file + * @brief Token cache internals and test hooks. + */ + +#pragma once + +#include "token_cache.hpp" + +#include + +#include +#include +#include +#include +#include +#include + +namespace Azure { namespace Identity { namespace _detail { + /** + * @brief Implements internal aspects of token cache and provides test hooks. + * + */ + class TokenCache::Internals final { + Internals() = delete; + ~Internals() = delete; + + public: + /** + * @brief Represents a unique set of characteristics that are used to distinguish between cache + * entries. + * + */ + struct CacheKey final + { + std::string TenantId; ///< Tenant ID. + std::string ClientId; ///< Client ID. + std::string AuthorityHost; ///< Authority Host. + std::string Scopes; ///< Authentication Scopes as a single string. + + bool operator<(TokenCache::Internals::CacheKey const& other) const + { + return std::tie(TenantId, ClientId, AuthorityHost, Scopes) + < std::tie(other.TenantId, other.ClientId, other.AuthorityHost, other.Scopes); + } + }; + + /** + * @brief Represents immediate cache value (token) and a synchronization primitive to handle its + * updates. + * + */ + struct CacheValue final + { + std::shared_timed_mutex ElementMutex; + Core::Credentials::AccessToken AccessToken; + }; + + /** + * @brief The cache itself. + * + */ + static std::map> Cache; + + /** + * @brief Mutex to access the cache container. + * + */ + static std::shared_timed_mutex CacheMutex; + +#if defined(TESTING_BUILD) + /** + * A test hook that gets invoked before cache write lock gets acquired. + * + */ + static std::function OnBeforeCacheWriteLock; + + /** + * A test hook that gets invoked before item write lock gets acquired. + * + */ + static std::function OnBeforeItemWriteLock; +#endif + }; +}}} // namespace Azure::Identity::_detail diff --git a/sdk/identity/azure-identity/src/token_cache.cpp b/sdk/identity/azure-identity/src/token_cache.cpp new file mode 100644 index 0000000000..4f88476e8e --- /dev/null +++ b/sdk/identity/azure-identity/src/token_cache.cpp @@ -0,0 +1,154 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include "private/token_cache.hpp" +#include "private/token_cache_internals.hpp" + +#include +#include + +using Azure::Identity::_detail::TokenCache; + +using Azure::DateTime; +using Azure::Core::Credentials::AccessToken; + +decltype(TokenCache::Internals::Cache) TokenCache::Internals::Cache; +decltype(TokenCache::Internals::CacheMutex) TokenCache::Internals::CacheMutex; + +#if defined(TESTING_BUILD) +std::function TokenCache::Internals::OnBeforeCacheWriteLock; +std::function TokenCache::Internals::OnBeforeItemWriteLock; +#endif + +namespace { +bool IsFresh( + std::shared_ptr const& item, + DateTime::duration minimumExpiration, + std::chrono::system_clock::time_point now) +{ + return item->AccessToken.ExpiresOn > (DateTime(now) + minimumExpiration); +} + +std::shared_ptr GetOrCreateValue( + TokenCache::Internals::CacheKey const& key, + DateTime::duration minimumExpiration) +{ + { + std::shared_lock cacheReadLock(TokenCache::Internals::CacheMutex); + + auto const found = TokenCache::Internals::Cache.find(key); + if (found != TokenCache::Internals::Cache.end()) + { + return found->second; + } + } + +#if defined(TESTING_BUILD) + if (TokenCache::Internals::OnBeforeCacheWriteLock != nullptr) + { + TokenCache::Internals::OnBeforeCacheWriteLock(); + } +#endif + + std::unique_lock cacheWriteLock(TokenCache::Internals::CacheMutex); + + // Search cache for the second time, in case the item was inserted between releasing the read lock + // and acquiring the write lock. + auto const found = TokenCache::Internals::Cache.find(key); + if (found != TokenCache::Internals::Cache.end()) + { + return found->second; + } + + // Clean up cache from expired items (once every N insertions). + { + auto const cacheSize = TokenCache::Internals::Cache.size(); + + // N: cacheSize (before insertion) is >= 32 and is a power of two. + // 32 as a starting point does not have any special meaning. + // + // Power of 2 trick: + // https://www.exploringbinary.com/ten-ways-to-check-if-an-integer-is-a-power-of-two-in-c/ + + if (cacheSize >= 32 && (cacheSize & (cacheSize - 1)) == 0) + { + auto now = std::chrono::system_clock::now(); + + auto iter = TokenCache::Internals::Cache.begin(); + while (iter != TokenCache::Internals::Cache.end()) + { + // Should we end up erasing the element, iterator to current will become invalid, after + // which we can't increment it. So we copy current, and safely advance the loop iterator. + auto const curr = iter; + ++iter; + + // We will try to obtain a write lock, but in a non-blocking way. We only lock it if no one + // was holding it for read and write at a time. If it's busy in any way, we don't wait, but + // move on. + auto const item = curr->second; + { + std::unique_lock lock(item->ElementMutex, std::defer_lock); + if (lock.try_lock() && !IsFresh(item, minimumExpiration, now)) + { + TokenCache::Internals::Cache.erase(curr); + } + } + } + } + } + + // Insert the blank value value and return it. + return TokenCache::Internals::Cache[key] = std::make_shared(); +} +} // namespace + +AccessToken TokenCache::GetToken( + std::string const& tenantId, + std::string const& clientId, + std::string const& authorityHost, + std::string const& scopes, + DateTime::duration minimumExpiration, + std::function const& getNewToken) +{ + auto const item + = GetOrCreateValue({tenantId, clientId, authorityHost, scopes}, minimumExpiration); + + { + std::shared_lock itemReadLock(item->ElementMutex); + + if (IsFresh(item, minimumExpiration, std::chrono::system_clock::now())) + { + return item->AccessToken; + } + } + + { +#if defined(TESTING_BUILD) + if (TokenCache::Internals::OnBeforeItemWriteLock != nullptr) + { + TokenCache::Internals::OnBeforeItemWriteLock(); + } +#endif + + std::unique_lock itemWriteLock(item->ElementMutex); + + // Check the expiration for the second time, in case it just got updated, after releasing the + // itemReadLock, and before acquiring itemWriteLock. + if (IsFresh(item, minimumExpiration, std::chrono::system_clock::now())) + { + return item->AccessToken; + } + + auto const newToken = getNewToken(); + item->AccessToken = newToken; + return newToken; + } +} + +#if defined(TESTING_BUILD) +void TokenCache::Clear() +{ + std::unique_lock cacheWriteLock(TokenCache::Internals::CacheMutex); + Internals::Cache.clear(); +} +#endif diff --git a/sdk/identity/azure-identity/test/ut/CMakeLists.txt b/sdk/identity/azure-identity/test/ut/CMakeLists.txt index 784e32018b..243f4d567c 100644 --- a/sdk/identity/azure-identity/test/ut/CMakeLists.txt +++ b/sdk/identity/azure-identity/test/ut/CMakeLists.txt @@ -25,6 +25,7 @@ add_executable ( macro_guard_test.cpp managed_identity_credential_test.cpp simplified_header_test.cpp + token_cache_test.cpp token_credential_impl_test.cpp token_credential_test.cpp ) diff --git a/sdk/identity/azure-identity/test/ut/credential_test_helper.cpp b/sdk/identity/azure-identity/test/ut/credential_test_helper.cpp index 8c0868535e..6285b4e782 100644 --- a/sdk/identity/azure-identity/test/ut/credential_test_helper.cpp +++ b/sdk/identity/azure-identity/test/ut/credential_test_helper.cpp @@ -3,6 +3,7 @@ #include "credential_test_helper.hpp" +#include "private/token_cache_internals.hpp" #include #include @@ -70,6 +71,8 @@ CredentialTestHelper::TokenRequestSimulationResult CredentialTestHelper::Simulat std::vector const& responses, GetTokenCallback getToken) { + Azure::Identity::_detail::TokenCache::Clear(); + using Azure::Core::Context; using Azure::Core::Http::HttpStatusCode; using Azure::Core::Http::RawResponse; diff --git a/sdk/identity/azure-identity/test/ut/token_cache_test.cpp b/sdk/identity/azure-identity/test/ut/token_cache_test.cpp new file mode 100644 index 0000000000..78c4368227 --- /dev/null +++ b/sdk/identity/azure-identity/test/ut/token_cache_test.cpp @@ -0,0 +1,450 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include "private/token_cache_internals.hpp" + +#include + +#include + +using Azure::DateTime; +using Azure::Core::Credentials::AccessToken; +using Azure::Identity::_detail::TokenCache; + +using namespace std::chrono_literals; + +TEST(TokenCache, KeyComparison) +{ + using Key = TokenCache::Internals::CacheKey; + Key const key1{"a", "b", "c", "d"}; + EXPECT_FALSE(key1 < key1); + + { + Key const key1dup{"a", "b", "c", "d"}; + + EXPECT_FALSE(key1 < key1dup); + EXPECT_FALSE(key1dup < key1); + } + + Key const key2{"a", "b", "c", "~"}; + Key const key3{"a", "b", "~", "d"}; + Key const key4{"a", "~", "c", "d"}; + Key const key5{"~", "b", "c", "d"}; + + EXPECT_TRUE(key1 < key2); + EXPECT_TRUE(key1 < key3); + EXPECT_TRUE(key1 < key4); + EXPECT_TRUE(key1 < key5); + EXPECT_FALSE(key2 < key1); + EXPECT_FALSE(key3 < key1); + EXPECT_FALSE(key4 < key1); + EXPECT_FALSE(key5 < key1); + + EXPECT_TRUE(key2 < key3); + EXPECT_TRUE(key2 < key4); + EXPECT_TRUE(key2 < key5); + EXPECT_FALSE(key3 < key2); + EXPECT_FALSE(key4 < key2); + EXPECT_FALSE(key5 < key2); + + EXPECT_TRUE(key3 < key4); + EXPECT_TRUE(key3 < key5); + EXPECT_FALSE(key4 < key3); + EXPECT_FALSE(key5 < key3); + + EXPECT_TRUE(key4 < key5); + EXPECT_FALSE(key5 < key4); +} + +TEST(TokenCache, GetReuseRefresh) +{ + TokenCache::Clear(); + + EXPECT_EQ(TokenCache::Internals::Cache.size(), 0UL); + + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + auto const Yesterday = Tomorrow - 48h; + + { + auto const token1 = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + AccessToken result; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + + EXPECT_EQ(token1.ExpiresOn, Tomorrow); + EXPECT_EQ(token1.Token, "T1"); + + auto const token2 = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); + AccessToken result; + result.Token = "T2"; + result.ExpiresOn = Tomorrow + 24h; + return result; + }); + + EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + + EXPECT_EQ(token1.ExpiresOn, token2.ExpiresOn); + EXPECT_EQ(token1.Token, token2.Token); + } + + { + TokenCache::Internals::Cache[{"A", "B", "C", "D"}]->AccessToken.ExpiresOn = Yesterday; + + auto const token = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + AccessToken result; + result.Token = "T3"; + result.ExpiresOn = Tomorrow + 1min; + return result; + }); + + EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow + 1min); + EXPECT_EQ(token.Token, "T3"); + } +} + +TEST(TokenCache, TwoThreadsAttemptToInsertTheSameKey) +{ + TokenCache::Clear(); + + EXPECT_EQ(TokenCache::Internals::Cache.size(), 0UL); + + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + + TokenCache::Internals::OnBeforeCacheWriteLock = [=]() { + TokenCache::Internals::OnBeforeCacheWriteLock = nullptr; + static_cast(TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + AccessToken result; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + })); + }; + + auto const token = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the fresh value was inserted just before " + "acquiring cache write lock"); + AccessToken result; + result.Token = "T2"; + result.ExpiresOn = Tomorrow + 1min; + return result; + }); + + EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "T1"); +} + +TEST(TokenCache, TwoThreadsAttemptToUpdateTheSameToken) +{ + TokenCache::Clear(); + + EXPECT_EQ(TokenCache::Internals::Cache.size(), 0UL); + + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + auto const Yesterday = Tomorrow - 48h; + + { + TokenCache::Internals::OnBeforeItemWriteLock = [=]() { + TokenCache::Internals::OnBeforeItemWriteLock = nullptr; + auto const item = TokenCache::Internals::Cache[{"A", "B", "C", "D"}]; + item->AccessToken.Token = "T1"; + item->AccessToken.ExpiresOn = Tomorrow; + }; + + auto const token = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the fresh value was inserted just before " + "acquiring item write lock"); + AccessToken result; + result.Token = "T2"; + result.ExpiresOn = Tomorrow + 1min; + return result; + }); + + EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow); + EXPECT_EQ(token.Token, "T1"); + } + + // Same as above, but the token that was inserted is already expired. + { + TokenCache::Clear(); + + TokenCache::Internals::OnBeforeItemWriteLock = [=]() { + TokenCache::Internals::OnBeforeItemWriteLock = nullptr; + auto const item = TokenCache::Internals::Cache[{"A", "B", "C", "D"}]; + item->AccessToken.Token = "T3"; + item->AccessToken.ExpiresOn = Yesterday; + }; + + auto const token = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + AccessToken result; + result.Token = "T4"; + result.ExpiresOn = Tomorrow + 3min; + return result; + }); + + EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + + EXPECT_EQ(token.ExpiresOn, Tomorrow + 3min); + EXPECT_EQ(token.Token, "T4"); + } +} + +TEST(TokenCache, ExpiredCleanup) +{ + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + auto const Yesterday = Tomorrow - 48h; + + TokenCache::Clear(); + EXPECT_EQ(TokenCache::Internals::Cache.size(), 0UL); + + for (auto i = 1; i <= 65; ++i) + { + auto const n = std::to_string(i); + static_cast(TokenCache::GetToken(n, n, n, n, 2min, [=]() { + AccessToken result; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + })); + } + + // Simply: we added 64+1 token, none of them has expired. None are expected to be cleaned up. + EXPECT_EQ(TokenCache::Internals::Cache.size(), 65UL); + + // Let's expire 3 of them, with numbers from 1 to 3. + for (auto i = 1; i <= 3; ++i) + { + auto const n = std::to_string(i); + TokenCache::Internals::Cache[{n, n, n, n}]->AccessToken.ExpiresOn = Yesterday; + } + + // Add tokens up to 128 total. When 129th gets added, clean up should get triggered. + for (auto i = 66; i <= 128; ++i) + { + auto const n = std::to_string(i); + static_cast(TokenCache::GetToken(n, n, n, n, 2min, [=]() { + AccessToken result; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + })); + } + + EXPECT_EQ(TokenCache::Internals::Cache.size(), 128UL); + + // Count is at 128. Tokens from 1 to 3 are still in cache even though they are expired. + for (auto i = 1; i <= 3; ++i) + { + auto const n = std::to_string(i); + EXPECT_NE(TokenCache::Internals::Cache.find({n, n, n, n}), TokenCache::Internals::Cache.end()); + } + + // One more addition to the cache and cleanup for the expired ones will get triggered. + static_cast(TokenCache::GetToken("129", "129", "129", "129", 2min, [=]() { + AccessToken result; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + })); + + // We were at 128 before we added 1 more, and now we're at 126. 3 were deleted, 1 was added. + EXPECT_EQ(TokenCache::Internals::Cache.size(), 126UL); + + // Items from 1 to 3 should no longer be in the cache. + for (auto i = 1; i <= 3; ++i) + { + auto const n = std::to_string(i); + EXPECT_EQ(TokenCache::Internals::Cache.find({n, n, n, n}), TokenCache::Internals::Cache.end()); + } + + // Let's expire items from 21 all the way up to 129. + for (auto i = 21; i <= 129; ++i) + { + auto const n = std::to_string(i); + TokenCache::Internals::Cache[{n, n, n, n}]->AccessToken.ExpiresOn = Yesterday; + } + + // Re-add items 2 and 3. Adding them should not trigger cleanup. After adding, cache should get to + // 128 items (with numbers from 2 to 129, and number 1 missing). + for (auto i = 2; i <= 3; ++i) + { + auto const n = std::to_string(i); + static_cast(TokenCache::GetToken(n, n, n, n, 2min, [=]() { + AccessToken result; + result.Token = "T2"; + result.ExpiresOn = Tomorrow; + return result; + })); + } + + // Cache is now at 128 again (items from 2 to 129). Adding 1 more will trigger cleanup. + EXPECT_EQ(TokenCache::Internals::Cache.size(), 128UL); + + // Now let's lock some of the items for reading, and some for writing. Cleanup should not block on + // token release, but will simply move on, without doing anything to the ones that were locked. + // Out of 4 locked, two are expired, so they should get cleared under normla circumstances, but + // this time they will remain in the cache. + std::shared_lock readLockForUnexpired( + TokenCache::Internals::Cache[{"2", "2", "2", "2"}]->ElementMutex); + + std::shared_lock readLockForExpired( + TokenCache::Internals::Cache[{"127", "127", "127", "127"}]->ElementMutex); + + std::unique_lock writeLockForUnexpired( + TokenCache::Internals::Cache[{"3", "3", "3", "3"}]->ElementMutex); + + std::unique_lock writeLockForExpired( + TokenCache::Internals::Cache[{"128", "128", "128", "128"}]->ElementMutex); + + // Count is at 128. Inserting the 129th element, and it will trigger cleanup. + static_cast(TokenCache::GetToken("1", "1", "1", "1", 2min, [=]() { + AccessToken result; + result.Token = "T2"; + result.ExpiresOn = Tomorrow; + return result; + })); + + // These should be 20 unexpired items + two that are expired but were locked, so 22 total. + EXPECT_EQ(TokenCache::Internals::Cache.size(), 22UL); + + for (auto i = 1; i <= 20; ++i) + { + auto const n = std::to_string(i); + EXPECT_NE(TokenCache::Internals::Cache.find({n, n, n, n}), TokenCache::Internals::Cache.end()); + } + + EXPECT_NE( + TokenCache::Internals::Cache.find({"127", "127", "127", "127"}), + TokenCache::Internals::Cache.end()); + + EXPECT_NE( + TokenCache::Internals::Cache.find({"128", "128", "128", "128"}), + TokenCache::Internals::Cache.end()); + + for (auto i = 21; i <= 126; ++i) + { + auto const n = std::to_string(i); + EXPECT_EQ(TokenCache::Internals::Cache.find({n, n, n, n}), TokenCache::Internals::Cache.end()); + } +} + +TEST(TokenCache, MinimumExpiration) +{ + TokenCache::Clear(); + + EXPECT_EQ(TokenCache::Internals::Cache.size(), 0UL); + + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + + auto const token1 = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + AccessToken result; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + + EXPECT_EQ(token1.ExpiresOn, Tomorrow); + EXPECT_EQ(token1.Token, "T1"); + + auto const token2 = TokenCache::GetToken("A", "B", "C", "D", 24h, [=]() { + AccessToken result; + result.Token = "T2"; + result.ExpiresOn = Tomorrow + 1h; + return result; + }); + + EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + + EXPECT_EQ(token2.ExpiresOn, Tomorrow + 1h); + EXPECT_EQ(token2.Token, "T2"); +} + +TEST(TokenCache, MultithreadedAccess) +{ + TokenCache::Clear(); + + EXPECT_EQ(TokenCache::Internals::Cache.size(), 0UL); + + DateTime const Tomorrow = std::chrono::system_clock::now() + 24h; + + auto const token1 = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + AccessToken result; + result.Token = "T1"; + result.ExpiresOn = Tomorrow; + return result; + }); + + EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + + EXPECT_EQ(token1.ExpiresOn, Tomorrow); + EXPECT_EQ(token1.Token, "T1"); + + { + std::shared_lock itemReadLock( + TokenCache::Internals::Cache[{"A", "B", "C", "D"}]->ElementMutex); + + { + std::shared_lock cacheReadLock(TokenCache::Internals::CacheMutex); + + // Parallel threads read both the container and the item we're accessing, and we can access it + // in parallel as well. + auto const token2 = TokenCache::GetToken("A", "B", "C", "D", 2min, [=]() { + EXPECT_FALSE("getNewToken does not get invoked when the existing cache value is good"); + AccessToken result; + result.Token = "T2"; + result.ExpiresOn = Tomorrow + 1h; + return result; + }); + + EXPECT_EQ(TokenCache::Internals::Cache.size(), 1UL); + + EXPECT_EQ(token2.ExpiresOn, token1.ExpiresOn); + EXPECT_EQ(token2.Token, token1.Token); + } + + // The cache is unlocked, but one item is being read in a parallel thread, which does not + // prevent new items (with different key) from being appended to cache. + auto const token3 = TokenCache::GetToken("E", "F", "G", "H", 2min, [=]() { + AccessToken result; + result.Token = "T3"; + result.ExpiresOn = Tomorrow + 2h; + return result; + }); + + EXPECT_EQ(TokenCache::Internals::Cache.size(), 2UL); + + EXPECT_EQ(token3.ExpiresOn, Tomorrow + 2h); + EXPECT_EQ(token3.Token, "T3"); + } + + { + std::unique_lock itemWriteLock( + TokenCache::Internals::Cache[{"A", "B", "C", "D"}]->ElementMutex); + + // The cache is unlocked, but one item is being written in a parallel thread, which does not + // prevent new items (with different key) from being appended to cache. + auto const token3 = TokenCache::GetToken("I", "J", "K", "L", 2min, [=]() { + AccessToken result; + result.Token = "T4"; + result.ExpiresOn = Tomorrow + 3h; + return result; + }); + + EXPECT_EQ(TokenCache::Internals::Cache.size(), 3UL); + + EXPECT_EQ(token3.ExpiresOn, Tomorrow + 3h); + EXPECT_EQ(token3.Token, "T4"); + } +} diff --git a/sdk/identity/azure-identity/test/ut/token_credential_impl_test.cpp b/sdk/identity/azure-identity/test/ut/token_credential_impl_test.cpp index dc9b7e0752..d2693f7ab2 100644 --- a/sdk/identity/azure-identity/test/ut/token_credential_impl_test.cpp +++ b/sdk/identity/azure-identity/test/ut/token_credential_impl_test.cpp @@ -3,6 +3,8 @@ #include "private/token_credential_impl.hpp" +#include "private/token_cache.hpp" + #include "credential_test_helper.hpp" #include @@ -18,6 +20,7 @@ using Azure::Core::Credentials::TokenCredential; using Azure::Core::Credentials::TokenCredentialOptions; using Azure::Core::Credentials::TokenRequestContext; using Azure::Core::Http::HttpMethod; +using Azure::Identity::_detail::TokenCache; using Azure::Identity::_detail::TokenCredentialImpl; using Azure::Identity::Test::_detail::CredentialTestHelper; @@ -50,6 +53,8 @@ class TokenCredentialImplTester : public TokenCredential { AccessToken GetToken(TokenRequestContext const& tokenRequestContext, Context const& context) const override { + TokenCache::Clear(); + return m_tokenCredentialImpl->GetToken(context, [&]() { m_throwingFunction(); diff --git a/sdk/identity/azure-identity/test/ut/token_credential_test.cpp b/sdk/identity/azure-identity/test/ut/token_credential_test.cpp index 1ad894727d..5dca4c58ce 100644 --- a/sdk/identity/azure-identity/test/ut/token_credential_test.cpp +++ b/sdk/identity/azure-identity/test/ut/token_credential_test.cpp @@ -7,6 +7,8 @@ #include #include +#include "private/token_cache.hpp" + #include #include @@ -58,6 +60,8 @@ TEST_F(TokenCredentialTest, ClientSecret) std::string const testName(GetTestName()); auto const clientSecretCredential = GetClientSecretCredential(testName); + _detail::TokenCache::Clear(); + auto const token = clientSecretCredential->GetToken( {{"https://vault.azure.net/.default"}}, Azure::Core::Context::ApplicationContext); @@ -70,6 +74,8 @@ TEST_F(TokenCredentialTest, EnvironmentCredential) std::string const testName(GetTestName()); auto const clientSecretCredential = GetEnvironmentCredential(testName); + _detail::TokenCache::Clear(); + auto const token = clientSecretCredential->GetToken( {{"https://vault.azure.net/.default"}}, Azure::Core::Context::ApplicationContext);