Skip to content

Commit

Permalink
Allow default AsyncFlowControls rather than throwing (#82912)
Browse files Browse the repository at this point in the history
ExecutionContext.SuppressFlow currently throws an exception if flow is already suppressed.  This makes it complicated to use, as you need to check whether IsFlowSuppressed first and take two different paths based on the result.  If we instead just allow SuppressFlow to return a default AsyncFlowControl rather than throwing, and have AsyncFlowControl's Undo nop rather than throw if it doesn't contain a Thread, we can again make it simple to just always use SuppressFlow without any of the other complications.
  • Loading branch information
stephentoub authored Mar 3, 2023
1 parent db084d9 commit 94bec76
Show file tree
Hide file tree
Showing 12 changed files with 67 additions and 194 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,20 +79,15 @@ public void FileSystemWatcher_File_Create_SuppressedExecutionContextHandled()

local.Value = 42;

ExecutionContext.SuppressFlow();
try
using (ExecutionContext.SuppressFlow())
{
watcher1.EnableRaisingEvents = true;
}
finally
{
ExecutionContext.RestoreFlow();
}

File.Create(fileName).Dispose();
tcs1.Task.Wait(WaitForExpectedEventTimeout);
File.Create(fileName).Dispose();
tcs1.Task.Wait(WaitForExpectedEventTimeout);

Assert.Equal(0, tcs1.Task.Result);
Assert.Equal(0, tcs1.Task.Result);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1291,15 +1291,8 @@ internal void HandleAltSvc(IEnumerable<string> altSvcHeaderValues, TimeSpan? res
{
var thisRef = new WeakReference<HttpConnectionPool>(this);

bool restoreFlow = false;
try
using (ExecutionContext.SuppressFlow())
{
if (!ExecutionContext.IsFlowSuppressed())
{
ExecutionContext.SuppressFlow();
restoreFlow = true;
}

_authorityExpireTimer = new Timer(static o =>
{
var wr = (WeakReference<HttpConnectionPool>)o!;
Expand All @@ -1309,10 +1302,6 @@ internal void HandleAltSvc(IEnumerable<string> altSvcHeaderValues, TimeSpan? res
}
}, thisRef, nextAuthorityMaxAge, Timeout.InfiniteTimeSpan);
}
finally
{
if (restoreFlow) ExecutionContext.RestoreFlow();
}
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,8 @@ public HttpConnectionPoolManager(HttpConnectionSettings settings)
_cleanPoolTimeout = timerPeriod.TotalSeconds >= MinScavengeSeconds ? timerPeriod : TimeSpan.FromSeconds(MinScavengeSeconds);
}

bool restoreFlow = false;
try
using (ExecutionContext.SuppressFlow()) // Don't capture the current ExecutionContext and its AsyncLocals onto the timer causing them to live forever
{
// Don't capture the current ExecutionContext and its AsyncLocals onto the timer causing them to live forever
if (!ExecutionContext.IsFlowSuppressed())
{
ExecutionContext.SuppressFlow();
restoreFlow = true;
}

// Create the timer. Ensure the Timer has a weak reference to this manager; otherwise, it
// can introduce a cycle that keeps the HttpConnectionPoolManager rooted by the Timer
// implementation until the handler is Disposed (or indefinitely if it's not).
Expand Down Expand Up @@ -131,14 +123,6 @@ public HttpConnectionPoolManager(HttpConnectionSettings settings)
}, thisRef, heartBeatInterval, heartBeatInterval);
}
}
finally
{
// Restore the current ExecutionContext
if (restoreFlow)
{
ExecutionContext.RestoreFlow();
}
}
}

// Figure out proxy stuff.
Expand Down Expand Up @@ -190,14 +174,7 @@ public void StartMonitoringNetworkChanges()
return;
}

if (!ExecutionContext.IsFlowSuppressed())
{
using (ExecutionContext.SuppressFlow())
{
NetworkChange.NetworkAddressChanged += networkChangedDelegate;
}
}
else
using (ExecutionContext.SuppressFlow())
{
NetworkChange.NetworkAddressChanged += networkChangedDelegate;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,23 +104,10 @@ public static event NetworkAvailabilityChangedEventHandler? NetworkAvailabilityC
if (s_availabilityTimer == null)
{
// Don't capture the current ExecutionContext and its AsyncLocals onto the timer causing them to live forever
bool restoreFlow = false;
try
using (ExecutionContext.SuppressFlow())
{
if (!ExecutionContext.IsFlowSuppressed())
{
ExecutionContext.SuppressFlow();
restoreFlow = true;
}

s_availabilityTimer = new Timer(s_availabilityTimerFiredCallback, null, Timeout.Infinite, Timeout.Infinite);
}
finally
{
// Restore the current ExecutionContext
if (restoreFlow)
ExecutionContext.RestoreFlow();
}
}

s_availabilityChangedSubscribers.TryAdd(value, ExecutionContext.Capture());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,10 @@ public async Task SocketAsyncEventArgs_ExecutionContextFlowsAcrossAcceptAsyncOpe
};

