Skip to content

Commit

Permalink
Add validation callback and tests (#4818)
Browse files Browse the repository at this point in the history
* initial

* Add validation callback and tests

* Address comments

* Missed to use Lazy HttpClient

* Update src/client/Microsoft.Identity.Client/Http/IHttpManager.cs

Co-authored-by: Bogdan Gavril <[email protected]>

* Update src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs

Co-authored-by: Bogdan Gavril <[email protected]>

* Rename

* Add retry policy

* Address comments

* Update tests to test managed identity retry policy as well

* Address comments

* Update the logic to check for custom http client first.

* Fix tests failing due to reflection

* Update the service fabric endpoint in tests

---------

Co-authored-by: Gladwin Johnson <[email protected]>
Co-authored-by: Gladwin Johnson <[email protected]>
Co-authored-by: Bogdan Gavril <[email protected]>
  • Loading branch information
4 people authored Jul 16, 2024
1 parent df47013 commit 9049fe9
Show file tree
Hide file tree
Showing 18 changed files with 187 additions and 41 deletions.
19 changes: 16 additions & 3 deletions src/client/Microsoft.Identity.Client/Http/HttpManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public async Task<HttpResponse> SendRequestAsync(
ILoggerAdapter logger,
bool doNotThrow,
X509Certificate2 bindingCertificate,
HttpClient customHttpClient,
CancellationToken cancellationToken,
int retryCount = 0)
{
Expand All @@ -75,6 +76,7 @@ public async Task<HttpResponse> SendRequestAsync(
clonedBody,
method,
bindingCertificate,
customHttpClient,
logger,
cancellationToken).ConfigureAwait(false);
}
Expand Down Expand Up @@ -111,6 +113,7 @@ public async Task<HttpResponse> SendRequestAsync(
logger,
doNotThrow,
bindingCertificate,
customHttpClient,
cancellationToken: cancellationToken,
retryCount) // Pass the updated retry count
.ConfigureAwait(false);
Expand Down Expand Up @@ -142,8 +145,17 @@ public async Task<HttpResponse> SendRequestAsync(
return response;
}

private HttpClient GetHttpClient(X509Certificate2 x509Certificate2)
{
private HttpClient GetHttpClient(X509Certificate2 x509Certificate2, HttpClient customHttpClient) {
if (x509Certificate2 != null && customHttpClient != null)
{
throw new NotImplementedException("Mtls certificate cannot be used with service fabric. A custom http client is used for service fabric managed identity to validate the server certificate.");
}

if (customHttpClient != null)
{
return customHttpClient;
}

if (_httpClientFactory is IMsalMtlsHttpClientFactory msalMtlsHttpClientFactory)
{
// If the factory is an IMsalMtlsHttpClientFactory, use it to get an HttpClient with the certificate
Expand Down Expand Up @@ -175,6 +187,7 @@ private async Task<HttpResponse> ExecuteAsync(
HttpContent body,
HttpMethod method,
X509Certificate2 bindingCertificate,
HttpClient customHttpClient,
ILoggerAdapter logger,
CancellationToken cancellationToken = default)
{
Expand All @@ -189,7 +202,7 @@ private async Task<HttpResponse> ExecuteAsync(

Stopwatch sw = Stopwatch.StartNew();

HttpClient client = GetHttpClient(bindingCertificate);
HttpClient client = GetHttpClient(bindingCertificate, customHttpClient);

using (HttpResponseMessage responseMessage =
await client.SendAsync(requestMessage, cancellationToken).ConfigureAwait(false))
Expand Down
17 changes: 17 additions & 0 deletions src/client/Microsoft.Identity.Client/Http/IHttpManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,29 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Internal;

namespace Microsoft.Identity.Client.Http
{
internal interface IHttpManager
{
long LastRequestDurationInMs { get; }

/// <summary>
/// Method to send a request to the server using the HttpClient configured in the implementation.
/// </summary>
/// <param name="endpoint">The endpoint to send the request to.</param>
/// <param name="headers">Headers to send in the http request.</param>
/// <param name="body">Body of the request.</param>
/// <param name="method">Http method.</param>
/// <param name="logger">Logger from the request context.</param>
/// <param name="doNotThrow">Flag to decide if MsalServiceException is thrown or the response is returned in case of 5xx errors.</param>
/// <param name="retry">Flag to indicate whether the retries are performed in for specific failures.</param>

Check warning on line 28 in src/client/Microsoft.Identity.Client/Http/IHttpManager.cs

View workflow job for this annotation

GitHub Actions / Run performance benchmarks

XML comment has a param tag for 'retry', but there is no parameter by that name

Check warning on line 28 in src/client/Microsoft.Identity.Client/Http/IHttpManager.cs

View workflow job for this annotation

GitHub Actions / Run performance benchmarks

XML comment has a param tag for 'retry', but there is no parameter by that name
/// <param name="mtlsCertificate">Certificate used for MTLS authentication.</param>
/// <param name="customHttpClient">Custom http client which bypasses the HttpClientFactory.
/// This is needed for service fabric managed identity where a cert validation callback is added to the handler.</param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
Task<HttpResponse> SendRequestAsync(
Uri endpoint,
IDictionary<string, string> headers,
Expand All @@ -23,6 +39,7 @@ Task<HttpResponse> SendRequestAsync(
ILoggerAdapter logger,
bool doNotThrow,
X509Certificate2 mtlsCertificate,
HttpClient customHttpClient,
CancellationToken cancellationToken,
int retryCount = 0);

Check warning on line 44 in src/client/Microsoft.Identity.Client/Http/IHttpManager.cs

View workflow job for this annotation

GitHub Actions / Run performance benchmarks

Parameter 'retryCount' has no matching param tag in the XML comment for 'IHttpManager.SendRequestAsync(Uri, IDictionary<string, string>, HttpContent, HttpMethod, ILoggerAdapter, bool, X509Certificate2, HttpClient, CancellationToken, int)' (but other parameters do)

Check warning on line 44 in src/client/Microsoft.Identity.Client/Http/IHttpManager.cs

View workflow job for this annotation

GitHub Actions / Run performance benchmarks

Parameter 'retryCount' has no matching param tag in the XML comment for 'IHttpManager.SendRequestAsync(Uri, IDictionary<string, string>, HttpContent, HttpMethod, ILoggerAdapter, bool, X509Certificate2, HttpClient, CancellationToken, int)' (but other parameters do)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,31 +199,33 @@ private async Task<RegionInfo> DiscoverAsync(ILoggerAdapter logger, Cancellation
Uri imdsUri = BuildImdsUri(DefaultApiVersion);

HttpResponse response = await _httpManager.SendRequestAsync(
imdsUri,
headers,
body: null,
HttpMethod.Get,
logger: logger,
doNotThrow: false,
mtlsCertificate: null,
GetCancellationToken(requestCancellationToken))
.ConfigureAwait(false);
imdsUri,
headers,
body: null,
HttpMethod.Get,
logger: logger,
doNotThrow: false,
mtlsCertificate: null,
customHttpClient: null,
GetCancellationToken(requestCancellationToken))
.ConfigureAwait(false);

// A bad request occurs when the version in the IMDS call is no longer supported.
if (response.StatusCode == HttpStatusCode.BadRequest)
{
string apiVersion = await GetImdsUriApiVersionAsync(logger, headers, requestCancellationToken).ConfigureAwait(false); // Get the latest version
imdsUri = BuildImdsUri(apiVersion);
response = await _httpManager.SendRequestAsync(
imdsUri,
headers,
body: null,
HttpMethod.Get,
logger: logger,
doNotThrow: false,
mtlsCertificate: null,
GetCancellationToken(requestCancellationToken))
.ConfigureAwait(false); // Call again with updated version
imdsUri,
headers,
body: null,
HttpMethod.Get,
logger: logger,
doNotThrow: false,
mtlsCertificate: null,
customHttpClient: null,
GetCancellationToken(requestCancellationToken))
.ConfigureAwait(false); // Call again with updated version
}

if (response.StatusCode == HttpStatusCode.OK && !response.Body.IsNullOrEmpty())
Expand Down Expand Up @@ -323,6 +325,7 @@ private async Task<string> GetImdsUriApiVersionAsync(ILoggerAdapter logger, Dict
logger: logger,
doNotThrow: false,
mtlsCertificate: null,
customHttpClient: null,
GetCancellationToken(userCancellationToken))
.ConfigureAwait(false);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,16 @@ public async Task ValidateAuthorityAsync(
string webFingerUrl = Constants.FormatAdfsWebFingerUrl(authorityInfo.Host, resource);

Http.HttpResponse httpResponse = await _requestContext.ServiceBundle.HttpManager.SendRequestAsync(
new Uri(webFingerUrl),
null,
body: null,
System.Net.Http.HttpMethod.Get,
logger: _requestContext.Logger,
doNotThrow: false,
mtlsCertificate: null,
_requestContext.UserCancellationToken)
.ConfigureAwait(false);
new Uri(webFingerUrl),
null,
body: null,
System.Net.Http.HttpMethod.Get,
logger: _requestContext.Logger,
doNotThrow: false,
mtlsCertificate: null,
customHttpClient: null,
_requestContext.UserCancellationToken)
.ConfigureAwait(false);

if (httpResponse.StatusCode != HttpStatusCode.OK)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
logger: _requestContext.Logger,
doNotThrow: true,
mtlsCertificate: null,
GetHttpClientWithSslValidation(_requestContext),
cancellationToken).ConfigureAwait(false);
}
else
Expand All @@ -76,6 +77,7 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
logger: _requestContext.Logger,
doNotThrow: true,
mtlsCertificate: null,
GetHttpClientWithSslValidation(_requestContext),
cancellationToken)
.ConfigureAwait(false);

Expand All @@ -90,6 +92,12 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
}
}

// This method is internal for testing purposes.
internal virtual HttpClient GetHttpClientWithSslValidation(RequestContext requestContext)
{
return null;
}

protected virtual Task<ManagedIdentityResponse> HandleResponseAsync(
AcquireTokenForManagedIdentityParameters parameters,
HttpResponse response,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ protected override async Task<ManagedIdentityResponse> HandleResponseAsync(
logger: _requestContext.Logger,
doNotThrow: false,
mtlsCertificate: null,
customHttpClient: null,
cancellationToken)
.ConfigureAwait(false);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ namespace Microsoft.Identity.Client.ManagedIdentity
internal class ServiceFabricManagedIdentitySource : AbstractManagedIdentity
{
private const string ServiceFabricMsiApiVersion = "2019-07-01-preview";

private readonly Uri _endpoint;
private readonly string _identityHeaderValue;
internal static Lazy<HttpClient> _httpClientLazy;

public static AbstractManagedIdentity Create(RequestContext requestContext)
{
Expand Down Expand Up @@ -45,8 +45,44 @@ public static AbstractManagedIdentity Create(RequestContext requestContext)
return new ServiceFabricManagedIdentitySource(requestContext, endpointUri, EnvironmentVariables.IdentityHeader);
}

internal override HttpClient GetHttpClientWithSslValidation(RequestContext requestContext)
{
if (_httpClientLazy == null)
{
_httpClientLazy = new Lazy<HttpClient>(() =>
{
HttpClientHandler handler = CreateHandlerWithSslValidation(requestContext.Logger);
return new HttpClient(handler);
});
}

return _httpClientLazy.Value;
}

internal HttpClientHandler CreateHandlerWithSslValidation(ILoggerAdapter logger)
{
#if NET471_OR_GREATER || NETSTANDARD || NET
logger.Info(() => "[Managed Identity] Setting up server certificate validation callback.");
return new HttpClientHandler
{
ServerCertificateCustomValidationCallback = (message, certificate, chain, sslPolicyErrors) =>
{
if (sslPolicyErrors != System.Net.Security.SslPolicyErrors.None)
{
return 0 == string.Compare(certificate.Thumbprint, EnvironmentVariables.IdentityServerThumbprint, StringComparison.OrdinalIgnoreCase);
}
return true;
}
};
#else
logger.Warning("[Managed Identity] Server certificate validation callback is not supported on .NET Framework.");
return new HttpClientHandler();
#endif
}


private ServiceFabricManagedIdentitySource(RequestContext requestContext, Uri endpoint, string identityHeaderValue) :
base(requestContext, ManagedIdentitySource.ServiceFabric)
base(requestContext, ManagedIdentitySource.ServiceFabric)
{
_endpoint = endpoint;
_identityHeaderValue = identityHeaderValue;
Expand Down
2 changes: 2 additions & 0 deletions src/client/Microsoft.Identity.Client/OAuth2/OAuth2Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ internal async Task<T> ExecuteRequestAsync<T>(
logger: requestContext.Logger,
doNotThrow: false,
mtlsCertificate: _mtlsCertificate,
customHttpClient: null,
requestContext.UserCancellationToken)
.ConfigureAwait(false);
}
Expand All @@ -152,6 +153,7 @@ internal async Task<T> ExecuteRequestAsync<T>(
logger: requestContext.Logger,
doNotThrow: false,
mtlsCertificate: null,
customHttpClient: null,
requestContext.UserCancellationToken)
.ConfigureAwait(false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public async Task<MexDocument> GetMexDocumentAsync(string federationMetadataUrl,
logger: requestContext.Logger,
doNotThrow: false,
mtlsCertificate: null,
customHttpClient: null,
requestContext.UserCancellationToken)
.ConfigureAwait(false);

Expand Down Expand Up @@ -106,6 +107,7 @@ public async Task<WsTrustResponse> GetWsTrustResponseAsync(
logger: requestContext.Logger,
doNotThrow: true,
mtlsCertificate: null,
customHttpClient: null,
requestContext.UserCancellationToken).ConfigureAwait(false);

if (resp.StatusCode != System.Net.HttpStatusCode.OK)
Expand Down Expand Up @@ -179,6 +181,7 @@ public async Task<UserRealmDiscoveryResponse> GetUserRealmAsync(
logger: requestContext.Logger,
doNotThrow: false,
mtlsCertificate: null,
customHttpClient: null,
requestContext.UserCancellationToken)
.ConfigureAwait(false);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public enum MsiAzureResource
ServiceFabric
}

public static void SetEnvironmentVariables(ManagedIdentitySource managedIdentitySource, string endpoint, string secret = "secret")
public static void SetEnvironmentVariables(ManagedIdentitySource managedIdentitySource, string endpoint, string secret = "secret", string thumbprint = "thumbprint")
{
switch (managedIdentitySource)
{
Expand All @@ -57,7 +57,7 @@ public static void SetEnvironmentVariables(ManagedIdentitySource managedIdentity
case ManagedIdentitySource.ServiceFabric:
Environment.SetEnvironmentVariable("IDENTITY_ENDPOINT", endpoint);
Environment.SetEnvironmentVariable("IDENTITY_HEADER", secret);
Environment.SetEnvironmentVariable("IDENTITY_SERVER_THUMBPRINT", "thumbprint");
Environment.SetEnvironmentVariable("IDENTITY_SERVER_THUMBPRINT", thumbprint);
break;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Http;
using Microsoft.Identity.Client.Internal;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace Microsoft.Identity.Test.Common.Core.Mocks
Expand All @@ -29,9 +30,11 @@ internal sealed class MockHttpManager : IHttpManager,
public MockHttpManager(string testName = null,
bool isManagedIdentity = false,
Func<MockHttpMessageHandler> messageHandlerFunc = null,
Func<HttpClient> validateServerCertificateCallback = null,
bool invokeNonMtlsHttpManagerFactory = false) :
this(true, testName, isManagedIdentity, messageHandlerFunc, invokeNonMtlsHttpManagerFactory)
{ }
{
}

public MockHttpManager(
bool retry,
Expand Down Expand Up @@ -110,6 +113,7 @@ public Task<HttpResponse> SendRequestAsync(
ILoggerAdapter logger,
bool doNotThrow,
X509Certificate2 mtlsCertificate,
HttpClient customHttpClient,
CancellationToken cancellationToken,
int retryCount = 0)
{
Expand All @@ -121,6 +125,7 @@ public Task<HttpResponse> SendRequestAsync(
logger,
doNotThrow,
mtlsCertificate,
customHttpClient: null,
cancellationToken,
retryCount);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Net.Http;
using System.Net.Http.Headers;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public async Task<HttpResponse> SendRequestAsync(
ILoggerAdapter logger,
bool doNotThrow,
X509Certificate2 mtlsCertificate,
HttpClient customHttpClient,
CancellationToken cancellationToken,
int retryCount = 0)
{
Expand Down
Loading

0 comments on commit 9049fe9

Please sign in to comment.