Skip to content

Commit

Permalink
FaultInjector: stream response instead of buffering (#7466)
Browse files Browse the repository at this point in the history
* Read and forward upstream response stream in chunks whenever possible to avoid buffering

---------

Co-authored-by: Mike Harder <[email protected]>
  • Loading branch information
lmolkova and mikeharder authored Jan 25, 2024
1 parent 693b937 commit ebca8d6
Showing 1 changed file with 144 additions and 65 deletions.
209 changes: 144 additions & 65 deletions tools/http-fault-injector/Azure.Sdk.Tools.HttpFaultInjector/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
using Microsoft.Extensions.Primitives;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
Expand Down Expand Up @@ -74,6 +76,7 @@ public static void Main(string[] args)

private static void Run(Options options)
{
TimeSpan keepAlive = TimeSpan.FromSeconds(options.KeepAliveTimeout);
if (options.Insecure)
{
_httpClient = new HttpClient(new HttpClientHandler()
Expand All @@ -87,6 +90,9 @@ private static void Run(Options options)
_httpClient = new HttpClient();
}

// TODO: we can switch to SocketsHttpHandler and configure read/write/connect timeouts separately
// for now let's just set upstream timeout to be slightly bigger than client timeout.
_httpClient.Timeout = keepAlive + TimeSpan.FromSeconds(1);
new WebHostBuilder()
.UseKestrel(kestrelOptions =>
{
Expand All @@ -95,13 +101,13 @@ private static void Run(Options options)
{
listenOptions.UseHttps();
});
kestrelOptions.Limits.KeepAliveTimeout = TimeSpan.FromSeconds(options.KeepAliveTimeout);
kestrelOptions.Limits.KeepAliveTimeout = keepAlive;
})
.Configure(app => app.Run(async context =>
{
try
{
var upstreamResponse = await SendUpstreamRequest(context.Request);
using var upstreamResponse = await SendUpstreamRequest(context.Request, context.RequestAborted);

// Attempt to remove the response selection header and use its value to handle the response selection.
if (context.Request.Headers.Remove(_responseSelectionHeader, out var selection))
Expand Down Expand Up @@ -149,7 +155,7 @@ private static void Run(Options options)
.Run();
}

private static async Task<UpstreamResponse> SendUpstreamRequest(HttpRequest request)
private static async Task<UpstreamResponse> SendUpstreamRequest(HttpRequest request, CancellationToken cancellationToken)
{
Console.WriteLine();
Log("Incoming Request");
Expand Down Expand Up @@ -186,74 +192,66 @@ private static async Task<UpstreamResponse> SendUpstreamRequest(HttpRequest requ
Log("Upstream Request");
Log($"URL: {upstreamUri}");

using (var upstreamRequest = new HttpRequestMessage(new HttpMethod(request.Method), upstreamUri))
using var upstreamRequest = new HttpRequestMessage(new HttpMethod(request.Method), upstreamUri);
Log("Headers:");

if (request.ContentLength > 0)
{
Log("Headers:");
upstreamRequest.Content = new StreamContent(request.Body);

if (request.ContentLength > 0)
foreach (var header in request.Headers.Where(h => _contentRequestHeaders.Contains(h.Key)))
{
upstreamRequest.Content = new StreamContent(request.Body);

foreach (var header in request.Headers.Where(h => _contentRequestHeaders.Contains(h.Key)))
{
Log($" {header.Key}:{header.Value.First()}");
upstreamRequest.Content.Headers.Add(header.Key, values: header.Value);
}
Log($" {header.Key}:{header.Value.First()}");
upstreamRequest.Content.Headers.Add(header.Key, values: header.Value);
}
}

foreach (var header in request.Headers.Where(h => !_excludedRequestHeaders.Contains(h.Key) && !_contentRequestHeaders.Contains(h.Key)))
foreach (var header in request.Headers.Where(h => !_excludedRequestHeaders.Contains(h.Key) && !_contentRequestHeaders.Contains(h.Key)))
{
Log($" {header.Key}:{header.Value.First()}");
if (!upstreamRequest.Headers.TryAddWithoutValidation(header.Key, values: header.Value))
{
Log($" {header.Key}:{header.Value.First()}");
if (!upstreamRequest.Headers.TryAddWithoutValidation(header.Key, values: header.Value))
{
throw new InvalidOperationException($"Could not add header {header.Key} with value {header.Value}");
}
throw new InvalidOperationException($"Could not add header {header.Key} with value {header.Value}");
}
}

Log("Sending request to upstream server...");
using (var upstreamResponseMessage = await _httpClient.SendAsync(upstreamRequest))
{
Console.WriteLine();
Log("Upstream Response");
var headers = new List<KeyValuePair<string, StringValues>>();
Log("Sending request to upstream server...");
var upstreamResponseMessage = await _httpClient.SendAsync(upstreamRequest, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
Console.WriteLine();
Log("Upstream Response");
var headers = new List<KeyValuePair<string, StringValues>>();

Log($"StatusCode: {upstreamResponseMessage.StatusCode}");
Log($"StatusCode: {upstreamResponseMessage.StatusCode}");

Log("Headers:");
foreach (var header in upstreamResponseMessage.Headers)
{
Log($" {header.Key}:{header.Value.First()}");
Log("Headers:");

// Must skip "Transfer-Encoding" header, since if it's set manually Kestrel requires you to implement
// your own chunking.
if (string.Equals(header.Key, "Transfer-Encoding", StringComparison.OrdinalIgnoreCase))
{
continue;
}
foreach (var header in upstreamResponseMessage.Headers)
{
Log($" {header.Key}:{header.Value.First()}");

headers.Add(new KeyValuePair<string, StringValues>(header.Key, header.Value.ToArray()));
}
// Must skip "Transfer-Encoding" header, since if it's set manually Kestrel requires you to implement
// your own chunking.
if (string.Equals(header.Key, "Transfer-Encoding", StringComparison.OrdinalIgnoreCase))
{
continue;
}

foreach (var header in upstreamResponseMessage.Content.Headers)
{
Log($" {header.Key}:{header.Value.First()}");
headers.Add(new KeyValuePair<string, StringValues>(header.Key, header.Value.ToArray()));
}
headers.Add(new KeyValuePair<string, StringValues>(header.Key, header.Value.ToArray()));
}

Log("Reading upstream response body...");
foreach (var header in upstreamResponseMessage.Content.Headers)
{
Log($" {header.Key}:{header.Value.First()}");
headers.Add(new KeyValuePair<string, StringValues>(header.Key, header.Value.ToArray()));
}

var upstreamResponse = new UpstreamResponse()
{
StatusCode = (int)upstreamResponseMessage.StatusCode,
Headers = headers.ToArray(),
Content = await upstreamResponseMessage.Content.ReadAsByteArrayAsync()
};
Log($"ContentLength: {upstreamResponseMessage.Content.Headers.ContentLength}, is chunked: {upstreamResponseMessage.Headers.TransferEncodingChunked}");

Log($"ContentLength: {upstreamResponse.Content.Length}");
var response = await UpstreamResponse.FromContent(upstreamResponseMessage.Content, cancellationToken);
response.StatusCode = (int)upstreamResponseMessage.StatusCode;
response.Headers = headers.ToArray();

return upstreamResponse;
}
}
return response;
}

private static async Task<bool> TryHandleResponseOption(string selection, HttpContext context, UpstreamResponse upstreamResponse)
Expand All @@ -262,26 +260,26 @@ private static async Task<bool> TryHandleResponseOption(string selection, HttpCo
{
case "f":
// Full response
await SendDownstreamResponse(upstreamResponse, context.Response);
await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.ContentLength, context.RequestAborted);
return true;
case "p":
// Partial Response (full headers, 50% of body), then wait indefinitely
await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.Content.Length / 2);
await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.ContentLength / 2, context.RequestAborted);
await Task.Delay(Timeout.InfiniteTimeSpan, context.RequestAborted);
return true;
case "pc":
// Partial Response (full headers, 50% of body), then close (TCP FIN)
await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.Content.Length / 2);
await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.ContentLength / 2, context.RequestAborted);
Close(context);
return true;
case "pa":
// Partial Response (full headers, 50% of body), then abort (TCP RST)
await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.Content.Length / 2);
await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.ContentLength / 2, context.RequestAborted);
Abort(context);
return true;
case "pn":
// Partial Response (full headers, 50% of body), then finish normally
await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.Content.Length / 2);
await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.ContentLength / 2, context.RequestAborted);
return true;
case "n":
// No response, then wait indefinitely
Expand All @@ -301,7 +299,7 @@ private static async Task<bool> TryHandleResponseOption(string selection, HttpCo
}
}

