Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eliminate token cache timeout setting in most cases #126

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions src/Arc4u.Standard.OAuth2.AspNetCore.Adal/Token/Cache.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
using Arc4u.Dependency;
using Arc4u.Dependency;
using Arc4u.Diagnostics;
using Microsoft.Extensions.Logging;
using Microsoft.IdentityModel.Clients.ActiveDirectory;
using System;
using System.Collections.Generic;

namespace Arc4u.OAuth2.Token.Adal
{
Expand Down Expand Up @@ -75,8 +73,19 @@ private void AfterAccessNotification(TokenCacheNotificationArgs args)
if (HasStateChanged)
{
Logger.Technical().From<Cache>().System($"Adding token information to the cache for the identifier: {_identifier}.").Log();
TokenCache.Put(_identifier, SerializeAdalV3());
Logger.Technical().From<Cache>().System($"Added token information to the cache for the identifier: {_identifier}.").Log();

var now = DateTimeOffset.UtcNow;
var maxExpires = ReadItems().Select(item => item.ExpiresOn).DefaultIfEmpty(now).Max();
var timeout = maxExpires - now;
if (timeout > TimeSpan.Zero)
{
TokenCache.Put(_identifier, timeout, SerializeAdalV3());
Logger.Technical().From<Cache>().System($"Added token information to the cache for the identifier: {_identifier}.").Log();
}
else
{
Logger.Technical().From<Cache>().System($"No valid tokens to cache for identifier: {_identifier} because of negative duration {timeout}.").Log();
}

HasStateChanged = false;
}
Expand Down
104 changes: 47 additions & 57 deletions src/Arc4u.Standard.OAuth2.AspNetCore/AppPrincipalTransform.cs
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
using System.Collections.Generic;
using System;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
using System.Security.Claims;
using System.Threading.Tasks;
using Arc4u.Configuration;
using Arc4u.Dependency.Attribute;
using Arc4u.Diagnostics;
using Arc4u.IdentityModel.Claims;
using Arc4u.OAuth2.Options;
using Arc4u.OAuth2.Token;
using Arc4u.Security.Principal;
using Arc4u.Standard.OAuth2.Security;
using Microsoft.AspNetCore.Authentication;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.DependencyInjection;

namespace Arc4u.OAuth2;

Expand Down Expand Up @@ -102,8 +98,6 @@ public async Task<ClaimsPrincipal> TransformAsync(ClaimsPrincipal principal)
}

#region Handling extra claims
private const string tokenExpirationClaimType = "exp";
private static readonly string[] ClaimsToExclude = { "aud", "iss", "iat", "nbf", "acr", "aio", "appidacr", "ipaddr", "scp", "sub", "tid", "uti", "unique_name", "apptype", "appid", "ver", "http://schemas.microsoft.com/ws/2008/06/identity/claims/authenticationinstant", "http://schemas.microsoft.com/identity/claims/scope" };

/// <summary>
/// This code is similar to the code in AppPrincipalFactory where the claims are stored in a secureCache.
Expand All @@ -128,21 +122,12 @@ private async Task LoadExtraClaimsAsync(ClaimsIdentity? identity)
// Something already in the cache? Avoid a expensive backend call.
var cachedClaims = GetClaimsFromCache(cacheKey);

// check expirity.
var cachedExpiredClaim = cachedClaims.FirstOrDefault(c => c.ClaimType.Equals(tokenExpirationClaimType, StringComparison.OrdinalIgnoreCase));

if (cachedExpiredClaim is not null && long.TryParse(cachedExpiredClaim.Value, out var cachedExpiredTicks))
// if cachedClaims is not null, it means that the information was in the cache and didn't expire yet.
if (cachedClaims is not null)
{
var expDate = DateTimeOffset.FromUnixTimeSeconds(cachedExpiredTicks).UtcDateTime;
if (expDate > DateTime.UtcNow)
{
identity.AddClaims(cachedClaims.Where(c => c.ClaimType != tokenExpirationClaimType)
.Where(c => !ClaimsToExclude.Any(arg => arg.Equals(c.ClaimType)))
.Where(c => !identity.Claims.Any(c1 => c1.Type == c.ClaimType))
.Select(c => new Claim(c.ClaimType, c.Value)));

return;
}
// Add the new claims to the identity. Note that we need to merge the same way as in the case they are retrieved (see below)
identity.MergeClaims(cachedClaims);
return;
}

// Add Telemetry.
Expand All @@ -154,65 +139,70 @@ private async Task LoadExtraClaimsAsync(ClaimsIdentity? identity)
// Should receive specific extra claims. This is the responsibility of the caller to provide the right claims.
// We expect the exp claim to be present.
// if not the persistence time will be the default one.
var claims = (await _claimsFiller.GetAsync(identity, settings, null).ConfigureAwait(false)).Where(c => !ClaimsToExclude.Any(arg => arg.Equals(c.ClaimType))).ToList();
var claims = (await _claimsFiller.GetAsync(identity, settings, null).ConfigureAwait(false)).Sanitize().ToList();

