diff --git a/src/libraries/System.Private.CoreLib/src/System/Threading/ExecutionContext.cs b/src/libraries/System.Private.CoreLib/src/System/Threading/ExecutionContext.cs index dc8070ada491b..03ccbe053f0a3 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Threading/ExecutionContext.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Threading/ExecutionContext.cs @@ -200,74 +200,27 @@ internal static void RunInternal(ExecutionContext? executionContext, ContextCall edi?.Throw(); } - // Direct copy of the above RunInternal overload, except that it passes the state into the callback strongly-typed and by ref. - internal static void RunInternal(ExecutionContext? executionContext, ContextCallback callback, ref TState state) + internal static void Restore(ExecutionContext? executionContext) { - // Note: ExecutionContext.RunInternal is an extremely hot function and used by every await, ThreadPool execution, etc. - // Note: Manual enregistering may be addressed by "Exception Handling Write Through Optimization" - // https://github.com/dotnet/runtime/blob/master/docs/design/features/eh-writethru.md + Thread currentThread = Thread.CurrentThread; - // Enregister variables with 0 post-fix so they can be used in registers without EH forcing them to stack - // Capture references to Thread Contexts - Thread currentThread0 = Thread.CurrentThread; - Thread currentThread = currentThread0; - ExecutionContext? previousExecutionCtx0 = currentThread0._executionContext; - if (previousExecutionCtx0 != null && previousExecutionCtx0.m_isDefault) + ExecutionContext? currentExecutionCtx = currentThread._executionContext; + if (currentExecutionCtx != null && currentExecutionCtx.m_isDefault) { // Default is a null ExecutionContext internally - previousExecutionCtx0 = null; + currentExecutionCtx = null; } - // Store current ExecutionContext and SynchronizationContext as "previousXxx". - // This allows us to restore them and undo any Context changes made in callback.Invoke - // so that they won't "leak" back into caller. - // These variables will cross EH so be forced to stack - ExecutionContext? previousExecutionCtx = previousExecutionCtx0; - SynchronizationContext? previousSyncCtx = currentThread0._synchronizationContext; - if (executionContext != null && executionContext.m_isDefault) { // Default is a null ExecutionContext internally executionContext = null; } - if (previousExecutionCtx0 != executionContext) + if (currentExecutionCtx != executionContext) { - RestoreChangedContextToThread(currentThread0, executionContext, previousExecutionCtx0); + RestoreChangedContextToThread(currentThread, executionContext, currentExecutionCtx); } - - ExceptionDispatchInfo? edi = null; - try - { - callback.Invoke(ref state); - } - catch (Exception ex) - { - // Note: we have a "catch" rather than a "finally" because we want - // to stop the first pass of EH here. That way we can restore the previous - // context before any of our callers' EH filters run. - edi = ExceptionDispatchInfo.Capture(ex); - } - - // Re-enregistrer variables post EH with 1 post-fix so they can be used in registers rather than from stack - SynchronizationContext? previousSyncCtx1 = previousSyncCtx; - Thread currentThread1 = currentThread; - // The common case is that these have not changed, so avoid the cost of a write barrier if not needed. - if (currentThread1._synchronizationContext != previousSyncCtx1) - { - // Restore changed SynchronizationContext back to previous - currentThread1._synchronizationContext = previousSyncCtx1; - } - - ExecutionContext? previousExecutionCtx1 = previousExecutionCtx; - ExecutionContext? currentExecutionCtx1 = currentThread1._executionContext; - if (currentExecutionCtx1 != previousExecutionCtx1) - { - RestoreChangedContextToThread(currentThread1, previousExecutionCtx1, currentExecutionCtx1); - } - - // If exception was thrown by callback, rethrow it now original contexts are restored - edi?.Throw(); } internal static void RunFromThreadPoolDispatchLoop(Thread threadPoolThread, ExecutionContext executionContext, ContextCallback callback, object state) diff --git a/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Sources/ManualResetValueTaskSourceCore.cs b/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Sources/ManualResetValueTaskSourceCore.cs index bef04bd62c7a6..4edbaf5d02301 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Sources/ManualResetValueTaskSourceCore.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Sources/ManualResetValueTaskSourceCore.cs @@ -207,51 +207,117 @@ private void SignalCompletion() } _completed = true; - if (_continuation != null || Interlocked.CompareExchange(ref _continuation, ManualResetValueTaskSourceCoreShared.s_sentinel, null) != null) + if (_continuation is null && Interlocked.CompareExchange(ref _continuation, ManualResetValueTaskSourceCoreShared.s_sentinel, null) is null) { - if (_executionContext != null) + return; + } + + if (_executionContext is null) + { + if (_capturedContext is null) { - ExecutionContext.RunInternal( - _executionContext, - (ref ManualResetValueTaskSourceCore s) => s.InvokeContinuation(), - ref this); + if (RunContinuationsAsynchronously) + { + ThreadPool.UnsafeQueueUserWorkItem(_continuation, _continuationState, preferLocal: true); + } + else + { + _continuation(_continuationState); + } } else { - InvokeContinuation(); + InvokeSchedulerContinuation(); } } + else + { + InvokeContinuationWithContext(); + } } - /// - /// Invokes the continuation with the appropriate captured context / scheduler. - /// This assumes that if is not null we're already - /// running within that . - /// - private void InvokeContinuation() + private void InvokeContinuationWithContext() { + // This is in a helper as the error handling causes the generated asm + // for the surrounding code to become less efficent (stack spills etc) + // and it is an uncommon path. + Debug.Assert(_continuation != null); + Debug.Assert(_executionContext != null); - switch (_capturedContext) + ExecutionContext? currentContext = ExecutionContext.Capture(); + // Restore the captured ExecutionContext before executing anything. + ExecutionContext.Restore(_executionContext); + + if (_capturedContext is null) { - case null: - if (RunContinuationsAsynchronously) + if (RunContinuationsAsynchronously) + { + try { - if (_executionContext != null) - { - ThreadPool.QueueUserWorkItem(_continuation, _continuationState, preferLocal: true); - } - else - { - ThreadPool.UnsafeQueueUserWorkItem(_continuation, _continuationState, preferLocal: true); - } + ThreadPool.QueueUserWorkItem(_continuation, _continuationState, preferLocal: true); } - else + finally + { + // Restore the current ExecutionContext. + ExecutionContext.Restore(currentContext); + } + } + else + { + // Running inline may throw; capture the edi if it does as we changed the ExecutionContext, + // so need to restore it back before propagating the throw. + ExceptionDispatchInfo? edi = null; + SynchronizationContext? syncContext = SynchronizationContext.Current; + try { _continuation(_continuationState); } - break; + catch (Exception ex) + { + // Note: we have a "catch" rather than a "finally" because we want + // to stop the first pass of EH here. That way we can restore the previous + // context before any of our callers' EH filters run. + edi = ExceptionDispatchInfo.Capture(ex); + } + finally + { + // Set sync context back to what it was prior to coming in + SynchronizationContext.SetSynchronizationContext(syncContext); + // Restore the current ExecutionContext. + ExecutionContext.Restore(currentContext); + } + + // Now rethrow the exception; if there is one. + edi?.Throw(); + } + return; + } + + try + { + InvokeSchedulerContinuation(); + } + finally + { + // Restore the current ExecutionContext. + ExecutionContext.Restore(currentContext); + } + } + + /// + /// Invokes the continuation with the appropriate scheduler. + /// This assumes that if is not null we're already + /// running within that . + /// + private void InvokeSchedulerContinuation() + { + Debug.Assert(_capturedContext != null); + Debug.Assert(_continuation != null); + + switch (_capturedContext) + { case SynchronizationContext sc: sc.Post(s => {