Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix AIFunctionFactory support for AOT. #5494

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ class AIFunctionFactory
/// <param name="method">The method to be represented via the created <see cref="AIFunction"/>.</param>
/// <param name="options">Metadata to use to override defaults inferred from <paramref name="method"/>.</param>
/// <returns>The created <see cref="AIFunction"/> for invoking <paramref name="method"/>.</returns>
[RequiresUnreferencedCode("Reflection is used to access types from the supplied MethodInfo.")]
[RequiresDynamicCode("Reflection is used to access types from the supplied MethodInfo.")]
public static AIFunction Create(Delegate method, AIFunctionFactoryCreateOptions options)
{
_ = Throw.IfNull(method);
Expand All @@ -69,8 +67,6 @@ public static AIFunction Create(Delegate method, string? name, string? descripti
/// <param name="name">The name to use for the <see cref="AIFunction"/>.</param>
/// <param name="description">The description to use for the <see cref="AIFunction"/>.</param>
/// <returns>The created <see cref="AIFunction"/> for invoking <paramref name="method"/>.</returns>
[RequiresUnreferencedCode("Reflection is used to access types from the supplied MethodInfo.")]
[RequiresDynamicCode("Reflection is used to access types from the supplied MethodInfo.")]
public static AIFunction Create(Delegate method, JsonSerializerOptions options, string? name = null, string? description = null)
{
_ = Throw.IfNull(method);
Expand Down Expand Up @@ -104,8 +100,6 @@ public static AIFunction Create(MethodInfo method, object? target = null)
/// </param>
/// <param name="options">Metadata to use to override defaults inferred from <paramref name="method"/>.</param>
/// <returns>The created <see cref="AIFunction"/> for invoking <paramref name="method"/>.</returns>
[RequiresUnreferencedCode("Reflection is used to access types from the supplied MethodInfo.")]
[RequiresDynamicCode("Reflection is used to access types from the supplied MethodInfo.")]
public static AIFunction Create(MethodInfo method, object? target, AIFunctionFactoryCreateOptions options)
{
_ = Throw.IfNull(method);
Expand Down Expand Up @@ -136,8 +130,6 @@ class ReflectionAIFunction : AIFunction
/// This should be <see langword="null"/> if and only if <paramref name="method"/> is a static method.
/// </param>
/// <param name="options">Function creation options.</param>
[RequiresUnreferencedCode("Reflection is used to access types from the supplied MethodInfo.")]
[RequiresDynamicCode("Reflection is used to access types from the supplied MethodInfo.")]
public ReflectionAIFunction(MethodInfo method, object? target, AIFunctionFactoryCreateOptions options)
{
_ = Throw.IfNull(method);
Expand Down Expand Up @@ -384,8 +376,6 @@ static bool IsAsyncMethod(MethodInfo method)
/// <summary>
/// Gets a delegate for handling the result value of a method, converting it into the <see cref="Task{FunctionResult}"/> to return from the invocation.
/// </summary>
[RequiresUnreferencedCode("Reflection is used to access types from the supplied MethodInfo.")]
[RequiresDynamicCode("Reflection is used to access types from the supplied MethodInfo.")]
private static Type GetReturnMarshaler(MethodInfo method, out Func<object?, ValueTask<object?>> marshaler)
{
// Handle each known return type for the method
Expand Down Expand Up @@ -416,9 +406,9 @@ private static Type GetReturnMarshaler(MethodInfo method, out Func<object?, Valu
if (returnType.IsGenericType)
{
// Task<T>
if (returnType.GetGenericTypeDefinition() == typeof(Task<>) &&
returnType.GetProperty(nameof(Task<int>.Result), BindingFlags.Public | BindingFlags.Instance)?.GetGetMethod() is MethodInfo taskResultGetter)
if (returnType.GetGenericTypeDefinition() == typeof(Task<>))
{
MethodInfo taskResultGetter = GetMethodFromGenericMethodDefinition(returnType, _taskGetResult);
marshaler = async result =>
{
await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false);
Expand All @@ -428,10 +418,10 @@ private static Type GetReturnMarshaler(MethodInfo method, out Func<object?, Valu
}

// ValueTask<T>
if (returnType.GetGenericTypeDefinition() == typeof(ValueTask<>) &&
returnType.GetMethod(nameof(ValueTask<int>.AsTask), BindingFlags.Public | BindingFlags.Instance) is MethodInfo valueTaskAsTask &&
valueTaskAsTask.ReturnType.GetProperty(nameof(ValueTask<int>.Result), BindingFlags.Public | BindingFlags.Instance)?.GetGetMethod() is MethodInfo asTaskResultGetter)
if (returnType.GetGenericTypeDefinition() == typeof(ValueTask<>))
{
MethodInfo valueTaskAsTask = GetMethodFromGenericMethodDefinition(returnType, _valueTaskAsTask);
MethodInfo asTaskResultGetter = GetMethodFromGenericMethodDefinition(valueTaskAsTask.ReturnType, _taskGetResult);
marshaler = async result =>
{
var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(result), null)!;
Expand Down Expand Up @@ -471,6 +461,20 @@ private static Type GetReturnMarshaler(MethodInfo method, out Func<object?, Valu
#endif
}

private static readonly MethodInfo _taskGetResult = typeof(Task<>).GetProperty(nameof(Task<int>.Result), BindingFlags.Instance | BindingFlags.Public)!.GetMethod!;
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
private static readonly MethodInfo _valueTaskAsTask = typeof(ValueTask<>).GetMethod(nameof(ValueTask<int>.AsTask), BindingFlags.Instance | BindingFlags.Public)!;

[UnconditionalSuppressMessage("Trimming", "IL2070:'this' argument does not satisfy 'DynamicallyAccessedMembersAttribute' in call to target method.",
Justification = "The MethodInfo we are looking for must have already been rooted by virtue of its generic definition being available.")]
private static MethodInfo GetMethodFromGenericMethodDefinition(Type specializedType, MethodInfo genericMethodDefinition)
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
{
Debug.Assert(specializedType.IsGenericType && specializedType.GetGenericTypeDefinition() == genericMethodDefinition.DeclaringType, "generic member definition doesn't match type.");
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
const BindingFlags All = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance;
#pragma warning restore S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
return specializedType.GetMethods(All).First(m => m.MetadataToken == genericMethodDefinition.MetadataToken);
}

/// <summary>
/// Remove characters from method name that are valid in metadata but shouldn't be used in a method name.
/// This is primarily intended to remove characters emitted by for compiler-generated method name mangling.
Expand Down
Loading