asyncLocal.Value = 42;
if (suppressContext) ExecutionContext.SuppressFlow();
try
using (suppressContext ? ExecutionContext.SuppressFlow() : default)
{
Assert.True(listener.AcceptAsync(saea));
}
finally
{
if (suppressContext) ExecutionContext.RestoreFlow();
}
asyncLocal.Value = 0;

client.Connect(listener.LocalEndPoint);
Expand All @@ -65,19 +60,14 @@ public async Task APM_ExecutionContextFlowsAcrossBeginAcceptOperation(bool suppr
var tcs = new TaskCompletionSource<int>();

asyncLocal.Value = 42;
if (suppressContext) ExecutionContext.SuppressFlow();
try
using (suppressContext ? ExecutionContext.SuppressFlow() : default)
{
listener.BeginAccept(iar =>
{
listener.EndAccept(iar).Dispose();
tcs.SetResult(asyncLocal.Value);
}, null);
}
finally
{
if (suppressContext) ExecutionContext.RestoreFlow();
}
asyncLocal.Value = 0;

client.Connect(listener.LocalEndPoint);
Expand Down Expand Up @@ -105,15 +95,10 @@ public async Task SocketAsyncEventArgs_ExecutionContextFlowsAcrossConnectAsyncOp

bool pending;
asyncLocal.Value = 42;
if (suppressContext) ExecutionContext.SuppressFlow();
try
using (suppressContext ? ExecutionContext.SuppressFlow() : default)
{
pending = client.ConnectAsync(saea);
}
finally
{
if (suppressContext) ExecutionContext.RestoreFlow();
}
asyncLocal.Value = 0;

if (pending)
Expand All @@ -139,19 +124,14 @@ public async Task APM_ExecutionContextFlowsAcrossBeginConnectOperation(bool supp

bool pending;
asyncLocal.Value = 42;
if (suppressContext) ExecutionContext.SuppressFlow();
try
using (suppressContext ? ExecutionContext.SuppressFlow() : default)
{
pending = !client.BeginConnect(listener.LocalEndPoint, iar =>
{
client.EndConnect(iar);
tcs.SetResult(asyncLocal.Value);
}, null).CompletedSynchronously;
}
finally
{
if (suppressContext) ExecutionContext.RestoreFlow();
}
asyncLocal.Value = 0;

if (pending)
Expand Down Expand Up @@ -182,15 +162,10 @@ public async Task SocketAsyncEventArgs_ExecutionContextFlowsAcrossDisconnectAsyn

bool pending;
asyncLocal.Value = 42;
if (suppressContext) ExecutionContext.SuppressFlow();
try
using (suppressContext ? ExecutionContext.SuppressFlow() : default)
{
pending = client.DisconnectAsync(saea);
}
finally
{
if (suppressContext) ExecutionContext.RestoreFlow();
}
asyncLocal.Value = 0;

if (pending)
Expand Down Expand Up @@ -220,19 +195,14 @@ public async Task APM_ExecutionContextFlowsAcrossBeginDisconnectOperation(bool s

bool pending;
asyncLocal.Value = 42;
if (suppressContext) ExecutionContext.SuppressFlow();
try
using (suppressContext ? ExecutionContext.SuppressFlow() : default)
{
pending = !client.BeginDisconnect(reuseSocket: false, iar =>
{
client.EndDisconnect(iar);
tcs.SetResult(asyncLocal.Value);
}, null).CompletedSynchronously;
}
finally
{
if (suppressContext) ExecutionContext.RestoreFlow();
}
asyncLocal.Value = 0;

if (pending)
Expand Down Expand Up @@ -267,17 +237,12 @@ public async Task SocketAsyncEventArgs_ExecutionContextFlowsAcrossReceiveAsyncOp
saea.RemoteEndPoint = server.LocalEndPoint;

asyncLocal.Value = 42;
if (suppressContext) ExecutionContext.SuppressFlow();
try
using (suppressContext ? ExecutionContext.SuppressFlow() : default)
{
Assert.True(receiveFrom ?
client.ReceiveFromAsync(saea) :
client.ReceiveAsync(saea));
}
finally
{
if (suppressContext) ExecutionContext.RestoreFlow();
}
asyncLocal.Value = 0;

server.Send(new byte[] { 18 });
Expand Down Expand Up @@ -306,8 +271,7 @@ public async Task APM_ExecutionContextFlowsAcrossBeginReceiveOperation(bool supp
var tcs = new TaskCompletionSource<int>();

asyncLocal.Value = 42;
if (suppressContext) ExecutionContext.SuppressFlow();
try
using (suppressContext ? ExecutionContext.SuppressFlow() : default)
{
EndPoint ep = server.LocalEndPoint;
Assert.False(receiveFrom ?
Expand All @@ -322,11 +286,6 @@ public async Task APM_ExecutionContextFlowsAcrossBeginReceiveOperation(bool supp
tcs.SetResult(asyncLocal.Value);
}, null).CompletedSynchronously);
}
finally
{
if (suppressContext)
ExecutionContext.RestoreFlow();
}
asyncLocal.Value = 0;

server.Send(new byte[] { 18 });
Expand Down Expand Up @@ -365,18 +324,13 @@ public async Task SocketAsyncEventArgs_ExecutionContextFlowsAcrossSendAsyncOpera

bool pending;
asyncLocal.Value = 42;
if (suppressContext) ExecutionContext.SuppressFlow();
try
using (suppressContext ? ExecutionContext.SuppressFlow() : default)
{
pending =
sendMode == 0 ? client.SendAsync(saea) :
sendMode == 1 ? client.SendToAsync(saea) :
client.SendPacketsAsync(saea);
}
finally
{
if (suppressContext) ExecutionContext.RestoreFlow();
}
asyncLocal.Value = 0;

int totalReceived = 0;
Expand Down Expand Up @@ -416,8 +370,7 @@ public async Task APM_ExecutionContextFlowsAcrossBeginSendOperation(bool suppres

bool pending;
asyncLocal.Value = 42;
if (suppressContext) ExecutionContext.SuppressFlow();
try
using (suppressContext ? ExecutionContext.SuppressFlow() : default)
{
pending = sendTo ?
!client.BeginSendTo(buffer, 0, buffer.Length, SocketFlags.None, server.LocalEndPoint, iar =>
Expand All @@ -431,10 +384,6 @@ public async Task APM_ExecutionContextFlowsAcrossBeginSendOperation(bool suppres
tcs.SetResult(asyncLocal.Value);
}, null).CompletedSynchronously;
}
finally
{
if (suppressContext) ExecutionContext.RestoreFlow();
}
asyncLocal.Value = 0;

int totalReceived = 0;
Expand Down Expand Up @@ -477,19 +426,14 @@ public async Task APM_ExecutionContextFlowsAcrossBeginSendFileOperation(bool sup

bool pending;
asyncLocal.Value = 42;
if (suppressContext) ExecutionContext.SuppressFlow();
try
using (suppressContext ? ExecutionContext.SuppressFlow() : default)
{
pending = !client.BeginSendFile(filePath, iar =>
{
client.EndSendFile(iar);
tcs.SetResult(asyncLocal.Value);
}, null).CompletedSynchronously;
}
finally
{
if (suppressContext) ExecutionContext.RestoreFlow();
}
asyncLocal.Value = 0;

if (pending)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,27 +122,25 @@ await Task.WhenAll(
using (Socket server = await acceptTask)
using (var receiveSaea = new SocketAsyncEventArgs())
{
if (suppressed)
using (suppressed ? ExecutionContext.SuppressFlow() : default)
{
ExecutionContext.SuppressFlow();
}

var local = new AsyncLocal<int>();
local.Value = 42;
int threadId = Environment.CurrentManagedThreadId;
var local = new AsyncLocal<int>();
local.Value = 42;
int threadId = Environment.CurrentManagedThreadId;

var mres = new ManualResetEventSlim();
receiveSaea.SetBuffer(new byte[1], 0, 1);
receiveSaea.Completed += delegate
{
Assert.NotEqual(threadId, Environment.CurrentManagedThreadId);
Assert.Equal(suppressed ? 0 : 42, local.Value);
mres.Set();
};

Assert.True(client.ReceiveAsync(receiveSaea));
server.Send(new byte[1]);
mres.Wait();
var mres = new ManualResetEventSlim();
receiveSaea.SetBuffer(new byte[1], 0, 1);
receiveSaea.Completed += delegate
{
Assert.NotEqual(threadId, Environment.CurrentManagedThreadId);
Assert.Equal(suppressed ? 0 : 42, local.Value);
mres.Set();
};

Assert.True(client.ReceiveAsync(receiveSaea));
server.Send(new byte[1]);
mres.Wait();
}
}
}
}
Expand Down
Loading

0 comments on commit 94bec76

Please sign in to comment.