// Load the claims into the identity but exclude the exp claim and the one already present.
identity.AddClaims(claims.Where(c => c.ClaimType != tokenExpirationClaimType)
.Where(c => !ClaimsToExclude.Any(arg => arg.Equals(c.ClaimType)))
.Where(c => !identity.Claims.Any(c1 => c1.Type == c.ClaimType))
.Select(c => new Claim(c.ClaimType, c.Value)));
identity.MergeClaims(claims);

SaveClaimsToCache(claims, cacheKey);
}
TimeSpan timeout;

}

private List<ClaimDto> GetClaimsFromCache(string? cacheKey)
{
var claims = new List<ClaimDto>();
// Check expiry claim explicitly returned in the call to _claimsFiller.GetAsync
// If no expiry claim found in the call to _claimsFiller.GetAsync, try to locate one in the existing identity claims (which are normally obtained from the bearer token)
if (claims.TryGetExpiration(out var ticks) || identity.Claims.TryGetExpiration(out ticks))
{
// there's an expiration claim in the claims returned by the claims filler or in the existing identity claims, so we can compute an expiration date
timeout = DateTimeOffset.FromUnixTimeSeconds(ticks) - DateTimeOffset.UtcNow;
}
else
{
// fall back to the configured max time.
timeout = _options.MaxTime;
}

if (string.IsNullOrWhiteSpace(cacheKey))
{
return claims;
SaveClaimsToCache(claims, cacheKey, timeout);
}
}

try
{
claims = _cacheHelper.GetCache().Get<List<ClaimDto>>(cacheKey) ?? new List<ClaimDto>();
}
catch (Exception ex)
private List<ClaimDto>? GetClaimsFromCache(string? cacheKey)
{
if (!string.IsNullOrWhiteSpace(cacheKey))
{
_logger.Technical().Exception(ex).Log();
try
{
return _cacheHelper.GetCache().Get<List<ClaimDto>>(cacheKey);
}
catch (Exception ex)
{
_logger.Technical().Exception(ex).Log();
}
Comment on lines +172 to +177
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Avoid catching general exceptions

Catching the general Exception type may mask specific issues and make debugging more difficult. Consider catching more specific exceptions that are expected from the cache retrieval operation.

}

return claims;
return null;
}

private void SaveClaimsToCache([DisallowNull] IEnumerable<ClaimDto> claims, string? cacheKey)
private void SaveClaimsToCache([DisallowNull] IEnumerable<ClaimDto> claims, string? cacheKey, TimeSpan timeout)
{
if (string.IsNullOrWhiteSpace(cacheKey))
{
return;
}

// Only if the cache key is valid does it make sense to test the claims validity
ArgumentNullException.ThrowIfNull(claims);

try
if (timeout > TimeSpan.Zero)
{
var cachedExpiredClaim = claims.FirstOrDefault(c => c.ClaimType.Equals(tokenExpirationClaimType, StringComparison.OrdinalIgnoreCase));

// if no expiration claim exist we assume the lifetime of the extra claims is defined by the cache context for the principal.
if (cachedExpiredClaim is null)
try
{
cachedExpiredClaim = new ClaimDto(tokenExpirationClaimType, DateTimeOffset.UtcNow.Add(_cacheOptions.MaxTime).ToUnixTimeSeconds().ToString(CultureInfo.InvariantCulture));
claims = new List<ClaimDto>(claims) { cachedExpiredClaim };
_cacheHelper.GetCache().Put(cacheKey, timeout, claims);
}
catch (Exception ex)
{
_logger.Technical().Exception(ex).Log();
Comment on lines +196 to +200
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Avoid catching general exceptions

In the cache save operation, catching the general Exception type can hide specific errors and complicate error handling. It's advisable to catch specific exceptions that you anticipate from the cache put operation.

}

_cacheHelper.GetCache().Put(cacheKey, _cacheOptions.MaxTime, claims);
}
catch (Exception ex)
else
{
_logger.Technical().Exception(ex).Log();
_logger.Technical().LogWarning("The timeout is not valid to save the claims in the cache.");
}
}
#endregion
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using System;
using System.Linq;
using Arc4u.Configuration;
using Arc4u.OAuth2.Options;
using Microsoft.Extensions.Configuration;
Expand All @@ -13,9 +11,21 @@ public static void AddClaimsFiller(this IServiceCollection services, Action<Clai
var validate = new ClaimsFillerOptions();
options(validate);

if (validate.LoadClaimsFromClaimsFillerProvider && (null == validate.SettingsKeys || !validate.SettingsKeys.Any()))
if (validate.LoadClaimsFromClaimsFillerProvider)
{
throw new ConfigurationException("Settings key must be provided.");
string? configErrors = null;
if (null == validate.SettingsKeys || !validate.SettingsKeys.Any())
{
configErrors += "Settings key must be provided." + System.Environment.NewLine;
}
if (validate.LoadClaimsFromClaimsFillerProvider && validate.MaxTime == TimeSpan.Zero)
{
configErrors += $"If {nameof(validate.LoadClaimsFromClaimsFillerProvider)} is true, then {nameof(validate.MaxTime)} must be provided." + System.Environment.NewLine;
}
if (configErrors is not null)
{
throw new ConfigurationException(configErrors);
}
}

services.Configure<ClaimsFillerOptions>(options);
Expand All @@ -39,10 +49,10 @@ public static void AddClaimsFiller(this IServiceCollection services, IConfigurat
}
}

