Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make use of NonValidated headers #1507

Merged
merged 1 commit into from
Jan 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions src/ReverseProxy/Forwarder/HttpForwarder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -522,15 +522,12 @@ private static ValueTask<bool> CopyResponseStatusAndHeadersAsync(HttpResponseMes

private static void RestoreUpgradeHeaders(HttpContext context, HttpResponseMessage response)
{
if (response.Headers.TryGetValues(HeaderNames.Connection, out var connectionValues)
&& response.Headers.TryGetValues(HeaderNames.Upgrade, out var upgradeValues))
// We don't use RequestUtilities.TryGetValues for the Connection as we do want value validation.
// HttpHeaders.TryGetValues will handle the parsing and split the values for us.
if (RequestUtilities.TryGetValues(response.Headers, HeaderNames.Upgrade, out var upgradeValues)
&& response.Headers.TryGetValues(HeaderNames.Connection, out var connectionValues))
{
var upgradeStringValues = StringValues.Empty;
foreach (var value in upgradeValues)
{
upgradeStringValues = StringValues.Concat(upgradeStringValues, value);
}
context.Response.Headers.TryAdd(HeaderNames.Upgrade, upgradeStringValues);
context.Response.Headers.TryAdd(HeaderNames.Upgrade, upgradeValues);

foreach (var value in connectionValues)
{
Expand Down
24 changes: 18 additions & 6 deletions src/ReverseProxy/Forwarder/HttpTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ public virtual ValueTask<bool> TransformResponseAsync(HttpContext httpContext, H
// remove the received Content-Length field prior to forwarding such
// a message downstream.
if (proxyResponse.Content != null
&& proxyResponse.Headers.TryGetValues(HeaderNames.TransferEncoding, out var _)
&& proxyResponse.Content.Headers.TryGetValues(HeaderNames.ContentLength, out var _))
&& RequestUtilities.ContainsHeader(proxyResponse.Headers, HeaderNames.TransferEncoding)
&& RequestUtilities.ContainsHeader(proxyResponse.Content.Headers, HeaderNames.ContentLength))
{
httpContext.Response.Headers.Remove(HeaderNames.ContentLength);
}
Expand Down Expand Up @@ -164,6 +164,20 @@ public virtual ValueTask TransformResponseTrailersAsync(HttpContext httpContext,

private static void CopyResponseHeaders(HttpHeaders source, IHeaderDictionary destination)
{
// We want to append to any prior values, if any.
// Not using Append here because it skips empty headers.
#if NET6_0_OR_GREATER
foreach (var header in source.NonValidated)
{
var headerName = header.Key;
if (RequestUtilities.ShouldSkipResponseHeader(headerName))
{
continue;
}

destination[headerName] = RequestUtilities.Concat(destination[headerName], header.Value);
}
#else
foreach (var header in source)
{
var headerName = header.Key;
Expand All @@ -174,10 +188,8 @@ private static void CopyResponseHeaders(HttpHeaders source, IHeaderDictionary de

Debug.Assert(header.Value is string[]);
var values = header.Value as string[] ?? header.Value.ToArray();
// We want to append to any prior values, if any.
// Not using Append here because it skips empty headers.
values = StringValues.Concat(destination[headerName], values);
destination[headerName] = values;
destination[headerName] = StringValues.Concat(destination[headerName], values);
}
#endif
}
}
92 changes: 92 additions & 0 deletions src/ReverseProxy/Forwarder/RequestUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Diagnostics;
using System.Linq;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Runtime.CompilerServices;
using System.Text;
using Microsoft.AspNetCore.Http;
Expand Down Expand Up @@ -344,4 +345,95 @@ internal static void RemoveHeader(HttpRequestMessage request, string headerName)
request.Headers.Remove(headerName);
}
}

#if NET6_0_OR_GREATER
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static StringValues Concat(in StringValues existing, in HeaderStringValues values)
{
if (values.Count <= 1)
{
return StringValues.Concat(existing, values.ToString());
}
else
{
return ConcatSlow(existing, values);
}

static StringValues ConcatSlow(in StringValues existing, in HeaderStringValues values)
{
Debug.Assert(values.Count > 1);

var count = existing.Count;
var newArray = new string[count + values.Count];

if (count == 1)
{
newArray[0] = existing.ToString();
}
else
{
existing.ToArray().CopyTo(newArray, 0);
}

foreach (var value in values)
{
newArray[count++] = value;
}
Debug.Assert(count == newArray.Length);

return newArray;
}
}
#endif

[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static bool TryGetValues(HttpHeaders headers, string headerName, out StringValues values)
{
#if NET6_0_OR_GREATER
if (headers.NonValidated.TryGetValues(headerName, out var headerStringValues))
{
if (headerStringValues.Count <= 1)
{
values = headerStringValues.ToString();
}
else
{
values = ToArray(headerStringValues);
}
return true;
}

static StringValues ToArray(in HeaderStringValues values)
{
var array = new string[values.Count];
var i = 0;
foreach (var value in values)
{
array[i++] = value;
}
Debug.Assert(i == array.Length);
return array;
}
#else
if (headers.TryGetValues(headerName, out var headerValues))
{
Debug.Assert(headerValues is string[]);
values = headerValues as string[] ?? headerValues.ToArray();
return true;
}
#endif

values = default;
return false;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static bool ContainsHeader(HttpHeaders headers, string headerName)
{
#if NET6_0_OR_GREATER
return headers.NonValidated.Contains(headerName);
#else
return headers.TryGetValues(headerName, out _);
#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public override ValueTask ApplyAsync(RequestTransformContext context)
{
if (context == null)
{
throw new System.ArgumentNullException(nameof(context));
throw new ArgumentNullException(nameof(context));
}

context.Query.Collection.Remove(Key);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

using System;
using System.Threading.Tasks;
using Microsoft.Net.Http.Headers;
using Yarp.ReverseProxy.Forwarder;

namespace Yarp.ReverseProxy.Transforms;

Expand Down Expand Up @@ -34,12 +36,16 @@ public override ValueTask ApplyAsync(RequestTransformContext context)
if (!context.HeadersCopied)
{
// Don't override a custom host
context.ProxyRequest.Headers.Host ??= context.HttpContext.Request.Host.Value;
if (!RequestUtilities.ContainsHeader(context.ProxyRequest.Headers, HeaderNames.Host))
{
context.ProxyRequest.Headers.TryAddWithoutValidation(HeaderNames.Host, context.HttpContext.Request.Host.Value);
}
}
}
else if (context.HeadersCopied
// Don't remove a custom host, only the original
&& string.Equals(context.HttpContext.Request.Host.Value, context.ProxyRequest.Headers.Host, StringComparison.Ordinal))
&& RequestUtilities.TryGetValues(context.ProxyRequest.Headers, HeaderNames.Host, out var existingHost)
&& string.Equals(context.HttpContext.Request.Host.Value, existingHost.ToString(), StringComparison.Ordinal))
{
// Remove it after the copy, use the destination host instead.
context.ProxyRequest.Headers.Host = null;
Expand Down
4 changes: 2 additions & 2 deletions src/ReverseProxy/Transforms/RequestHeaderValueTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ public override ValueTask ApplyAsync(RequestTransformContext context)
throw new ArgumentNullException(nameof(context));
}

var existingValues = TakeHeader(context, HeaderName);

if (Append)
{
var existingValues = TakeHeader(context, HeaderName);
var values = StringValues.Concat(existingValues, Value);
AddHeader(context, HeaderName, values);
}
else
{
// Set
RemoveHeader(context, HeaderName);
AddHeader(context, HeaderName, Value);
}

Expand Down
26 changes: 13 additions & 13 deletions src/ReverseProxy/Transforms/RequestTransform.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System;
using System.Threading.Tasks;
using Microsoft.Extensions.Primitives;
using Yarp.ReverseProxy.Forwarder;
Expand Down Expand Up @@ -30,19 +31,18 @@ public static StringValues TakeHeader(RequestTransformContext context, string he
{
if (string.IsNullOrEmpty(headerName))
{
throw new System.ArgumentException($"'{nameof(headerName)}' cannot be null or empty.", nameof(headerName));
throw new ArgumentException($"'{nameof(headerName)}' cannot be null or empty.", nameof(headerName));
}

var existingValues = StringValues.Empty;
if (context.ProxyRequest.Headers.TryGetValues(headerName, out var values))
var proxyRequest = context.ProxyRequest;

if (RequestUtilities.TryGetValues(proxyRequest.Headers, headerName, out var existingValues))
{
context.ProxyRequest.Headers.Remove(headerName);
existingValues = (string[])values;
proxyRequest.Headers.Remove(headerName);
}
else if (context.ProxyRequest.Content?.Headers.TryGetValues(headerName, out values) ?? false)
else if (proxyRequest.Content is { } content && RequestUtilities.TryGetValues(content.Headers, headerName, out existingValues))
{
context.ProxyRequest.Content.Headers.Remove(headerName);
existingValues = (string[])values!;
content.Headers.Remove(headerName);
}
else if (!context.HeadersCopied)
{
Expand All @@ -59,30 +59,30 @@ public static void AddHeader(RequestTransformContext context, string headerName,
{
if (context is null)
{
throw new System.ArgumentNullException(nameof(context));
throw new ArgumentNullException(nameof(context));
}

if (string.IsNullOrEmpty(headerName))
{
throw new System.ArgumentException($"'{nameof(headerName)}' cannot be null or empty.", nameof(headerName));
throw new ArgumentException($"'{nameof(headerName)}' cannot be null or empty.", nameof(headerName));
}

RequestUtilities.AddHeader(context.ProxyRequest, headerName, values);
}

/// <summary>
/// Removed the given header from the HttpRequestMessage or HttpContent where applicable.
/// Removes the given header from the HttpRequestMessage or HttpContent where applicable.
/// </summary>
public static void RemoveHeader(RequestTransformContext context, string headerName)
{
if (context is null)
{
throw new System.ArgumentNullException(nameof(context));
throw new ArgumentNullException(nameof(context));
}

if (string.IsNullOrEmpty(headerName))
{
throw new System.ArgumentException($"'{nameof(headerName)}' cannot be null or empty.", nameof(headerName));
throw new ArgumentException($"'{nameof(headerName)}' cannot be null or empty.", nameof(headerName));
}

RequestUtilities.RemoveHeader(context.ProxyRequest, headerName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ public override ValueTask ApplyAsync(ResponseTransformContext context)
if (Condition == ResponseCondition.Always
|| Success(context) == (Condition == ResponseCondition.Success))
{
var existingHeader = TakeHeader(context, HeaderName);
if (Append)
{
var existingHeader = TakeHeader(context, HeaderName);
var value = StringValues.Concat(existingHeader, Value);
SetHeader(context, HeaderName, value);
}
Expand Down
25 changes: 21 additions & 4 deletions src/ReverseProxy/Transforms/ResponseHeadersAllowedTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Primitives;
using Yarp.ReverseProxy.Forwarder;

namespace Yarp.ReverseProxy.Transforms;

Expand Down Expand Up @@ -60,17 +61,33 @@ public override ValueTask ApplyAsync(ResponseTransformContext context)
return default;
}

// See https://github.com/microsoft/reverse-proxy/blob/51d797986b1fea03500a1ad173d13a1176fb5552/src/ReverseProxy/Forwarder/HttpTransformer.cs#L102-L115
// See https://github.com/microsoft/reverse-proxy/blob/main/src/ReverseProxy/Forwarder/HttpTransformer.cs#:~:text=void-,CopyResponseHeaders
private void CopyResponseHeaders(HttpHeaders source, IHeaderDictionary destination)
{
#if NET6_0_OR_GREATER
foreach (var header in source.NonValidated)
{
var headerName = header.Key;
if (!AllowedHeadersSet.Contains(headerName))
{
continue;
}

destination[headerName] = RequestUtilities.Concat(destination[headerName], header.Value);
}
#else
foreach (var header in source)
{
var headerName = header.Key;
if (AllowedHeadersSet.Contains(headerName))
if (!AllowedHeadersSet.Contains(headerName))
{
Debug.Assert(header.Value is string[]);
destination.Append(headerName, header.Value as string[] ?? header.Value.ToArray());
continue;
}

Debug.Assert(header.Value is string[]);
var values = header.Value as string[] ?? header.Value.ToArray();
destination[headerName] = StringValues.Concat(destination[headerName], values);
}
#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ public override ValueTask ApplyAsync(ResponseTrailersTransformContext context)
if (Condition == ResponseCondition.Always
|| Success(context) == (Condition == ResponseCondition.Success))
{
var existingHeader = TakeHeader(context, HeaderName);
if (Append)
{
var existingHeader = TakeHeader(context, HeaderName);
var value = StringValues.Concat(existingHeader, Value);
SetHeader(context, HeaderName, value);
}
Expand Down
Loading