Skip to content

Commit

Permalink
Avoid exception checking nullability
Browse files Browse the repository at this point in the history
- Do not throw if we cannot determine the nullability of a dictionary.
- Clean-up some code analysis suggestions.

Resolves #3070.
Resolves #2793.
  • Loading branch information
martincostello committed Nov 23, 2024
1 parent b8e1f0f commit 00890c9
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,19 @@ public static bool IsDictionaryValueNonNullable(this MemberInfo memberInfo)
{
#if NET6_0_OR_GREATER
var nullableInfo = GetNullabilityInfo(memberInfo);
if (nullableInfo.GenericTypeArguments.Length != 2)
{
var length = nullableInfo.GenericTypeArguments.Length;
var type = nullableInfo.Type.FullName;
var container = memberInfo.DeclaringType.FullName;
var member = memberInfo.Name;
throw new InvalidOperationException($"Expected Dictionary to have two generic type arguments but it had {length}. Member: {container}.{member} Type: {type}.");
}

return nullableInfo.GenericTypeArguments[1].ReadState == NullabilityState.NotNull;
// Assume one generic argument means TKey and TValue are the same type.
// Assume two generic arguments match TKey and TValue for a dictionary.
// A better solution would be to inspect the type declaration (base types,
// interfaces, etc.) to determine if the type is a dictionary, but the
// nullability information is not available to be able to do that.
// See https://stackoverflow.com/q/75786306/1064169.
return nullableInfo.GenericTypeArguments.Length switch
{
1 => nullableInfo.GenericTypeArguments[0].ReadState == NullabilityState.NotNull,
2 => nullableInfo.GenericTypeArguments[1].ReadState == NullabilityState.NotNull,
_ => false,
};
#else
var memberType = memberInfo.MemberType == MemberTypes.Field
? ((FieldInfo)memberInfo).FieldType
Expand Down Expand Up @@ -156,19 +159,19 @@ private static bool GetNullableFallbackValue(this MemberInfo memberInfo)
{
var declaringTypes = memberInfo.DeclaringType.IsNested
? GetDeclaringTypeChain(memberInfo)
: new List<Type>(1) { memberInfo.DeclaringType };
: [memberInfo.DeclaringType];

foreach (var declaringType in declaringTypes)
{
var attributes = (IEnumerable<object>)declaringType.GetCustomAttributes(false);
IEnumerable<object> attributes = declaringType.GetCustomAttributes(false);

var nullableContext = attributes
.FirstOrDefault(attr => string.Equals(attr.GetType().FullName, NullableContextAttributeFullTypeName));

if (nullableContext != null)
{
if (nullableContext.GetType().GetField(FlagFieldName) is FieldInfo field &&
field.GetValue(nullableContext) is byte flag && flag == NotAnnotated)
field.GetValue(nullableContext) is byte flag && flag == NotAnnotated)
{
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ public SchemaGenerator(SchemaGeneratorOptions generatorOptions, ISerializerDataC
{
}

public SchemaGenerator(SchemaGeneratorOptions generatorOptions, ISerializerDataContractResolver serializerDataContractResolver, IOptions<MvcOptions> mvcOptions)
public SchemaGenerator(
SchemaGeneratorOptions generatorOptions,
ISerializerDataContractResolver serializerDataContractResolver,
IOptions<MvcOptions> mvcOptions)
{
_generatorOptions = generatorOptions;
_serializerDataContractResolver = serializerDataContractResolver;
Expand Down Expand Up @@ -104,7 +107,7 @@ private OpenApiSchema GenerateSchemaForMember(
var genericTypes = modelType
.GetInterfaces()
#if NETSTANDARD2_0
.Concat(new[] { modelType })
.Concat([modelType])
#else
.Append(modelType)
#endif
Expand Down Expand Up @@ -309,7 +312,7 @@ private static OpenApiSchema CreatePrimitiveSchema(DataContract dataContract)
};

#pragma warning disable CS0618 // Type or member is obsolete
// For backcompat only - EnumValues is obsolete
// For backwards compatibility only - EnumValues is obsolete
if (dataContract.EnumValues != null)
{
schema.Enum = dataContract.EnumValues
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Xunit;

namespace Swashbuckle.AspNetCore.SwaggerGen;

#nullable enable

public static class MemberInfoExtensionsTests
{
[Theory]
[InlineData(typeof(MyClass), nameof(MyClass.DictionaryInt32NonNullable), true)]
[InlineData(typeof(MyClass), nameof(MyClass.DictionaryInt32Nullable), false)]
[InlineData(typeof(MyClass), nameof(MyClass.DictionaryStringNonNullable), true)]
[InlineData(typeof(MyClass), nameof(MyClass.DictionaryStringNullable), false)]
[InlineData(typeof(MyClass), nameof(MyClass.IDictionaryInt32NonNullable), true)]
[InlineData(typeof(MyClass), nameof(MyClass.IDictionaryInt32Nullable), false)]
[InlineData(typeof(MyClass), nameof(MyClass.IDictionaryStringNonNullable), true)]
[InlineData(typeof(MyClass), nameof(MyClass.IDictionaryStringNullable), false)]
[InlineData(typeof(MyClass), nameof(MyClass.IReadOnlyDictionaryInt32NonNullable), true)]
[InlineData(typeof(MyClass), nameof(MyClass.IReadOnlyDictionaryInt32Nullable), false)]
[InlineData(typeof(MyClass), nameof(MyClass.IReadOnlyDictionaryStringNonNullable), true)]
[InlineData(typeof(MyClass), nameof(MyClass.IReadOnlyDictionaryStringNullable), false)]
[InlineData(typeof(MyClass), nameof(MyClass.StringDictionary), false)] // There is no way to inspect the nullability of the base class' TValue argument
[InlineData(typeof(MyClass), nameof(MyClass.NullableStringDictionary), false)]
[InlineData(typeof(MyClass), nameof(MyClass.SameTypesDictionary), true)]
[InlineData(typeof(MyClass), nameof(MyClass.CustomDictionaryStringNullable), false)]
[InlineData(typeof(MyClass), nameof(MyClass.CustomDictionaryStringNonNullable), true)]
public static void IsDictionaryValueNonNullable_Returns_Correct_Value(Type type, string memberName, bool expected)
{
// Arrange
var memberInfo = type.GetMember(memberName).First();

// Act
var actual = memberInfo.IsDictionaryValueNonNullable();

// Assert
Assert.Equal(expected, actual);
}

public class MyClass
{
public Dictionary<string, int> DictionaryInt32NonNullable { get; set; } = [];

public Dictionary<string, int?> DictionaryInt32Nullable { get; set; } = [];

public Dictionary<string, string> DictionaryStringNonNullable { get; set; } = [];

public Dictionary<string, string?> DictionaryStringNullable { get; set; } = [];

public IDictionary<string, int> IDictionaryInt32NonNullable { get; set; } = new Dictionary<string, int>();

public IDictionary<string, int?> IDictionaryInt32Nullable { get; set; } = new Dictionary<string, int?>();

public IDictionary<string, string> IDictionaryStringNonNullable { get; set; } = new Dictionary<string, string>();

public IDictionary<string, string?> IDictionaryStringNullable { get; set; } = new Dictionary<string, string?>();

public IReadOnlyDictionary<string, int> IReadOnlyDictionaryInt32NonNullable { get; set; } = new Dictionary<string, int>();

public IReadOnlyDictionary<string, int?> IReadOnlyDictionaryInt32Nullable { get; set; } = new Dictionary<string, int?>();

public IReadOnlyDictionary<string, string> IReadOnlyDictionaryStringNonNullable { get; set; } = new Dictionary<string, string>();

public IReadOnlyDictionary<string, string?> IReadOnlyDictionaryStringNullable { get; set; } = new Dictionary<string, string?>();

public StringDictionary StringDictionary { get; set; } = [];

public NullableStringDictionary NullableStringDictionary { get; set; } = [];

public SameTypesDictionary<string> SameTypesDictionary { get; set; } = [];

public CustomDictionary<string, string?> CustomDictionaryStringNullable { get; set; } = [];

public CustomDictionary<string, string> CustomDictionaryStringNonNullable { get; set; } = [];
}

public class StringDictionary : Dictionary<string, string>;

public class NullableStringDictionary : Dictionary<string, string?>;

public class SameTypesDictionary<T> : Dictionary<T, T> where T : notnull;

public class CustomDictionary<TKey, TValue> : Dictionary<TKey, TValue> where TKey : notnull;
}

0 comments on commit 00890c9

Please sign in to comment.