Skip to content

Commit

Permalink
Adding ChallengeCacheHandler to update challenge cache on 401 response (
Browse files Browse the repository at this point in the history
Azure#6950)

* Adding ChallengeCacheHandler to update challenge cache on 401 responses with auth challenge

* updates addressing PR feedback
  • Loading branch information
schaabs authored Jul 17, 2019
1 parent 9904eab commit c593eee
Show file tree
Hide file tree
Showing 4 changed files with 336 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Azure.KeyVault.Customized.Authentication
{
/// <summary>
/// A <see cref="DelegatingHandler"/> which will update the <see cref="HttpBearerChallengeCache"/> when a 401 response is returned with a WWW-Authenticate bearer challenge header.
/// </summary>
public class ChallengeCacheHandler : MessageProcessingHandler
{
/// <summary>
/// Returns the specified request without performing any processing
/// </summary>
/// <param name="request"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
protected override HttpRequestMessage ProcessRequest(HttpRequestMessage request, CancellationToken cancellationToken)
{
return request;
}
/// <summary>
/// Updates the <see cref="HttpBearerChallengeCache"/> when the specified response has a return code of 401
/// </summary>
/// <param name="response">The response to evaluate</param>
/// <param name="cancellationToken">The cancellation token</param>
/// <returns></returns>
protected override HttpResponseMessage ProcessResponse(HttpResponseMessage response, CancellationToken cancellationToken)
{
// if the response came back as 401 and the response contains a bearer challenge update the challenge cache
if (response.StatusCode == System.Net.HttpStatusCode.Unauthorized)
{
HttpBearerChallenge challenge = HttpBearerChallenge.GetBearerChallengeFromResponse(response);

if (challenge != null)
{
// Update challenge cache
HttpBearerChallengeCache.GetInstance().SetChallengeForURL(response.RequestMessage.RequestUri, challenge);
}
}

return response;
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Linq;

namespace Microsoft.Azure.KeyVault
{
Expand Down Expand Up @@ -36,6 +38,20 @@ public static bool IsBearerChallenge(string challenge)
return true;
}

internal static HttpBearerChallenge GetBearerChallengeFromResponse(HttpResponseMessage response)
{
if (response == null) return null;

string challenge = response?.Headers.WwwAuthenticate.FirstOrDefault()?.ToString();

if (!string.IsNullOrEmpty(challenge) && IsBearerChallenge(challenge))
{
return new HttpBearerChallenge(response.RequestMessage.RequestUri, challenge);
}

return null;
}

private Dictionary<string, string> _parameters = null;
private string _sourceAuthority = null;
private Uri _sourceUri = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace Microsoft.Azure.KeyVault
using Rest.Serialization;
using Newtonsoft.Json;
using System.Linq;
using Microsoft.Azure.KeyVault.Customized.Authentication;


/// <summary>
Expand Down Expand Up @@ -226,5 +227,20 @@ public KeyVaultClient(KeyVaultCredential credential, HttpClient httpClient)
}
return _result;
}

/// <summary>
/// Overrides the base <see cref="CreateHttpHandlerPipeline"/> to add a <see cref="ChallengeCacheHandler"/> as the outermost <see cref="DelegatingHandler"/>.
/// </summary>
/// <param name="httpClientHandler">The base handler for the client</param>
/// <param name="handlers">The handler pipeline for the client</param>
/// <returns></returns>
protected override DelegatingHandler CreateHttpHandlerPipeline(HttpClientHandler httpClientHandler, params DelegatingHandler[] handlers)
{
var challengeCacheHandler = new ChallengeCacheHandler();

challengeCacheHandler.InnerHandler = base.CreateHttpHandlerPipeline(httpClientHandler, handlers);

return challengeCacheHandler;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
using Microsoft.Azure.KeyVault.Customized.Authentication;
using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
using Kvp = System.Collections.Generic.KeyValuePair<string, string>;

namespace Microsoft.Azure.KeyVault.Tests
{
public class ChallengeCacheHandlerTests
{
[Fact]
public async Task CacheAddOn401Async()
{
var handler = new ChallengeCacheHandler();

var expChallenge = MockChallenge.Create();

handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, expChallenge.ToString());

var client = new HttpClient(handler);

var requestUrl = CreateMockUrl(2);

var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl));

AssertChallengeCacheEntry(requestUrl, expChallenge);
}

[Fact]
public async Task CacheUpdateOn401Async()
{
string requestUrlBase = CreateMockUrl();

string requestUrl1 = CreateMockUrl(requestUrlBase, 2);

HttpBearerChallengeCache.GetInstance().SetChallengeForURL(new Uri(requestUrl1), MockChallenge.Create().ToHttpBearerChallenge(requestUrl1));

string requestUrl2 = CreateMockUrl(requestUrlBase, 2);

var handler = new ChallengeCacheHandler();

var expChallenge = MockChallenge.Create();

handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, expChallenge.ToString());

var client = new HttpClient(handler);

var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl2));

