Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow ValueTask as a return type of the hub method. #583

Merged
merged 1 commit into from
Dec 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ public static MagicOnionMethod<TRequest, TResponse, TRawRequest, TRawResponse> C
}
}


// WORKAROUND: Prior to MagicOnion 5.0, the request type for the parameter-less method was byte[].
// DynamicClient sends byte[], but GeneratedClient sends Nil, which is incompatible,
// so as a special case we do not serialize/deserialize and always convert to a fixed values.
Expand All @@ -124,7 +123,7 @@ public static MagicOnionMethod<TRequest, TResponse, TRawRequest, TRawResponse> C

var writer = ctx.GetBufferWriter();
var buffer = writer.GetSpan(unsafeNilBytes.Length); // Write `Nil` as `byte[]` to the buffer.
MagicOnionMarshallers.UnsafeNilBytes.CopyTo(buffer);
unsafeNilBytes.CopyTo(buffer);
writer.Advance(unsafeNilBytes.Length);

ctx.Complete();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ public static class MagicOnionMarshallers
.OrderBy(x => x.GetGenericArguments().Length)
.ToArray();

public static readonly byte[] UnsafeNilBytes = new byte[] { MessagePackCode.Nil };

public static readonly Marshaller<byte[]> ThroughMarshaller = new Marshaller<byte[]>(x => x, x => x);

internal static Type CreateRequestType(ParameterInfo[] parameters)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ public static class KnownTypes
public static MagicOnionTypeInfo System_Boolean { get; } = new MagicOnionTypeInfo("System", "Boolean", SubType.ValueType);
public static MagicOnionTypeInfo MessagePack_Nil { get; } = new MagicOnionTypeInfo("MessagePack", "Nil", SubType.ValueType);
public static MagicOnionTypeInfo System_Threading_Tasks_Task { get; } = new MagicOnionTypeInfo("System.Threading.Tasks", "Task");
public static MagicOnionTypeInfo System_Threading_Tasks_ValueTask { get; } = new MagicOnionTypeInfo("System.Threading.Tasks", "ValueTask", SubType.ValueType);
public static MagicOnionTypeInfo MagicOnion_UnaryResult { get; } = new MagicOnionTypeInfo("MagicOnion", "UnaryResult", SubType.ValueType);
// ReSharper restore InconsistentNaming
}
Expand All @@ -26,6 +27,12 @@ public static class KnownTypes
public IReadOnlyList<MagicOnionTypeInfo> GenericArguments { get; }
public bool HasGenericArguments => GenericArguments.Any();

public MagicOnionTypeInfo GetGenericTypeDefinition()
{
if (!HasGenericArguments) throw new InvalidOperationException("The type is not constructed generic type.");
return MagicOnionTypeInfo.Create(Namespace, Name, Array.Empty<MagicOnionTypeInfo>(), IsValueType);
}

