diff --git a/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/Authentication/ChallengeCacheHandler.cs b/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/Authentication/ChallengeCacheHandler.cs
new file mode 100644
index 0000000000000..86c013cf7e02b
--- /dev/null
+++ b/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/Authentication/ChallengeCacheHandler.cs
@@ -0,0 +1,50 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net.Http;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.Azure.KeyVault.Customized.Authentication
+{
+ ///
+ /// A which will update the when a 401 response is returned with a WWW-Authenticate bearer challenge header.
+ ///
+ public class ChallengeCacheHandler : MessageProcessingHandler
+ {
+ ///
+ /// Returns the specified request without performing any processing
+ ///
+ ///
+ ///
+ ///
+ protected override HttpRequestMessage ProcessRequest(HttpRequestMessage request, CancellationToken cancellationToken)
+ {
+ return request;
+ }
+ ///
+ /// Updates the when the specified response has a return code of 401
+ ///
+ /// The response to evaluate
+ /// The cancellation token
+ ///
+ protected override HttpResponseMessage ProcessResponse(HttpResponseMessage response, CancellationToken cancellationToken)
+ {
+ // if the response came back as 401 and the response contains a bearer challenge update the challenge cache
+ if (response.StatusCode == System.Net.HttpStatusCode.Unauthorized)
+ {
+ HttpBearerChallenge challenge = HttpBearerChallenge.GetBearerChallengeFromResponse(response);
+
+ if (challenge != null)
+ {
+ // Update challenge cache
+ HttpBearerChallengeCache.GetInstance().SetChallengeForURL(response.RequestMessage.RequestUri, challenge);
+ }
+ }
+
+ return response;
+ }
+
+ }
+}
diff --git a/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/Authentication/HttpBearerChallenge.cs b/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/Authentication/HttpBearerChallenge.cs
index a454c07db1c4b..1ebe236602e27 100644
--- a/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/Authentication/HttpBearerChallenge.cs
+++ b/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/Authentication/HttpBearerChallenge.cs
@@ -4,6 +4,8 @@
using System;
using System.Collections.Generic;
+using System.Net.Http;
+using System.Linq;
namespace Microsoft.Azure.KeyVault
{
@@ -36,6 +38,20 @@ public static bool IsBearerChallenge(string challenge)
return true;
}
+ internal static HttpBearerChallenge GetBearerChallengeFromResponse(HttpResponseMessage response)
+ {
+ if (response == null) return null;
+
+ string challenge = response?.Headers.WwwAuthenticate.FirstOrDefault()?.ToString();
+
+ if (!string.IsNullOrEmpty(challenge) && IsBearerChallenge(challenge))
+ {
+ return new HttpBearerChallenge(response.RequestMessage.RequestUri, challenge);
+ }
+
+ return null;
+ }
+
private Dictionary _parameters = null;
private string _sourceAuthority = null;
private Uri _sourceUri = null;
diff --git a/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/KeyVaultClient.cs b/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/KeyVaultClient.cs
index b3dc4d1150352..9a2a8cc78f3ee 100644
--- a/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/KeyVaultClient.cs
+++ b/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/KeyVaultClient.cs
@@ -16,6 +16,7 @@ namespace Microsoft.Azure.KeyVault
using Rest.Serialization;
using Newtonsoft.Json;
using System.Linq;
+ using Microsoft.Azure.KeyVault.Customized.Authentication;
///
@@ -226,5 +227,20 @@ public KeyVaultClient(KeyVaultCredential credential, HttpClient httpClient)
}
return _result;
}
+
+ ///
+ /// Overrides the base to add a as the outermost .
+ ///
+ /// The base handler for the client
+ /// The handler pipeline for the client
+ ///
+ protected override DelegatingHandler CreateHttpHandlerPipeline(HttpClientHandler httpClientHandler, params DelegatingHandler[] handlers)
+ {
+ var challengeCacheHandler = new ChallengeCacheHandler();
+
+ challengeCacheHandler.InnerHandler = base.CreateHttpHandlerPipeline(httpClientHandler, handlers);
+
+ return challengeCacheHandler;
+ }
}
}
\ No newline at end of file
diff --git a/sdk/keyvault/Microsoft.Azure.KeyVault/tests/ChallengeCacheHandlerTests.cs b/sdk/keyvault/Microsoft.Azure.KeyVault/tests/ChallengeCacheHandlerTests.cs
new file mode 100644
index 0000000000000..c4dc3f47b7a68
--- /dev/null
+++ b/sdk/keyvault/Microsoft.Azure.KeyVault/tests/ChallengeCacheHandlerTests.cs
@@ -0,0 +1,254 @@
+using Microsoft.Azure.KeyVault.Customized.Authentication;
+using System;
+using System.Collections.Generic;
+using System.Net;
+using System.Net.Http;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+using Kvp = System.Collections.Generic.KeyValuePair;
+
+namespace Microsoft.Azure.KeyVault.Tests
+{
+ public class ChallengeCacheHandlerTests
+ {
+ [Fact]
+ public async Task CacheAddOn401Async()
+ {
+ var handler = new ChallengeCacheHandler();
+
+ var expChallenge = MockChallenge.Create();
+
+ handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, expChallenge.ToString());
+
+ var client = new HttpClient(handler);
+
+ var requestUrl = CreateMockUrl(2);
+
+ var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl));
+
+ AssertChallengeCacheEntry(requestUrl, expChallenge);
+ }
+
+ [Fact]
+ public async Task CacheUpdateOn401Async()
+ {
+ string requestUrlBase = CreateMockUrl();
+
+ string requestUrl1 = CreateMockUrl(requestUrlBase, 2);
+
+ HttpBearerChallengeCache.GetInstance().SetChallengeForURL(new Uri(requestUrl1), MockChallenge.Create().ToHttpBearerChallenge(requestUrl1));
+
+ string requestUrl2 = CreateMockUrl(requestUrlBase, 2);
+
+ var handler = new ChallengeCacheHandler();
+
+ var expChallenge = MockChallenge.Create();
+
+ handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, expChallenge.ToString());
+
+ var client = new HttpClient(handler);
+
+ var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl2));
+
+ AssertChallengeCacheEntry(requestUrl1, expChallenge);
+
+ AssertChallengeCacheEntry(requestUrl2, expChallenge);
+ }
+
+ [Fact]
+ public async Task CacheNotUpdatedNoChallengeAsync()
+ {
+ var handler = new ChallengeCacheHandler();
+
+ handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, null);
+
+ var client = new HttpClient(handler);
+
+ var requestUrl = CreateMockUrl(2);
+
+ var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl));
+
+ AssertChallengeCacheEntry(requestUrl, null);
+ }
+
+ [Fact]
+ public async Task CacheNotUpdatedNon401Aysnc()
+ {
+ var handler = new ChallengeCacheHandler();
+
+ var expChallenge = MockChallenge.Create();
+
+ handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Forbidden, expChallenge.ToString());
+
+ var client = new HttpClient(handler);
+
+ var requestUrl = CreateMockUrl(2);
+
+ var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl));
+
+ AssertChallengeCacheEntry(requestUrl, null);
+ }
+
+ [Fact]
+ public async Task CacheNotUpdatedNonBearerChallengeAsync()
+ {
+ var handler = new ChallengeCacheHandler();
+
+ handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, MockChallenge.Create("PoP").ToString());
+
+ var client = new HttpClient(handler);
+
+ var requestUrl = CreateMockUrl(2);
+
+ var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl));
+
+ AssertChallengeCacheEntry(requestUrl, null);
+ }
+
+ private static void AssertChallengeCacheEntry(string requestUrl, MockChallenge expChallenge)
+ {
+ var actChallenge = HttpBearerChallengeCache.GetInstance().GetChallengeForURL(new Uri(requestUrl));
+
+ if (expChallenge == null)
+ {
+ Assert.Null(actChallenge);
+ }
+ else
+ {
+ Assert.NotNull(actChallenge);
+
+ Assert.Equal(expChallenge.AuthorizationServer, actChallenge.AuthorizationServer);
+
+ Assert.Equal(expChallenge.Scope, actChallenge.Scope);
+
+ Assert.Equal(expChallenge.Resource, actChallenge.Resource);
+ }
+ }
+
+ private static string BuildChallengeString(params Kvp[] parameters)
+ {
+ // remove the trailing ',' and return
+ return BuildChallengeString("Bearer", parameters);
+ }
+
+ private static string BuildChallengeString(string challengeType, params Kvp[] parameters)
+ {
+ StringBuilder buff = new StringBuilder(challengeType).Append(" ");
+
+ foreach (var kvp in parameters)
+ {
+ buff.Append(kvp.Key).Append("=\"").Append(kvp.Value).Append("\",");
+ }
+
+ // remove the trailing ',' and return
+ return buff.Remove(buff.Length - 1, 1).ToString();
+ }
+
+ public static string CreateMockUrl(int pathCount = 0)
+ {
+ return CreateMockUrl("https://" + Guid.NewGuid().ToString("N"), pathCount);
+ }
+
+ public static string CreateMockUrl(string baseUrl, int pathCount = 0)
+ {
+ var buff = new StringBuilder(baseUrl);
+
+ if(baseUrl.EndsWith("/"))
+ {
+ buff.Remove(buff.Length - 1, 1);
+ }
+
+ for (int i = 0; i < pathCount; i++)
+ {
+ buff.Append("/").Append(Guid.NewGuid().ToString("N"));
+ }
+
+ return buff.ToString();
+ }
+
+ private class MockChallenge
+ {
+ public string ChallengeType { get; set; }
+
+ public string AuthorizationServer { get; set; }
+
+ public string Resource { get; set; }
+
+ public string Scope { get; set; }
+
+ public static MockChallenge Create(string challengeType = null, string authority = null, string resource = null, string scope = null)
+ {
+ var mock = new MockChallenge();
+ mock.ChallengeType = challengeType ?? "Bearer";
+ mock.AuthorizationServer = authority ?? CreateMockUrl(1);
+ mock.Resource = resource ?? CreateMockUrl(0);
+ mock.Scope = scope ?? mock.Resource + "/.default";
+ return mock;
+ }
+
+ public HttpBearerChallenge ToHttpBearerChallenge(string requestUrl)
+ {
+ return new HttpBearerChallenge(new Uri(requestUrl), ToString());
+ }
+
+ public override string ToString()
+ {
+ var parameters = new List();
+
+ StringBuilder buff = new StringBuilder();
+
+ if(AuthorizationServer != null)
+ {
+ parameters.Add(new Kvp("authorization", AuthorizationServer));
+ }
+
+ if (Resource != null)
+ {
+ parameters.Add(new Kvp("resource", Resource));
+ }
+
+ if (Scope != null)
+ {
+ parameters.Add(new Kvp("scope", Scope));
+ }
+
+ return BuildChallengeString(ChallengeType, parameters.ToArray());
+ }
+
+ }
+
+
+ private class StaticChallengeResponseHandler : HttpMessageHandler
+ {
+ private HttpStatusCode _statusCode;
+ private string _challengeHeader;
+
+ public StaticChallengeResponseHandler(HttpStatusCode statusCode, string challengeHeader)
+ {
+ _statusCode = statusCode;
+ _challengeHeader = challengeHeader;
+ }
+
+ protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
+ {
+ return Task.FromResult(CreateResponse(request));
+ }
+
+ private HttpResponseMessage CreateResponse(HttpRequestMessage request)
+ {
+ var response = new HttpResponseMessage(_statusCode);
+
+ response.RequestMessage = request;
+
+ if(_challengeHeader != null)
+ {
+ response.Headers.Add("WWW-Authenticate", _challengeHeader);
+ }
+
+ return response;
+ }
+ }
+ }
+}