Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding ChallengeCacheHandler to update challenge cache on 401 response #6950

Merged
merged 3 commits into from
Jul 17, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Azure.KeyVault.Customized.Authentication
{
/// <summary>
/// A <see cref="DelegatingHandler"/> which will update the <see cref="HttpBearerChallengeCache"/> when a 401 response is returned with a WWW-Authenticate bearer challenge header.
/// </summary>
public class ChallengeCacheHandler : MessageProcessingHandler
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I presume having the class public is intentional?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes unfortunately. When constructing the key vault client the user has the option of giving us an HttpClient, in which case we our handler won't get added. So I made the class public so that users doing this would have a way to handle this bug. It's possible that I could have made it private and had users re-implement it, but I think this gives a better experience and also leaves us the ability to tweak the behavior.

{
/// <summary>
/// Returns the specified request without performing any processing
/// </summary>
/// <param name="request"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
protected override HttpRequestMessage ProcessRequest(HttpRequestMessage request, CancellationToken cancellationToken)
{
return request;
}
/// <summary>
/// Updates the <see cref="HttpBearerChallengeCache"/> when the specified response has a return code of 401
/// </summary>
/// <param name="response">The response to evaluate</param>
/// <param name="cancellationToken">The cancellation token</param>
/// <returns></returns>
protected override HttpResponseMessage ProcessResponse(HttpResponseMessage response, CancellationToken cancellationToken)
{
// if the response came back as 401 and the response contains a bearer challenge update the challenge cache
if (response.StatusCode == System.Net.HttpStatusCode.Unauthorized)
{
HttpBearerChallenge challenge = HttpBearerChallenge.GetBearerChallengeFromResponse(response);

if (challenge != null)
{
// Update challenge cache
HttpBearerChallengeCache.GetInstance().SetChallengeForURL(response.RequestMessage.RequestUri, challenge);
}
}

return response;
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Linq;

namespace Microsoft.Azure.KeyVault
{
Expand Down Expand Up @@ -36,6 +38,20 @@ public static bool IsBearerChallenge(string challenge)
return true;
}

internal static HttpBearerChallenge GetBearerChallengeFromResponse(HttpResponseMessage response)
{
if (response == null) return null;

string challenge = response?.Headers.WwwAuthenticate.FirstOrDefault()?.ToString();

if (!string.IsNullOrEmpty(challenge) && IsBearerChallenge(challenge))
{
return new HttpBearerChallenge(response.RequestMessage.RequestUri, challenge);
}

return null;
}

private Dictionary<string, string> _parameters = null;
private string _sourceAuthority = null;
private Uri _sourceUri = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace Microsoft.Azure.KeyVault
using Rest.Serialization;
using Newtonsoft.Json;
using System.Linq;
using Microsoft.Azure.KeyVault.Customized.Authentication;


/// <summary>
Expand Down Expand Up @@ -226,5 +227,20 @@ public KeyVaultClient(KeyVaultCredential credential, HttpClient httpClient)
}
return _result;
}

/// <summary>
/// Overrides the base <see cref="CreateHttpHandlerPipeline"/> to add a <see cref="ChallengeCacheHandler"/> as the outermost <see cref="DelegatingHandler"/>.
/// </summary>
/// <param name="httpClientHandler">The base handler for the client</param>
/// <param name="handlers">The handler pipeline for the client</param>
/// <returns></returns>
protected override DelegatingHandler CreateHttpHandlerPipeline(HttpClientHandler httpClientHandler, params DelegatingHandler[] handlers)
{
var challengeCacheHandler = new ChallengeCacheHandler();

challengeCacheHandler.InnerHandler = base.CreateHttpHandlerPipeline(httpClientHandler, handlers);

return challengeCacheHandler;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
using Microsoft.Azure.KeyVault.Customized.Authentication;
using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
using Kvp = System.Collections.Generic.KeyValuePair<string, string>;

namespace Microsoft.Azure.KeyVault.Tests
{
public class ChallengeCacheHandlerTests
{
[Fact]
public async Task CacheAddOn401Async()
{
var handler = new ChallengeCacheHandler();

var expChallenge = MockChallenge.Create();

handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, expChallenge.ToString());

var client = new HttpClient(handler);

var requestUrl = CreateMockUrl(2);

var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl));

AssertChallengeCacheEntry(requestUrl, expChallenge);
}

[Fact]
public async Task CacheUpdateOn401Async()
{
string requestUrlBase = CreateMockUrl();

string requestUrl1 = CreateMockUrl(requestUrlBase, 2);

HttpBearerChallengeCache.GetInstance().SetChallengeForURL(new Uri(requestUrl1), MockChallenge.Create().ToHttpBearerChallenge(requestUrl1));

string requestUrl2 = CreateMockUrl(requestUrlBase, 2);

var handler = new ChallengeCacheHandler();

var expChallenge = MockChallenge.Create();

handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, expChallenge.ToString());

var client = new HttpClient(handler);

var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl2));