public bool IsArray => _subType == SubType.Array;
public int ArrayRank { get; }
public MagicOnionTypeInfo ElementType { get; }
Expand Down Expand Up @@ -120,6 +127,7 @@ public static MagicOnionTypeInfo Create(string @namespace, string name, MagicOni
if (@namespace == "System" && name == "String") return KnownTypes.System_String;
if (@namespace == "System" && name == "Boolean") return KnownTypes.System_Boolean;
if (@namespace == "System.Threading.Tasks" && name == "Task" && genericArguments.Length == 0) return KnownTypes.System_Threading_Tasks_Task;
if (@namespace == "System.Threading.Tasks" && name == "ValueTask" && genericArguments.Length == 0) return KnownTypes.System_Threading_Tasks_ValueTask;
if (@namespace == "MagicOnion" && name == "UnaryResult" && genericArguments.Length == 0) return KnownTypes.MagicOnion_UnaryResult;

return new MagicOnionTypeInfo(@namespace, name, isValueType ? SubType.ValueType : SubType.None, arrayRank:0, genericArguments);
Expand Down
2 changes: 2 additions & 0 deletions src/MagicOnion.GeneratorCore/CodeAnalysis/MethodCollector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,11 @@ MagicOnionStreamingHubInfo.MagicOnionHubMethodInfo CreateHubMethodInfoFromMethod
switch (methodReturnType.FullNameOpenType)
{
case "global::System.Threading.Tasks.Task":
case "global::System.Threading.Tasks.ValueTask":
//responseType = MagicOnionTypeInfo.KnownTypes.MessagePack_Nil;
break;
case "global::System.Threading.Tasks.Task<>":
case "global::System.Threading.Tasks.ValueTask<>":
responseType = methodReturnType.GenericArguments[0];
break;
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,30 @@ static void EmitHubMethods(StreamingHubClientBuildContext ctx, bool isFireAndFor
_ => $", {method.Parameters.ToNewDynamicArgumentTuple()}",
};

ctx.TextWriter.WriteLines($"""
public {method.MethodReturnType.FullName} {method.MethodName}({method.Parameters.ToMethodSignaturize()})
=> {(isFireAndForget ? "parent.WriteMessageFireAndForgetAsync" : "base.WriteMessageWithResponseAsync")}<{method.RequestType.FullName}, {method.ResponseType.FullName}>({method.HubId}{writeMessageParameters});
""");
if (method.MethodReturnType == MagicOnionTypeInfo.KnownTypes.System_Threading_Tasks_ValueTask)
{
// ValueTask
ctx.TextWriter.WriteLines($"""
public {method.MethodReturnType.FullName} {method.MethodName}({method.Parameters.ToMethodSignaturize()})
=> new global::System.Threading.Tasks.ValueTask({(isFireAndForget ? "parent.WriteMessageFireAndForgetAsync" : "base.WriteMessageWithResponseAsync")}<{method.RequestType.FullName}, {method.ResponseType.FullName}>({method.HubId}{writeMessageParameters}));
""");
}
else if (method.MethodReturnType.HasGenericArguments && method.MethodReturnType.GetGenericTypeDefinition() == MagicOnionTypeInfo.KnownTypes.System_Threading_Tasks_ValueTask)
{
// ValueTask<T>
ctx.TextWriter.WriteLines($"""
public {method.MethodReturnType.FullName} {method.MethodName}({method.Parameters.ToMethodSignaturize()})
=> new global::System.Threading.Tasks.ValueTask<{method.ResponseType.FullName}>({(isFireAndForget ? "parent.WriteMessageFireAndForgetAsync" : "base.WriteMessageWithResponseAsync")}<{method.RequestType.FullName}, {method.ResponseType.FullName}>({method.HubId}{writeMessageParameters}));
""");
}
else
{
// Task, Task<T>
ctx.TextWriter.WriteLines($"""
public {method.MethodReturnType.FullName} {method.MethodName}({method.Parameters.ToMethodSignaturize()})
=> {(isFireAndForget ? "parent.WriteMessageFireAndForgetAsync" : "base.WriteMessageWithResponseAsync")}<{method.RequestType.FullName}, {method.ResponseType.FullName}>({method.HubId}{writeMessageParameters});
""");
}
} // #endif
}

Expand Down
4 changes: 2 additions & 2 deletions src/MagicOnion.Server/Hubs/StreamingHubContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public ConcurrentDictionary<string, object> Items
internal Type? responseType;

