diff --git a/tools/http-fault-injector/Azure.Sdk.Tools.HttpFaultInjector/FaultInjectingController.cs b/tools/http-fault-injector/Azure.Sdk.Tools.HttpFaultInjector/FaultInjectingMiddleware.cs similarity index 86% rename from tools/http-fault-injector/Azure.Sdk.Tools.HttpFaultInjector/FaultInjectingController.cs rename to tools/http-fault-injector/Azure.Sdk.Tools.HttpFaultInjector/FaultInjectingMiddleware.cs index e31a5352c5c..25cecf94a8a 100644 --- a/tools/http-fault-injector/Azure.Sdk.Tools.HttpFaultInjector/FaultInjectingController.cs +++ b/tools/http-fault-injector/Azure.Sdk.Tools.HttpFaultInjector/FaultInjectingMiddleware.cs @@ -20,23 +20,20 @@ public class FaultInjectingMiddleware { private readonly ILogger _logger; private readonly HttpClient _httpClient; - private readonly RequestDelegate _next; - public FaultInjectingMiddleware(RequestDelegate next, IHttpClientFactory httpClientFactory, ILogger logger) + public FaultInjectingMiddleware(RequestDelegate _, IHttpClientFactory httpClientFactory, ILogger logger) { _logger = logger; _httpClient = httpClientFactory.CreateClient("upstream"); - _next = next; } - public async Task InvokeAsync(HttpContext context, - CancellationToken cancellationToken) + public async Task InvokeAsync(HttpContext context) { string faultHeaderValue = context.Request.Headers[Utils.ResponseSelectionHeader]; string upstreamBaseUri = context.Request.Headers[Utils.UpstreamBaseUriHeader]; if (ValidateOrReadFaultMode(faultHeaderValue, out var fault)) { - await ProxyResponse(context, upstreamBaseUri, fault, cancellationToken); + await ProxyResponse(context, upstreamBaseUri, fault, context.RequestAborted); } else { @@ -87,21 +84,18 @@ private async Task SendUpstreamRequest(HttpRequest request, st } } - using (var upstreamResponseMessage = await _httpClient.SendAsync(upstreamRequest)) - { - var headers = new List>>(); - // Must skip "Transfer-Encoding" header, since if it's set manually Kestrel requires you to implement - // your own chunking. - headers.AddRange(upstreamResponseMessage.Headers.Where(header => !string.Equals(header.Key, "Transfer-Encoding", StringComparison.OrdinalIgnoreCase))); - headers.AddRange(upstreamResponseMessage.Content.Headers); - - var upstreamResponse = await UpstreamResponseFromHttpContent(upstreamResponseMessage.Content, cancellationToken); - upstreamResponse.StatusCode = (int)upstreamResponseMessage.StatusCode; - upstreamResponse.Headers = headers.Select(h => new KeyValuePair(h.Key, h.Value.ToArray())); - _logger.LogInformation("Finished reading response body ({length})", upstreamResponse.ContentLength); - - return upstreamResponse; - } + var upstreamResponseMessage = await _httpClient.SendAsync(upstreamRequest); + var headers = new List>>(); + // Must skip "Transfer-Encoding" header, since if it's set manually Kestrel requires you to implement + // your own chunking. + headers.AddRange(upstreamResponseMessage.Headers.Where(header => !string.Equals(header.Key, "Transfer-Encoding", StringComparison.OrdinalIgnoreCase))); + headers.AddRange(upstreamResponseMessage.Content.Headers); + + var upstreamResponse = await UpstreamResponseFromHttpContent(upstreamResponseMessage.Content, cancellationToken); + upstreamResponse.StatusCode = (int)upstreamResponseMessage.StatusCode; + upstreamResponse.Headers = headers.Select(h => new KeyValuePair(h.Key, h.Value.ToArray())); + + return upstreamResponse; } } @@ -125,6 +119,7 @@ private async Task BufferContentAsync(HttpContent content, Cancell _logger.LogWarning("Response does not have content length (is chunked or malformed) and is being buffered"); byte[] contentBytes = await content.ReadAsByteArrayAsync(cancellationToken); + _logger.LogInformation("Finished buffering response body ({length})", contentBytes.Length); return new MemoryStream(contentBytes); }