diff --git a/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/Authentication/ChallengeCacheHandler.cs b/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/Authentication/ChallengeCacheHandler.cs new file mode 100644 index 0000000000000..86c013cf7e02b --- /dev/null +++ b/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/Authentication/ChallengeCacheHandler.cs @@ -0,0 +1,50 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Azure.KeyVault.Customized.Authentication +{ + /// + /// A which will update the when a 401 response is returned with a WWW-Authenticate bearer challenge header. + /// + public class ChallengeCacheHandler : MessageProcessingHandler + { + /// + /// Returns the specified request without performing any processing + /// + /// + /// + /// + protected override HttpRequestMessage ProcessRequest(HttpRequestMessage request, CancellationToken cancellationToken) + { + return request; + } + /// + /// Updates the when the specified response has a return code of 401 + /// + /// The response to evaluate + /// The cancellation token + /// + protected override HttpResponseMessage ProcessResponse(HttpResponseMessage response, CancellationToken cancellationToken) + { + // if the response came back as 401 and the response contains a bearer challenge update the challenge cache + if (response.StatusCode == System.Net.HttpStatusCode.Unauthorized) + { + HttpBearerChallenge challenge = HttpBearerChallenge.GetBearerChallengeFromResponse(response); + + if (challenge != null) + { + // Update challenge cache + HttpBearerChallengeCache.GetInstance().SetChallengeForURL(response.RequestMessage.RequestUri, challenge); + } + } + + return response; + } + + } +} diff --git a/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/Authentication/HttpBearerChallenge.cs b/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/Authentication/HttpBearerChallenge.cs index a454c07db1c4b..1ebe236602e27 100644 --- a/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/Authentication/HttpBearerChallenge.cs +++ b/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/Authentication/HttpBearerChallenge.cs @@ -4,6 +4,8 @@ using System; using System.Collections.Generic; +using System.Net.Http; +using System.Linq; namespace Microsoft.Azure.KeyVault { @@ -36,6 +38,20 @@ public static bool IsBearerChallenge(string challenge) return true; } + internal static HttpBearerChallenge GetBearerChallengeFromResponse(HttpResponseMessage response) + { + if (response == null) return null; + + string challenge = response?.Headers.WwwAuthenticate.FirstOrDefault()?.ToString(); + + if (!string.IsNullOrEmpty(challenge) && IsBearerChallenge(challenge)) + { + return new HttpBearerChallenge(response.RequestMessage.RequestUri, challenge); + } + + return null; + } + private Dictionary _parameters = null; private string _sourceAuthority = null; private Uri _sourceUri = null; diff --git a/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/KeyVaultClient.cs b/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/KeyVaultClient.cs index b3dc4d1150352..9a2a8cc78f3ee 100644 --- a/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/KeyVaultClient.cs +++ b/sdk/keyvault/Microsoft.Azure.KeyVault/src/Customized/KeyVaultClient.cs @@ -16,6 +16,7 @@ namespace Microsoft.Azure.KeyVault using Rest.Serialization; using Newtonsoft.Json; using System.Linq; + using Microsoft.Azure.KeyVault.Customized.Authentication; /// @@ -226,5 +227,20 @@ public KeyVaultClient(KeyVaultCredential credential, HttpClient httpClient) } return _result; } + + /// + /// Overrides the base to add a as the outermost . + /// + /// The base handler for the client + /// The handler pipeline for the client + /// + protected override DelegatingHandler CreateHttpHandlerPipeline(HttpClientHandler httpClientHandler, params DelegatingHandler[] handlers) + { + var challengeCacheHandler = new ChallengeCacheHandler(); + + challengeCacheHandler.InnerHandler = base.CreateHttpHandlerPipeline(httpClientHandler, handlers); + + return challengeCacheHandler; + } } } \ No newline at end of file diff --git a/sdk/keyvault/Microsoft.Azure.KeyVault/tests/ChallengeCacheHandlerTests.cs b/sdk/keyvault/Microsoft.Azure.KeyVault/tests/ChallengeCacheHandlerTests.cs new file mode 100644 index 0000000000000..c4dc3f47b7a68 --- /dev/null +++ b/sdk/keyvault/Microsoft.Azure.KeyVault/tests/ChallengeCacheHandlerTests.cs @@ -0,0 +1,254 @@ +using Microsoft.Azure.KeyVault.Customized.Authentication; +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; +using Kvp = System.Collections.Generic.KeyValuePair; + +namespace Microsoft.Azure.KeyVault.Tests +{ + public class ChallengeCacheHandlerTests + { + [Fact] + public async Task CacheAddOn401Async() + { + var handler = new ChallengeCacheHandler(); + + var expChallenge = MockChallenge.Create(); + + handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, expChallenge.ToString()); + + var client = new HttpClient(handler); + + var requestUrl = CreateMockUrl(2); + + var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl)); + + AssertChallengeCacheEntry(requestUrl, expChallenge); + } + + [Fact] + public async Task CacheUpdateOn401Async() + { + string requestUrlBase = CreateMockUrl(); + + string requestUrl1 = CreateMockUrl(requestUrlBase, 2); + + HttpBearerChallengeCache.GetInstance().SetChallengeForURL(new Uri(requestUrl1), MockChallenge.Create().ToHttpBearerChallenge(requestUrl1)); + + string requestUrl2 = CreateMockUrl(requestUrlBase, 2); + + var handler = new ChallengeCacheHandler(); + + var expChallenge = MockChallenge.Create(); + + handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, expChallenge.ToString()); + + var client = new HttpClient(handler); + + var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl2)); + + AssertChallengeCacheEntry(requestUrl1, expChallenge); + + AssertChallengeCacheEntry(requestUrl2, expChallenge); + } + + [Fact] + public async Task CacheNotUpdatedNoChallengeAsync() + { + var handler = new ChallengeCacheHandler(); + + handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, null); + + var client = new HttpClient(handler); + + var requestUrl = CreateMockUrl(2); + + var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl)); + + AssertChallengeCacheEntry(requestUrl, null); + } + + [Fact] + public async Task CacheNotUpdatedNon401Aysnc() + { + var handler = new ChallengeCacheHandler(); + + var expChallenge = MockChallenge.Create(); + + handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Forbidden, expChallenge.ToString()); + + var client = new HttpClient(handler); + + var requestUrl = CreateMockUrl(2); + + var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl)); + + AssertChallengeCacheEntry(requestUrl, null); + } + + [Fact] + public async Task CacheNotUpdatedNonBearerChallengeAsync() + { + var handler = new ChallengeCacheHandler(); + + handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, MockChallenge.Create("PoP").ToString()); + + var client = new HttpClient(handler); + + var requestUrl = CreateMockUrl(2); + + var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl)); + + AssertChallengeCacheEntry(requestUrl, null); + } + + private static void AssertChallengeCacheEntry(string requestUrl, MockChallenge expChallenge) + { + var actChallenge = HttpBearerChallengeCache.GetInstance().GetChallengeForURL(new Uri(requestUrl)); + + if (expChallenge == null) + { + Assert.Null(actChallenge); + } + else + { + Assert.NotNull(actChallenge); + + Assert.Equal(expChallenge.AuthorizationServer, actChallenge.AuthorizationServer); + + Assert.Equal(expChallenge.Scope, actChallenge.Scope); + + Assert.Equal(expChallenge.Resource, actChallenge.Resource); + } + } + + private static string BuildChallengeString(params Kvp[] parameters) + { + // remove the trailing ',' and return + return BuildChallengeString("Bearer", parameters); + } + + private static string BuildChallengeString(string challengeType, params Kvp[] parameters) + { + StringBuilder buff = new StringBuilder(challengeType).Append(" "); + + foreach (var kvp in parameters) + { + buff.Append(kvp.Key).Append("=\"").Append(kvp.Value).Append("\","); + } + + // remove the trailing ',' and return + return buff.Remove(buff.Length - 1, 1).ToString(); + } + + public static string CreateMockUrl(int pathCount = 0) + { + return CreateMockUrl("https://" + Guid.NewGuid().ToString("N"), pathCount); + } + + public static string CreateMockUrl(string baseUrl, int pathCount = 0) + { + var buff = new StringBuilder(baseUrl); + + if(baseUrl.EndsWith("/")) + { + buff.Remove(buff.Length - 1, 1); + } + + for (int i = 0; i < pathCount; i++) + { + buff.Append("/").Append(Guid.NewGuid().ToString("N")); + } + + return buff.ToString(); + } + + private class MockChallenge + { + public string ChallengeType { get; set; } + + public string AuthorizationServer { get; set; } + + public string Resource { get; set; } + + public string Scope { get; set; } + + public static MockChallenge Create(string challengeType = null, string authority = null, string resource = null, string scope = null) + { + var mock = new MockChallenge(); + mock.ChallengeType = challengeType ?? "Bearer"; + mock.AuthorizationServer = authority ?? CreateMockUrl(1); + mock.Resource = resource ?? CreateMockUrl(0); + mock.Scope = scope ?? mock.Resource + "/.default"; + return mock; + } + + public HttpBearerChallenge ToHttpBearerChallenge(string requestUrl) + { + return new HttpBearerChallenge(new Uri(requestUrl), ToString()); + } + + public override string ToString() + { + var parameters = new List(); + + StringBuilder buff = new StringBuilder(); + + if(AuthorizationServer != null) + { + parameters.Add(new Kvp("authorization", AuthorizationServer)); + } + + if (Resource != null) + { + parameters.Add(new Kvp("resource", Resource)); + } + + if (Scope != null) + { + parameters.Add(new Kvp("scope", Scope)); + } + + return BuildChallengeString(ChallengeType, parameters.ToArray()); + } + + } + + + private class StaticChallengeResponseHandler : HttpMessageHandler + { + private HttpStatusCode _statusCode; + private string _challengeHeader; + + public StaticChallengeResponseHandler(HttpStatusCode statusCode, string challengeHeader) + { + _statusCode = statusCode; + _challengeHeader = challengeHeader; + } + + protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + return Task.FromResult(CreateResponse(request)); + } + + private HttpResponseMessage CreateResponse(HttpRequestMessage request) + { + var response = new HttpResponseMessage(_statusCode); + + response.RequestMessage = request; + + if(_challengeHeader != null) + { + response.Headers.Add("WWW-Authenticate", _challengeHeader); + } + + return response; + } + } + } +}