// helper for reflection
internal async ValueTask WriteResponseMessageNil(Task value)
internal async ValueTask WriteResponseMessageNil(ValueTask value)
{
if (MessageId == -1) // don't write.
{
Expand Down Expand Up @@ -70,7 +70,7 @@ byte[] BuildMessage()
responseType = typeof(Nil);
}

internal async ValueTask WriteResponseMessage<T>(Task<T> value)
internal async ValueTask WriteResponseMessage<T>(ValueTask<T> value)
{
if (MessageId == -1) // don't write.
{
Expand Down
119 changes: 85 additions & 34 deletions src/MagicOnion.Server/Hubs/StreamingHubHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using MessagePack;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
using Grpc.Core;
using MagicOnion.Server.Filters;
using MagicOnion.Server.Filters.Internal;
Expand Down Expand Up @@ -59,9 +60,7 @@ public StreamingHubHandler(Type classType, MethodInfo methodInfo, StreamingHubHa
var invokeHubMethodFunc = Expression.Lambda(callHubMethod, contextArg, requestArg).Compile();

// Create a StreamingHub method invoker and a wrapped-invoke method.
Type invokerType = metadata.ResponseType is null
? typeof(StreamingHubMethodInvoker<>).MakeGenericType(metadata.RequestType)
: typeof(StreamingHubMethodInvoker<,>).MakeGenericType(metadata.RequestType, metadata.ResponseType);
Type invokerType = StreamingHubMethodInvoker.CreateInvokerTypeFromMetadata(metadata);
StreamingHubMethodInvoker invoker = (StreamingHubMethodInvoker)Activator.CreateInstance(invokerType, messageSerializer, invokeHubMethodFunc)!;

var filters = FilterHelper.GetFilters(handlerOptions.GlobalStreamingHubFilters, classType, methodInfo);
Expand All @@ -73,23 +72,63 @@ public StreamingHubHandler(Type classType, MethodInfo methodInfo, StreamingHubHa
}
}

abstract class StreamingHubMethodInvoker
public override string ToString()
=> toStringCache;

public override int GetHashCode()
=> getHashCodeCache;

public bool Equals(StreamingHubHandler? other)
=> other != null && HubName.Equals(other.HubName) && MethodInfo.Name.Equals(other.MethodInfo.Name);
}

/// <summary>
/// Options for StreamingHubHandler construction.
/// </summary>
public class StreamingHubHandlerOptions
{
public IList<StreamingHubFilterDescriptor> GlobalStreamingHubFilters { get; }

public IMagicOnionMessageSerializerProvider MessageSerializer { get; }

public StreamingHubHandlerOptions(MagicOnionOptions options)
{
protected IMagicOnionMessageSerializer MessageSerializer { get; }
GlobalStreamingHubFilters = options.GlobalStreamingHubFilters;
MessageSerializer = options.MessageSerializer;
}
}

protected StreamingHubMethodInvoker(IMagicOnionMessageSerializer messageSerializer)
{
MessageSerializer = messageSerializer;
}
internal abstract class StreamingHubMethodInvoker
{
protected IMagicOnionMessageSerializer MessageSerializer { get; }

public abstract ValueTask InvokeAsync(StreamingHubContext context);
protected StreamingHubMethodInvoker(IMagicOnionMessageSerializer messageSerializer)
{
MessageSerializer = messageSerializer;
}

public abstract ValueTask InvokeAsync(StreamingHubContext context);

public static Type CreateInvokerTypeFromMetadata(in StreamingHubMethodHandlerMetadata metadata)
{
var isTaskOrTaskOfT = metadata.InterfaceMethod.ReturnType == typeof(Task) ||
(metadata.InterfaceMethod.ReturnType is { IsGenericType: true } t && t.BaseType == typeof(Task));
return isTaskOrTaskOfT
? (metadata.ResponseType is null
? typeof(StreamingHubMethodInvokerTask<>).MakeGenericType(metadata.RequestType)
: typeof(StreamingHubMethodInvokerTask<,>).MakeGenericType(metadata.RequestType, metadata.ResponseType)
)
: (metadata.ResponseType is null
? typeof(StreamingHubMethodInvokerValueTask<>).MakeGenericType(metadata.RequestType)
: typeof(StreamingHubMethodInvokerValueTask<,>).MakeGenericType(metadata.RequestType, metadata.ResponseType)
);
}

sealed class StreamingHubMethodInvoker<TRequest, TResponse> : StreamingHubMethodInvoker
sealed class StreamingHubMethodInvokerTask<TRequest, TResponse> : StreamingHubMethodInvoker
{
readonly Func<StreamingHubContext, TRequest, Task<TResponse>> hubMethodFunc;

public StreamingHubMethodInvoker(IMagicOnionMessageSerializer messageSerializer, Delegate hubMethodFunc) : base(messageSerializer)
public StreamingHubMethodInvokerTask(IMagicOnionMessageSerializer messageSerializer, Delegate hubMethodFunc) : base(messageSerializer)
{
this.hubMethodFunc = (Func<StreamingHubContext, TRequest, Task<TResponse>>)hubMethodFunc;
}
Expand All @@ -99,15 +138,15 @@ public override ValueTask InvokeAsync(StreamingHubContext context)
var seq = new ReadOnlySequence<byte>(context.Request);
TRequest request = MessageSerializer.Deserialize<TRequest>(seq);
Task<TResponse> response = hubMethodFunc(context, request);
return context.WriteResponseMessage(response);
return context.WriteResponseMessage(new ValueTask<TResponse>(response));
}
}

sealed class StreamingHubMethodInvoker<TRequest> : StreamingHubMethodInvoker
sealed class StreamingHubMethodInvokerTask<TRequest> : StreamingHubMethodInvoker
{
readonly Func<StreamingHubContext, TRequest, Task> hubMethodFunc;

public StreamingHubMethodInvoker(IMagicOnionMessageSerializer messageSerializer, Delegate hubMethodFunc) : base(messageSerializer)
public StreamingHubMethodInvokerTask(IMagicOnionMessageSerializer messageSerializer, Delegate hubMethodFunc) : base(messageSerializer)
{
this.hubMethodFunc = (Func<StreamingHubContext, TRequest, Task>)hubMethodFunc;
}
Expand All @@ -117,32 +156,44 @@ public override ValueTask InvokeAsync(StreamingHubContext context)
var seq = new ReadOnlySequence<byte>(context.Request);
TRequest request = MessageSerializer.Deserialize<TRequest>(seq);
Task response = hubMethodFunc(context, request);
return context.WriteResponseMessageNil(response);
return context.WriteResponseMessageNil(new ValueTask(response));
}
}

public override string ToString()
=> toStringCache;
sealed class StreamingHubMethodInvokerValueTask<TRequest, TResponse> : StreamingHubMethodInvoker
{
readonly Func<StreamingHubContext, TRequest, ValueTask<TResponse>> hubMethodFunc;

public override int GetHashCode()
=> getHashCodeCache;
public StreamingHubMethodInvokerValueTask(IMagicOnionMessageSerializer messageSerializer, Delegate hubMethodFunc) : base(messageSerializer)
{
this.hubMethodFunc = (Func<StreamingHubContext, TRequest, ValueTask<TResponse>>)hubMethodFunc;
}

public bool Equals(StreamingHubHandler? other)
=> other != null && HubName.Equals(other.HubName) && MethodInfo.Name.Equals(other.MethodInfo.Name);
}
public override ValueTask InvokeAsync(StreamingHubContext context)
{
var seq = new ReadOnlySequence<byte>(context.Request);
TRequest request = MessageSerializer.Deserialize<TRequest>(seq);
ValueTask<TResponse> response = hubMethodFunc(context, request);
return context.WriteResponseMessage(response);
}
}