AssertChallengeCacheEntry(requestUrl1, expChallenge);

AssertChallengeCacheEntry(requestUrl2, expChallenge);
}

[Fact]
public async Task CacheNotUpdatedNoChallengeAsync()
{
var handler = new ChallengeCacheHandler();

handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, null);

var client = new HttpClient(handler);

var requestUrl = CreateMockUrl(2);

var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl));

AssertChallengeCacheEntry(requestUrl, null);
}

[Fact]
public async Task CacheNotUpdatedNon401Aysnc()
{
var handler = new ChallengeCacheHandler();

var expChallenge = MockChallenge.Create();

handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Forbidden, expChallenge.ToString());

var client = new HttpClient(handler);

var requestUrl = CreateMockUrl(2);

var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl));

AssertChallengeCacheEntry(requestUrl, null);
}

[Fact]
public async Task CacheNotUpdatedNonBearerChallengeAsync()
{
var handler = new ChallengeCacheHandler();

handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, MockChallenge.Create("PoP").ToString());

var client = new HttpClient(handler);

var requestUrl = CreateMockUrl(2);

var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl));

AssertChallengeCacheEntry(requestUrl, null);
}

private static void AssertChallengeCacheEntry(string requestUrl, MockChallenge expChallenge)
{
var actChallenge = HttpBearerChallengeCache.GetInstance().GetChallengeForURL(new Uri(requestUrl));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any other tests these tests might conflict with when running in parallel?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These shouldn't interfere with other tests as all the URLs that I'm using are guid based.


if (expChallenge == null)
{
Assert.Null(actChallenge);
}
else
{
Assert.NotNull(actChallenge);

Assert.Equal(expChallenge.AuthorizationServer, actChallenge.AuthorizationServer);

Assert.Equal(expChallenge.Scope, actChallenge.Scope);

Assert.Equal(expChallenge.Resource, actChallenge.Resource);
}
}

private static string BuildChallengeString(params Kvp[] parameters)
{
// remove the trailing ',' and return
return BuildChallengeString("Bearer", parameters);
}

private static string BuildChallengeString(string challengeType, params Kvp[] parameters)
{
StringBuilder buff = new StringBuilder(challengeType).Append(" ");

foreach (var kvp in parameters)
{
buff.Append(kvp.Key).Append("=\"").Append(kvp.Value).Append("\",");
}

// remove the trailing ',' and return
return buff.Remove(buff.Length - 1, 1).ToString();
}

public static string CreateMockUrl(int pathCount = 0)
{
return CreateMockUrl("https://" + Guid.NewGuid().ToString().Replace('-', '.'), pathCount);
}

public static string CreateMockUrl(string baseUrl, int pathCount = 0)
{
var buff = new StringBuilder(baseUrl);

if(baseUrl.EndsWith("/"))
{
buff.Remove(buff.Length - 1, 1);
}

for (int i = 0; i < pathCount; i++)
{
buff.Append("/").Append(Guid.NewGuid().ToString().Replace("-", ""));
pakrym marked this conversation as resolved.
Show resolved Hide resolved
}

return buff.ToString();
}

private class MockChallenge
{
public string ChallengeType { get; set; }

public string AuthorizationServer { get; set; }

public string Resource { get; set; }

public string Scope { get; set; }

public static MockChallenge Create(string challengeType = null, string authority = null, string resource = null, string scope = null)
{
var mock = new MockChallenge();
mock.ChallengeType = challengeType ?? "Bearer";
mock.AuthorizationServer = authority ?? CreateMockUrl(1);
mock.Resource = resource ?? CreateMockUrl(0);
mock.Scope = scope ?? mock.Resource + "/.default";
return mock;
}

public HttpBearerChallenge ToHttpBearerChallenge(string requestUrl)
{
return new HttpBearerChallenge(new Uri(requestUrl), ToString());
}

public override string ToString()
{
var parameters = new List<Kvp>();

StringBuilder buff = new StringBuilder();

if(AuthorizationServer != null)
{
parameters.Add(new Kvp("authorization", AuthorizationServer));
}

if (Resource != null)
{
parameters.Add(new Kvp("resource", Resource));
}

if (Scope != null)
{
parameters.Add(new Kvp("scope", Scope));
}

return BuildChallengeString(ChallengeType, parameters.ToArray());
}

}


private class StaticChallengeResponseHandler : HttpMessageHandler
{
private HttpStatusCode _statusCode;
private string _challengeHeader;

public StaticChallengeResponseHandler(HttpStatusCode statusCode, string challengeHeader)
{
_statusCode = statusCode;
_challengeHeader = challengeHeader;
}

protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
return Task.FromResult(CreateResponse(request));
}

private HttpResponseMessage CreateResponse(HttpRequestMessage request)
{
var response = new HttpResponseMessage(_statusCode);

response.RequestMessage = request;

if(_challengeHeader != null)
{
response.Headers.Add("WWW-Authenticate", _challengeHeader);
}

return response;
}
}
}
}