Skip to content

Commit

Permalink
Rewrite HttpFaultInjector to use ILogger and built-in ASP.NET Core lo…
Browse files Browse the repository at this point in the history
…gging (#7476)

* Rewrite failtinjector to use ILogger and built-in logging
  • Loading branch information
lmolkova authored Jan 31, 2024
1 parent 1f41756 commit 2cea39b
Show file tree
Hide file tree
Showing 6 changed files with 401 additions and 388 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@

<ItemGroup>
<PackageReference Include="CommandLineParser" Version="2.8.0" />
<PackageReference Include="OpenTelemetry.Exporter.Console" Version="1.7.0" />
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
using System.Threading.Tasks;
using System.Threading;
using Microsoft.Extensions.Logging;
using System;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Primitives;
using System.Collections.Generic;
using System.Net.Http;
using System.Net.Sockets;
using System.Reflection;
using System.Linq;
using System.IO;
using System.Buffers;
using System.Diagnostics;

namespace Azure.Sdk.Tools.HttpFaultInjector
{
public class FaultInjectingMiddleware
{
private readonly ILogger<FaultInjectingMiddleware> _logger;
private readonly HttpClient _httpClient;
public FaultInjectingMiddleware(RequestDelegate _, IHttpClientFactory httpClientFactory, ILogger<FaultInjectingMiddleware> logger)
{
_logger = logger;
_httpClient = httpClientFactory.CreateClient("upstream");
}

public async Task InvokeAsync(HttpContext context)
{
string faultHeaderValue = context.Request.Headers[Utils.ResponseSelectionHeader];
string upstreamBaseUri = context.Request.Headers[Utils.UpstreamBaseUriHeader];

if (ValidateOrReadFaultMode(faultHeaderValue, out var fault))
{
await ProxyResponse(context, upstreamBaseUri, fault, context.RequestAborted);
}
else
{
context.Response.StatusCode = 400;
}
}


private async Task<UpstreamResponse> SendUpstreamRequest(HttpRequest request, string uri, CancellationToken cancellationToken)
{
var incomingUriBuilder = new UriBuilder()
{
Scheme = request.Scheme,
Host = request.Host.Host,
Path = request.Path.Value,
Query = request.QueryString.Value,
};
if (request.Host.Port.HasValue)
{
incomingUriBuilder.Port = request.Host.Port.Value;
}
var incomingUri = incomingUriBuilder.Uri;

var upstreamUriBuilder = new UriBuilder(uri)
{
Path = request.Path.Value,
Query = request.QueryString.Value,
};

var upstreamUri = upstreamUriBuilder.Uri;

using (var upstreamRequest = new HttpRequestMessage(new HttpMethod(request.Method), upstreamUri))
{
if (request.ContentLength > 0)
{
upstreamRequest.Content = new StreamContent(request.Body);
foreach (var header in request.Headers.Where(h => Utils.ContentRequestHeaders.Contains(h.Key)))
{
upstreamRequest.Content.Headers.Add(header.Key, values: header.Value);
}
}

foreach (var header in request.Headers.Where(h => !Utils.ExcludedRequestHeaders.Contains(h.Key) && !Utils.ContentRequestHeaders.Contains(h.Key)))
{
if (!upstreamRequest.Headers.TryAddWithoutValidation(header.Key, values: header.Value))
{
throw new InvalidOperationException($"Could not add header {header.Key} with value {header.Value}");
}
}

var upstreamResponseMessage = await _httpClient.SendAsync(upstreamRequest);
var headers = new List<KeyValuePair<string, IEnumerable<string>>>();
// Must skip "Transfer-Encoding" header, since if it's set manually Kestrel requires you to implement
// your own chunking.
headers.AddRange(upstreamResponseMessage.Headers.Where(header => !string.Equals(header.Key, "Transfer-Encoding", StringComparison.OrdinalIgnoreCase)));
headers.AddRange(upstreamResponseMessage.Content.Headers);

var upstreamResponse = await UpstreamResponseFromHttpContent(upstreamResponseMessage.Content, cancellationToken);
upstreamResponse.StatusCode = (int)upstreamResponseMessage.StatusCode;
upstreamResponse.Headers = headers.Select(h => new KeyValuePair<string, StringValues>(h.Key, h.Value.ToArray()));

return upstreamResponse;
}
}


private async Task<UpstreamResponse> UpstreamResponseFromHttpContent(HttpContent content, CancellationToken cancellationToken)
{
if (content.Headers.ContentLength == null)
{
MemoryStream contentStream = await BufferContentAsync(content, cancellationToken);
// we no longer need that content and can let the connection go back to the pool.
content.Dispose();
return new UpstreamResponse(contentStream);
}

return new UpstreamResponse(content);
}

private async Task<MemoryStream> BufferContentAsync(HttpContent content, CancellationToken cancellationToken)
{
Debug.Assert(content.Headers.ContentLength == null, "We should not buffer content if length is available.");

_logger.LogWarning("Response does not have content length (is chunked or malformed) and is being buffered");
byte[] contentBytes = await content.ReadAsByteArrayAsync(cancellationToken);
_logger.LogInformation("Finished buffering response body ({length})", contentBytes.Length);
return new MemoryStream(contentBytes);
}

private async Task ProxyResponse(HttpContext context, string upstreamUri, string fault, CancellationToken cancellationToken)
{
UpstreamResponse upstreamResponse = await SendUpstreamRequest(context.Request, upstreamUri, cancellationToken);
switch (fault)
{
case "f":
// Full response
await SendDownstreamResponse(context.Response, upstreamResponse, upstreamResponse.ContentLength, cancellationToken);
return;
case "p":
// Partial Response (full headers, 50% of body), then wait indefinitely
await SendDownstreamResponse(context.Response, upstreamResponse, upstreamResponse.ContentLength / 2, cancellationToken);
await Task.Delay(Timeout.InfiniteTimeSpan, cancellationToken);
return;
case "pc":
// Partial Response (full headers, 50% of body), then close (TCP FIN)
await SendDownstreamResponse(context.Response,upstreamResponse, upstreamResponse.ContentLength / 2, cancellationToken);
Close(context);
return;
case "pa":
// Partial Response (full headers, 50% of body), then abort (TCP RST)
await SendDownstreamResponse(context.Response,upstreamResponse, upstreamResponse.ContentLength / 2, cancellationToken);
Abort(context);
return;
case "pn":
// Partial Response (full headers, 50% of body), then finish normally
await SendDownstreamResponse(context.Response, upstreamResponse, upstreamResponse.ContentLength / 2, cancellationToken);
return;
case "n":
// No response, then wait indefinitely
await Task.Delay(Timeout.InfiniteTimeSpan, cancellationToken);
return;
case "nc":
// No response, then close (TCP FIN)
Close(context);
return;
case "na":
// No response, then abort (TCP RST)
Abort(context);
return;
default:
// can't really happen since we validated options before calling into this method.
throw new ArgumentException($"Invalid fault mode: {fault}", nameof(fault));
}
}

private async Task SendDownstreamResponse(HttpResponse response, UpstreamResponse upstreamResponse, long contentBytes, CancellationToken cancellationToken)
{
response.StatusCode = upstreamResponse.StatusCode;
foreach (var header in upstreamResponse.Headers)
{
response.Headers.Add(header.Key, header.Value);
}

_logger.LogInformation("Started writing response body, {actualLength}", contentBytes);

byte[] buffer = ArrayPool<byte>.Shared.Rent(81920);

try
{
using Stream source = await upstreamResponse.GetContentStreamAsync(cancellationToken);
for (long remaining = contentBytes; remaining > 0;)
{
int read = await source.ReadAsync(buffer, 0, (int)Math.Min(buffer.Length, remaining), cancellationToken);
if (read <= 0)
{
break;
}

remaining -= read;
await response.Body.WriteAsync(buffer, 0, read, cancellationToken);
}

await response.Body.FlushAsync();
}
catch (Exception ex)
{
_logger.LogError(ex, "Can't write response body");
}
finally
{
// disponse content as early as possible (before infinite wait that might happen later)
// so that underlying connection returns to connection pool
// and we won't run out of them
upstreamResponse.Dispose();
ArrayPool<byte>.Shared.Return(buffer);
_logger.LogInformation("Finished writing response body");
}
}

// Close the TCP connection by sending FIN
private void Close(HttpContext context)
{
context.Abort();
}

// Abort the TCP connection by sending RST
private void Abort(HttpContext context)
{
// SocketConnection registered "this" as the IConnectionIdFeature among other things.
var socketConnection = context.Features.Get<IConnectionIdFeature>();
var socket = (Socket)socketConnection.GetType().GetField("_socket", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(socketConnection);
socket.LingerState = new LingerOption(true, 0);
socket.Dispose();
}

private bool ValidateOrReadFaultMode(string headerValue, out string fault)
{
fault = headerValue ?? Utils.ReadSelectionFromConsole();
if (!Utils.FaultModes.TryGetValue(fault, out var description))
{
_logger.LogError("Unknown {ResponseSelectionHeader} value - {fault}.", Utils.ResponseSelectionHeader, fault);
return false;
}

_logger.LogInformation("Using response option '{description}' from header value.", description);
return true;
}
}
}
Loading

0 comments on commit 2cea39b

Please sign in to comment.