Skip to content

Commit

Permalink
According to https://www.rfc-editor.org/rfc/rfc7519 section 10.1, cla…
Browse files Browse the repository at this point in the history
…im names (a.k.a. "types") are case-sensitive.

Claims manipulations have been factored out in a separate utility class.
  • Loading branch information
vvdb-architecture committed Oct 9, 2024
1 parent 8332c5c commit f0c3fc6
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 62 deletions.
43 changes: 16 additions & 27 deletions src/Arc4u.Standard.OAuth2.AspNetCore/AppPrincipalTransform.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Security.Claims;
using Arc4u.Configuration;
using Arc4u.Dependency.Attribute;
Expand All @@ -9,6 +8,7 @@
using Arc4u.OAuth2.Options;
using Arc4u.OAuth2.Token;
using Arc4u.Security.Principal;
using Arc4u.Standard.OAuth2.Security;
using Microsoft.AspNetCore.Authentication;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -98,8 +98,6 @@ public async Task<ClaimsPrincipal> TransformAsync(ClaimsPrincipal principal)
}

#region Handling extra claims
private const string tokenExpirationClaimType = "exp";
private static readonly HashSet<string> ClaimsToExclude = new HashSet<string>(StringComparer.OrdinalIgnoreCase) { "aud", "iss", "iat", "nbf", "acr", "aio", "appidacr", "ipaddr", "scp", "sub", "tid", "uti", "unique_name", "apptype", "appid", "ver", "http://schemas.microsoft.com/ws/2008/06/identity/claims/authenticationinstant", "http://schemas.microsoft.com/identity/claims/scope" };

/// <summary>
/// This code is similar to the code in AppPrincipalFactory where the claims are stored in a secureCache.
Expand Down Expand Up @@ -127,9 +125,8 @@ private async Task LoadExtraClaimsAsync(ClaimsIdentity? identity)
// if cachedClaims is not null, it means that the information was in the cache and didn't expire yet.
if (cachedClaims is not null)
{
// Add the new claims to the identity
identity.AddClaims(cachedClaims.Where(c => !identity.Claims.Any(c1 => StringComparer.OrdinalIgnoreCase.Equals(c1.Type, c.ClaimType)))
.Select(c => new Claim(c.ClaimType, c.Value)));
// Add the new claims to the identity. Note that we need to merge the same way as in the case they are retrieved (see below)
identity.MergeClaims(cachedClaims);
return;
}

Expand All @@ -142,37 +139,28 @@ private async Task LoadExtraClaimsAsync(ClaimsIdentity? identity)
// Should receive specific extra claims. This is the responsibility of the caller to provide the right claims.
// We expect the exp claim to be present.
// if not the persistence time will be the default one.
var claims = (await _claimsFiller.GetAsync(identity, settings, null).ConfigureAwait(false)).Where(c => !ClaimsToExclude.Contains(c.ClaimType)).ToList();
var claims = (await _claimsFiller.GetAsync(identity, settings, null).ConfigureAwait(false)).Sanitize().ToList();

// Load the claims into the identity but exclude the exp claim and the one already present.
identity.AddClaims(claims.Where(c => !StringComparer.OrdinalIgnoreCase.Equals(c.ClaimType, tokenExpirationClaimType))
.Where(c => !identity.Claims.Any(c1 => StringComparer.OrdinalIgnoreCase.Equals(c1.Type, c.ClaimType)))
.Select(c => new Claim(c.ClaimType, c.Value)));
identity.MergeClaims(claims);

TimeSpan timeout;

// Check expiry claim explicitly returned in the call to _claimsFiller.GetAsync
// If no expiry claim found in the call to _claimsFiller.GetAsync, try to locate one in the existing identity claims (which are normally obtained from the bearer token)
DateTimeOffset? expDate;

if (!TryParseExpirationDate(claims.FirstOrDefault(c => StringComparer.OrdinalIgnoreCase.Equals(c.ClaimType, tokenExpirationClaimType))?.Value, out expDate) && !TryParseExpirationDate(identity.Claims.FirstOrDefault(c => StringComparer.OrdinalIgnoreCase.Equals(c.Type, tokenExpirationClaimType))?.Value, out expDate))
if (claims.TryGetExpiration(out var ticks) || identity.Claims.TryGetExpiration(out ticks))
{
// there's an expiration claim in the claims returned by the claims filler or in the existing identity claims, so we can compute an expiration date
timeout = DateTimeOffset.FromUnixTimeSeconds(ticks) - DateTimeOffset.UtcNow;
}
else
{
// fall back to the configured max time.
expDate = DateTimeOffset.UtcNow.Add(_options.MaxTime);
timeout = _options.MaxTime;
}

