Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.

Commit

Permalink
Add ManualResetValueTaskSourceLogic
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub committed Oct 18, 2018
1 parent 0ca17b0 commit baa1583
Show file tree
Hide file tree
Showing 4 changed files with 356 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,7 @@
<Compile Include="$(MSBuildThisFileDirectory)System\Threading\Tasks\TaskToApm.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Threading\Tasks\TaskSchedulerException.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Threading\Tasks\ValueTask.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Threading\Tasks\Sources\ManualResetValueTaskSourceLogic.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Threading\Tasks\Sources\IValueTaskSource.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Threading\ThreadAbortException.cs" />
<Compile Include="$(MSBuildThisFileDirectory)System\Threading\ThreadInterruptedException.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ namespace System.Threading
{
public delegate void ContextCallback(object state);

internal delegate void ContextCallback<TState>(ref TState state);

public sealed class ExecutionContext : IDisposable, ISerializable
{
internal static readonly ExecutionContext Default = new ExecutionContext(isDefault: true);
Expand Down Expand Up @@ -201,6 +203,85 @@ internal static void RunInternal(ExecutionContext executionContext, ContextCallb
edi?.Throw();
}

// Direct copy of the above RunInternal overload, except that it passes the state into the callback strongly-typed and by ref.
internal static void RunInternal<TState>(ExecutionContext executionContext, ContextCallback<TState> callback, ref TState state)
{
// Note: ExecutionContext.RunInternal is an extremely hot function and used by every await, ThreadPool execution, etc.
// Note: Manual enregistering may be addressed by "Exception Handling Write Through Optimization"
// https://github.com/dotnet/coreclr/blob/master/Documentation/design-docs/eh-writethru.md

// Enregister variables with 0 post-fix so they can be used in registers without EH forcing them to stack
// Capture references to Thread Contexts
Thread currentThread0 = Thread.CurrentThread;
Thread currentThread = currentThread0;
ExecutionContext previousExecutionCtx0 = currentThread0.ExecutionContext;

// Store current ExecutionContext and SynchronizationContext as "previousXxx".
// This allows us to restore them and undo any Context changes made in callback.Invoke
// so that they won't "leak" back into caller.
// These variables will cross EH so be forced to stack
ExecutionContext previousExecutionCtx = previousExecutionCtx0;
SynchronizationContext previousSyncCtx = currentThread0.SynchronizationContext;

if (executionContext != null && executionContext.m_isDefault)
{
// Default is a null ExecutionContext internally
executionContext = null;
}

if (previousExecutionCtx0 != executionContext)
{
// Restore changed ExecutionContext
currentThread0.ExecutionContext = executionContext;
if ((executionContext != null && executionContext.HasChangeNotifications) ||
(previousExecutionCtx0 != null && previousExecutionCtx0.HasChangeNotifications))
{
// There are change notifications; trigger any affected
OnValuesChanged(previousExecutionCtx0, executionContext);
}
}

ExceptionDispatchInfo edi = null;
try
{
callback.Invoke(ref state);
}
catch (Exception ex)
{
// Note: we have a "catch" rather than a "finally" because we want
// to stop the first pass of EH here. That way we can restore the previous
// context before any of our callers' EH filters run.
edi = ExceptionDispatchInfo.Capture(ex);
}

// Re-enregistrer variables post EH with 1 post-fix so they can be used in registers rather than from stack
SynchronizationContext previousSyncCtx1 = previousSyncCtx;
Thread currentThread1 = currentThread;
// The common case is that these have not changed, so avoid the cost of a write barrier if not needed.
if (currentThread1.SynchronizationContext != previousSyncCtx1)
{
// Restore changed SynchronizationContext back to previous
currentThread1.SynchronizationContext = previousSyncCtx1;
}

ExecutionContext previousExecutionCtx1 = previousExecutionCtx;
ExecutionContext currentExecutionCtx1 = currentThread1.ExecutionContext;
if (currentExecutionCtx1 != previousExecutionCtx1)
{
// Restore changed ExecutionContext back to previous
currentThread1.ExecutionContext = previousExecutionCtx1;
if ((currentExecutionCtx1 != null && currentExecutionCtx1.HasChangeNotifications) ||
(previousExecutionCtx1 != null && previousExecutionCtx1.HasChangeNotifications))
{
// There are change notifications; trigger any affected
OnValuesChanged(currentExecutionCtx1, previousExecutionCtx1);
}
}

// If exception was thrown by callback, rethrow it now original contexts are restored
edi?.Throw();
}

internal static void OnValuesChanged(ExecutionContext previousExecutionCtx, ExecutionContext nextExecutionCtx)
{
Debug.Assert(previousExecutionCtx != nextExecutionCtx);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;

namespace System.Threading.Tasks.Sources
{
/// <summary>Provides the core logic for implementing a manual-reset <see cref="IValueTaskSource"/> or <see cref="IValueTaskSource{TResult}"/>.</summary>
/// <typeparam name="TResult"></typeparam>
[StructLayout(LayoutKind.Auto)]
public struct ManualResetValueTaskSourceLogic<TResult>
{
/// <summary>
/// The callback to invoke when the operation completes if <see cref="OnCompleted"/> was called before the operation completed,
/// or <see cref="ManualResetValueTaskSourceLogicShared.s_sentinel"/> if the operation completed before a callback was supplied,
/// or null if a callback hasn't yet been provided and the operation hasn't yet completed.
/// </summary>
private Action<object> _continuation;
/// <summary>State to pass to <see cref="_continuation"/>.</summary>
private object _continuationState;
/// <summary><see cref="ExecutionContext"/> to flow to the callback, or null if no flowing is required.</summary>
private ExecutionContext _executionContext;
/// <summary>
/// A "captured" <see cref="SynchronizationContext"/> or <see cref="TaskScheduler"/> with which to invoke the callback,
/// or null if no special context is required.
/// </summary>
private object _capturedContext;
/// <summary>Whether the current operation has completed.</summary>
private bool _completed;
/// <summary>The result with which the operation succeeded, or the default value if it hasn't yet completed or failed.</summary>
private TResult _result;
/// <summary>The exception with which the operation failed, or null if it hasn't yet completed or completed successfully.</summary>
private ExceptionDispatchInfo _error;
/// <summary>The current version of this value, used to help prevent misuse.</summary>
private short _version;

/// <summary>Gets or sets whether to force continuations to run asynchronously.</summary>
/// <remarks>Continuations may run asynchronously if this is false, but they'll never run synchronously if this is true.</remarks>
public bool RunContinuationsAsynchronously { get; set; }

/// <summary>Resets to prepare for the next operation.</summary>
public void Reset()
{
// Reset/update state for the next use/await of this instance.
_version++;
_completed = false;
_result = default;
_error = null;
_executionContext = null;
_capturedContext = null;
_continuation = null;
_continuationState = null;
}

/// <summary>Completes with a successful result.</summary>
/// <param name="result">The result.</param>
public void SetResult(TResult result)
{
_result = result;
SignalCompletion();
}

/// <summary>Complets with an error.</summary>
/// <param name="error"></param>
public void SetException(Exception error)
{
_error = ExceptionDispatchInfo.Capture(error);
SignalCompletion();
}

/// <summary>Gets the operation version.</summary>
public short Version => _version;

/// <summary>Gets the status of the operation.</summary>
/// <param name="token">Opaque value that was provided to the <see cref="ValueTask"/>'s constructor.</param>
public ValueTaskSourceStatus GetStatus(short token)
{
ValidateToken(token);
return
!_completed ? ValueTaskSourceStatus.Pending :
_error == null ? ValueTaskSourceStatus.Succeeded :
_error.SourceException is OperationCanceledException ? ValueTaskSourceStatus.Canceled :
ValueTaskSourceStatus.Faulted;
}

/// <summary>Gets the result of the operation.</summary>
/// <param name="token">Opaque value that was provided to the <see cref="ValueTask"/>'s constructor.</param>
public TResult GetResult(short token)
{
ValidateToken(token);
if (!_completed)
{
ThrowInvalidOperationException();
}

_error?.Throw();
return _result;
}

/// <summary>Schedules the continuation action for this operation.</summary>
/// <param name="continuation">The continuation to invoke when the operation has completed.</param>
/// <param name="state">The state object to pass to <paramref name="continuation"/> when it's invoked.</param>
/// <param name="token">Opaque value that was provided to the <see cref="ValueTask"/>'s constructor.</param>
/// <param name="flags">The flags describing the behavior of the continuation.</param>
public void OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags)
{
if (continuation == null)
{
throw new ArgumentNullException(nameof(continuation));
}
ValidateToken(token);

if ((flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) != 0)
{
_executionContext = ExecutionContext.Capture();
}

if ((flags & ValueTaskSourceOnCompletedFlags.UseSchedulingContext) != 0)
{
SynchronizationContext sc = SynchronizationContext.Current;
if (sc != null && sc.GetType() != typeof(SynchronizationContext))
{
_capturedContext = sc;
}
else
{
TaskScheduler ts = TaskScheduler.Current;
if (ts != TaskScheduler.Default)
{
_capturedContext = ts;
}
}
}

_continuationState = state;
if (Interlocked.CompareExchange(ref _continuation, continuation, null) != null)
{
switch (_capturedContext)
{
case null:
if (_executionContext != null)
{
ThreadPool.QueueUserWorkItem(continuation, state, preferLocal: true);
}
else
{
ThreadPool.UnsafeQueueUserWorkItem(continuation, state, preferLocal: true);
}
break;

case SynchronizationContext sc:
sc.Post(s =>
{
var tuple = (Tuple<Action<object>, object>)s;
tuple.Item1(tuple.Item2);
}, Tuple.Create(continuation, state));
break;

case TaskScheduler ts:
Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, ts);
break;
}
}
}

/// <summary>Ensures that the specified token matches the current version.</summary>
/// <param name="token">The token supplied by <see cref="ValueTask"/>.</param>
private void ValidateToken(short token)
{
if (token != _version)
{
ThrowInvalidOperationException();
}
}

/// <summary>Signals that that the operation has completed. Invoked after the result or error has been set.</summary>
private void SignalCompletion()
{
if (_completed)
{
ThrowInvalidOperationException();
}
_completed = true;

if (Interlocked.CompareExchange(ref _continuation, ManualResetValueTaskSourceLogicShared.s_sentinel, null) != null)
{
if (_executionContext != null)
{
ExecutionContext.RunInternal(
_executionContext,
(ref ManualResetValueTaskSourceLogic<TResult> s) => s.InvokeContinuation(),
ref this);
}
else
{
InvokeContinuation();
}
}
}

/// <summary>
/// Invokes the continuation with the appropriate captured context / scheduler.
/// This assumes that if <see cref="_executionContext"/> is not null we're already
/// running within that <see cref="ExecutionContext"/>.
/// </summary>
private void InvokeContinuation()
{
switch (_capturedContext)
{
case null:
if (RunContinuationsAsynchronously)
{
if (_executionContext != null)
{
ThreadPool.QueueUserWorkItem(_continuation, _continuationState, preferLocal: true);
}
else
{
ThreadPool.UnsafeQueueUserWorkItem(_continuation, _continuationState, preferLocal: true);
}
}
else
{
_continuation(_continuationState);
}
break;

case SynchronizationContext sc:
sc.Post(s =>
{
var state = (Tuple<Action<object>, object>)s;
state.Item1(state.Item2);
}, Tuple.Create(_continuation, _continuationState));
break;

case TaskScheduler ts:
Task.Factory.StartNew(_continuation, _continuationState, CancellationToken.None, TaskCreationOptions.DenyChildAttach, ts);
break;
}
}

private static void ThrowInvalidOperationException() =>
throw new InvalidOperationException();
}

internal static class ManualResetValueTaskSourceLogicShared
{
internal static readonly Action<object> s_sentinel = new Action<object>(s =>
{
Debug.Fail("The sentinel delegate should never be invoked.");
throw null;
});
}
}
16 changes: 16 additions & 0 deletions src/System.Private.CoreLib/src/System/Threading/ThreadPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1333,6 +1333,22 @@ public static bool QueueUserWorkItem<TState>(Action<TState> callBack, TState sta
return true;
}

// TODO: https://github.com/dotnet/corefx/issues/32547. Make public.
internal static bool UnsafeQueueUserWorkItem<TState>(Action<TState> callBack, TState state, bool preferLocal)
{
if (callBack == null)
{
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.callBack);
}

EnsureVMInitialized();

ThreadPoolGlobals.workQueue.Enqueue(
new QueueUserWorkItemCallback<TState>(callBack, state, null), forceGlobal: !preferLocal);

return true;
}

public static bool UnsafeQueueUserWorkItem(WaitCallback callBack, object state)
{
if (callBack == null)
Expand Down

0 comments on commit baa1583

Please sign in to comment.