Skip to content

Commit

Permalink
Merge pull request Cysharp#642 from Cysharp/hotfix/HubMethodRunContin…
Browse files Browse the repository at this point in the history
…uationsAsync

Run continuations of hub method calls asynchronously.
  • Loading branch information
mayuki authored May 8, 2023
2 parents 077c3b1 + 5e5f651 commit c5cb55b
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 23 deletions.
1 change: 0 additions & 1 deletion samples/JwtAuthentication/JwtAuthApp.Client/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ private async Task MainCore(string[] args)
{ "Authorization", "Bearer " + AuthenticationTokenStorage.Current.Token }
}));
await timerHubClient.SetAsync(TimeSpan.FromSeconds(5));
await Task.Yield(); // NOTE: Release the gRPC's worker thread here.
}

// 6. Insufficient privilege (The current user is not in administrators role).
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System;
using System;
using System.Threading.Tasks;

namespace MagicOnion.Utils
Expand All @@ -11,6 +11,11 @@ internal interface ITaskCompletion

internal class TaskCompletionSourceEx<T> : TaskCompletionSource<T>, ITaskCompletion
{
public TaskCompletionSourceEx()
{ }
public TaskCompletionSourceEx(TaskCreationOptions options) : base(options)
{ }

bool ITaskCompletion.TrySetCanceled()
{
return this.TrySetCanceled();
Expand All @@ -21,4 +26,4 @@ bool ITaskCompletion.TrySetException(Exception ex)
return this.TrySetException(ex);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public abstract class StreamingHubClientBase<TStreamingHub, TReceiver>
TaskCompletionSource<object> waitForDisconnect = new TaskCompletionSource<object>();

// {messageId, TaskCompletionSource}
ConcurrentDictionary<int, object> responseFutures = new ConcurrentDictionary<int, object>();
ConcurrentDictionary<int, ITaskCompletion> responseFutures = new ConcurrentDictionary<int, ITaskCompletion>();
protected CancellationTokenSource cts = new CancellationTokenSource();
int messageId = 0;
bool disposed;
Expand Down Expand Up @@ -196,8 +196,7 @@ void ConsumeData(SynchronizationContext syncContext, byte[] data)
if (arrayLength == 3)
{
var messageId = messagePackReader.ReadInt32();
object future;
if (responseFutures.TryRemove(messageId, out future))
if (responseFutures.TryRemove(messageId, out var future))
{
var methodId = messagePackReader.ReadInt32();
try
Expand All @@ -208,7 +207,7 @@ void ConsumeData(SynchronizationContext syncContext, byte[] data)
}
catch (Exception ex)
{
if (!(future as ITaskCompletion).TrySetException(ex))
if (!future.TrySetException(ex))
{
throw;
}
Expand All @@ -218,8 +217,7 @@ void ConsumeData(SynchronizationContext syncContext, byte[] data)
else if (arrayLength == 4)
{
var messageId = messagePackReader.ReadInt32();
object future;
if (responseFutures.TryRemove(messageId, out future))
if (responseFutures.TryRemove(messageId, out var future))
{
var statusCode = messagePackReader.ReadInt32();
var detail = messagePackReader.ReadString();
Expand All @@ -236,7 +234,7 @@ void ConsumeData(SynchronizationContext syncContext, byte[] data)
ex = new RpcException(new Status((StatusCode)statusCode, detail), detail + Environment.NewLine + error);
}

(future as ITaskCompletion).TrySetException(ex);
future.TrySetException(ex);
}
}
else
Expand Down Expand Up @@ -290,8 +288,11 @@ protected async Task<TResponse> WriteMessageWithResponseAsync<TRequest, TRespons
ThrowIfDisposed();

var mid = Interlocked.Increment(ref messageId);
var tcs = new TaskCompletionSourceEx<TResponse>(); // use Ex
responseFutures[mid] = (object)tcs;
// NOTE: The continuations (user code) should be executed asynchronously.
// This is because the continuation may block the thread, for example, Console.ReadLine().
// If the thread is blocked, it will no longer return to the message consuming loop.
var tcs = new TaskCompletionSourceEx<TResponse>(TaskCreationOptions.RunContinuationsAsynchronously);
responseFutures[mid] = tcs;

byte[] BuildMessage()
{
Expand Down
19 changes: 10 additions & 9 deletions src/MagicOnion.Client/StreamingHubClientBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public abstract class StreamingHubClientBase<TStreamingHub, TReceiver>
TaskCompletionSource<object> waitForDisconnect = new TaskCompletionSource<object>();

// {messageId, TaskCompletionSource}
ConcurrentDictionary<int, object> responseFutures = new ConcurrentDictionary<int, object>();
ConcurrentDictionary<int, ITaskCompletion> responseFutures = new ConcurrentDictionary<int, ITaskCompletion>();
protected CancellationTokenSource cts = new CancellationTokenSource();
int messageId = 0;
bool disposed;
Expand Down Expand Up @@ -196,8 +196,7 @@ void ConsumeData(SynchronizationContext syncContext, byte[] data)
if (arrayLength == 3)
{
var messageId = messagePackReader.ReadInt32();
object future;
if (responseFutures.TryRemove(messageId, out future))
if (responseFutures.TryRemove(messageId, out var future))
{
var methodId = messagePackReader.ReadInt32();
try
Expand All @@ -208,7 +207,7 @@ void ConsumeData(SynchronizationContext syncContext, byte[] data)
}
catch (Exception ex)
{
if (!(future as ITaskCompletion).TrySetException(ex))
if (!future.TrySetException(ex))
{
throw;
}
Expand All @@ -218,8 +217,7 @@ void ConsumeData(SynchronizationContext syncContext, byte[] data)
else if (arrayLength == 4)
{
var messageId = messagePackReader.ReadInt32();
object future;
if (responseFutures.TryRemove(messageId, out future))
if (responseFutures.TryRemove(messageId, out var future))
{
var statusCode = messagePackReader.ReadInt32();
var detail = messagePackReader.ReadString();
Expand All @@ -236,7 +234,7 @@ void ConsumeData(SynchronizationContext syncContext, byte[] data)
ex = new RpcException(new Status((StatusCode)statusCode, detail), detail + Environment.NewLine + error);
}

(future as ITaskCompletion).TrySetException(ex);
future.TrySetException(ex);
}
}
else
Expand Down Expand Up @@ -290,8 +288,11 @@ protected async Task<TResponse> WriteMessageWithResponseAsync<TRequest, TRespons
ThrowIfDisposed();

var mid = Interlocked.Increment(ref messageId);
var tcs = new TaskCompletionSourceEx<TResponse>(); // use Ex
responseFutures[mid] = (object)tcs;
// NOTE: The continuations (user code) should be executed asynchronously.
// This is because the continuation may block the thread, for example, Console.ReadLine().
// If the thread is blocked, it will no longer return to the message consuming loop.
var tcs = new TaskCompletionSourceEx<TResponse>(TaskCreationOptions.RunContinuationsAsynchronously);
responseFutures[mid] = tcs;

byte[] BuildMessage()
{
Expand Down
9 changes: 7 additions & 2 deletions src/MagicOnion.Shared/Utils/TaskCompletionSourceEx.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System;
using System;
using System.Threading.Tasks;

namespace MagicOnion.Utils
Expand All @@ -11,6 +11,11 @@ internal interface ITaskCompletion

internal class TaskCompletionSourceEx<T> : TaskCompletionSource<T>, ITaskCompletion
{
public TaskCompletionSourceEx()
{ }
public TaskCompletionSourceEx(TaskCreationOptions options) : base(options)
{ }

bool ITaskCompletion.TrySetCanceled()
{
return this.TrySetCanceled();
Expand All @@ -21,4 +26,4 @@ bool ITaskCompletion.TrySetException(Exception ex)
return this.TrySetException(ex);
}
}
}
}
38 changes: 38 additions & 0 deletions tests/MagicOnion.Integration.Tests/StreamingHubTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,30 @@ public async Task Receiver_RefType_Null(TestStreamingHubClientFactory clientFact
receiver.Verify(x => x.Receiver_RefType_Null(default));
}

[Theory]
[MemberData(nameof(EnumerateStreamingHubClientFactory))]
public async Task ContinuationBlocking(TestStreamingHubClientFactory clientFactory)
{
// Arrange
var httpClient = factory.CreateDefaultClient();
var channel = GrpcChannel.ForAddress("http://localhost", new GrpcChannelOptions() { HttpClient = httpClient });

var receiver = new Mock<IStreamingHubTestHubReceiver>();
var client = await clientFactory.CreateAndConnectAsync<IStreamingHubTestHub, IStreamingHubTestHubReceiver>(channel, receiver.Object);

// Act
// NOTE: Runs on another thread.
_ = Task.Run(async () =>
{
await client.CallReceiver_Delay(500); // The receiver will be called after 500ms.
Thread.Sleep(60 * 1000); // Block the continuation.
});
await Task.Delay(1000); // Wait for broadcast queue to be consumed.

// Assert
receiver.Verify(x => x.Receiver_Delay());
}

}

public class StreamingHubTestHub : StreamingHubBase<IStreamingHubTestHub, IStreamingHubTestHubReceiver>, IStreamingHubTestHub
Expand Down Expand Up @@ -574,6 +598,17 @@ public Task CallReceiver_RefType_Null()
Broadcast(group).Receiver_RefType_Null(default);
return Task.CompletedTask;
}

public Task CallReceiver_Delay(int milliseconds)
{
_ = Task.Run(async () =>
{
await Task.Delay(milliseconds);
Broadcast(group).Receiver_Delay();
});

return Task.CompletedTask;
}
}

public interface IStreamingHubTestHubReceiver
Expand All @@ -583,6 +618,7 @@ public interface IStreamingHubTestHubReceiver
void Receiver_Parameter_Many(int arg0, string arg1, bool arg2);
void Receiver_RefType(MyStreamingResponse request);
void Receiver_RefType_Null(MyStreamingResponse? request);
void Receiver_Delay();
}

