diff --git a/tools/http-fault-injector/Azure.Sdk.Tools.HttpFaultInjector/Program.cs b/tools/http-fault-injector/Azure.Sdk.Tools.HttpFaultInjector/Program.cs index 1341333953e..a898879ba13 100644 --- a/tools/http-fault-injector/Azure.Sdk.Tools.HttpFaultInjector/Program.cs +++ b/tools/http-fault-injector/Azure.Sdk.Tools.HttpFaultInjector/Program.cs @@ -7,6 +7,8 @@ using Microsoft.Extensions.Primitives; using System; using System.Collections.Generic; +using System.Diagnostics; +using System.IO; using System.Linq; using System.Net; using System.Net.Http; @@ -74,6 +76,7 @@ public static void Main(string[] args) private static void Run(Options options) { + TimeSpan keepAlive = TimeSpan.FromSeconds(options.KeepAliveTimeout); if (options.Insecure) { _httpClient = new HttpClient(new HttpClientHandler() @@ -87,6 +90,9 @@ private static void Run(Options options) _httpClient = new HttpClient(); } + // TODO: we can switch to SocketsHttpHandler and configure read/write/connect timeouts separately + // for now let's just set upstream timeout to be slightly bigger than client timeout. + _httpClient.Timeout = keepAlive + TimeSpan.FromSeconds(1); new WebHostBuilder() .UseKestrel(kestrelOptions => { @@ -95,13 +101,13 @@ private static void Run(Options options) { listenOptions.UseHttps(); }); - kestrelOptions.Limits.KeepAliveTimeout = TimeSpan.FromSeconds(options.KeepAliveTimeout); + kestrelOptions.Limits.KeepAliveTimeout = keepAlive; }) .Configure(app => app.Run(async context => { try { - var upstreamResponse = await SendUpstreamRequest(context.Request); + using var upstreamResponse = await SendUpstreamRequest(context.Request, context.RequestAborted); // Attempt to remove the response selection header and use its value to handle the response selection. if (context.Request.Headers.Remove(_responseSelectionHeader, out var selection)) @@ -149,7 +155,7 @@ private static void Run(Options options) .Run(); } - private static async Task SendUpstreamRequest(HttpRequest request) + private static async Task SendUpstreamRequest(HttpRequest request, CancellationToken cancellationToken) { Console.WriteLine(); Log("Incoming Request"); @@ -186,74 +192,66 @@ private static async Task SendUpstreamRequest(HttpRequest requ Log("Upstream Request"); Log($"URL: {upstreamUri}"); - using (var upstreamRequest = new HttpRequestMessage(new HttpMethod(request.Method), upstreamUri)) + using var upstreamRequest = new HttpRequestMessage(new HttpMethod(request.Method), upstreamUri); + Log("Headers:"); + + if (request.ContentLength > 0) { - Log("Headers:"); + upstreamRequest.Content = new StreamContent(request.Body); - if (request.ContentLength > 0) + foreach (var header in request.Headers.Where(h => _contentRequestHeaders.Contains(h.Key))) { - upstreamRequest.Content = new StreamContent(request.Body); - - foreach (var header in request.Headers.Where(h => _contentRequestHeaders.Contains(h.Key))) - { - Log($" {header.Key}:{header.Value.First()}"); - upstreamRequest.Content.Headers.Add(header.Key, values: header.Value); - } + Log($" {header.Key}:{header.Value.First()}"); + upstreamRequest.Content.Headers.Add(header.Key, values: header.Value); } + } - foreach (var header in request.Headers.Where(h => !_excludedRequestHeaders.Contains(h.Key) && !_contentRequestHeaders.Contains(h.Key))) + foreach (var header in request.Headers.Where(h => !_excludedRequestHeaders.Contains(h.Key) && !_contentRequestHeaders.Contains(h.Key))) + { + Log($" {header.Key}:{header.Value.First()}"); + if (!upstreamRequest.Headers.TryAddWithoutValidation(header.Key, values: header.Value)) { - Log($" {header.Key}:{header.Value.First()}"); - if (!upstreamRequest.Headers.TryAddWithoutValidation(header.Key, values: header.Value)) - { - throw new InvalidOperationException($"Could not add header {header.Key} with value {header.Value}"); - } + throw new InvalidOperationException($"Could not add header {header.Key} with value {header.Value}"); } + } - Log("Sending request to upstream server..."); - using (var upstreamResponseMessage = await _httpClient.SendAsync(upstreamRequest)) - { - Console.WriteLine(); - Log("Upstream Response"); - var headers = new List>(); + Log("Sending request to upstream server..."); + var upstreamResponseMessage = await _httpClient.SendAsync(upstreamRequest, HttpCompletionOption.ResponseHeadersRead, cancellationToken); + Console.WriteLine(); + Log("Upstream Response"); + var headers = new List>(); - Log($"StatusCode: {upstreamResponseMessage.StatusCode}"); + Log($"StatusCode: {upstreamResponseMessage.StatusCode}"); - Log("Headers:"); - foreach (var header in upstreamResponseMessage.Headers) - { - Log($" {header.Key}:{header.Value.First()}"); + Log("Headers:"); - // Must skip "Transfer-Encoding" header, since if it's set manually Kestrel requires you to implement - // your own chunking. - if (string.Equals(header.Key, "Transfer-Encoding", StringComparison.OrdinalIgnoreCase)) - { - continue; - } + foreach (var header in upstreamResponseMessage.Headers) + { + Log($" {header.Key}:{header.Value.First()}"); - headers.Add(new KeyValuePair(header.Key, header.Value.ToArray())); - } + // Must skip "Transfer-Encoding" header, since if it's set manually Kestrel requires you to implement + // your own chunking. + if (string.Equals(header.Key, "Transfer-Encoding", StringComparison.OrdinalIgnoreCase)) + { + continue; + } - foreach (var header in upstreamResponseMessage.Content.Headers) - { - Log($" {header.Key}:{header.Value.First()}"); - headers.Add(new KeyValuePair(header.Key, header.Value.ToArray())); - } + headers.Add(new KeyValuePair(header.Key, header.Value.ToArray())); + } - Log("Reading upstream response body..."); + foreach (var header in upstreamResponseMessage.Content.Headers) + { + Log($" {header.Key}:{header.Value.First()}"); + headers.Add(new KeyValuePair(header.Key, header.Value.ToArray())); + } - var upstreamResponse = new UpstreamResponse() - { - StatusCode = (int)upstreamResponseMessage.StatusCode, - Headers = headers.ToArray(), - Content = await upstreamResponseMessage.Content.ReadAsByteArrayAsync() - }; + Log($"ContentLength: {upstreamResponseMessage.Content.Headers.ContentLength}, is chunked: {upstreamResponseMessage.Headers.TransferEncodingChunked}"); - Log($"ContentLength: {upstreamResponse.Content.Length}"); + var response = await UpstreamResponse.FromContent(upstreamResponseMessage.Content, cancellationToken); + response.StatusCode = (int)upstreamResponseMessage.StatusCode; + response.Headers = headers.ToArray(); - return upstreamResponse; - } - } + return response; } private static async Task TryHandleResponseOption(string selection, HttpContext context, UpstreamResponse upstreamResponse) @@ -262,26 +260,26 @@ private static async Task TryHandleResponseOption(string selection, HttpCo { case "f": // Full response - await SendDownstreamResponse(upstreamResponse, context.Response); + await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.ContentLength, context.RequestAborted); return true; case "p": // Partial Response (full headers, 50% of body), then wait indefinitely - await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.Content.Length / 2); + await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.ContentLength / 2, context.RequestAborted); await Task.Delay(Timeout.InfiniteTimeSpan, context.RequestAborted); return true; case "pc": // Partial Response (full headers, 50% of body), then close (TCP FIN) - await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.Content.Length / 2); + await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.ContentLength / 2, context.RequestAborted); Close(context); return true; case "pa": // Partial Response (full headers, 50% of body), then abort (TCP RST) - await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.Content.Length / 2); + await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.ContentLength / 2, context.RequestAborted); Abort(context); return true; case "pn": // Partial Response (full headers, 50% of body), then finish normally - await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.Content.Length / 2); + await SendDownstreamResponse(upstreamResponse, context.Response, upstreamResponse.ContentLength / 2, context.RequestAborted); return true; case "n": // No response, then wait indefinitely @@ -301,7 +299,7 @@ private static async Task TryHandleResponseOption(string selection, HttpCo } } - private static async Task SendDownstreamResponse(UpstreamResponse upstreamResponse, HttpResponse response, int? contentBytes = null) + private static async Task SendDownstreamResponse(UpstreamResponse upstreamResponse, HttpResponse response, long? contentBytes, CancellationToken cancellationToken) { Console.WriteLine(); @@ -318,10 +316,41 @@ private static async Task SendDownstreamResponse(UpstreamResponse upstreamRespon response.Headers.Add(header.Key, header.Value); } - var count = contentBytes ?? upstreamResponse.Content.Length; + long count = contentBytes ?? long.MaxValue; + Log($"Writing response body of {count} bytes..."); - await response.Body.WriteAsync(upstreamResponse.Content, 0, count); - Log($"Finished writing response body"); + + using Stream source = await upstreamResponse.GetContentStreamAsync(cancellationToken); + byte[] buffer = new byte[8192]; + + try + { + for (long remaining = count; remaining > 0;) + { + int read = await source.ReadAsync(buffer, 0, (int)Math.Min(buffer.Length, remaining), cancellationToken); + if (read <= 0) + { + break; + } + + remaining -= read; + await response.Body.WriteAsync(buffer, 0, read, cancellationToken); + } + + await response.Body.FlushAsync(); + Log($"Finished writing response body"); + } + catch (Exception ex) + { + Log(ex.ToString()); + } + finally + { + // disponse content as early as possible (before infinite wait that might happen later) + // so that underlying connection returns to connection pool + // and we won't run out of them + upstreamResponse.Dispose(); + } } // Close the TCP connection by sending FIN @@ -345,11 +374,61 @@ private static void Log(object value) Console.WriteLine($"[{DateTime.Now:hh:mm:ss.fff}] {value}"); } - private class UpstreamResponse + private class UpstreamResponse : IDisposable { + private readonly HttpContent _content = null; + private readonly Stream _contentStream = null; + private UpstreamResponse(HttpContent content) + { + _content = content; + ContentLength = _content.Headers.ContentLength ?? throw new ArgumentNullException("ContentLength must not be null."); + } + + private UpstreamResponse(MemoryStream contentStream) + { + _contentStream = contentStream; + ContentLength = contentStream.Length; + } + public int StatusCode { get; set; } public KeyValuePair[] Headers { get; set; } - public byte[] Content { get; set; } + public async Task GetContentStreamAsync(CancellationToken cancellationToken) + { + return _contentStream != null ? _contentStream : await _content.ReadAsStreamAsync(cancellationToken); + } + + public long ContentLength { get; } + + public void Dispose() + { + _content?.Dispose(); + _contentStream?.Dispose(); + } + + public static async Task FromContent(HttpContent content, CancellationToken cancellationToken) + { + if (content.Headers.ContentLength == null) + { + MemoryStream contentStream = await BufferContentAsync(content, cancellationToken); + // we no longer need that content and can let the connection go back to the pool. + content.Dispose(); + return new UpstreamResponse(contentStream); + } + + return new UpstreamResponse(content); + } + + private static async Task BufferContentAsync(HttpContent content, CancellationToken cancellationToken) + { + Debug.Assert(content.Headers.ContentLength == null, "We should not buffer content if length is available."); + + Log("Response does not have content length (is chunked or malformed) and is being buffered"); + byte[] contentBytes = await content.ReadAsByteArrayAsync(cancellationToken); + Log($"Content was buffered, total length - {contentBytes.Length}."); + + // comparing to buffering full content memory stream allocation is not such a big deal. + return new MemoryStream(contentBytes); + } } } }