Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite HttpFaultInjector to use ILogger and built-in ASP.NET Core logging #7476

Merged
merged 5 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@

<ItemGroup>
<PackageReference Include="CommandLineParser" Version="2.8.0" />
<PackageReference Include="OpenTelemetry.Exporter.Console" Version="1.7.0" />
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
using System.Threading.Tasks;
using System.Threading;
using Microsoft.Extensions.Logging;
using System;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Primitives;
using System.Collections.Generic;
using System.Net.Http;
using System.Net.Sockets;
using System.Reflection;
using System.Linq;
using System.IO;
using System.Buffers;
using System.Diagnostics;

namespace Azure.Sdk.Tools.HttpFaultInjector
{
public class FaultInjectingMiddleware
{
private readonly ILogger<FaultInjectingMiddleware> _logger;
private readonly HttpClient _httpClient;
public FaultInjectingMiddleware(RequestDelegate _, IHttpClientFactory httpClientFactory, ILogger<FaultInjectingMiddleware> logger)
{
_logger = logger;
_httpClient = httpClientFactory.CreateClient("upstream");
}

public async Task InvokeAsync(HttpContext context)
{
string faultHeaderValue = context.Request.Headers[Utils.ResponseSelectionHeader];
string upstreamBaseUri = context.Request.Headers[Utils.UpstreamBaseUriHeader];

if (ValidateOrReadFaultMode(faultHeaderValue, out var fault))
{
await ProxyResponse(context, upstreamBaseUri, fault, context.RequestAborted);
}
else
{
context.Response.StatusCode = 400;
}
}


private async Task<UpstreamResponse> SendUpstreamRequest(HttpRequest request, string uri, CancellationToken cancellationToken)
{
var incomingUriBuilder = new UriBuilder()
{
Scheme = request.Scheme,
Host = request.Host.Host,
Path = request.Path.Value,
Query = request.QueryString.Value,
};
if (request.Host.Port.HasValue)
{
incomingUriBuilder.Port = request.Host.Port.Value;
}
var incomingUri = incomingUriBuilder.Uri;

var upstreamUriBuilder = new UriBuilder(uri)
{
Path = request.Path.Value,
Query = request.QueryString.Value,
};

var upstreamUri = upstreamUriBuilder.Uri;

using (var upstreamRequest = new HttpRequestMessage(new HttpMethod(request.Method), upstreamUri))
{
if (request.ContentLength > 0)
{
upstreamRequest.Content = new StreamContent(request.Body);
foreach (var header in request.Headers.Where(h => Utils.ContentRequestHeaders.Contains(h.Key)))
{
upstreamRequest.Content.Headers.Add(header.Key, values: header.Value);
}
}

foreach (var header in request.Headers.Where(h => !Utils.ExcludedRequestHeaders.Contains(h.Key) && !Utils.ContentRequestHeaders.Contains(h.Key)))
{
if (!upstreamRequest.Headers.TryAddWithoutValidation(header.Key, values: header.Value))
{
throw new InvalidOperationException($"Could not add header {header.Key} with value {header.Value}");
}
}

var upstreamResponseMessage = await _httpClient.SendAsync(upstreamRequest);
var headers = new List<KeyValuePair<string, IEnumerable<string>>>();
// Must skip "Transfer-Encoding" header, since if it's set manually Kestrel requires you to implement
// your own chunking.
headers.AddRange(upstreamResponseMessage.Headers.Where(header => !string.Equals(header.Key, "Transfer-Encoding", StringComparison.OrdinalIgnoreCase)));
headers.AddRange(upstreamResponseMessage.Content.Headers);

var upstreamResponse = await UpstreamResponseFromHttpContent(upstreamResponseMessage.Content, cancellationToken);
upstreamResponse.StatusCode = (int)upstreamResponseMessage.StatusCode;
upstreamResponse.Headers = headers.Select(h => new KeyValuePair<string, StringValues>(h.Key, h.Value.ToArray()));

return upstreamResponse;
}
}


private async Task<UpstreamResponse> UpstreamResponseFromHttpContent(HttpContent content, CancellationToken cancellationToken)
{
if (content.Headers.ContentLength == null)
{
MemoryStream contentStream = await BufferContentAsync(content, cancellationToken);
// we no longer need that content and can let the connection go back to the pool.
content.Dispose();
return new UpstreamResponse(contentStream);
}

return new UpstreamResponse(content);
}

private async Task<MemoryStream> BufferContentAsync(HttpContent content, CancellationToken cancellationToken)
{
Debug.Assert(content.Headers.ContentLength == null, "We should not buffer content if length is available.");

_logger.LogWarning("Response does not have content length (is chunked or malformed) and is being buffered");
byte[] contentBytes = await content.ReadAsByteArrayAsync(cancellationToken);
_logger.LogInformation("Finished buffering response body ({length})", contentBytes.Length);
return new MemoryStream(contentBytes);
}

private async Task ProxyResponse(HttpContext context, string upstreamUri, string fault, CancellationToken cancellationToken)
{
UpstreamResponse upstreamResponse = await SendUpstreamRequest(context.Request, upstreamUri, cancellationToken);
switch (fault)
{
case "f":
// Full response
await SendDownstreamResponse(context.Response, upstreamResponse, upstreamResponse.ContentLength, cancellationToken);
return;
case "p":
// Partial Response (full headers, 50% of body), then wait indefinitely
await SendDownstreamResponse(context.Response, upstreamResponse, upstreamResponse.ContentLength / 2, cancellationToken);
await Task.Delay(Timeout.InfiniteTimeSpan, cancellationToken);
return;
case "pc":
// Partial Response (full headers, 50% of body), then close (TCP FIN)
await SendDownstreamResponse(context.Response,upstreamResponse, upstreamResponse.ContentLength / 2, cancellationToken);
Close(context);
return;
case "pa":
// Partial Response (full headers, 50% of body), then abort (TCP RST)
await SendDownstreamResponse(context.Response,upstreamResponse, upstreamResponse.ContentLength / 2, cancellationToken);
Abort(context);
return;
case "pn":
// Partial Response (full headers, 50% of body), then finish normally
await SendDownstreamResponse(context.Response, upstreamResponse, upstreamResponse.ContentLength / 2, cancellationToken);
return;
case "n":
// No response, then wait indefinitely
await Task.Delay(Timeout.InfiniteTimeSpan, cancellationToken);
return;
case "nc":
// No response, then close (TCP FIN)
Close(context);
return;
case "na":
// No response, then abort (TCP RST)
Abort(context);
return;
default:
// can't really happen since we validated options before calling into this method.
throw new ArgumentException($"Invalid fault mode: {fault}", nameof(fault));
}
}

private async Task SendDownstreamResponse(HttpResponse response, UpstreamResponse upstreamResponse, long contentBytes, CancellationToken cancellationToken)
{
response.StatusCode = upstreamResponse.StatusCode;
foreach (var header in upstreamResponse.Headers)
{
response.Headers.Add(header.Key, header.Value);
}

_logger.LogInformation("Started writing response body, {actualLength}", contentBytes);

byte[] buffer = ArrayPool<byte>.Shared.Rent(81920);

try
{
using Stream source = await upstreamResponse.GetContentStreamAsync(cancellationToken);
for (long remaining = contentBytes; remaining > 0;)
{
int read = await source.ReadAsync(buffer, 0, (int)Math.Min(buffer.Length, remaining), cancellationToken);
if (read <= 0)
{
break;
}

remaining -= read;
await response.Body.WriteAsync(buffer, 0, read, cancellationToken);
}

await response.Body.FlushAsync();
}
catch (Exception ex)
{
_logger.LogError(ex, "Can't write response body");
}
finally
{
// disponse content as early as possible (before infinite wait that might happen later)
// so that underlying connection returns to connection pool
// and we won't run out of them
upstreamResponse.Dispose();
ArrayPool<byte>.Shared.Return(buffer);
_logger.LogInformation("Finished writing response body");
}
}

// Close the TCP connection by sending FIN
private void Close(HttpContext context)
{
context.Abort();
}

// Abort the TCP connection by sending RST
private void Abort(HttpContext context)
{
// SocketConnection registered "this" as the IConnectionIdFeature among other things.
var socketConnection = context.Features.Get<IConnectionIdFeature>();
var socket = (Socket)socketConnection.GetType().GetField("_socket", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(socketConnection);
socket.LingerState = new LingerOption(true, 0);
socket.Dispose();
}

private bool ValidateOrReadFaultMode(string headerValue, out string fault)
{
fault = headerValue ?? Utils.ReadSelectionFromConsole();
if (!Utils.FaultModes.TryGetValue(fault, out var description))
{
_logger.LogError("Unknown {ResponseSelectionHeader} value - {fault}.", Utils.ResponseSelectionHeader, fault);
return false;
}

_logger.LogInformation("Using response option '{description}' from header value.", description);
return true;
}
}
}
Loading
Loading