public interface IStreamingHubTestHub : IStreamingHub<IStreamingHubTestHub, IStreamingHubTestHubReceiver>
Expand Down Expand Up @@ -617,4 +653,6 @@ public interface IStreamingHubTestHub : IStreamingHub<IStreamingHubTestHub, IStr
Task<MyStreamingResponse?> RefType_Null(MyStreamingRequest? request);
Task CallReceiver_RefType(MyStreamingRequest request);
Task CallReceiver_RefType_Null();

Task CallReceiver_Delay(int milliseconds);
}
13 changes: 13 additions & 0 deletions tests/MagicOnion.Integration.Tests/_GeneratedClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,8 @@ public StreamingHubTestHubClient(global::Grpc.Core.CallInvoker callInvoker, glob
=> base.WriteMessageWithResponseAsync<global::MagicOnion.Integration.Tests.MyStreamingRequest, global::MessagePack.Nil>(1503747814, request);
public global::System.Threading.Tasks.Task CallReceiver_RefType_Null()
=> base.WriteMessageWithResponseAsync<global::MessagePack.Nil, global::MessagePack.Nil>(-1093215042, global::MessagePack.Nil.Default);
public global::System.Threading.Tasks.Task CallReceiver_Delay(global::System.Int32 milliseconds)
=> base.WriteMessageWithResponseAsync<global::System.Int32, global::MessagePack.Nil>(1865731236, milliseconds);

public global::MagicOnion.Integration.Tests.IStreamingHubTestHub FireAndForget()
=> new FireAndForgetClient(this);
Expand Down Expand Up @@ -828,6 +830,8 @@ public FireAndForgetClient(StreamingHubTestHubClient parent)
=> parent.WriteMessageFireAndForgetAsync<global::MagicOnion.Integration.Tests.MyStreamingRequest, global::MessagePack.Nil>(1503747814, request);
public global::System.Threading.Tasks.Task CallReceiver_RefType_Null()
=> parent.WriteMessageFireAndForgetAsync<global::MessagePack.Nil, global::MessagePack.Nil>(-1093215042, global::MessagePack.Nil.Default);
public global::System.Threading.Tasks.Task CallReceiver_Delay(global::System.Int32 milliseconds)
=> parent.WriteMessageFireAndForgetAsync<global::System.Int32, global::MessagePack.Nil>(1865731236, milliseconds);

}

Expand Down Expand Up @@ -865,6 +869,12 @@ protected override void OnBroadcastEvent(global::System.Int32 methodId, global::
receiver.Receiver_RefType_Null(value);
}
break;
case -5486432: // Void Receiver_Delay()
{
var value = base.Deserialize<global::MessagePack.Nil>(data);
receiver.Receiver_Delay();
}
break;
}
}

Expand Down Expand Up @@ -941,6 +951,9 @@ protected override void OnResponseEvent(global::System.Int32 methodId, global::S
case -1093215042: // Task CallReceiver_RefType_Null()
base.SetResultForResponse<global::MessagePack.Nil>(taskCompletionSource, data);
break;
case 1865731236: // Task CallReceiver_Delay(global::System.Int32 milliseconds)
base.SetResultForResponse<global::MessagePack.Nil>(taskCompletionSource, data);
break;
}
}

Expand Down

0 comments on commit c5cb55b

Please sign in to comment.