Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pool CancellationTokenSources #1297

Merged
merged 3 commits into from
Oct 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 15 additions & 18 deletions src/ReverseProxy/Forwarder/HttpForwarder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ public async ValueTask<ForwarderError> SendAsync(

ForwarderTelemetry.Log.ForwarderStart(destinationPrefix);

var activityCancellationSource = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted);
var activityTimeout = requestConfig?.ActivityTimeout ?? DefaultTimeout;
activityCancellationSource.CancelAfter(activityTimeout);
var activityCancellationSource = ActivityCancellationTokenSource.Rent(requestConfig?.ActivityTimeout ?? DefaultTimeout, context.RequestAborted);
try
{
var isClientHttp2 = ProtocolHelper.IsHttp2(context.Request.Protocol);
Expand All @@ -119,7 +117,7 @@ public async ValueTask<ForwarderError> SendAsync(

// :: Step 1-3: Create outgoing HttpRequestMessage
var (destinationRequest, requestContent) = await CreateRequestMessageAsync(
context, destinationPrefix, transformer, requestConfig, isStreamingRequest, activityCancellationSource, activityTimeout);
context, destinationPrefix, transformer, requestConfig, isStreamingRequest, activityCancellationSource);

// :: Step 4: Send the outgoing request using HttpClient
HttpResponseMessage destinationResponse;
Expand All @@ -130,7 +128,7 @@ public async ValueTask<ForwarderError> SendAsync(
ForwarderTelemetry.Log.ForwarderStage(ForwarderStage.SendAsyncStop);

// Reset the timeout since we received the response headers.
activityCancellationSource.CancelAfter(activityTimeout);
activityCancellationSource.ResetTimeout();
}
catch (Exception requestException)
{
Expand Down Expand Up @@ -185,7 +183,7 @@ public async ValueTask<ForwarderError> SendAsync(
if (destinationResponse.StatusCode == HttpStatusCode.SwitchingProtocols)
{
Debug.Assert(requestContent?.Started != true);
return await HandleUpgradedResponse(context, destinationResponse, activityCancellationSource, activityTimeout);
return await HandleUpgradedResponse(context, destinationResponse, activityCancellationSource);
}

// NOTE: it may *seem* wise to call `context.Response.StartAsync()` at this point
Expand All @@ -199,7 +197,7 @@ public async ValueTask<ForwarderError> SendAsync(
// and clients misbehave if the initial headers response does not indicate stream end.

// :: Step 7-B: Copy response body Client ◄-- Proxy ◄-- Destination
var (responseBodyCopyResult, responseBodyException) = await CopyResponseBodyAsync(destinationResponse.Content, context.Response.Body, activityCancellationSource, activityTimeout);
var (responseBodyCopyResult, responseBodyException) = await CopyResponseBodyAsync(destinationResponse.Content, context.Response.Body, activityCancellationSource);

if (responseBodyCopyResult != StreamCopyResult.Success)
{
Expand Down Expand Up @@ -248,15 +246,15 @@ public async ValueTask<ForwarderError> SendAsync(
}
finally
{
activityCancellationSource.Dispose();
activityCancellationSource.Return();
ForwarderTelemetry.Log.ForwarderStop(context.Response.StatusCode);
}

return ForwarderError.None;
}

private async ValueTask<(HttpRequestMessage, StreamCopyHttpContent?)> CreateRequestMessageAsync(HttpContext context, string destinationPrefix,
HttpTransformer transformer, ForwarderRequestConfig? requestConfig, bool isStreamingRequest, CancellationTokenSource activityToken, TimeSpan activityTimeout)
HttpTransformer transformer, ForwarderRequestConfig? requestConfig, bool isStreamingRequest, ActivityCancellationTokenSource activityToken)
{
// "http://a".Length = 8
if (destinationPrefix == null || destinationPrefix.Length < 8)
Expand Down Expand Up @@ -288,7 +286,7 @@ public async ValueTask<ForwarderError> SendAsync(

// :: Step 2: Setup copy of request body (background) Client --► Proxy --► Destination
// Note that we must do this before step (3) because step (3) may also add headers to the HttpContent that we set up here.
var requestContent = SetupRequestBodyCopy(context.Request, isStreamingRequest, activityToken, activityTimeout);
var requestContent = SetupRequestBodyCopy(context.Request, isStreamingRequest, activityToken);
destinationRequest.Content = requestContent;

// :: Step 3: Copy request headers Client --► Proxy --► Destination
Expand Down Expand Up @@ -334,7 +332,7 @@ private static void RestoreUpgradeHeaders(HttpContext context, HttpRequestMessag
}
}

private StreamCopyHttpContent? SetupRequestBodyCopy(HttpRequest request, bool isStreamingRequest, CancellationTokenSource activityToken, TimeSpan activityTimeout)
private StreamCopyHttpContent? SetupRequestBodyCopy(HttpRequest request, bool isStreamingRequest, ActivityCancellationTokenSource activityToken)
{
// If we generate an HttpContent without a Content-Length then for HTTP/1.1 HttpClient will add a Transfer-Encoding: chunked header
// even if it's a GET request. Some servers reject requests containing a Transfer-Encoding header if they're not expecting a body.
Expand Down Expand Up @@ -410,8 +408,7 @@ private static void RestoreUpgradeHeaders(HttpContext context, HttpRequestMessag
source: request.Body,
autoFlushHttpClientOutgoingStream: isStreamingRequest,
clock: _clock,
activityToken,
activityTimeout);
activityToken);
}

return null;
Expand Down Expand Up @@ -547,7 +544,7 @@ private static void RestoreUpgradeHeaders(HttpContext context, HttpResponseMessa
}

private async ValueTask<ForwarderError> HandleUpgradedResponse(HttpContext context, HttpResponseMessage destinationResponse,
CancellationTokenSource activityCancellationSource, TimeSpan activityTimeout)
ActivityCancellationTokenSource activityCancellationSource)
{
ForwarderTelemetry.Log.ForwarderStage(ForwarderStage.ResponseUpgrade);

Expand Down Expand Up @@ -588,8 +585,8 @@ private async ValueTask<ForwarderError> HandleUpgradedResponse(HttpContext conte
// :: Step 7-A-2: Copy duplex streams
using var destinationStream = await destinationResponse.Content.ReadAsStreamAsync();

var requestTask = StreamCopier.CopyAsync(isRequest: true, clientStream, destinationStream, _clock, activityCancellationSource, activityTimeout).AsTask();
var responseTask = StreamCopier.CopyAsync(isRequest: false, destinationStream, clientStream, _clock, activityCancellationSource, activityTimeout).AsTask();
var requestTask = StreamCopier.CopyAsync(isRequest: true, clientStream, destinationStream, _clock, activityCancellationSource).AsTask();
var responseTask = StreamCopier.CopyAsync(isRequest: false, destinationStream, clientStream, _clock, activityCancellationSource).AsTask();

// Make sure we report the first failure.
var firstTask = await Task.WhenAny(requestTask, responseTask);
Expand Down Expand Up @@ -637,7 +634,7 @@ ForwarderError ReportResult(HttpContext context, bool reqeuest, StreamCopyResult
}

private async ValueTask<(StreamCopyResult, Exception?)> CopyResponseBodyAsync(HttpContent destinationResponseContent, Stream clientResponseStream,
CancellationTokenSource activityCancellationSource, TimeSpan activityTimeout)
ActivityCancellationTokenSource activityCancellationSource)
{
// SocketHttpHandler and similar transports always provide an HttpContent object, even if it's empty.
// In 3.1 this is only likely to return null in tests.
Expand All @@ -646,7 +643,7 @@ ForwarderError ReportResult(HttpContext context, bool reqeuest, StreamCopyResult
if (destinationResponseContent != null)
{
using var destinationResponseStream = await destinationResponseContent.ReadAsStreamAsync();
return await StreamCopier.CopyAsync(isRequest: false, destinationResponseStream, clientResponseStream, _clock, activityCancellationSource, activityTimeout);
return await StreamCopier.CopyAsync(isRequest: false, destinationResponseStream, clientResponseStream, _clock, activityCancellationSource);
}

return (StreamCopyResult.Success, null);
Expand Down
6 changes: 3 additions & 3 deletions src/ReverseProxy/Forwarder/StreamCopier.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ internal static class StreamCopier
/// Based on <c>Microsoft.AspNetCore.Http.StreamCopyOperationInternal.CopyToAsync</c>.
/// See: <see href="https://github.com/dotnet/aspnetcore/blob/080660967b6043f731d4b7163af9e9e6047ef0c4/src/Http/Shared/StreamCopyOperationInternal.cs"/>.
/// </remarks>
public static async ValueTask<(StreamCopyResult, Exception?)> CopyAsync(bool isRequest, Stream input, Stream output, IClock clock, CancellationTokenSource activityToken, TimeSpan activityTimeout)
public static async ValueTask<(StreamCopyResult, Exception?)> CopyAsync(bool isRequest, Stream input, Stream output, IClock clock, ActivityCancellationTokenSource activityToken)
{
_ = input ?? throw new ArgumentNullException(nameof(input));
_ = output ?? throw new ArgumentNullException(nameof(output));
Expand Down Expand Up @@ -69,7 +69,7 @@ internal static class StreamCopier
read = await input.ReadAsync(buffer.AsMemory(), cancellation);

// Success, reset the activity monitor.
activityToken.CancelAfter(activityTimeout);
activityToken.ResetTimeout();
}
finally
{
Expand Down Expand Up @@ -106,7 +106,7 @@ internal static class StreamCopier
await output.WriteAsync(buffer.AsMemory(0, read), cancellation);

// Success, reset the activity monitor.
activityToken.CancelAfter(activityTimeout);
activityToken.ResetTimeout();
}
finally
{
Expand Down
18 changes: 5 additions & 13 deletions src/ReverseProxy/Forwarder/StreamCopyHttpContent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,17 @@ internal sealed class StreamCopyHttpContent : HttpContent
private readonly Stream _source;
private readonly bool _autoFlushHttpClientOutgoingStream;
private readonly IClock _clock;
private readonly CancellationTokenSource _activityToken;
private readonly TimeSpan _activityTimeout;
private readonly ActivityCancellationTokenSource _activityToken;
private readonly TaskCompletionSource<(StreamCopyResult, Exception?)> _tcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
private int _started;

public StreamCopyHttpContent(Stream source, bool autoFlushHttpClientOutgoingStream, IClock clock, CancellationTokenSource activityToken, TimeSpan activityTimeout)
public StreamCopyHttpContent(Stream source, bool autoFlushHttpClientOutgoingStream, IClock clock, ActivityCancellationTokenSource activityToken)
{
_source = source ?? throw new ArgumentNullException(nameof(source));
_autoFlushHttpClientOutgoingStream = autoFlushHttpClientOutgoingStream;
_clock = clock ?? throw new ArgumentNullException(nameof(clock));

_activityToken = activityToken;
_activityTimeout = activityTimeout;
}

/// <summary>
Expand Down Expand Up @@ -145,17 +143,11 @@ async Task SerializeToStreamAsync(Stream stream, TransportContext? context, Canc
if (_activityToken.Token != cancellationToken)
{
Debug.Assert(cancellationToken.CanBeCanceled);
registration = cancellationToken.Register(static obj =>
{
var activityToken = (CancellationTokenSource)obj!;
activityToken.Cancel();
}, _activityToken);
cancellationToken = _activityToken.Token;
registration = _activityToken.LinkTo(cancellationToken);
}
#else
// On .NET Core 3.1, cancellationToken will always be CancellationToken.None
Debug.Assert(!cancellationToken.CanBeCanceled);
cancellationToken = _activityToken.Token;
#endif

try
Expand All @@ -174,7 +166,7 @@ async Task SerializeToStreamAsync(Stream stream, TransportContext? context, Canc
// https://github.com/dotnet/corefx/issues/39586#issuecomment-516210081
try
{
await stream.FlushAsync(cancellationToken);
await stream.FlushAsync(_activityToken.Token);
}
catch (OperationCanceledException oex)
{
Expand All @@ -187,7 +179,7 @@ async Task SerializeToStreamAsync(Stream stream, TransportContext? context, Canc
return;
}

var (result, error) = await StreamCopier.CopyAsync(isRequest: true, _source, stream, _clock, _activityToken, _activityTimeout);
var (result, error) = await StreamCopier.CopyAsync(isRequest: true, _source, stream, _clock, _activityToken);
_tcs.TrySetResult((result, error));

// Check for errors that weren't the result of the destination failing.
Expand Down
81 changes: 81 additions & 0 deletions src/ReverseProxy/Utilities/ActivityCancellationTokenSource.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System;
using System.Collections.Concurrent;
using System.Threading;

namespace Yarp.ReverseProxy.Utilities
{
internal sealed class ActivityCancellationTokenSource : CancellationTokenSource
Tratcher marked this conversation as resolved.
Show resolved Hide resolved
{
#if NET6_0_OR_GREATER
private const int MaxQueueSize = 1024;
private static readonly ConcurrentQueue<ActivityCancellationTokenSource> _sharedSources = new();
Tratcher marked this conversation as resolved.
Show resolved Hide resolved
private static int _count;
#endif

private static readonly Action<object?> _linkedTokenCancelDelegate = static s =>
{
((ActivityCancellationTokenSource)s!).Cancel(throwOnFirstException: false);
};

private int _activityTimeoutMs;
private CancellationTokenRegistration _linkedRegistration;

private ActivityCancellationTokenSource() { }

public void ResetTimeout()
{
CancelAfter(_activityTimeoutMs);
}

public static ActivityCancellationTokenSource Rent(TimeSpan activityTimeout, CancellationToken linkedToken)
{
#if NET6_0_OR_GREATER
if (_sharedSources.TryDequeue(out var cts))
{
Interlocked.Decrement(ref _count);
}
else
{
cts = new ActivityCancellationTokenSource();
}
#else
var cts = new ActivityCancellationTokenSource();
#endif

cts._activityTimeoutMs = (int)activityTimeout.TotalMilliseconds;
cts._linkedRegistration = cts.LinkTo(linkedToken);
cts.ResetTimeout();

return cts;
}

public void Return()
{
_linkedRegistration.Dispose();
_linkedRegistration = default;

#if NET6_0_OR_GREATER
if (TryReset())
{
if (Interlocked.Increment(ref _count) <= MaxQueueSize)
{
_sharedSources.Enqueue(this);
return;
}

Interlocked.Decrement(ref _count);
}
#endif

Dispose();
}

public CancellationTokenRegistration LinkTo(CancellationToken linkedToken)
{
return linkedToken.UnsafeRegister(_linkedTokenCancelDelegate, this);
}
}
}
Loading