private static async Task SendDownstreamResponse(UpstreamResponse upstreamResponse, HttpResponse response, int? contentBytes = null)
private static async Task SendDownstreamResponse(UpstreamResponse upstreamResponse, HttpResponse response, long? contentBytes, CancellationToken cancellationToken)
{
Console.WriteLine();

Expand All @@ -318,10 +316,41 @@ private static async Task SendDownstreamResponse(UpstreamResponse upstreamRespon
response.Headers.Add(header.Key, header.Value);
}

var count = contentBytes ?? upstreamResponse.Content.Length;
long count = contentBytes ?? long.MaxValue;

Log($"Writing response body of {count} bytes...");
await response.Body.WriteAsync(upstreamResponse.Content, 0, count);
Log($"Finished writing response body");

using Stream source = await upstreamResponse.GetContentStreamAsync(cancellationToken);
byte[] buffer = new byte[8192];

try
{
for (long remaining = count; remaining > 0;)
{
int read = await source.ReadAsync(buffer, 0, (int)Math.Min(buffer.Length, remaining), cancellationToken);
if (read <= 0)
{
break;
}

remaining -= read;
await response.Body.WriteAsync(buffer, 0, read, cancellationToken);
}

await response.Body.FlushAsync();
Log($"Finished writing response body");
}
catch (Exception ex)
{
Log(ex.ToString());
}
finally
{
// disponse content as early as possible (before infinite wait that might happen later)
// so that underlying connection returns to connection pool
// and we won't run out of them
upstreamResponse.Dispose();
}
}

