Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FaultInjector: stream response instead of buffering #7466

Merged
merged 6 commits into from
Jan 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
using Microsoft.Extensions.Primitives;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
Expand Down Expand Up @@ -74,6 +76,7 @@ public static void Main(string[] args)

private static void Run(Options options)
{
TimeSpan keepAlive = TimeSpan.FromSeconds(options.KeepAliveTimeout);
if (options.Insecure)
{
_httpClient = new HttpClient(new HttpClientHandler()
Expand All @@ -87,6 +90,9 @@ private static void Run(Options options)
_httpClient = new HttpClient();
}

// TODO: we can switch to SocketsHttpHandler and configure read/write/connect timeouts separately
// for now let's just set upstream timeout to be slightly bigger than client timeout.
_httpClient.Timeout = keepAlive + TimeSpan.FromSeconds(1);
new WebHostBuilder()
.UseKestrel(kestrelOptions =>
{
Expand All @@ -95,13 +101,13 @@ private static void Run(Options options)
{
listenOptions.UseHttps();
});
kestrelOptions.Limits.KeepAliveTimeout = TimeSpan.FromSeconds(options.KeepAliveTimeout);
kestrelOptions.Limits.KeepAliveTimeout = keepAlive;
})
.Configure(app => app.Run(async context =>
{
try
{
var upstreamResponse = await SendUpstreamRequest(context.Request);
using var upstreamResponse = await SendUpstreamRequest(context.Request, context.RequestAborted);

// Attempt to remove the response selection header and use its value to handle the response selection.
if (context.Request.Headers.Remove(_responseSelectionHeader, out var selection))
Expand Down Expand Up @@ -149,7 +155,7 @@ private static void Run(Options options)
.Run();
}

private static async Task<UpstreamResponse> SendUpstreamRequest(HttpRequest request)
private static async Task<UpstreamResponse> SendUpstreamRequest(HttpRequest request, CancellationToken cancellationToken)
{
Console.WriteLine();
Log("Incoming Request");
Expand Down Expand Up @@ -186,74 +192,66 @@ private static async Task<UpstreamResponse> SendUpstreamRequest(HttpRequest requ
Log("Upstream Request");
Log($"URL: {upstreamUri}");

using (var upstreamRequest = new HttpRequestMessage(new HttpMethod(request.Method), upstreamUri))
using var upstreamRequest = new HttpRequestMessage(new HttpMethod(request.Method), upstreamUri);
Log("Headers:");

if (request.ContentLength > 0)
{
Log("Headers:");
upstreamRequest.Content = new StreamContent(request.Body);

if (request.ContentLength > 0)
foreach (var header in request.Headers.Where(h => _contentRequestHeaders.Contains(h.Key)))
{
upstreamRequest.Content = new StreamContent(request.Body);

foreach (var header in request.Headers.Where(h => _contentRequestHeaders.Contains(h.Key)))
{
Log($" {header.Key}:{header.Value.First()}");
upstreamRequest.Content.Headers.Add(header.Key, values: header.Value);
}
Log($" {header.Key}:{header.Value.First()}");
upstreamRequest.Content.Headers.Add(header.Key, values: header.Value);
}
}

foreach (var header in request.Headers.Where(h => !_excludedRequestHeaders.Contains(h.Key) && !_contentRequestHeaders.Contains(h.Key)))
foreach (var header in request.Headers.Where(h => !_excludedRequestHeaders.Contains(h.Key) && !_contentRequestHeaders.Contains(h.Key)))
{
Log($" {header.Key}:{header.Value.First()}");
if (!upstreamRequest.Headers.TryAddWithoutValidation(header.Key, values: header.Value))
{
Log($" {header.Key}:{header.Value.First()}");
if (!upstreamRequest.Headers.TryAddWithoutValidation(header.Key, values: header.Value))
{
throw new InvalidOperationException($"Could not add header {header.Key} with value {header.Value}");
}
throw new InvalidOperationException($"Could not add header {header.Key} with value {header.Value}");
}
}

Log("Sending request to upstream server...");
using (var upstreamResponseMessage = await _httpClient.SendAsync(upstreamRequest))
{
Console.WriteLine();
Log("Upstream Response");
var headers = new List<KeyValuePair<string, StringValues>>();
Log("Sending request to upstream server...");
var upstreamResponseMessage = await _httpClient.SendAsync(upstreamRequest, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
Console.WriteLine();
Log("Upstream Response");
var headers = new List<KeyValuePair<string, StringValues>>();

Log($"StatusCode: {upstreamResponseMessage.StatusCode}");
Log($"StatusCode: {upstreamResponseMessage.StatusCode}");

Log("Headers:");
foreach (var header in upstreamResponseMessage.Headers)
{
Log($" {header.Key}:{header.Value.First()}");
Log("Headers:");

// Must skip "Transfer-Encoding" header, since if it's set manually Kestrel requires you to implement
// your own chunking.
if (string.Equals(header.Key, "Transfer-Encoding", StringComparison.OrdinalIgnoreCase))
{
continue;
}
foreach (var header in upstreamResponseMessage.Headers)
{
Log($" {header.Key}:{header.Value.First()}");

headers.Add(new KeyValuePair<string, StringValues>(header.Key, header.Value.ToArray()));
}
// Must skip "Transfer-Encoding" header, since if it's set manually Kestrel requires you to implement
// your own chunking.
if (string.Equals(header.Key, "Transfer-Encoding", StringComparison.OrdinalIgnoreCase))
{
continue;
}

foreach (var header in upstreamResponseMessage.Content.Headers)
{
Log($" {header.Key}:{header.Value.First()}");
headers.Add(new KeyValuePair<string, StringValues>(header.Key, header.Value.ToArray()));
}
headers.Add(new KeyValuePair<string, StringValues>(header.Key, header.Value.ToArray()));
}

Log("Reading upstream response body...");
foreach (var header in upstreamResponseMessage.Content.Headers)
{
Log($" {header.Key}:{header.Value.First()}");
headers.Add(new KeyValuePair<string, StringValues>(header.Key, header.Value.ToArray()));
}

var upstreamResponse = new UpstreamResponse()
{
StatusCode = (int)upstreamResponseMessage.StatusCode,
Headers = headers.ToArray(),
Content = await upstreamResponseMessage.Content.ReadAsByteArrayAsync()
};
Log($"ContentLength: {upstreamResponseMessage.Content.Headers.ContentLength}, is chunked: {upstreamResponseMessage.Headers.TransferEncodingChunked}");

Log($"ContentLength: {upstreamResponse.Content.Length}");
var response = await UpstreamResponse.FromContent(upstreamResponseMessage.Content, cancellationToken);
response.StatusCode = (int)upstreamResponseMessage.StatusCode;
response.Headers = headers.ToArray();

return upstreamResponse;
}
}
return response;
}

private static async Task<bool> TryHandleResponseOption(string selection, HttpContext context, UpstreamResponse upstreamResponse)
Expand All @@ -262,26 +260,26 @@ private static async Task<bool> TryHandleResponseOption(string selection, HttpCo
{
case "f":
// Full response
await SendDownstreamResponse(upstreamResponse, context.Response);
await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.ContentLength, context.RequestAborted);
return true;
case "p":
// Partial Response (full headers, 50% of body), then wait indefinitely
await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.Content.Length / 2);
await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.ContentLength / 2, context.RequestAborted);
await Task.Delay(Timeout.InfiniteTimeSpan, context.RequestAborted);
return true;
case "pc":
// Partial Response (full headers, 50% of body), then close (TCP FIN)
await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.Content.Length / 2);
await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.ContentLength / 2, context.RequestAborted);
Close(context);
return true;
case "pa":
// Partial Response (full headers, 50% of body), then abort (TCP RST)
await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.Content.Length / 2);
await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.ContentLength / 2, context.RequestAborted);
Abort(context);
return true;
case "pn":
// Partial Response (full headers, 50% of body), then finish normally
await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.Content.Length / 2);
await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.ContentLength / 2, context.RequestAborted);
return true;
case "n":
// No response, then wait indefinitely
Expand All @@ -301,7 +299,7 @@ private static async Task<bool> TryHandleResponseOption(string selection, HttpCo
}
}

