-
Notifications
You must be signed in to change notification settings - Fork 4.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding ChallengeCacheHandler to update challenge cache on 401 response #6950
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Net.Http; | ||
using System.Text; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
|
||
namespace Microsoft.Azure.KeyVault.Customized.Authentication | ||
{ | ||
/// <summary> | ||
/// A <see cref="DelegatingHandler"/> which will update the <see cref="HttpBearerChallengeCache"/> when a 401 response is returned with a WWW-Authenticate bearer challenge header. | ||
/// </summary> | ||
public class ChallengeCacheHandler : MessageProcessingHandler | ||
{ | ||
/// <summary> | ||
/// Returns the specified request without performing any processing | ||
/// </summary> | ||
/// <param name="request"></param> | ||
/// <param name="cancellationToken"></param> | ||
/// <returns></returns> | ||
protected override HttpRequestMessage ProcessRequest(HttpRequestMessage request, CancellationToken cancellationToken) | ||
{ | ||
return request; | ||
} | ||
/// <summary> | ||
/// Updates the <see cref="HttpBearerChallengeCache"/> when the specified response has a return code of 401 | ||
/// </summary> | ||
/// <param name="response">The response to evaluate</param> | ||
/// <param name="cancellationToken">The cancellation token</param> | ||
/// <returns></returns> | ||
protected override HttpResponseMessage ProcessResponse(HttpResponseMessage response, CancellationToken cancellationToken) | ||
{ | ||
// if the response came back as 401 and the response contains a bearer challenge update the challenge cache | ||
if (response.StatusCode == System.Net.HttpStatusCode.Unauthorized) | ||
{ | ||
HttpBearerChallenge challenge = HttpBearerChallenge.GetBearerChallengeFromResponse(response); | ||
|
||
if (challenge != null) | ||
{ | ||
// Update challenge cache | ||
HttpBearerChallengeCache.GetInstance().SetChallengeForURL(response.RequestMessage.RequestUri, challenge); | ||
} | ||
} | ||
|
||
return response; | ||
} | ||
|
||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,254 @@ | ||
using Microsoft.Azure.KeyVault.Customized.Authentication; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Net; | ||
using System.Net.Http; | ||
using System.Text; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
using Xunit; | ||
using Kvp = System.Collections.Generic.KeyValuePair<string, string>; | ||
|
||
namespace Microsoft.Azure.KeyVault.Tests | ||
{ | ||
public class ChallengeCacheHandlerTests | ||
{ | ||
[Fact] | ||
public async Task CacheAddOn401Async() | ||
{ | ||
var handler = new ChallengeCacheHandler(); | ||
|
||
var expChallenge = MockChallenge.Create(); | ||
|
||
handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, expChallenge.ToString()); | ||
|
||
var client = new HttpClient(handler); | ||
|
||
var requestUrl = CreateMockUrl(2); | ||
|
||
var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl)); | ||
|
||
AssertChallengeCacheEntry(requestUrl, expChallenge); | ||
} | ||
|
||
[Fact] | ||
public async Task CacheUpdateOn401Async() | ||
{ | ||
string requestUrlBase = CreateMockUrl(); | ||
|
||
string requestUrl1 = CreateMockUrl(requestUrlBase, 2); | ||
|
||
HttpBearerChallengeCache.GetInstance().SetChallengeForURL(new Uri(requestUrl1), MockChallenge.Create().ToHttpBearerChallenge(requestUrl1)); | ||
|
||
string requestUrl2 = CreateMockUrl(requestUrlBase, 2); | ||
|
||
var handler = new ChallengeCacheHandler(); | ||
|
||
var expChallenge = MockChallenge.Create(); | ||
|
||
handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, expChallenge.ToString()); | ||
|
||
var client = new HttpClient(handler); | ||
|
||
var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl2)); | ||
|
||
AssertChallengeCacheEntry(requestUrl1, expChallenge); | ||
|
||
AssertChallengeCacheEntry(requestUrl2, expChallenge); | ||
} | ||
|
||
[Fact] | ||
public async Task CacheNotUpdatedNoChallengeAsync() | ||
{ | ||
var handler = new ChallengeCacheHandler(); | ||
|
||
handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, null); | ||
|
||
var client = new HttpClient(handler); | ||
|
||
var requestUrl = CreateMockUrl(2); | ||
|
||
var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl)); | ||
|
||
AssertChallengeCacheEntry(requestUrl, null); | ||
} | ||
|
||
[Fact] | ||
public async Task CacheNotUpdatedNon401Aysnc() | ||
{ | ||
var handler = new ChallengeCacheHandler(); | ||
|
||
var expChallenge = MockChallenge.Create(); | ||
|
||
handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Forbidden, expChallenge.ToString()); | ||
|
||
var client = new HttpClient(handler); | ||
|
||
var requestUrl = CreateMockUrl(2); | ||
|
||
var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl)); | ||
|
||
AssertChallengeCacheEntry(requestUrl, null); | ||
} | ||
|
||
[Fact] | ||
public async Task CacheNotUpdatedNonBearerChallengeAsync() | ||
{ | ||
var handler = new ChallengeCacheHandler(); | ||
|
||
handler.InnerHandler = new StaticChallengeResponseHandler(HttpStatusCode.Unauthorized, MockChallenge.Create("PoP").ToString()); | ||
|
||
var client = new HttpClient(handler); | ||
|
||
var requestUrl = CreateMockUrl(2); | ||
|
||
var _ = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, requestUrl)); | ||
|
||
AssertChallengeCacheEntry(requestUrl, null); | ||
} | ||
|
||
private static void AssertChallengeCacheEntry(string requestUrl, MockChallenge expChallenge) | ||
{ | ||
var actChallenge = HttpBearerChallengeCache.GetInstance().GetChallengeForURL(new Uri(requestUrl)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are there any other tests these tests might conflict with when running in parallel? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These shouldn't interfere with other tests as all the URLs that I'm using are guid based. |
||
|
||
if (expChallenge == null) | ||
{ | ||
Assert.Null(actChallenge); | ||
} | ||
else | ||
{ | ||
Assert.NotNull(actChallenge); | ||
|
||
Assert.Equal(expChallenge.AuthorizationServer, actChallenge.AuthorizationServer); | ||
|
||
Assert.Equal(expChallenge.Scope, actChallenge.Scope); | ||
|
||
Assert.Equal(expChallenge.Resource, actChallenge.Resource); | ||
} | ||
} | ||
|
||
private static string BuildChallengeString(params Kvp[] parameters) | ||
{ | ||
// remove the trailing ',' and return | ||
return BuildChallengeString("Bearer", parameters); | ||
} | ||
|
||
private static string BuildChallengeString(string challengeType, params Kvp[] parameters) | ||
{ | ||
StringBuilder buff = new StringBuilder(challengeType).Append(" "); | ||
|
||
foreach (var kvp in parameters) | ||
{ | ||
buff.Append(kvp.Key).Append("=\"").Append(kvp.Value).Append("\","); | ||
} | ||
|
||
// remove the trailing ',' and return | ||
return buff.Remove(buff.Length - 1, 1).ToString(); | ||
} | ||
|
||
public static string CreateMockUrl(int pathCount = 0) | ||
{ | ||
return CreateMockUrl("https://" + Guid.NewGuid().ToString("N"), pathCount); | ||
} | ||
|
||
public static string CreateMockUrl(string baseUrl, int pathCount = 0) | ||
{ | ||
var buff = new StringBuilder(baseUrl); | ||
|
||
if(baseUrl.EndsWith("/")) | ||
{ | ||
buff.Remove(buff.Length - 1, 1); | ||
} | ||
|
||
for (int i = 0; i < pathCount; i++) | ||
{ | ||
buff.Append("/").Append(Guid.NewGuid().ToString("N")); | ||
} | ||
|
||
return buff.ToString(); | ||
} | ||
|
||
private class MockChallenge | ||
{ | ||
public string ChallengeType { get; set; } | ||
|
||
public string AuthorizationServer { get; set; } | ||
|
||
public string Resource { get; set; } | ||
|
||
public string Scope { get; set; } | ||
|
||
public static MockChallenge Create(string challengeType = null, string authority = null, string resource = null, string scope = null) | ||
{ | ||
var mock = new MockChallenge(); | ||
mock.ChallengeType = challengeType ?? "Bearer"; | ||
mock.AuthorizationServer = authority ?? CreateMockUrl(1); | ||
mock.Resource = resource ?? CreateMockUrl(0); | ||
mock.Scope = scope ?? mock.Resource + "/.default"; | ||
return mock; | ||
} | ||
|
||
public HttpBearerChallenge ToHttpBearerChallenge(string requestUrl) | ||
{ | ||
return new HttpBearerChallenge(new Uri(requestUrl), ToString()); | ||
} | ||
|
||
public override string ToString() | ||
{ | ||
var parameters = new List<Kvp>(); | ||
|
||
StringBuilder buff = new StringBuilder(); | ||
|
||
if(AuthorizationServer != null) | ||
{ | ||
parameters.Add(new Kvp("authorization", AuthorizationServer)); | ||
} | ||
|
||
if (Resource != null) | ||
{ | ||
parameters.Add(new Kvp("resource", Resource)); | ||
} | ||
|
||
if (Scope != null) | ||
{ | ||
parameters.Add(new Kvp("scope", Scope)); | ||
} | ||
|
||
return BuildChallengeString(ChallengeType, parameters.ToArray()); | ||
} | ||
|
||
} | ||
|
||
|
||
private class StaticChallengeResponseHandler : HttpMessageHandler | ||
{ | ||
private HttpStatusCode _statusCode; | ||
private string _challengeHeader; | ||
|
||
public StaticChallengeResponseHandler(HttpStatusCode statusCode, string challengeHeader) | ||
{ | ||
_statusCode = statusCode; | ||
_challengeHeader = challengeHeader; | ||
} | ||
|
||
protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) | ||
{ | ||
return Task.FromResult(CreateResponse(request)); | ||
} | ||
|
||
private HttpResponseMessage CreateResponse(HttpRequestMessage request) | ||
{ | ||
var response = new HttpResponseMessage(_statusCode); | ||
|
||
response.RequestMessage = request; | ||
|
||
if(_challengeHeader != null) | ||
{ | ||
response.Headers.Add("WWW-Authenticate", _challengeHeader); | ||
} | ||
|
||
return response; | ||
} | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I presume having the class public is intentional?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes unfortunately. When constructing the key vault client the user has the option of giving us an HttpClient, in which case we our handler won't get added. So I made the class public so that users doing this would have a way to handle this bug. It's possible that I could have made it private and had users re-implement it, but I think this gives a better experience and also leaves us the ability to tweak the behavior.