/// <summary>
/// Options for StreamingHubHandler construction.
/// </summary>
public class StreamingHubHandlerOptions
{
public IList<StreamingHubFilterDescriptor> GlobalStreamingHubFilters { get; }
sealed class StreamingHubMethodInvokerValueTask<TRequest> : StreamingHubMethodInvoker
{
readonly Func<StreamingHubContext, TRequest, ValueTask> hubMethodFunc;

public IMagicOnionMessageSerializerProvider MessageSerializer { get; }
public StreamingHubMethodInvokerValueTask(IMagicOnionMessageSerializer messageSerializer, Delegate hubMethodFunc) : base(messageSerializer)
{
this.hubMethodFunc = (Func<StreamingHubContext, TRequest, ValueTask>)hubMethodFunc;
}

public StreamingHubHandlerOptions(MagicOnionOptions options)
{
GlobalStreamingHubFilters = options.GlobalStreamingHubFilters;
MessageSerializer = options.MessageSerializer;
public override ValueTask InvokeAsync(StreamingHubContext context)
{
var seq = new ReadOnlySequence<byte>(context.Request);
TRequest request = MessageSerializer.Deserialize<TRequest>(seq);
ValueTask response = hubMethodFunc(context, request);
return context.WriteResponseMessageNil(response);
}
}

}
16 changes: 8 additions & 8 deletions src/MagicOnion.Server/Internal/MethodHandlerMetadata.cs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public static StreamingHubMethodHandlerMetadata CreateStreamingHubMethodHandlerM
{
var hubInterface = serviceClass.GetInterfaces().First(x => x.GetTypeInfo().IsGenericType && x.GetGenericTypeDefinition() == typeof(IStreamingHub<,>)).GetGenericArguments()[0];
var parameters = methodInfo.GetParameters();
var responseType = UnwrapStreamingHubResponseType(methodInfo, out var responseIsTask);
var responseType = UnwrapStreamingHubResponseType(methodInfo, out var responseIsTaskOrValueTask);
var requestType = GetRequestTypeFromMethod(methodInfo, parameters);

var attributeLookup = serviceClass.GetCustomAttributes(true)
Expand All @@ -105,9 +105,9 @@ public static StreamingHubMethodHandlerMetadata CreateStreamingHubMethodHandlerM

var interfaceMethodInfo = ResolveInterfaceMethod(serviceClass, hubInterface, methodInfo.Name);

if (!responseIsTask)
if (!responseIsTaskOrValueTask)
{
throw new InvalidOperationException($"A type of the StreamingHub method must be Task or Task<T>. (Member:{serviceClass.Name}.{methodInfo.Name})");
throw new InvalidOperationException($"A type of the StreamingHub method must be Task, Task<T>, ValueTask or ValueTask<T>. (Member:{serviceClass.Name}.{methodInfo.Name})");
}

var methodId = interfaceMethodInfo.GetCustomAttribute<MethodIdAttribute>()?.MethodId ?? FNV1A32.GetHashCode(interfaceMethodInfo.Name);
Expand Down Expand Up @@ -187,19 +187,19 @@ static Type UnwrapUnaryResponseType(MethodInfo methodInfo, out MethodType method
throw new InvalidOperationException($"The method '{methodInfo.Name}' has invalid return type. path:{methodInfo.DeclaringType!.Name + "/" + methodInfo.Name} type:{methodInfo.ReturnType.Name}");
}

static Type? UnwrapStreamingHubResponseType(MethodInfo methodInfo, out bool responseIsTask)
static Type? UnwrapStreamingHubResponseType(MethodInfo methodInfo, out bool responseIsTaskOrValueTask)
{
var t = methodInfo.ReturnType;

// Task<T>
if (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Task<>))
if (t.IsGenericType && (t.GetGenericTypeDefinition() == typeof(Task<>) || t.GetGenericTypeDefinition() == typeof(ValueTask<>)))
{
responseIsTask = true;
responseIsTaskOrValueTask = true;
return t.GetGenericArguments()[0];
}
else if (t == typeof(Task))
else if (t == typeof(Task) || t == typeof(ValueTask))
{
responseIsTask = true;
responseIsTaskOrValueTask = true;
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -698,4 +698,15 @@ public void EnumerateDependentTypes_Generics_Nested()
MagicOnionTypeInfo.CreateFromType<byte[]>(),
MagicOnionTypeInfo.CreateFromType<byte>());
}

[Fact]
public void GetGenericTypeDefinition()
{
// Arrange
var typeInfo = MagicOnionTypeInfo.CreateFromType<ValueTuple<int, string>>();
// Act
var genericDefinition = typeInfo.GetGenericTypeDefinition();
// Assert
genericDefinition.Should().Be(MagicOnionTypeInfo.Create("System", "ValueTuple", Array.Empty<MagicOnionTypeInfo>(), true));
}
}
Loading