private static async Task SendDownstreamResponse(UpstreamResponse upstreamResponse, HttpResponse response, int? contentBytes = null)
private static async Task SendDownstreamResponse(UpstreamResponse upstreamResponse, HttpResponse response, long? contentBytes, CancellationToken cancellationToken)
{
Console.WriteLine();

Expand All @@ -318,10 +316,41 @@ private static async Task SendDownstreamResponse(UpstreamResponse upstreamRespon
response.Headers.Add(header.Key, header.Value);
}

var count = contentBytes ?? upstreamResponse.Content.Length;
long count = contentBytes ?? long.MaxValue;

Log($"Writing response body of {count} bytes...");
await response.Body.WriteAsync(upstreamResponse.Content, 0, count);
Log($"Finished writing response body");

using Stream source = await upstreamResponse.GetContentStreamAsync(cancellationToken);
byte[] buffer = new byte[8192];
lmolkova marked this conversation as resolved.
Show resolved Hide resolved

try
{
for (long remaining = count; remaining > 0;)
{
int read = await source.ReadAsync(buffer, 0, (int)Math.Min(buffer.Length, remaining), cancellationToken);
if (read <= 0)
{
break;
}

remaining -= read;
await response.Body.WriteAsync(buffer, 0, read, cancellationToken);
}

await response.Body.FlushAsync();
Log($"Finished writing response body");
}
catch (Exception ex)
{
Log(ex.ToString());
}
finally
{
lmolkova marked this conversation as resolved.
Show resolved Hide resolved
// disponse content as early as possible (before infinite wait that might happen later)
// so that underlying connection returns to connection pool
// and we won't run out of them
upstreamResponse.Dispose();
}
}

