From 1357aeca506732bc0dcebe53c2a6f44029713407 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Tue, 19 Oct 2021 12:28:38 +0200 Subject: [PATCH 1/3] Pool CancellationTokenSources --- src/ReverseProxy/Forwarder/HttpForwarder.cs | 33 +++++---- src/ReverseProxy/Forwarder/StreamCopier.cs | 6 +- .../Forwarder/StreamCopyHttpContent.cs | 18 ++--- .../ActivityCancellationTokenSource.cs | 65 ++++++++++++++++++ .../Forwarder/StreamCopierTests.cs | 26 ++++--- .../Forwarder/StreamCopyHttpContentTests.cs | 10 +-- .../ActivityCancellationTokenSourceTests.cs | 67 +++++++++++++++++++ 7 files changed, 172 insertions(+), 53 deletions(-) create mode 100644 src/ReverseProxy/Utilities/ActivityCancellationTokenSource.cs create mode 100644 test/ReverseProxy.Tests/Utilities/ActivityCancellationTokenSourceTests.cs diff --git a/src/ReverseProxy/Forwarder/HttpForwarder.cs b/src/ReverseProxy/Forwarder/HttpForwarder.cs index 281f8a916..f94996947 100644 --- a/src/ReverseProxy/Forwarder/HttpForwarder.cs +++ b/src/ReverseProxy/Forwarder/HttpForwarder.cs @@ -106,9 +106,7 @@ public async ValueTask SendAsync( ForwarderTelemetry.Log.ForwarderStart(destinationPrefix); - var activityCancellationSource = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted); - var activityTimeout = requestConfig?.ActivityTimeout ?? DefaultTimeout; - activityCancellationSource.CancelAfter(activityTimeout); + var activityCancellationSource = ActivityCancellationTokenSource.Rent(requestConfig?.ActivityTimeout ?? DefaultTimeout, context.RequestAborted); try { var isClientHttp2 = ProtocolHelper.IsHttp2(context.Request.Protocol); @@ -119,7 +117,7 @@ public async ValueTask SendAsync( // :: Step 1-3: Create outgoing HttpRequestMessage var (destinationRequest, requestContent) = await CreateRequestMessageAsync( - context, destinationPrefix, transformer, requestConfig, isStreamingRequest, activityCancellationSource, activityTimeout); + context, destinationPrefix, transformer, requestConfig, isStreamingRequest, activityCancellationSource); // :: Step 4: Send the outgoing request using HttpClient HttpResponseMessage destinationResponse; @@ -130,7 +128,7 @@ public async ValueTask SendAsync( ForwarderTelemetry.Log.ForwarderStage(ForwarderStage.SendAsyncStop); // Reset the timeout since we received the response headers. - activityCancellationSource.CancelAfter(activityTimeout); + activityCancellationSource.ResetTimeout(); } catch (Exception requestException) { @@ -185,7 +183,7 @@ public async ValueTask SendAsync( if (destinationResponse.StatusCode == HttpStatusCode.SwitchingProtocols) { Debug.Assert(requestContent?.Started != true); - return await HandleUpgradedResponse(context, destinationResponse, activityCancellationSource, activityTimeout); + return await HandleUpgradedResponse(context, destinationResponse, activityCancellationSource); } // NOTE: it may *seem* wise to call `context.Response.StartAsync()` at this point @@ -199,7 +197,7 @@ public async ValueTask SendAsync( // and clients misbehave if the initial headers response does not indicate stream end. // :: Step 7-B: Copy response body Client ◄-- Proxy ◄-- Destination - var (responseBodyCopyResult, responseBodyException) = await CopyResponseBodyAsync(destinationResponse.Content, context.Response.Body, activityCancellationSource, activityTimeout); + var (responseBodyCopyResult, responseBodyException) = await CopyResponseBodyAsync(destinationResponse.Content, context.Response.Body, activityCancellationSource); if (responseBodyCopyResult != StreamCopyResult.Success) { @@ -248,7 +246,7 @@ public async ValueTask SendAsync( } finally { - activityCancellationSource.Dispose(); + activityCancellationSource.Return(); ForwarderTelemetry.Log.ForwarderStop(context.Response.StatusCode); } @@ -256,7 +254,7 @@ public async ValueTask SendAsync( } private async ValueTask<(HttpRequestMessage, StreamCopyHttpContent?)> CreateRequestMessageAsync(HttpContext context, string destinationPrefix, - HttpTransformer transformer, ForwarderRequestConfig? requestConfig, bool isStreamingRequest, CancellationTokenSource activityToken, TimeSpan activityTimeout) + HttpTransformer transformer, ForwarderRequestConfig? requestConfig, bool isStreamingRequest, ActivityCancellationTokenSource activityToken) { // "http://a".Length = 8 if (destinationPrefix == null || destinationPrefix.Length < 8) @@ -288,7 +286,7 @@ public async ValueTask SendAsync( // :: Step 2: Setup copy of request body (background) Client --► Proxy --► Destination // Note that we must do this before step (3) because step (3) may also add headers to the HttpContent that we set up here. - var requestContent = SetupRequestBodyCopy(context.Request, isStreamingRequest, activityToken, activityTimeout); + var requestContent = SetupRequestBodyCopy(context.Request, isStreamingRequest, activityToken); destinationRequest.Content = requestContent; // :: Step 3: Copy request headers Client --► Proxy --► Destination @@ -334,7 +332,7 @@ private static void RestoreUpgradeHeaders(HttpContext context, HttpRequestMessag } } - private StreamCopyHttpContent? SetupRequestBodyCopy(HttpRequest request, bool isStreamingRequest, CancellationTokenSource activityToken, TimeSpan activityTimeout) + private StreamCopyHttpContent? SetupRequestBodyCopy(HttpRequest request, bool isStreamingRequest, ActivityCancellationTokenSource activityToken) { // If we generate an HttpContent without a Content-Length then for HTTP/1.1 HttpClient will add a Transfer-Encoding: chunked header // even if it's a GET request. Some servers reject requests containing a Transfer-Encoding header if they're not expecting a body. @@ -410,8 +408,7 @@ private static void RestoreUpgradeHeaders(HttpContext context, HttpRequestMessag source: request.Body, autoFlushHttpClientOutgoingStream: isStreamingRequest, clock: _clock, - activityToken, - activityTimeout); + activityToken); } return null; @@ -547,7 +544,7 @@ private static void RestoreUpgradeHeaders(HttpContext context, HttpResponseMessa } private async ValueTask HandleUpgradedResponse(HttpContext context, HttpResponseMessage destinationResponse, - CancellationTokenSource activityCancellationSource, TimeSpan activityTimeout) + ActivityCancellationTokenSource activityCancellationSource) { ForwarderTelemetry.Log.ForwarderStage(ForwarderStage.ResponseUpgrade); @@ -588,8 +585,8 @@ private async ValueTask HandleUpgradedResponse(HttpContext conte // :: Step 7-A-2: Copy duplex streams using var destinationStream = await destinationResponse.Content.ReadAsStreamAsync(); - var requestTask = StreamCopier.CopyAsync(isRequest: true, clientStream, destinationStream, _clock, activityCancellationSource, activityTimeout).AsTask(); - var responseTask = StreamCopier.CopyAsync(isRequest: false, destinationStream, clientStream, _clock, activityCancellationSource, activityTimeout).AsTask(); + var requestTask = StreamCopier.CopyAsync(isRequest: true, clientStream, destinationStream, _clock, activityCancellationSource).AsTask(); + var responseTask = StreamCopier.CopyAsync(isRequest: false, destinationStream, clientStream, _clock, activityCancellationSource).AsTask(); // Make sure we report the first failure. var firstTask = await Task.WhenAny(requestTask, responseTask); @@ -637,7 +634,7 @@ ForwarderError ReportResult(HttpContext context, bool reqeuest, StreamCopyResult } private async ValueTask<(StreamCopyResult, Exception?)> CopyResponseBodyAsync(HttpContent destinationResponseContent, Stream clientResponseStream, - CancellationTokenSource activityCancellationSource, TimeSpan activityTimeout) + ActivityCancellationTokenSource activityCancellationSource) { // SocketHttpHandler and similar transports always provide an HttpContent object, even if it's empty. // In 3.1 this is only likely to return null in tests. @@ -646,7 +643,7 @@ ForwarderError ReportResult(HttpContext context, bool reqeuest, StreamCopyResult if (destinationResponseContent != null) { using var destinationResponseStream = await destinationResponseContent.ReadAsStreamAsync(); - return await StreamCopier.CopyAsync(isRequest: false, destinationResponseStream, clientResponseStream, _clock, activityCancellationSource, activityTimeout); + return await StreamCopier.CopyAsync(isRequest: false, destinationResponseStream, clientResponseStream, _clock, activityCancellationSource); } return (StreamCopyResult.Success, null); diff --git a/src/ReverseProxy/Forwarder/StreamCopier.cs b/src/ReverseProxy/Forwarder/StreamCopier.cs index 65c6e8569..88bb9fc66 100644 --- a/src/ReverseProxy/Forwarder/StreamCopier.cs +++ b/src/ReverseProxy/Forwarder/StreamCopier.cs @@ -25,7 +25,7 @@ internal static class StreamCopier /// Based on Microsoft.AspNetCore.Http.StreamCopyOperationInternal.CopyToAsync. /// See: . /// - public static async ValueTask<(StreamCopyResult, Exception?)> CopyAsync(bool isRequest, Stream input, Stream output, IClock clock, CancellationTokenSource activityToken, TimeSpan activityTimeout) + public static async ValueTask<(StreamCopyResult, Exception?)> CopyAsync(bool isRequest, Stream input, Stream output, IClock clock, ActivityCancellationTokenSource activityToken) { _ = input ?? throw new ArgumentNullException(nameof(input)); _ = output ?? throw new ArgumentNullException(nameof(output)); @@ -69,7 +69,7 @@ internal static class StreamCopier read = await input.ReadAsync(buffer.AsMemory(), cancellation); // Success, reset the activity monitor. - activityToken.CancelAfter(activityTimeout); + activityToken.ResetTimeout(); } finally { @@ -106,7 +106,7 @@ internal static class StreamCopier await output.WriteAsync(buffer.AsMemory(0, read), cancellation); // Success, reset the activity monitor. - activityToken.CancelAfter(activityTimeout); + activityToken.ResetTimeout(); } finally { diff --git a/src/ReverseProxy/Forwarder/StreamCopyHttpContent.cs b/src/ReverseProxy/Forwarder/StreamCopyHttpContent.cs index df4b7d3e8..6053a6ad4 100644 --- a/src/ReverseProxy/Forwarder/StreamCopyHttpContent.cs +++ b/src/ReverseProxy/Forwarder/StreamCopyHttpContent.cs @@ -41,19 +41,17 @@ internal sealed class StreamCopyHttpContent : HttpContent private readonly Stream _source; private readonly bool _autoFlushHttpClientOutgoingStream; private readonly IClock _clock; - private readonly CancellationTokenSource _activityToken; - private readonly TimeSpan _activityTimeout; + private readonly ActivityCancellationTokenSource _activityToken; private readonly TaskCompletionSource<(StreamCopyResult, Exception?)> _tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); private int _started; - public StreamCopyHttpContent(Stream source, bool autoFlushHttpClientOutgoingStream, IClock clock, CancellationTokenSource activityToken, TimeSpan activityTimeout) + public StreamCopyHttpContent(Stream source, bool autoFlushHttpClientOutgoingStream, IClock clock, ActivityCancellationTokenSource activityToken) { _source = source ?? throw new ArgumentNullException(nameof(source)); _autoFlushHttpClientOutgoingStream = autoFlushHttpClientOutgoingStream; _clock = clock ?? throw new ArgumentNullException(nameof(clock)); _activityToken = activityToken; - _activityTimeout = activityTimeout; } /// @@ -145,17 +143,11 @@ async Task SerializeToStreamAsync(Stream stream, TransportContext? context, Canc if (_activityToken.Token != cancellationToken) { Debug.Assert(cancellationToken.CanBeCanceled); - registration = cancellationToken.Register(static obj => - { - var activityToken = (CancellationTokenSource)obj!; - activityToken.Cancel(); - }, _activityToken); - cancellationToken = _activityToken.Token; + registration = cancellationToken.UnsafeRegister(ActivityCancellationTokenSource.LinkedTokenCancelDelegate, _activityToken); } #else // On .NET Core 3.1, cancellationToken will always be CancellationToken.None Debug.Assert(!cancellationToken.CanBeCanceled); - cancellationToken = _activityToken.Token; #endif try @@ -174,7 +166,7 @@ async Task SerializeToStreamAsync(Stream stream, TransportContext? context, Canc // https://github.com/dotnet/corefx/issues/39586#issuecomment-516210081 try { - await stream.FlushAsync(cancellationToken); + await stream.FlushAsync(_activityToken.Token); } catch (OperationCanceledException oex) { @@ -187,7 +179,7 @@ async Task SerializeToStreamAsync(Stream stream, TransportContext? context, Canc return; } - var (result, error) = await StreamCopier.CopyAsync(isRequest: true, _source, stream, _clock, _activityToken, _activityTimeout); + var (result, error) = await StreamCopier.CopyAsync(isRequest: true, _source, stream, _clock, _activityToken); _tcs.TrySetResult((result, error)); // Check for errors that weren't the result of the destination failing. diff --git a/src/ReverseProxy/Utilities/ActivityCancellationTokenSource.cs b/src/ReverseProxy/Utilities/ActivityCancellationTokenSource.cs new file mode 100644 index 000000000..8b3b0d105 --- /dev/null +++ b/src/ReverseProxy/Utilities/ActivityCancellationTokenSource.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Concurrent; +using System.Threading; + +namespace Yarp.ReverseProxy.Utilities +{ + internal sealed class ActivityCancellationTokenSource : CancellationTokenSource + { +#if NET6_0_OR_GREATER + private static readonly ConcurrentQueue _sharedSources = new(); +#endif + + public static readonly Action LinkedTokenCancelDelegate = static s => + { + ((ActivityCancellationTokenSource)s!).Cancel(throwOnFirstException: false); + }; + + private int _activityTimeoutMs; + private CancellationTokenRegistration _linkedRegistration; + + private ActivityCancellationTokenSource() { } + + public void ResetTimeout() + { + CancelAfter(_activityTimeoutMs); + } + + public static ActivityCancellationTokenSource Rent(TimeSpan activityTimeout, CancellationToken linkedToken) + { +#if NET6_0_OR_GREATER + if (!_sharedSources.TryDequeue(out var cts)) + { + cts = new ActivityCancellationTokenSource(); + } +#else + var cts = new ActivityCancellationTokenSource(); +#endif + + cts._activityTimeoutMs = (int)activityTimeout.TotalMilliseconds; + cts._linkedRegistration = linkedToken.UnsafeRegister(LinkedTokenCancelDelegate, cts); + cts.ResetTimeout(); + + return cts; + } + + public void Return() + { + _linkedRegistration.Dispose(); + _linkedRegistration = default; + +#if NET6_0_OR_GREATER + if (TryReset()) + { + _sharedSources.Enqueue(this); + return; + } +#endif + + Dispose(); + } + } +} diff --git a/test/ReverseProxy.Tests/Forwarder/StreamCopierTests.cs b/test/ReverseProxy.Tests/Forwarder/StreamCopierTests.cs index d49c78bbe..7f91064e8 100644 --- a/test/ReverseProxy.Tests/Forwarder/StreamCopierTests.cs +++ b/test/ReverseProxy.Tests/Forwarder/StreamCopierTests.cs @@ -28,8 +28,8 @@ public async Task CopyAsync_Works(bool isRequest) var source = new MemoryStream(sourceBytes); var destination = new MemoryStream(); - using var cts = new CancellationTokenSource(); - await StreamCopier.CopyAsync(isRequest, source, destination, new Clock(), cts, TimeSpan.FromSeconds(10)); + using var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); + await StreamCopier.CopyAsync(isRequest, source, destination, new Clock(), cts); Assert.False(cts.Token.IsCancellationRequested); @@ -50,8 +50,8 @@ public async Task SourceThrows_Reported(bool isRequest) var source = new SlowStream(new ThrowStream(), clock, sourceWaitTime); var destination = new MemoryStream(); - using var cts = new CancellationTokenSource(); - var (result, error) = await StreamCopier.CopyAsync(isRequest, source, destination, clock, cts, TimeSpan.FromSeconds(10)); + using var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); + var (result, error) = await StreamCopier.CopyAsync(isRequest, source, destination, clock, cts); Assert.Equal(StreamCopyResult.InputError, result); Assert.IsAssignableFrom(error); @@ -79,8 +79,8 @@ public async Task DestinationThrows_Reported(bool isRequest) var source = new SlowStream(new MemoryStream(new byte[SourceSize]), clock, sourceWaitTime) { MaxBytesPerRead = BytesPerRead }; var destination = new SlowStream(new ThrowStream(), clock, destinationWaitTime); - using var cts = new CancellationTokenSource(); - var (result, error) = await StreamCopier.CopyAsync(isRequest, source, destination, clock, cts, TimeSpan.FromSeconds(10)); + using var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); + var (result, error) = await StreamCopier.CopyAsync(isRequest, source, destination, clock, cts); Assert.Equal(StreamCopyResult.OutputError, result); Assert.IsAssignableFrom(error); @@ -102,9 +102,9 @@ public async Task Cancelled_Reported(bool isRequest) var source = new MemoryStream(new byte[10]); var destination = new MemoryStream(); - using var cts = new CancellationTokenSource(); + using var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); cts.Cancel(); - var (result, error) = await StreamCopier.CopyAsync(isRequest, source, destination, new Clock(), cts, TimeSpan.FromSeconds(10)); + var (result, error) = await StreamCopier.CopyAsync(isRequest, source, destination, new Clock(), cts); Assert.Equal(StreamCopyResult.Canceled, result); Assert.IsAssignableFrom(error); @@ -132,14 +132,13 @@ public async Task SlowStreams_TelemetryReportsCorrectTime(bool isRequest) var sourceWaitTime = TimeSpan.FromMilliseconds(12345); var destinationWaitTime = TimeSpan.FromMilliseconds(42); - using var cts = new CancellationTokenSource(); + using var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); await StreamCopier.CopyAsync( isRequest, new SlowStream(source, clock, sourceWaitTime), new SlowStream(destination, clock, destinationWaitTime), clock, - cts, - TimeSpan.FromSeconds(10)); + cts); Assert.Equal(sourceBytes, destination.ToArray()); @@ -169,14 +168,13 @@ public async Task LongContentTransfer_TelemetryReportsTransferringEvents(bool is const int BytesPerRead = 3; var contentReads = (int)Math.Ceiling((double)SourceSize / BytesPerRead); - using var cts = new CancellationTokenSource(); + using var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); await StreamCopier.CopyAsync( isRequest, new SlowStream(source, clock, sourceWaitTime) { MaxBytesPerRead = BytesPerRead }, new SlowStream(destination, clock, destinationWaitTime), clock, - cts, - TimeSpan.FromSeconds(10)); + cts); Assert.Equal(sourceBytes, destination.ToArray()); diff --git a/test/ReverseProxy.Tests/Forwarder/StreamCopyHttpContentTests.cs b/test/ReverseProxy.Tests/Forwarder/StreamCopyHttpContentTests.cs index 9a842e3e5..52850f0f1 100644 --- a/test/ReverseProxy.Tests/Forwarder/StreamCopyHttpContentTests.cs +++ b/test/ReverseProxy.Tests/Forwarder/StreamCopyHttpContentTests.cs @@ -17,14 +17,14 @@ namespace Yarp.ReverseProxy.Forwarder.Tests { public class StreamCopyHttpContentTests { - private static StreamCopyHttpContent CreateContent(Stream source = null, bool autoFlushHttpClientOutgoingStream = false, IClock clock = null, CancellationTokenSource contentCancellation = null) + private static StreamCopyHttpContent CreateContent(Stream source = null, bool autoFlushHttpClientOutgoingStream = false, IClock clock = null, ActivityCancellationTokenSource contentCancellation = null) { source ??= new MemoryStream(); clock ??= new Clock(); - contentCancellation ??= new CancellationTokenSource(); + contentCancellation ??= ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); - return new StreamCopyHttpContent(source, autoFlushHttpClientOutgoingStream, clock, contentCancellation, TimeSpan.FromSeconds(10)); + return new StreamCopyHttpContent(source, autoFlushHttpClientOutgoingStream, clock, contentCancellation); } [Fact] @@ -141,7 +141,7 @@ public async Task SerializeToStreamAsync_RespectsContentCancellation() return 0; }); - var contentCts = new CancellationTokenSource(); + using var contentCts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); var sut = CreateContent(source, contentCancellation: contentCts); @@ -168,7 +168,7 @@ public async Task SerializeToStreamAsync_CanBeCanceledExternally() var sut = CreateContent(source); - var cts = new CancellationTokenSource(); + using var cts = new CancellationTokenSource(); var copyToTask = sut.CopyToAsync(new MemoryStream(), cts.Token); cts.Cancel(); tcs.SetResult(0); diff --git a/test/ReverseProxy.Tests/Utilities/ActivityCancellationTokenSourceTests.cs b/test/ReverseProxy.Tests/Utilities/ActivityCancellationTokenSourceTests.cs new file mode 100644 index 000000000..dfd6c8267 --- /dev/null +++ b/test/ReverseProxy.Tests/Utilities/ActivityCancellationTokenSourceTests.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Yarp.ReverseProxy.Utilities.Tests +{ + public class ActivityCancellationTokenSourceTests + { + [Fact] + public void ActivityCancellationTokenSource_PoolsSources() + { + var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); + + cts.Return(); + + var cts2 = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); + + Assert.Same(cts, cts2); + } + + [Fact] + public void ActivityCancellationTokenSource_RespectsLinkedToken() + { + var linkedCts = new CancellationTokenSource(); + + var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), linkedCts.Token); + linkedCts.Cancel(); + + Assert.True(cts.IsCancellationRequested); + } + + [Fact] + public void ActivityCancellationTokenSource_ClearsRegistrations() + { + var linkedCts = new CancellationTokenSource(); + + var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), linkedCts.Token); + cts.Return(); + + linkedCts.Cancel(); + + Assert.False(cts.IsCancellationRequested); + } + + [Fact] + public async Task ActivityCancellationTokenSource_RespectsTimeout() + { + var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromMilliseconds(1), CancellationToken.None); + + for (var i = 0; i < 1000; i++) + { + if (cts.IsCancellationRequested) + { + return; + } + + await Task.Delay(1); + } + + throw new TimeoutException("Cts was not canceled"); + } + } +} From 2a506228fc17630a7f55a7b8f62ffe2220f87dd2 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Tue, 19 Oct 2021 14:15:20 +0200 Subject: [PATCH 2/3] Don't assert the sources are pooled before 6.0 --- .../ActivityCancellationTokenSourceTests.cs | 32 ++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/test/ReverseProxy.Tests/Utilities/ActivityCancellationTokenSourceTests.cs b/test/ReverseProxy.Tests/Utilities/ActivityCancellationTokenSourceTests.cs index dfd6c8267..cf48bb949 100644 --- a/test/ReverseProxy.Tests/Utilities/ActivityCancellationTokenSourceTests.cs +++ b/test/ReverseProxy.Tests/Utilities/ActivityCancellationTokenSourceTests.cs @@ -10,16 +10,40 @@ namespace Yarp.ReverseProxy.Utilities.Tests { public class ActivityCancellationTokenSourceTests { +#if NET6_0_OR_GREATER // CancellationTokenSource.TryReset() is only available in 6.0+ [Fact] public void ActivityCancellationTokenSource_PoolsSources() { - var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); + // This test can run in parallel with others making use of ActivityCancellationTokenSource + // A different thread could have already added/removed a source from the queue - cts.Return(); + for (var i = 0; i < 1000; i++) + { + var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); + cts.Return(); + + var cts2 = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); + cts2.Return(); + + if (ReferenceEquals(cts, cts2)) + { + return; + } + } + + Assert.True(false, "CancellationTokenSources were not pooled"); + } +#endif + + [Fact] + public void ActivityCancellationTokenSource_DoesNotPoolsCanceledSources() + { + var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); + cts.Cancel(); var cts2 = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); - Assert.Same(cts, cts2); + Assert.NotSame(cts, cts2); } [Fact] @@ -61,7 +85,7 @@ public async Task ActivityCancellationTokenSource_RespectsTimeout() await Task.Delay(1); } - throw new TimeoutException("Cts was not canceled"); + Assert.True(false, "Cts was not canceled"); } } } From 733e986b87103f74fc753f4762a24b6392ffaa75 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Wed, 20 Oct 2021 04:54:07 +0200 Subject: [PATCH 3/3] Limit pool capacity, add LinkTo helper --- .../Forwarder/StreamCopyHttpContent.cs | 2 +- .../ActivityCancellationTokenSource.cs | 26 +++++++++++++++---- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/ReverseProxy/Forwarder/StreamCopyHttpContent.cs b/src/ReverseProxy/Forwarder/StreamCopyHttpContent.cs index 6053a6ad4..31dd00eca 100644 --- a/src/ReverseProxy/Forwarder/StreamCopyHttpContent.cs +++ b/src/ReverseProxy/Forwarder/StreamCopyHttpContent.cs @@ -143,7 +143,7 @@ async Task SerializeToStreamAsync(Stream stream, TransportContext? context, Canc if (_activityToken.Token != cancellationToken) { Debug.Assert(cancellationToken.CanBeCanceled); - registration = cancellationToken.UnsafeRegister(ActivityCancellationTokenSource.LinkedTokenCancelDelegate, _activityToken); + registration = _activityToken.LinkTo(cancellationToken); } #else // On .NET Core 3.1, cancellationToken will always be CancellationToken.None diff --git a/src/ReverseProxy/Utilities/ActivityCancellationTokenSource.cs b/src/ReverseProxy/Utilities/ActivityCancellationTokenSource.cs index 8b3b0d105..f5f6d48ec 100644 --- a/src/ReverseProxy/Utilities/ActivityCancellationTokenSource.cs +++ b/src/ReverseProxy/Utilities/ActivityCancellationTokenSource.cs @@ -10,10 +10,12 @@ namespace Yarp.ReverseProxy.Utilities internal sealed class ActivityCancellationTokenSource : CancellationTokenSource { #if NET6_0_OR_GREATER + private const int MaxQueueSize = 1024; private static readonly ConcurrentQueue _sharedSources = new(); + private static int _count; #endif - public static readonly Action LinkedTokenCancelDelegate = static s => + private static readonly Action _linkedTokenCancelDelegate = static s => { ((ActivityCancellationTokenSource)s!).Cancel(throwOnFirstException: false); }; @@ -31,7 +33,11 @@ public void ResetTimeout() public static ActivityCancellationTokenSource Rent(TimeSpan activityTimeout, CancellationToken linkedToken) { #if NET6_0_OR_GREATER - if (!_sharedSources.TryDequeue(out var cts)) + if (_sharedSources.TryDequeue(out var cts)) + { + Interlocked.Decrement(ref _count); + } + else { cts = new ActivityCancellationTokenSource(); } @@ -40,7 +46,7 @@ public static ActivityCancellationTokenSource Rent(TimeSpan activityTimeout, Can #endif cts._activityTimeoutMs = (int)activityTimeout.TotalMilliseconds; - cts._linkedRegistration = linkedToken.UnsafeRegister(LinkedTokenCancelDelegate, cts); + cts._linkedRegistration = cts.LinkTo(linkedToken); cts.ResetTimeout(); return cts; @@ -54,12 +60,22 @@ public void Return() #if NET6_0_OR_GREATER if (TryReset()) { - _sharedSources.Enqueue(this); - return; + if (Interlocked.Increment(ref _count) <= MaxQueueSize) + { + _sharedSources.Enqueue(this); + return; + } + + Interlocked.Decrement(ref _count); } #endif Dispose(); } + + public CancellationTokenRegistration LinkTo(CancellationToken linkedToken) + { + return linkedToken.UnsafeRegister(_linkedTokenCancelDelegate, this); + } } }