SaveClaimsToCache(claims, cacheKey, expDate!.Value - DateTimeOffset.UtcNow);
}
}

private static bool TryParseExpirationDate(string? value, out DateTimeOffset? expirationDate)
{
if (value is not null && long.TryParse(value, NumberStyles.Integer, CultureInfo.InvariantCulture, out var ticks))
{
expirationDate = DateTimeOffset.FromUnixTimeSeconds(ticks);
return true;
SaveClaimsToCache(claims, cacheKey, timeout);
}

expirationDate = null;
return false;
}

private List<ClaimDto>? GetClaimsFromCache(string? cacheKey)
Expand All @@ -198,6 +186,7 @@ private void SaveClaimsToCache([DisallowNull] IEnumerable<ClaimDto> claims, stri
return;
}

// Only if the cache key is valid does it make sense to test the claims validity
ArgumentNullException.ThrowIfNull(claims);

if (timeout > TimeSpan.Zero)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
using System;
using System.Diagnostics;
using System.IdentityModel.Tokens.Jwt;
using System.Linq;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using Arc4u.Configuration;
using Arc4u.Dependency;
using Arc4u.Dependency.Attribute;
using Arc4u.Diagnostics;
using Arc4u.OAuth2.Token;
using Arc4u.Security.Principal;
using Arc4u.ServiceModel;
using Arc4u.Standard.OAuth2.Security;
using Microsoft.AspNetCore.Authentication;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
Expand All @@ -23,8 +20,6 @@ public class AppServicePrincipalFactory : IAppPrincipalFactory
{
public const string ProviderKey = "ProviderId";

public static readonly string tokenExpirationClaimType = "exp";
public static readonly string[] ClaimsToExclude = { "exp", "aud", "iss", "iat", "nbf", "acr", "aio", "appidacr", "ipaddr", "scp", "sub", "tid", "uti", "unique_name", "apptype", "appid", "ver", "http://schemas.microsoft.com/ws/2008/06/identity/claims/authenticationinstant", "http://schemas.microsoft.com/identity/claims/scope" };

private readonly IContainerResolve _container;
private readonly ILogger<AppServicePrincipalFactory> _logger;
Expand Down Expand Up @@ -102,13 +97,12 @@ private async Task BuildTheIdentity(ClaimsIdentity identity, IKeyValueSettings s
}

// Check the settings contains the service url.
TokenInfo? token = null;
try
{
token = await provider.GetTokenAsync(settings, parameter).ConfigureAwait(true);
var token = await provider.GetTokenAsync(settings, parameter).ConfigureAwait(true);
identity.BootstrapContext = token.Token;
var jwtToken = new JwtSecurityToken(token.Token);
identity.AddClaims(jwtToken.Claims.Where(c => !ClaimsToExclude.Any(arg => arg.Equals(c.Type))).Select(c => new Claim(c.Type, c.Value)));
identity.MergeClaims(jwtToken.Claims.Sanitize());
}
catch (Exception ex)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IdentityModel.Tokens.Jwt;
using System.Linq;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using Arc4u.Caching;
using Arc4u.Configuration;
using Arc4u.Dependency;
Expand All @@ -16,6 +10,7 @@
using Arc4u.OAuth2.Token;
using Arc4u.Security.Principal;
using Arc4u.ServiceModel;
using Arc4u.Standard.OAuth2.Security;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

Expand All @@ -28,8 +23,6 @@ public class AppPrincipalFactory : IAppPrincipalFactory
public const string DefaultSettingsResolveName = "OAuth2";
public const string PlatformParameters = "platformParameters";

public static readonly string tokenExpirationClaimType = "exp";
public static readonly string[] ClaimsToExclude = { "exp", "aud", "iss", "iat", "nbf", "acr", "aio", "appidacr", "ipaddr", "scp", "sub", "tid", "uti", "unique_name", "apptype", "appid", "ver", "http://schemas.microsoft.com/ws/2008/06/identity/claims/authenticationinstant", "http://schemas.microsoft.com/identity/claims/scope" };

private readonly IContainerResolve _container;
private readonly INetworkInformation _networkInformation;
Expand Down Expand Up @@ -118,10 +111,6 @@ private async Task BuildTheIdentity(ClaimsIdentity identity, IKeyValueSettings s
// The token has claims filled by the STS.
// We can fill the new federated identity with the claims from the token.
var jwtToken = new JwtSecurityToken(token.Token);
var expTokenClaim = jwtToken.Claims.FirstOrDefault(c => c.Type.Equals(tokenExpirationClaimType, StringComparison.InvariantCultureIgnoreCase));
long expTokenTicks = 0;
if (null != expTokenClaim)
long.TryParse(expTokenClaim.Value, NumberStyles.Integer, CultureInfo.InvariantCulture, out expTokenTicks);

// The key for the cache is based on the claims from a ClaimsIdentity => build a dummy identity with the claim from the token.
var dummyIdentity = new ClaimsIdentity(jwtToken.Claims);
Expand All @@ -132,38 +121,37 @@ private async Task BuildTheIdentity(ClaimsIdentity identity, IKeyValueSettings s
// if we have a token "cached" from the system, we can take the authorization claims from the cache (if exists)...
// so we avoid too many backend calls for nothing.
// But every time we have a token that has been refreshed, we will call the backend (if available and reload the claims).
var cachedExpiredClaim = cachedClaims.FirstOrDefault(c => c.ClaimType.Equals(tokenExpirationClaimType, StringComparison.InvariantCultureIgnoreCase));
long cachedExpiredTicks = 0;

if (null != cachedExpiredClaim)
long.TryParse(cachedExpiredClaim.Value, NumberStyles.Integer, CultureInfo.InvariantCulture, out cachedExpiredTicks);

// we only call the backend if the ticks are not the same.
bool copyClaimsFromCache = cachedExpiredTicks > 0 && expTokenTicks > 0 && cachedClaims.Count > 0 && cachedExpiredTicks == expTokenTicks;

if (copyClaimsFromCache)
var copyClaimsFromCache = jwtToken.Claims.TryGetExpiration(out Claim? expTokenClaim) && expTokenClaim.Value.TryParseTicks(out var expTokenTicks)
&& cachedClaims.TryGetExpiration(out var cachedExpiredTicks)
&& cachedExpiredTicks == expTokenTicks;
if (copyClaimsFromCache && cachedClaims.Count > 0)
{
identity.AddClaims(cachedClaims.Select(p => new Claim(p.ClaimType, p.Value)));
messages.Add(new Message(ServiceModel.MessageCategory.Technical, MessageType.Information, "Create the principal from the cache, token has not been refreshed."));
}
else
{
// Fill the claims based on the token and the backend call
identity.AddClaims(jwtToken.Claims.Where(c => !ClaimsToExclude.Any(arg => arg.Equals(c.Type))).Select(c => new Claim(c.Type, c.Value)));
identity.AddClaim(expTokenClaim);
identity.AddClaims(jwtToken.Claims.Sanitize().Select(c => new Claim(c.Type, c.Value)));
// there is no guarantee that expTokenClaim is not null, so we need to check it.
if (expTokenClaim is not null)
{
identity.AddClaim(expTokenClaim);
}

if (_container.TryResolve(out IClaimsFiller claimFiller)) // Fill the claims with more information.
{
try
{
// Get the claims and clean any technical claims in case of.
var claims = (await claimFiller.GetAsync(identity, new List<IKeyValueSettings> { settings }, parameter))
.Where(c => !ClaimsToExclude.Any(arg => arg.Equals(c.ClaimType))).ToList();
var claims = (await claimFiller.GetAsync(identity, new List<IKeyValueSettings> { settings }, parameter)).Sanitize().ToList();

// We copy the claims from the backend but the exp claim will be the value of the token (front end definition) and not the backend one. Otherwhise there will be always a difference.
identity.AddClaims(claims.Where(c => !identity.Claims.Any(c1 => c1.Type == c.ClaimType)).Select(c => new Claim(c.ClaimType, c.Value)));
identity.MergeClaims(claims);

messages.Add(new Message(ServiceModel.MessageCategory.Technical, MessageType.Information, $"Add {claims.Count()} claims to the principal."));
messages.Add(new Message(ServiceModel.MessageCategory.Technical, MessageType.Information, $"Add {claims.Count} claims to the principal."));
messages.Add(new Message(ServiceModel.MessageCategory.Technical, MessageType.Information, $"Save claims to the cache."));
}
catch (Exception e)
Expand Down
117 changes: 117 additions & 0 deletions src/Arc4u.Standard.OAuth2/Security/ClaimsUtility.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#if NET6_0_OR_GREATER
using System.Diagnostics.CodeAnalysis;
#endif
using System.Globalization;
using System.Security.Claims;
using Arc4u.IdentityModel.Claims;

namespace Arc4u.Standard.OAuth2.Security;

/// <summary>
/// Claims utilities and extension methods, used in AppPrincipalFactory and AppPrincipalTransform.
/// </summary>
/// <remarks>
/// Claim type comparison is always case-sensitive. See https://www.rfc-editor.org/rfc/rfc7519 section 10.1.
/// </remarks>
public static class ClaimsUtility
{
private const string ExpirationClaimType = "exp";
private static readonly HashSet<string> ClaimsToExclude = ["aud", "iss", "iat", "nbf", "acr", "aio", "appidacr", "ipaddr", "scp", "sub", "tid", "uti", "unique_name", "apptype", "appid", "ver", "http://schemas.microsoft.com/ws/2008/06/identity/claims/authenticationinstant", "http://schemas.microsoft.com/identity/claims/scope"];

/// <summary>
/// Sanitize the list of external claims by removing the claims we don't need.
/// The <paramref name="claims"/> may come from an external system. If there is an "exp" claim in the list, it will be included.
/// </summary>
/// <param name="claims"></param>
/// <returns></returns>
public static IEnumerable<ClaimDto> Sanitize(this IEnumerable<ClaimDto> claims)
{
foreach (var claim in claims)
{
if (ClaimsToExclude.Contains(claim.ClaimType))
{
continue;
}
yield return claim;
}
}

/// <summary>
/// Sanitize the list of internal claims by removing the claims we don't need.
/// The <paramref name="claims"/> come from the token claims . If there is an "exp" claim in the list, it will be *excluded* as well
/// since it may be replaced by the equivalent external claim.
/// </summary>
/// <param name="claims"></param>
/// <returns></returns>
public static IEnumerable<Claim> Sanitize(this IEnumerable<Claim> claims)
{
foreach (var claim in claims)
{
if (ClaimsToExclude.Contains(claim.Type) || claim.Type == ExpirationClaimType)
{
continue;
}
yield return claim;
}
}

/// <summary>
/// Merge the <paramref name="claimsToMerge"/> into the <paramref name="claimsIdentity"/>, but only the claims that are not already present.
/// The "exp" claim is always excluded from the merge.
/// </summary>
/// <param name="claimsIdentity"></param>
/// <param name="claimsToMerge"></param>
public static void MergeClaims(this ClaimsIdentity claimsIdentity, IEnumerable<ClaimDto> claimsToMerge)
{
claimsIdentity.AddClaims(claimsToMerge.Where(c => c.ClaimType != ExpirationClaimType && !claimsIdentity.HasClaim(c1 => c1.Type == c.ClaimType)).Select(c => new Claim(c.ClaimType, c.Value)));
}

/// <summary>
/// Merge the <paramref name="claimsToMerge"/> into the <paramref name="claimsIdentity"/>, but only the claims that are not already present.
/// The "exp" claim is always excluded from the merge.
/// </summary>
/// <param name="claimsIdentity"></param>
/// <param name="claimsToMerge"></param>
public static void MergeClaims(this ClaimsIdentity claimsIdentity, IEnumerable<Claim> claimsToMerge)
{
claimsIdentity.AddClaims(claimsToMerge.Where(c => c.Type != ExpirationClaimType && !claimsIdentity.HasClaim(c1 => c1.Type == c.Type)).Select(c => new Claim(c.Type, c.Value)));
}

#if NET6_0_OR_GREATER
public static bool TryGetExpiration(this IEnumerable<Claim> claims, [NotNullWhen(true)] out Claim? expClaim)
#else
public static bool TryGetExpiration(this IEnumerable<Claim> claims, out Claim expClaim)
#endif
{
expClaim = claims.FirstOrDefault(c => c.Type == ExpirationClaimType);
return expClaim is not null;
}

public static bool TryGetExpiration(this IEnumerable<Claim> claims, out long ticks)
{
var claim = claims.FirstOrDefault(c => c.Type == ExpirationClaimType);

if (claim is null)
{
ticks = 0;
return false;
}

return claim.Value.TryParseTicks(out ticks);
}

public static bool TryGetExpiration(this IEnumerable<ClaimDto> claims, out long ticks)
{
var claim = claims.FirstOrDefault(c => c.ClaimType == ExpirationClaimType);

if (claim is null)
{
ticks = 0;
return false;
}

return claim.Value.TryParseTicks(out ticks);
}

public static bool TryParseTicks(this string s, out long ticks) => long.TryParse(s, NumberStyles.Integer, CultureInfo.InvariantCulture, out ticks);
}

0 comments on commit f0c3fc6

Please sign in to comment.