// Close the TCP connection by sending FIN
Expand All @@ -345,11 +374,61 @@ private static void Log(object value)
Console.WriteLine($"[{DateTime.Now:hh:mm:ss.fff}] {value}");
}

private class UpstreamResponse
private class UpstreamResponse : IDisposable
{
private readonly HttpContent _content = null;
private readonly Stream _contentStream = null;
private UpstreamResponse(HttpContent content)
{
_content = content;
ContentLength = _content.Headers.ContentLength ?? throw new ArgumentNullException("ContentLength must not be null.");
}

private UpstreamResponse(MemoryStream contentStream)
{
_contentStream = contentStream;
ContentLength = contentStream.Length;
}

public int StatusCode { get; set; }
public KeyValuePair<string, StringValues>[] Headers { get; set; }
public byte[] Content { get; set; }
public async Task<Stream> GetContentStreamAsync(CancellationToken cancellationToken)
{
return _contentStream != null ? _contentStream : await _content.ReadAsStreamAsync(cancellationToken);
}

public long ContentLength { get; }

public void Dispose()
{
_content?.Dispose();
_contentStream?.Dispose();
}

public static async Task<UpstreamResponse> FromContent(HttpContent content, CancellationToken cancellationToken)
{
if (content.Headers.ContentLength == null)
{
MemoryStream contentStream = await BufferContentAsync(content, cancellationToken);
// we no longer need that content and can let the connection go back to the pool.
content.Dispose();
return new UpstreamResponse(contentStream);
}

return new UpstreamResponse(content);
}

private static async Task<MemoryStream> BufferContentAsync(HttpContent content, CancellationToken cancellationToken)
{
Debug.Assert(content.Headers.ContentLength == null, "We should not buffer content if length is available.");

Log("Response does not have content length (is chunked or malformed) and is being buffered");
byte[] contentBytes = await content.ReadAsByteArrayAsync(cancellationToken);
Log($"Content was buffered, total length - {contentBytes.Length}.");

// comparing to buffering full content memory stream allocation is not such a big deal.
return new MemoryStream(contentBytes);
}
}
}
}
Loading