Skip to content

Commit

Permalink
Fix nullable field types due to compiler optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
EdwardCooke committed Dec 23, 2024
1 parent 7bed5cf commit 0e3bbac
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 23 deletions.
64 changes: 63 additions & 1 deletion YamlDotNet.Test/Serialization/DeserializerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,55 @@ public void DeserializeWithoutDuplicateKeyChecking_YamlWithDuplicateKeys_DoesNot
}

[Fact]
public void EnforceNulalbleTypesWhenNullThrowsException()
public void EnforceNullableWhenClassIsDefaultNullableThrows()
{
var deserializer = new DeserializerBuilder().WithEnforceNullability().Build();
var yaml = @"
TestString: null
TestBool: null
TestBool1: null
";
try
{
var test = deserializer.Deserialize<NullableDefaultClass>(yaml);
}
catch (YamlException e)
{
if (e.InnerException is NullReferenceException)
{
return;
}
}

throw new Exception("Non nullable property was set to null.");
}

[Fact]
public void EnforceNullableWhenClassIsNotDefaultNullableThrows()
{
var deserializer = new DeserializerBuilder().WithEnforceNullability().Build();
var yaml = @"
TestString: null
TestBool: null
TestBool1: null
";
try
{
var test = deserializer.Deserialize<NullableNotDefaultClass>(yaml);
}
catch (YamlException e)
{
if (e.InnerException is NullReferenceException)
{
return;
}
}

throw new Exception("Non nullable property was set to null.");
}

[Fact]
public void EnforceNullableTypesWhenNullThrowsException()
{
var deserializer = new DeserializerBuilder().WithEnforceNullability().Build();
var yaml = @"
Expand Down Expand Up @@ -589,6 +637,20 @@ public enum EnumMemberedEnum
#endif

#nullable enable
public class NullableDefaultClass
{
public string? TestString { get; set; }
public string? TestBool { get; set; }
public string TestBool1 { get; set; } = "";
}

public class NullableNotDefaultClass
{
public string? TestString { get; set; }
public string TestBool { get; set; } = "";
public string TestBool1 { get; set; } = "";
}

public class NonNullableClass
{
public string Test { get; set; } = "Some default value";
Expand Down
41 changes: 19 additions & 22 deletions YamlDotNet/ReflectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -309,35 +309,32 @@ public static Attribute[] GetAllCustomAttributes<TAttribute>(this PropertyInfo m

public static bool AcceptsNull(this MemberInfo member)
{
var result = true; //default to allowing nulls, this will be set to false if there is a null context on the type
#if NET8_0_OR_GREATER
var typeHasNullContext = TypesHaveNullContext.GetOrAdd(member.DeclaringType, (Type t) =>
{
var attributes = t.GetCustomAttributes(typeof(System.Runtime.CompilerServices.NullableContextAttribute), true);
return (attributes?.Length ?? 0) > 0;
});
var classAttributes = member.DeclaringType.GetCustomAttributes(typeof(System.Runtime.CompilerServices.NullableContextAttribute), true);
var defaultFlag = classAttributes.OfType<System.Runtime.CompilerServices.NullableContextAttribute>().FirstOrDefault()?.Flag ?? 0;

if (typeHasNullContext)
{
// we have a nullable context on that type, only allow null if the NullableAttribute is on the member.
var memberAttributes = member.GetCustomAttributes(typeof(System.Runtime.CompilerServices.NullableAttribute), true);
result = (memberAttributes?.Length ?? 0) > 0;
}
// we have a nullable context on that type, only allow null if the NullableAttribute is on the member.
var memberAttributes = member.GetCustomAttributes(typeof(System.Runtime.CompilerServices.NullableAttribute), true);
var nullableFlag = memberAttributes.OfType<System.Runtime.CompilerServices.NullableAttribute>().FirstOrDefault()?.NullableFlags.Any(flag => flag == 2);
var result = nullableFlag ?? defaultFlag == 2;

return result;
#else
var typeHasNullContext = TypesHaveNullContext.GetOrAdd(member.DeclaringType, (Type t) =>
var classAttributes = member.DeclaringType.GetCustomAttributes(true);
var classAttribute = classAttributes.FirstOrDefault(x => x.GetType().FullName == "System.Runtime.CompilerServices.NullableContextAttribute");
var defaultFlag = 0;
if (classAttribute != null)
{
var attributes = t.GetCustomAttributes(true);
return attributes.Any(x => x.GetType().FullName == "System.Runtime.CompilerServices.NullableContextAttribute");
});

if (typeHasNullContext)
{
var memberAttributes = member.GetCustomAttributes(true);
result = memberAttributes.Any(x => x.GetType().FullName == "System.Runtime.CompilerServices.NullableAttribute");
var classAttributeType = classAttribute.GetType();
var classProperty = classAttributeType.GetProperty("Flag")!;
defaultFlag = (byte)classProperty.GetValue(classAttribute)!;
}

var memberAttributes = member.GetCustomAttributes(true);
var memberAttribute = memberAttributes.FirstOrDefault(x => x.GetType().FullName == "System.Runtime.CompilerServices.NullableAttribute");
var memberAttributeType = memberAttribute?.GetType();
var memberProperty = memberAttributeType?.GetProperty("NullableFlags")!;
var flags = (byte[])memberProperty.GetValue(memberAttribute)!;
var result = flags.Any(x => x == 2) || defaultFlag == 2;
return result;
#endif
}
Expand Down

0 comments on commit 0e3bbac

Please sign in to comment.