Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[http-fault-injector] Add support for interrupting request #7698

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,10 @@ private async Task<UpstreamResponse> SendUpstreamRequest(HttpRequest request, st

using (var upstreamRequest = new HttpRequestMessage(new HttpMethod(request.Method), upstreamUri))
{
if (request.ContentLength > 0)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should work even for requests with no body, or requests with chunked encoding. Shouldn't need to gate on content-length header.

upstreamRequest.Content = new StreamContent(request.Body);
foreach (var header in request.Headers.Where(h => Utils.ContentRequestHeaders.Contains(h.Key)))
{
upstreamRequest.Content = new StreamContent(request.Body);
foreach (var header in request.Headers.Where(h => Utils.ContentRequestHeaders.Contains(h.Key)))
{
upstreamRequest.Content.Headers.Add(header.Key, values: header.Value);
}
upstreamRequest.Content.Headers.Add(header.Key, values: header.Value);
}

foreach (var header in request.Headers.Where(h => !Utils.ExcludedRequestHeaders.Contains(h.Key) && !Utils.ContentRequestHeaders.Contains(h.Key)))
Expand Down Expand Up @@ -125,7 +122,42 @@ private async Task<MemoryStream> BufferContentAsync(HttpContent content, Cancell

private async Task ProxyResponse(HttpContext context, string upstreamUri, string fault, CancellationToken cancellationToken)
{
switch (fault)
{
case "nq":
// No request body, then wait indefinitely
await Task.Delay(Timeout.InfiniteTimeSpan, cancellationToken);
return;
case "nqc":
// No request body, then close (TCP FIN)
Close(context);
return;
case "nqa":
// No request body, then abort (TCP RST)
Abort(context);
return;
case "pq":
// Partial request (50% of body), then wait indefinitely
await ReadPartialRequest(context.Request, cancellationToken);
await Task.Delay(Timeout.InfiniteTimeSpan, cancellationToken);
return;
case "pqc":
// Partial request (50% of body), then close (TCP FIN)
await ReadPartialRequest(context.Request, cancellationToken);
Close(context);
return;
case "pqa":
// Partial request (50% of body), then abort (TCP RST)
await ReadPartialRequest(context.Request, cancellationToken);
Abort(context);
return;
default:
// Fall through and read full request body
break;
}

UpstreamResponse upstreamResponse = await SendUpstreamRequest(context.Request, upstreamUri, cancellationToken);

switch (fault)
{
case "f":
Expand All @@ -139,12 +171,12 @@ private async Task ProxyResponse(HttpContext context, string upstreamUri, string
return;
case "pc":
// Partial Response (full headers, 50% of body), then close (TCP FIN)
await SendDownstreamResponse(context.Response,upstreamResponse, upstreamResponse.ContentLength / 2, cancellationToken);
await SendDownstreamResponse(context.Response, upstreamResponse, upstreamResponse.ContentLength / 2, cancellationToken);
Close(context);
return;
case "pa":
// Partial Response (full headers, 50% of body), then abort (TCP RST)
await SendDownstreamResponse(context.Response,upstreamResponse, upstreamResponse.ContentLength / 2, cancellationToken);
await SendDownstreamResponse(context.Response, upstreamResponse, upstreamResponse.ContentLength / 2, cancellationToken);
Abort(context);
return;
case "pn":
Expand All @@ -169,6 +201,36 @@ private async Task ProxyResponse(HttpContext context, string upstreamUri, string
}
}

private static async Task ReadPartialRequest(HttpRequest request, CancellationToken cancellationToken)
{
var contentLength = request.ContentLength
?? throw new InvalidOperationException("Partial request options require content-length request headers");
var bytesToRead = contentLength / 2;
long totalBytesRead = 0;
var buffer = ArrayPool<byte>.Shared.Rent(81920);
try
{
while (true)
{
var bytesRead = await request.Body.ReadAsync(
buffer,
0,
(int)Math.Min(buffer.Length, bytesToRead - totalBytesRead),
cancellationToken
);
totalBytesRead += bytesRead;
if (totalBytesRead >= bytesToRead || bytesRead == 0)
{
break;
}
}
}
finally
{
ArrayPool<byte>.Shared.Return(buffer);
}
}

private async Task SendDownstreamResponse(HttpResponse response, UpstreamResponse upstreamResponse, long contentBytes, CancellationToken cancellationToken)
{
response.StatusCode = upstreamResponse.StatusCode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,20 @@ public static class Utils
{
public static readonly IDictionary<string, string> FaultModes = new Dictionary<string, string>()
{
{ "f", "Full response" },
{ "p", "Partial Response (full headers, 50% of body), then wait indefinitely" },
{"pc", "Partial Response (full headers, 50% of body), then close (TCP FIN)" },
{"pa", "Partial Response (full headers, 50% of body), then abort (TCP RST)" },
{"pn", "Partial Response (full headers, 50% of body), then finish normally" },
{"n", "No response, then wait indefinitely"},
{"nc", "No response, then close (TCP FIN)" },
{"na", "No response, then abort (TCP RST)" }
{ "f", "Full response" },
{ "p", "Partial Response (full headers, 50% of body), then wait indefinitely" },
{ "pc", "Partial Response (full headers, 50% of body), then close (TCP FIN)" },
{ "pa", "Partial Response (full headers, 50% of body), then abort (TCP RST)" },
{ "pn", "Partial Response (full headers, 50% of body), then finish normally" },
{ "n", "No response, then wait indefinitely"},
{ "nc", "No response, then close (TCP FIN)" },
{ "na", "No response, then abort (TCP RST)" },
{ "pq", "Partial request (50% of body), then wait indefinitely"},
{"pqc", "Partial request (50% of body), then close (TCP FIN)"},
{"pqa", "Partial request (50% of body), then abort (TCP RST)"},
{ "nq", "No request body, then wait indefinitely"},
{"nqc", "No request body, then close (TCP FIN)"},
{"nqa", "No request body, then abort (TCP RST)"},
};

public static readonly string[] ExcludedRequestHeaders = new string[] {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using System;
using System;
using System.IO;
using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

Expand All @@ -9,11 +11,34 @@ class Program
{
static async Task Main(string[] args)
{
var httpClient = new HttpClient(new FaultInjectionClientHandler(new Uri("http://localhost:7777")));
var directClient = new HttpClient();
var faultInjectionClient = new HttpClient(new FaultInjectionClientHandler(new Uri("http://localhost:7777")))
{
// Short timeout for testing no response
Timeout = TimeSpan.FromSeconds(10)
};

Console.WriteLine("Sending request directly...");
await Test(directClient);

Console.WriteLine("Sending request through fault injector...");
await Test(faultInjectionClient);
}

private static async Task Test(HttpClient client)
{
var baseUrl = "http://localhost:5000";

var uploadStream = new LoggingStream(new MemoryStream(Encoding.UTF8.GetBytes(new string('a', 10 * 1024 * 1024))));

var response = await client.PutAsync(baseUrl + "/upload", new StreamContent(uploadStream));

Console.WriteLine("Sending request...");
var response = await httpClient.GetAsync("https://www.example.org");
Console.WriteLine(response.StatusCode);
var content = await response.Content.ReadAsStringAsync();
var shortContent = (content.Length <= 40 ? content : content.Substring(0, 40) + "...");

Console.WriteLine($"Status: {response.StatusCode}");
Console.WriteLine($"Content: {shortContent}");
Console.WriteLine($"Length: {content.Length}");
}

class FaultInjectionClientHandler : HttpClientHandler
Expand Down Expand Up @@ -52,5 +77,58 @@ protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage reques
return base.SendAsync(request, cancellationToken);
}
}

class LoggingStream : Stream
{
private readonly Stream _stream;
private long _totalBytesRead;

public LoggingStream(Stream stream)
{
_stream = stream;
}

public override bool CanRead => _stream.CanRead;

public override bool CanSeek => _stream.CanSeek;

public override bool CanWrite => _stream.CanWrite;

public override long Length => _stream.Length;

public override long Position
{
get => _stream.Position;
set => _stream.Position = value;
}

public override void Flush()
{
_stream.Flush();
}

public override int Read(byte[] buffer, int offset, int count)
{
var bytesRead = _stream.Read(buffer, offset, count);
_totalBytesRead += bytesRead;
Console.WriteLine($"Read(buffer: byte[{buffer.Length}], offset: {offset}, count: {count}) => {bytesRead} (total {_totalBytesRead})");
return bytesRead;
}

public override long Seek(long offset, SeekOrigin origin)
{
return _stream.Seek(offset, origin);
}

public override void SetLength(long value)
{
_stream.SetLength(value);
}

public override void Write(byte[] buffer, int offset, int count)
{
_stream.Write(buffer, offset, count);
}
}
}
}
Loading