AssertChallengeCacheEntry(requestUrl1, expChallenge);

AssertChallengeCacheEntry(requestUrl2, expChallenge);
}

[Fact]
public async Task CacheNotUpdatedNoChallengeAsync()
{
var handler = new ChallengeCacheHandler();

handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, null);

var client = new HttpClient(handler);

var requestUrl = CreateMockUrl(2);

var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl));

AssertChallengeCacheEntry(requestUrl, null);
}

[Fact]
public async Task CacheNotUpdatedNon401Aysnc()
{
var handler = new ChallengeCacheHandler();

var expChallenge = MockChallenge.Create();

handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Forbidden, expChallenge.ToString());

var client = new HttpClient(handler);

var requestUrl = CreateMockUrl(2);

var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl));

AssertChallengeCacheEntry(requestUrl, null);
}

[Fact]
public async Task CacheNotUpdatedNonBearerChallengeAsync()
{
var handler = new ChallengeCacheHandler();

handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, MockChallenge.Create("PoP").ToString());

var client = new HttpClient(handler);

var requestUrl = CreateMockUrl(2);

var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl));

AssertChallengeCacheEntry(requestUrl, null);
}

private static void AssertChallengeCacheEntry(string requestUrl, MockChallenge expChallenge)
{
var actChallenge = HttpBearerChallengeCache.GetInstance().GetChallengeForURL(new Uri(requestUrl));

if (expChallenge == null)
{
Assert.Null(actChallenge);
}
else
{
Assert.NotNull(actChallenge);

Assert.Equal(expChallenge.AuthorizationServer, actChallenge.AuthorizationServer);

Assert.Equal(expChallenge.Scope, actChallenge.Scope);

Assert.Equal(expChallenge.Resource, actChallenge.Resource);
}
}

private static string BuildChallengeString(params Kvp[] parameters)
{
// remove the trailing ',' and return
return BuildChallengeString("Bearer", parameters);
}

private static string BuildChallengeString(string challengeType, params Kvp[] parameters)
{
StringBuilder buff = new StringBuilder(challengeType).Append(" ");

foreach (var kvp in parameters)
{
buff.Append(kvp.Key).Append("=\"").Append(kvp.Value).Append("\",");
}

// remove the trailing ',' and return
return buff.Remove(buff.Length - 1, 1).ToString();
}

public static string CreateMockUrl(int pathCount = 0)
{
return CreateMockUrl("https://" + Guid.NewGuid().ToString("N"), pathCount);
}

public static string CreateMockUrl(string baseUrl, int pathCount = 0)
{
var buff = new StringBuilder(baseUrl);

if(baseUrl.EndsWith("/"))
{
buff.Remove(buff.Length - 1, 1);
}

for (int i = 0; i < pathCount; i++)
{
buff.Append("/").Append(Guid.NewGuid().ToString("N"));
}

return buff.ToString();
}

private class MockChallenge
{
public string ChallengeType { get; set; }

public string AuthorizationServer { get; set; }

public string Resource { get; set; }

public string Scope { get; set; }

public static MockChallenge Create(string challengeType = null, string authority = null, string resource = null, string scope = null)
{
var mock = new MockChallenge();
mock.ChallengeType = challengeType ?? "Bearer";
mock.AuthorizationServer = authority ?? CreateMockUrl(1);
mock.Resource = resource ?? CreateMockUrl(0);
mock.Scope = scope ?? mock.Resource + "/.default";
return mock;
}

public HttpBearerChallenge ToHttpBearerChallenge(string requestUrl)
{
return new HttpBearerChallenge(new Uri(requestUrl), ToString());
}

public override string ToString()
{
var parameters = new List<Kvp>();

StringBuilder buff = new StringBuilder();

if(AuthorizationServer != null)
{
parameters.Add(new Kvp("authorization", AuthorizationServer));
}

if (Resource != null)
{
parameters.Add(new Kvp("resource", Resource));
}

if (Scope != null)
{
parameters.Add(new Kvp("scope", Scope));
}

return BuildChallengeString(ChallengeType, parameters.ToArray());
}

}


private class StaticChallengeResponseHandler : HttpMessageHandler
{
private HttpStatusCode _statusCode;
private string _challengeHeader;

public StaticChallengeResponseHandler(HttpStatusCode statusCode, string challengeHeader)
{
_statusCode = statusCode;
_challengeHeader = challengeHeader;
}

protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
return Task.FromResult(CreateResponse(request));
}

private HttpResponseMessage CreateResponse(HttpRequestMessage request)
{
var response = new HttpResponseMessage(_statusCode);

response.RequestMessage = request;

if(_challengeHeader != null)
{
response.Headers.Add("WWW-Authenticate", _challengeHeader);
}

return response;
}
}
}
}

0 comments on commit c593eee

Please sign in to comment.