Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MVVM Toolkit LINQ expression issues on .NET Framework #4282

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions Microsoft.Toolkit.Mvvm/ComponentModel/ObservableValidator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,21 @@ static Action<object> GetValidationAction(Type type)
// Fallback method to create the delegate with a compiled LINQ expression
static Action<object> GetValidationActionFallback(Type type)
{
// Get the collection of all properties to validate
(string Name, MethodInfo GetMethod)[] validatableProperties = (
from property in type.GetProperties(BindingFlags.Instance | BindingFlags.Public)
where property.GetIndexParameters().Length == 0 &&
property.GetCustomAttributes<ValidationAttribute>(true).Any()
let getMethod = property.GetMethod
where getMethod is not null
select (property.Name, getMethod)).ToArray();

// Short path if there are no properties to validate
if (validatableProperties.Length == 0)
{
return static _ => { };
}

// MyViewModel inst0 = (MyViewModel)arg0;
ParameterExpression arg0 = Expression.Parameter(typeof(object));
UnaryExpression inst0 = Expression.Convert(arg0, type);
Expand All @@ -513,14 +528,10 @@ static Action<object> GetValidationActionFallback(Type type)
// ObservableValidator externally, but that is fine because IL doesn't really have
// a concept of member visibility, that's purely a C# build-time feature.
BlockExpression body = Expression.Block(
from property in type.GetProperties(BindingFlags.Instance | BindingFlags.Public)
where property.GetIndexParameters().Length == 0 &&
property.GetCustomAttributes<ValidationAttribute>(true).Any()
let getter = property.GetMethod
where getter is not null
from property in validatableProperties
select Expression.Call(inst0, validateMethod, new Expression[]
{
Expression.Convert(Expression.Call(inst0, getter), typeof(object)),
Expression.Convert(Expression.Call(inst0, property.GetMethod), typeof(object)),
Expression.Constant(property.Name)
}));

Expand Down
20 changes: 15 additions & 5 deletions Microsoft.Toolkit.Mvvm/Messaging/IMessengerExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,20 @@ static Action<IMessenger, object, TToken> LoadRegistrationMethodsForType(Type re
// The LINQ codegen bloat is not really important for the same reason.
static Action<IMessenger, object, TToken> LoadRegistrationMethodsForTypeFallback(Type recipientType)
{
// Get the collection of validation methods
MethodInfo[] registrationMethods = (
from interfaceType in recipientType.GetInterfaces()
where interfaceType.IsGenericType &&
interfaceType.GetGenericTypeDefinition() == typeof(IRecipient<>)
let messageType = interfaceType.GenericTypeArguments[0]
select MethodInfos.RegisterIRecipient.MakeGenericMethod(messageType, typeof(TToken))).ToArray();

// Short path if there are no message handlers to register
if (registrationMethods.Length == 0)
{
return static (_, _, _) => { };
}

// Input parameters (IMessenger instance, non-generic recipient, token)
ParameterExpression
arg0 = Expression.Parameter(typeof(IMessenger)),
Expand All @@ -178,11 +192,7 @@ static Action<IMessenger, object, TToken> LoadRegistrationMethodsForTypeFallback
// We also add an explicit object conversion to cast the input recipient type to
// the actual specific type, so that the exposed message handlers are accessible.
BlockExpression body = Expression.Block(
from interfaceType in recipientType.GetInterfaces()
where interfaceType.IsGenericType &&
interfaceType.GetGenericTypeDefinition() == typeof(IRecipient<>)
let messageType = interfaceType.GenericTypeArguments[0]
let registrationMethod = MethodInfos.RegisterIRecipient.MakeGenericMethod(messageType, typeof(TToken))
from registrationMethod in registrationMethods
select Expression.Call(registrationMethod, new Expression[]
{
arg0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
using Microsoft.Toolkit.Mvvm.ComponentModel;
using Microsoft.VisualStudio.TestTools.UnitTesting;

#pragma warning disable SA1124
#pragma warning disable SA1124, SA1307, SA1401

#nullable enable

Expand Down Expand Up @@ -371,7 +371,7 @@ public partial class ModelWithValuePropertyWithValidation : ObservableValidator
[MinLength(5)]
private string? value;
}

public partial class ViewModelWithValidatableGeneratedProperties : ObservableValidator
{
[Required]
Expand Down
94 changes: 94 additions & 0 deletions UnitTests/UnitTests.Shared/Mvvm/Test_ObservableValidator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.ComponentModel;
using System.ComponentModel.DataAnnotations;
using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions;
using Microsoft.Toolkit.Mvvm.ComponentModel;
using Microsoft.VisualStudio.TestTools.UnitTesting;
Expand Down Expand Up @@ -335,6 +336,54 @@ public void Test_ObservableValidator_ValidateAllProperties()
model.Age = -10;

model.ValidateAllProperties();

Assert.IsTrue(model.HasErrors);
Assert.IsTrue(events.Count == 1);
Assert.IsTrue(events.Any(e => e.PropertyName == nameof(Person.Age)));
}

[TestCategory("Mvvm")]
[TestMethod]
public void Test_ObservableValidator_ValidateAllProperties_WithFallback()
{
var model = new PersonWithDeferredValidation();
var events = new List<DataErrorsChangedEventArgs>();

MethodInfo[] staticMethods = typeof(ObservableValidator).GetMethods(BindingFlags.Static | BindingFlags.NonPublic);
MethodInfo validationMethod = staticMethods.Single(static m => m.Name.Contains("GetValidationActionFallback"));
Func<Type, Action<object>> validationFunc = (Func<Type, Action<object>>)validationMethod.CreateDelegate(typeof(Func<Type, Action<object>>));
Action<object> validationAction = validationFunc(model.GetType());

model.ErrorsChanged += (s, e) => events.Add(e);

validationAction(model);

Assert.IsTrue(model.HasErrors);
Assert.IsTrue(events.Count == 2);

// Note: we can't use an index here because the order used to return properties
// from reflection APIs is an implementation detail and might change at any time.
Assert.IsTrue(events.Any(e => e.PropertyName == nameof(Person.Name)));
Assert.IsTrue(events.Any(e => e.PropertyName == nameof(Person.Age)));

events.Clear();

model.Name = "James";
model.Age = 42;

validationAction(model);

Assert.IsFalse(model.HasErrors);
Assert.IsTrue(events.Count == 2);
Assert.IsTrue(events.Any(e => e.PropertyName == nameof(Person.Name)));
Assert.IsTrue(events.Any(e => e.PropertyName == nameof(Person.Age)));

events.Clear();

model.Age = -10;

validationAction(model);

Assert.IsTrue(model.HasErrors);
Assert.IsTrue(events.Count == 1);
Assert.IsTrue(events.Any(e => e.PropertyName == nameof(Person.Age)));
Expand Down Expand Up @@ -414,6 +463,34 @@ public void Test_ObservableValidator_ValidationWithFormattedDisplayName()
Assert.AreEqual(allErrors[1].ErrorMessage, $"SECOND: {nameof(ValidationWithDisplayName.AnotherRequiredField)}.");
}

// See: https://github.com/CommunityToolkit/WindowsCommunityToolkit/issues/4272
[TestCategory("Mvvm")]
[TestMethod]
[DataRow(typeof(MyBase))]
[DataRow(typeof(MyDerived2))]
public void Test_ObservableRecipient_ValidationOnNonValidatableProperties(Type type)
{
MyBase viewmodel = (MyBase)Activator.CreateInstance(type);

viewmodel.ValidateAll();
}

// See: https://github.com/CommunityToolkit/WindowsCommunityToolkit/issues/4272
[TestCategory("Mvvm")]
[TestMethod]
[DataRow(typeof(MyBase))]
[DataRow(typeof(MyDerived2))]
public void Test_ObservableRecipient_ValidationOnNonValidatableProperties_WithFallback(Type type)
{
MyBase viewmodel = (MyBase)Activator.CreateInstance(type);

MethodInfo[] staticMethods = typeof(ObservableValidator).GetMethods(BindingFlags.Static | BindingFlags.NonPublic);
MethodInfo validationMethod = staticMethods.Single(static m => m.Name.Contains("GetValidationActionFallback"));
Func<Type, Action<object>> validationFunc = (Func<Type, Action<object>>)validationMethod.CreateDelegate(typeof(Func<Type, Action<object>>));

validationFunc(viewmodel.GetType())(viewmodel);
}

public class Person : ObservableValidator
{
private string name;
Expand Down Expand Up @@ -631,5 +708,22 @@ public string AnotherRequiredField
set => SetProperty(ref this.anotherRequiredField, value, true);
}
}

public class MyBase : ObservableValidator
{
public int? MyDummyInt { get; set; } = 0;

public void ValidateAll()
{
ValidateAllProperties();
}
}

public class MyDerived2 : MyBase
{
public string Name { get; set; }

public int SomeRandomproperty { get; set; }
}
}
}