Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Enum field type bug found when underlying type is set from assembly loaded with MLC #106375

Merged
merged 2 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -303,4 +303,7 @@
<data name="InvalidOperation_UnmatchingSymScope" xml:space="preserve">
<value>Unmatching symbol scope.</value>
</data>
<data name="Argument_MustBeEnum" xml:space="preserve">
<value>Type provided must be an Enum.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ protected override void SetCustomAttributeCore(ConstructorInfo con, ReadOnlySpan

public override Type? ReflectedType => _typeBuilder.ReflectedType;

public override Type UnderlyingSystemType => GetEnumUnderlyingType();
public override Type UnderlyingSystemType => this;

public override Type GetEnumUnderlyingType() => _underlyingField.FieldType;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,78 +40,10 @@ internal FieldBuilderImpl(TypeBuilderImpl typeBuilder, string fieldName, Type ty
protected override void SetConstantCore(object? defaultValue)
{
_typeBuilder.ThrowIfCreated();
ValidateDefaultValueType(defaultValue, _fieldType);
_defaultValue = defaultValue;
_attributes |= FieldAttributes.HasDefault;
}

internal static void ValidateDefaultValueType(object? defaultValue, Type destinationType)
{
if (defaultValue == null)
{
// nullable value types can hold null value.
if (destinationType.IsValueType && !(destinationType.IsGenericType && destinationType.GetGenericTypeDefinition() == typeof(Nullable<>)))
{
throw new ArgumentException(SR.Argument_ConstantNull);
}
}
else
{
Type sourceType = defaultValue.GetType();
// We should allow setting a constant value on a ByRef parameter
if (destinationType.IsByRef)
{
destinationType = destinationType.GetElementType()!;
}

// Convert nullable types to their underlying type.
destinationType = Nullable.GetUnderlyingType(destinationType) ?? destinationType;

if (destinationType.IsEnum)
{
Type underlyingType;
if (destinationType is EnumBuilderImpl enumBldr)
{
underlyingType = enumBldr.GetEnumUnderlyingType();

if (sourceType != enumBldr._typeBuilder.UnderlyingSystemType &&
sourceType != underlyingType &&
// If the source type is an enum, should not throw when the underlying types match
sourceType.IsEnum &&
sourceType.GetEnumUnderlyingType() != underlyingType)
{
throw new ArgumentException(SR.Argument_ConstantDoesntMatch);
}
}
else if (destinationType is TypeBuilderImpl typeBldr)
{
underlyingType = typeBldr.UnderlyingSystemType;

if (underlyingType == null || (sourceType != typeBldr.UnderlyingSystemType && sourceType != underlyingType))
{
throw new ArgumentException(SR.Argument_ConstantDoesntMatch);
}
}
else
{
underlyingType = Enum.GetUnderlyingType(destinationType);

if (sourceType != destinationType && sourceType != underlyingType)
{
throw new ArgumentException(SR.Argument_ConstantDoesntMatch);
}
}
}
else
{
if (!destinationType.IsAssignableFrom(sourceType))
{
throw new ArgumentException(SR.Argument_ConstantDoesntMatch);
}
}
}
}

internal void SetData(byte[] data)
{
_rvaData = data;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ public override void EmitCall(OpCode opcode, MethodInfo methodInfo, Type[]? opti
}

EmitOpcode(opcode);
UpdateStackSize(GetStackChange(opcode, methodInfo, optionalParameterTypes));
UpdateStackSize(GetStackChange(opcode, methodInfo, _moduleBuilder.GetTypeFromCoreAssembly(CoreTypeId.Void), optionalParameterTypes));
if (optionalParameterTypes == null || optionalParameterTypes.Length == 0)
{
WriteOrReserveToken(_moduleBuilder.TryGetMethodHandle(methodInfo), methodInfo);
Expand All @@ -613,12 +613,12 @@ public override void EmitCall(OpCode opcode, MethodInfo methodInfo, Type[]? opti
}
}

private static int GetStackChange(OpCode opcode, MethodInfo methodInfo, Type[]? optionalParameterTypes)
private static int GetStackChange(OpCode opcode, MethodInfo methodInfo, Type voidType, Type[]? optionalParameterTypes)
{
int stackChange = 0;

// Push the return value if there is one.
if (methodInfo.ReturnType != typeof(void))
if (methodInfo.ReturnType != voidType)
{
stackChange++;
}
Expand Down Expand Up @@ -665,7 +665,7 @@ public override void EmitCalli(OpCode opcode, CallingConventions callingConventi
}
}

int stackChange = GetStackChange(returnType, parameterTypes);
int stackChange = GetStackChange(returnType, _moduleBuilder.GetTypeFromCoreAssembly(CoreTypeId.Void), parameterTypes);

// Pop off VarArg arguments.
if (optionalParameterTypes != null)
Expand All @@ -685,17 +685,17 @@ public override void EmitCalli(OpCode opcode, CallingConventions callingConventi

public override void EmitCalli(OpCode opcode, CallingConvention unmanagedCallConv, Type? returnType, Type[]? parameterTypes)
{
int stackChange = GetStackChange(returnType, parameterTypes);
int stackChange = GetStackChange(returnType, _moduleBuilder.GetTypeFromCoreAssembly(CoreTypeId.Void), parameterTypes);
UpdateStackSize(stackChange);
Emit(OpCodes.Calli);
_il.Token(_moduleBuilder.GetSignatureToken(unmanagedCallConv, returnType, parameterTypes));
}

private static int GetStackChange(Type? returnType, Type[]? parameterTypes)
private static int GetStackChange(Type? returnType, Type voidType, Type[]? parameterTypes)
{
int stackChange = 0;
// If there is a non-void return type, push one.
if (returnType != typeof(void))
if (returnType != voidType)
{
stackChange++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ public ParameterBuilderImpl(MethodBuilderImpl methodBuilder, int sequence, Param

public override void SetConstant(object? defaultValue)
{
Type parameterType = _position == 0 ? _methodBuilder.ReturnType : _methodBuilder.ParameterTypes![_position - 1];
FieldBuilderImpl.ValidateDefaultValueType(defaultValue, parameterType);
_defaultValue = defaultValue;
_attributes |= ParameterAttributes.HasDefault;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ protected override void AddOtherMethodCore(MethodBuilder mdBuilder)
protected override void SetConstantCore(object? defaultValue)
{
_containingType.ThrowIfCreated();
FieldBuilderImpl.ValidateDefaultValueType(defaultValue, _propertyType);
_defaultValue = defaultValue;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,6 @@ protected override MethodBuilder DefineMethodCore(string name, MethodAttributes
{
ThrowIfCreated();


MethodBuilderImpl methodBuilder = new(name, attributes, callingConvention, returnType, returnTypeRequiredCustomModifiers,
returnTypeOptionalCustomModifiers, parameterTypes, parameterTypeRequiredCustomModifiers, parameterTypeOptionalCustomModifiers, _module, this);
_methodDefinitions.Add(methodBuilder);
Expand Down Expand Up @@ -616,23 +615,22 @@ public override Type GetGenericTypeDefinition()
public override string? Namespace => _namespace;
public override Assembly Assembly => _module.Assembly;
public override Module Module => _module;
public override Type UnderlyingSystemType
public override Type UnderlyingSystemType => this;

public override Type GetEnumUnderlyingType()
{
get
if (IsEnum)
{
if (IsEnum)
{
if (_enumUnderlyingType == null)
{
throw new InvalidOperationException(SR.InvalidOperation_NoUnderlyingTypeOnEnum);
}

return _enumUnderlyingType;
}
else
if (_enumUnderlyingType == null)
{
return this;
throw new InvalidOperationException(SR.InvalidOperation_NoUnderlyingTypeOnEnum);
}

return _enumUnderlyingType;
}
else
{
throw new ArgumentException(SR.Argument_MustBeEnum);
}
}
public override bool IsSZArray => false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public static IEnumerable<object[]> DefineLiteral_TestData()
yield return new object[] { typeof(uint), (uint)1 };

yield return new object[] { typeof(int), 0 };
yield return new object[] { typeof(int), 1 };
yield return new object[] { typeof(int), Test.Second };

yield return new object[] { typeof(ulong), (ulong)0 };
yield return new object[] { typeof(ulong), (ulong)1 };
Expand Down Expand Up @@ -100,7 +100,7 @@ public void CreateEnumWithMlc()
PersistedAssemblyBuilder ab = new PersistedAssemblyBuilder(PopulateAssemblyName(), mlc.CoreAssembly);
ModuleBuilder mb = ab.DefineDynamicModule("My Module");
Type intType = mlc.CoreAssembly.GetType("System.Int32");
EnumBuilder enumBuilder = mb.DefineEnum("TestEnum", TypeAttributes.Public, typeof(int));
EnumBuilder enumBuilder = mb.DefineEnum("TestEnum", TypeAttributes.Public, intType);
FieldBuilder field = enumBuilder.DefineLiteral("Default", 0);

enumBuilder.CreateTypeInfo();
Expand All @@ -118,7 +118,7 @@ public void CreateEnumWithMlc()

FieldInfo testField = createdEnum.GetField("Default");
Assert.Equal(createdEnum, testField.FieldType);
Assert.Equal(typeof(int), enumBuilder.GetEnumUnderlyingType());
Assert.Equal(intType, enumBuilder.GetEnumUnderlyingType());
Assert.Equal(FieldAttributes.Public | FieldAttributes.Static | FieldAttributes.Literal | FieldAttributes.HasDefault, testField.Attributes);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,24 @@ public void SetConstantVariousValues(Type returnType, object defaultValue)
Assert.Equal(defaultValue, property.GetConstantValue());
}

[Theory]
[MemberData(nameof(SetConstant_TestData))]
public void SetConstantVariousValuesMlcCoreAssembly(Type returnType, object defaultValue)
{
using (MetadataLoadContext mlc = new MetadataLoadContext(new CoreMetadataAssemblyResolver()))
{
PersistedAssemblyBuilder ab = new PersistedAssemblyBuilder(new AssemblyName("MyDynamicAssembly"), mlc.CoreAssembly);
ModuleBuilder mb = ab.DefineDynamicModule("My Module");
Type returnTypeFromCore = returnType != typeof(PropertyBuilderTest11.Colors) ? mlc.CoreAssembly.GetType(returnType.FullName, true) : returnType;
TypeBuilder type = mb.DefineType("MyType", TypeAttributes.Public);

PropertyBuilder property = type.DefineProperty("TestProperty", PropertyAttributes.HasDefault, returnTypeFromCore, null);
property.SetConstant(defaultValue);

Assert.Equal(defaultValue, property.GetConstantValue());
}
}

[Fact]
public void SetCustomAttribute_ConstructorInfo_ByteArray_NullConstructorInfo_ThrowsArgumentNullException()
{
Expand Down Expand Up @@ -194,7 +212,6 @@ public void Set_WhenTypeAlreadyCreated_ThrowsInvalidOperationException()
MethodAttributes getMethodAttributes = MethodAttributes.Public | MethodAttributes.SpecialName | MethodAttributes.HideBySig;
MethodBuilder method = type.DefineMethod("TestMethod", getMethodAttributes, typeof(int), null);
method.GetILGenerator().Emit(OpCodes.Ret);
AssertExtensions.Throws<ArgumentException>(() => property.SetConstant((decimal)10));
CustomAttributeBuilder customAttrBuilder = new CustomAttributeBuilder(typeof(IntPropertyAttribute).GetConstructor([typeof(int)]), [10]);
type.CreateType();

Expand All @@ -204,18 +221,5 @@ public void Set_WhenTypeAlreadyCreated_ThrowsInvalidOperationException()
Assert.Throws<InvalidOperationException>(() => property.SetConstant(1));
Assert.Throws<InvalidOperationException>(() => property.SetCustomAttribute(customAttrBuilder));
}

[Fact]
public void SetConstant_ValidationThrows()
{
AssemblySaveTools.PopulateAssemblyBuilderAndTypeBuilder(out TypeBuilder type);
FieldBuilder field = type.DefineField("TestField", typeof(int), FieldAttributes.Private);
PropertyBuilder property = type.DefineProperty("TestProperty", PropertyAttributes.HasDefault, typeof(int), null);

AssertExtensions.Throws<ArgumentException>(() => property.SetConstant((decimal)10));
AssertExtensions.Throws<ArgumentException>(() => property.SetConstant(null));
type.CreateType();
Assert.Throws<InvalidOperationException>(() => property.SetConstant(1));
}
}
}
Loading