Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.
/ corefx Public archive

Commit

Permalink
Use new ManualResetValueTaskSourceCore in tests, and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub committed Oct 30, 2018
1 parent 9a35b16 commit ae7bca4
Show file tree
Hide file tree
Showing 7 changed files with 374 additions and 267 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,235 +2,21 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;

namespace System.Runtime.CompilerServices
{
public interface IStrongBox<T>
{
ref T Value { get; }
}
}

namespace System.Threading.Tasks.Sources
{
public sealed class ManualResetValueTaskSource<T> : IStrongBox<ManualResetValueTaskSourceLogic<T>>, IValueTaskSource<T>, IValueTaskSource
public sealed class ManualResetValueTaskSource<T> : IValueTaskSource<T>, IValueTaskSource
{
private ManualResetValueTaskSourceLogic<T> _logic; // mutable struct; do not make this readonly

public ManualResetValueTaskSource() => _logic = new ManualResetValueTaskSourceLogic<T>(this);

public short Version => _logic.Version;

public void Reset() => _logic.Reset();

public void SetResult(T result) => _logic.SetResult(result);

public void SetException(Exception error) => _logic.SetException(error);

public T GetResult(short token) => _logic.GetResult(token);
void IValueTaskSource.GetResult(short token) => _logic.GetResult(token);

public ValueTaskSourceStatus GetStatus(short token) => _logic.GetStatus(token);

public void OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) => _logic.OnCompleted(continuation, state, token, flags);

ref ManualResetValueTaskSourceLogic<T> IStrongBox<ManualResetValueTaskSourceLogic<T>>.Value => ref _logic;
}

public struct ManualResetValueTaskSourceLogic<TResult>
{
private static readonly Action<object> s_sentinel = new Action<object>(s => throw new InvalidOperationException());

private readonly IStrongBox<ManualResetValueTaskSourceLogic<TResult>> _parent;
private Action<object> _continuation;
private object _continuationState;
private object _capturedContext;
private ExecutionContext _executionContext;
private bool _completed;
private TResult _result;
private ExceptionDispatchInfo _error;
private short _version;

public ManualResetValueTaskSourceLogic(IStrongBox<ManualResetValueTaskSourceLogic<TResult>> parent)
{
_parent = parent ?? throw new ArgumentNullException(nameof(parent));
_continuation = null;
_continuationState = null;
_capturedContext = null;
_executionContext = null;
_completed = false;
_result = default;
_error = null;
_version = 0;
}

public short Version => _version;

private void ValidateToken(short token)
{
if (token != _version)
{
throw new InvalidOperationException();
}
}

public ValueTaskSourceStatus GetStatus(short token)
{
ValidateToken(token);

return
!_completed ? ValueTaskSourceStatus.Pending :
_error == null ? ValueTaskSourceStatus.Succeeded :
_error.SourceException is OperationCanceledException ? ValueTaskSourceStatus.Canceled :
ValueTaskSourceStatus.Faulted;
}

public TResult GetResult(short token)
{
ValidateToken(token);

if (!_completed)
{
throw new InvalidOperationException();
}

_error?.Throw();
return _result;
}

public void Reset()
{
_version++;

_completed = false;
_continuation = null;
_continuationState = null;
_result = default;
_error = null;
_executionContext = null;
_capturedContext = null;
}

public void OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags)
{
if (continuation == null)
{
throw new ArgumentNullException(nameof(continuation));
}
ValidateToken(token);

if ((flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) != 0)
{
_executionContext = ExecutionContext.Capture();
}

if ((flags & ValueTaskSourceOnCompletedFlags.UseSchedulingContext) != 0)
{
SynchronizationContext sc = SynchronizationContext.Current;
if (sc != null && sc.GetType() != typeof(SynchronizationContext))
{
_capturedContext = sc;
}
else
{
TaskScheduler ts = TaskScheduler.Current;
if (ts != TaskScheduler.Default)
{
_capturedContext = ts;
}
}
}

_continuationState = state;
if (Interlocked.CompareExchange(ref _continuation, continuation, null) != null)
{
_executionContext = null;

object cc = _capturedContext;
_capturedContext = null;

switch (cc)
{
case null:
Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default);
break;

case SynchronizationContext sc:
sc.Post(s =>
{
var tuple = (Tuple<Action<object>, object>)s;
tuple.Item1(tuple.Item2);
}, Tuple.Create(continuation, state));
break;

case TaskScheduler ts:
Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, ts);
break;
}
}
}

public void SetResult(TResult result)
{
_result = result;
SignalCompletion();
}

public void SetException(Exception error)
{
_error = ExceptionDispatchInfo.Capture(error);
SignalCompletion();
}

private void SignalCompletion()
{
if (_completed)
{
throw new InvalidOperationException();
}
_completed = true;

if (Interlocked.CompareExchange(ref _continuation, s_sentinel, null) != null)
{
if (_executionContext != null)
{
ExecutionContext.Run(
_executionContext,
s => ((IStrongBox<ManualResetValueTaskSourceLogic<TResult>>)s).Value.InvokeContinuation(),
_parent ?? throw new InvalidOperationException());
}
else
{
InvokeContinuation();
}
}
}

private void InvokeContinuation()
{
object cc = _capturedContext;
_capturedContext = null;

switch (cc)
{
case null:
_continuation(_continuationState);
break;

case SynchronizationContext sc:
sc.Post(s =>
{
ref ManualResetValueTaskSourceLogic<TResult> logicRef = ref ((IStrongBox<ManualResetValueTaskSourceLogic<TResult>>)s).Value;
logicRef._continuation(logicRef._continuationState);
}, _parent ?? throw new InvalidOperationException());
break;

case TaskScheduler ts:
Task.Factory.StartNew(_continuation, _continuationState, CancellationToken.None, TaskCreationOptions.DenyChildAttach, ts);
break;
}
}
private ManualResetValueTaskSourceCore<T> _core; // mutable struct; do not make this readonly

public bool RunContinuationsAsynchronously { get => _core.RunContinuationsAsynchronously; set => _core.RunContinuationsAsynchronously = value; }
public short Version => _core.Version;
public void Reset() => _core.Reset();
public void SetResult(T result) => _core.SetResult(result);
public void SetException(Exception error) => _core.SetException(error);

public T GetResult(short token) => _core.GetResult(token);
void IValueTaskSource.GetResult(short token) => _core.GetResult(token);
public ValueTaskSourceStatus GetStatus(short token) => _core.GetStatus(token);
public void OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) => _core.OnCompleted(continuation, state, token, flags);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
// See the LICENSE file in the project root for more information.

using System.Runtime.ExceptionServices;
using System.Threading.Tasks.Sources;

namespace System.Threading.Tasks.Tests
namespace System.Threading.Tasks.Sources.Tests
{
internal static class ManualResetValueTaskSourceFactory
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System.Runtime.CompilerServices;
using System.Threading.Tasks.Sources.Tests;
using Xunit;

namespace System.Threading.Tasks.Tests
Expand Down
Loading

0 comments on commit ae7bca4

Please sign in to comment.