// Close the TCP connection by sending FIN
Expand All @@ -345,11 +374,61 @@ private static void Log(object value)
Console.WriteLine($"[{DateTime.Now:hh:mm:ss.fff}] {value}");
}

private class UpstreamResponse
private class UpstreamResponse : IDisposable
{
private readonly HttpContent _content = null;
private readonly Stream _contentStream = null;
private UpstreamResponse(HttpContent content)
{
_content = content;
ContentLength = _content.Headers.ContentLength ?? throw new ArgumentNullException("ContentLength must not be null.");
}

private UpstreamResponse(MemoryStream contentStream)
{
_contentStream = contentStream;
ContentLength = contentStream.Length;
}

public int StatusCode { get; set; }
public KeyValuePair<string, StringValues>[] Headers { get; set; }
public byte[] Content { get; set; }
public async Task<Stream> GetContentStreamAsync(CancellationToken cancellationToken)
{
return _contentStream != null ? _contentStream : await _content.ReadAsStreamAsync(cancellationToken);
}

public long ContentLength { get; }

public void Dispose()
{
_content?.Dispose();
_contentStream?.Dispose();
}

public static async Task<UpstreamResponse> FromContent(HttpContent content, CancellationToken cancellationToken)
{
if (content.Headers.ContentLength == null)
{
MemoryStream contentStream = await BufferContentAsync(content, cancellationToken);
// we no longer need that content and can let the connection go back to the pool.
content.Dispose();
return new UpstreamResponse(contentStream);
}

return new UpstreamResponse(content);
}

private static async Task<MemoryStream> BufferContentAsync(HttpContent content, CancellationToken cancellationToken)
{
Debug.Assert(content.Headers.ContentLength == null, "We should not buffer content if length is available.");

Log("Response does not have content length (is chunked or malformed) and is being buffered");
byte[] contentBytes = await content.ReadAsByteArrayAsync(cancellationToken);
Log($"Content was buffered, total length - {contentBytes.Length}.");

// comparing to buffering full content memory stream allocation is not such a big deal.
return new MemoryStream(contentBytes);
}
}
}
}

0 comments on commit ebca8d6

Please sign in to comment.