AddClaimsFiller(services, o =>
{
o.LoadClaimsFromClaimsFillerProvider = options.LoadClaimsFromClaimsFillerProvider;
o.SettingsKeys = options.SettingsKeys;
});
AddClaimsFiller(services, o =>
{
o.LoadClaimsFromClaimsFillerProvider = options.LoadClaimsFromClaimsFillerProvider;
o.SettingsKeys = options.SettingsKeys;
});
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
using System;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
using System.Net.Http.Headers;
using System.Text;
using System.Threading.Tasks;
using Arc4u.Configuration;
using Arc4u.Dependency;
using Arc4u.Diagnostics;
Expand Down Expand Up @@ -89,7 +86,7 @@ public async Task Invoke(HttpContext context)
{
tokenInfo = await GetTokenFromCredentialAsync(credential, context.RequestServices).ConfigureAwait(false);

tokenCache?.Put(cacheKey, tokenInfo);
tokenCache?.Put(cacheKey, tokenInfo.ExpiresOnUtc - DateTime.UtcNow, tokenInfo);
}

if (tokenInfo is not null)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using System.Collections.Generic;

namespace Arc4u.OAuth2.Options;

/// <summary>
Expand All @@ -10,6 +8,12 @@ public class ClaimsFillerOptions
// By default, a specific claims filler is needed to manage the right.
public bool LoadClaimsFromClaimsFillerProvider { get; set; } = true;

/// <summary>
/// If <see cref="LoadClaimsFromClaimsFillerProvider"/> is true, the extra claims are cached for <see cref="MaxTime"/>.
/// The default value of 20 minutes provides a balance between performance and data freshness for most use cases.
/// </summary>
public TimeSpan MaxTime { get; set; } = TimeSpan.FromMinutes(20);

/// <summary>
/// The settings key to load the claims from. => registered as a named <see cref="SimpleKeyValueSettings>"/>
/// By default no on behalf of scenario is defined => AOauth2 must be added to perform the on behalf of scenario.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
using System;
using System.Diagnostics;
using System.IdentityModel.Tokens.Jwt;
using System.Linq;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using Arc4u.Configuration;
using Arc4u.Dependency;
using Arc4u.Dependency.Attribute;
using Arc4u.Diagnostics;
using Arc4u.OAuth2.Token;
using Arc4u.Security.Principal;
using Arc4u.ServiceModel;
using Arc4u.Standard.OAuth2.Security;
using Microsoft.AspNetCore.Authentication;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
Expand All @@ -23,8 +20,6 @@ public class AppServicePrincipalFactory : IAppPrincipalFactory
{
public const string ProviderKey = "ProviderId";

public static readonly string tokenExpirationClaimType = "exp";
public static readonly string[] ClaimsToExclude = { "exp", "aud", "iss", "iat", "nbf", "acr", "aio", "appidacr", "ipaddr", "scp", "sub", "tid", "uti", "unique_name", "apptype", "appid", "ver", "http://schemas.microsoft.com/ws/2008/06/identity/claims/authenticationinstant", "http://schemas.microsoft.com/identity/claims/scope" };

private readonly IContainerResolve _container;
private readonly ILogger<AppServicePrincipalFactory> _logger;
Expand Down Expand Up @@ -102,13 +97,12 @@ private async Task BuildTheIdentity(ClaimsIdentity identity, IKeyValueSettings s
}

// Check the settings contains the service url.
TokenInfo? token = null;
try
{
token = await provider.GetTokenAsync(settings, parameter).ConfigureAwait(true);
var token = await provider.GetTokenAsync(settings, parameter).ConfigureAwait(true);
identity.BootstrapContext = token.Token;
var jwtToken = new JwtSecurityToken(token.Token);
identity.AddClaims(jwtToken.Claims.Where(c => !ClaimsToExclude.Any(arg => arg.Equals(c.Type))).Select(c => new Claim(c.Type, c.Value)));
identity.MergeClaims(jwtToken.Claims.Sanitize());
}
catch (Exception ex)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
using System;
using System.Threading.Tasks;
using Arc4u.Configuration;
using Arc4u.Dependency;
using Arc4u.Dependency.Attribute;
using Arc4u.Diagnostics;
Expand Down Expand Up @@ -84,7 +81,7 @@ public async Task<TokenInfo> GetTokenAsync(IKeyValueSettings settings, Credentia
try
{
_logger.Technical().System($"Save the token in the cache for {cacheKey}, will expire at {tokenInfo.ExpiresOnUtc} Utc.").Log();
TokenCache.Put(cacheKey, tokenInfo);
TokenCache.Put(cacheKey, tokenInfo.ExpiresOnUtc - DateTime.UtcNow, tokenInfo);
}
catch (Exception ex)
{
Expand Down
Loading