Skip to content

Commit

Permalink
implements 'IEquatable<T>' interface on records (#273)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianoc committed Jun 9, 2024
1 parent aaeec3f commit 0b0c67e
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 15 deletions.
11 changes: 10 additions & 1 deletion Cecilifier.Core.Tests/Tests/Unit/TypeTests.Records.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,22 @@ public void EqualityContractProperty_WhenReferenceRecord_IsEmitted(string kind)
[TestCase(null)]
public void RecordType_Implements_IEquatable(string classOrStruct)
{
var result = RunCecilifier($"public record {classOrStruct} TheRecord;");
var result = RunCecilifier($"public record {classOrStruct} TheRecord(int Value);");

var cecilifiedCode = result.GeneratedCode.ReadToEnd();
Assert.That(cecilifiedCode, Does.Match($"//Record {(classOrStruct ?? "class").PascalCase()} : TheRecord"));
Assert.That(cecilifiedCode, Does.Match("""
(?<recVar>rec_theRecord_\d+)\.Interfaces\.Add\(new InterfaceImplementation\(.+ImportReference\(typeof\(System.IEquatable<>\)\).MakeGenericInstanceType\(\k<recVar>\)\)\);
\s+assembly.MainModule.Types.Add\(\k<recVar>\);
"""));

Assert.That(cecilifiedCode, Does.Match("""
//IEquatable<>.Equals\(TheRecord other\)
\s+var (?<eq>m_equals_\d+) = new MethodDefinition\("Equals", (MethodAttributes.)Public \| \1HideBySig \| \1NewSlot \| \1Virtual, .+TypeSystem.Boolean\);
\s+var (?<param>p_other_\d+) = new ParameterDefinition\("other", ParameterAttributes.None, (?<rec_var>rec_theRecord_\d+)\);
\s+\k<eq>.Parameters.Add\(\k<param>\);
\s+\k<rec_var>.Methods.Add\(\k<eq>\);
"""));
}

[Test]
Expand All @@ -94,6 +102,7 @@ public void PrimaryConstructorParameters_AreMappedToPublicProperties2()

private static void AssertPropertiesFromPrimaryConstructor(string[] expectedNameTypePairs, string cecilifiedCode)
{
//TODO: Take inheritance into account
Span<Range> ranges = stackalloc Range[2];
foreach (var pair in expectedNameTypePairs)
{
Expand Down
15 changes: 13 additions & 2 deletions Cecilifier.Core/CodeGeneration/PrimaryConstructor.Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ void AddInit()
context.WriteCecilExpressions([$"var {ilVar} = {setMethodVar}.Body.GetILProcessor();"]);

propertyGenerator.AddAutoSetterMethodImplementation(in propertyData, ilVar);
context.EmitCilInstruction(ilVar, OpCodes.Ret);
}
}
}
Expand All @@ -118,16 +119,20 @@ internal static void AddPrimaryConstructor(IVisitorContext context, string recor
$"{recordTypeDefinitionVariable}.Methods.Add({ctorVar});"
]);

var ctorIlVar = context.Naming.Constructor(typeDeclaration, false);
if (typeDeclaration.ParameterList?.Parameters == null)
return;

var ctorIlVar = context.Naming.ILProcessor($"ctor_{typeDeclaration.Identifier.ValueText}");
var ctorExps = CecilDefinitionsFactory.MethodBody2(context.Naming, ctorVar, ctorIlVar, Array.Empty<InstructionRepresentation>());
context.WriteCecilExpressions(ctorExps);

//TODO: Extract code from RecordGenerator.GetUniqueParameters() and use instead of `typeDeclaration.ParameterList?.Parameters`
foreach (var parameter in typeDeclaration.ParameterList?.Parameters)
{
context.WriteComment($"Parameter: {parameter.Identifier}");
var paramVar = context.Naming.Parameter(parameter);
var parameterType = context.TypeResolver.Resolve(context.SemanticModel.GetTypeInfo(parameter.Type!).Type);
var paramExps = CecilDefinitionsFactory.Parameter(parameter.Identifier.ValueText, RefKind.None, false, ctorVar, paramVar, parameterType, "paramAttrs", ("", false));
var paramExps = CecilDefinitionsFactory.Parameter(parameter.Identifier.ValueText, RefKind.None, false, ctorVar, paramVar, parameterType, Constants.ParameterAttributes.None, ("", false));
context.WriteCecilExpressions(paramExps);

context.EmitCilInstruction(ctorIlVar, OpCodes.Ldarg_0);
Expand All @@ -139,5 +144,11 @@ internal static void AddPrimaryConstructor(IVisitorContext context, string recor

context.EmitCilInstruction(ctorIlVar, OpCodes.Stfld, backingFieldVar.VariableName);
}

//TODO: Take inheritance into account.
var baseCtor = context.RoslynTypeSystem.SystemObject.GetMembers().OfType<IMethodSymbol>().Single(m => m is { Name: ".ctor" }).MethodResolverExpression(context);
context.EmitCilInstruction(ctorIlVar, OpCodes.Ldarg_0);
context.EmitCilInstruction(ctorIlVar, OpCodes.Call, baseCtor);
context.EmitCilInstruction(ctorIlVar, OpCodes.Ret);
}
}
207 changes: 195 additions & 12 deletions Cecilifier.Core/CodeGeneration/Record.Generator.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Cecilifier.Core.AST;
using Cecilifier.Core.Extensions;
using Cecilifier.Core.Misc;
using Cecilifier.Core.Naming;
using Cecilifier.Core.Variables;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Mono.Cecil.Cil;

namespace Cecilifier.Core.CodeGeneration;

Expand All @@ -12,7 +22,174 @@ internal void AddSyntheticMembers(IVisitorContext context, string recordTypeDefi
AddEqualityContractPropertyIfNeeded(context, recordTypeDefinitionVariable, record);
PrimaryConstructorGenerator.AddPropertiesFrom(context, recordTypeDefinitionVariable, record);
PrimaryConstructorGenerator.AddPrimaryConstructor(context, recordTypeDefinitionVariable, record);
AddIEquatableEquals(context, recordTypeDefinitionVariable, record);
}

//TODO: Record struct (no need to check for null, no need to check EqualityContract, etc)
private void AddIEquatableEquals(IVisitorContext context, string recordTypeDefinitionVariable, TypeDeclarationSyntax record)
{
context.WriteNewLine();
context.WriteComment($"IEquatable<>.Equals({record.Identifier.ValueText} other)");
var equalsMethodVar = context.Naming.SyntheticVariable("Equals", ElementKind.Method);
var exps = CecilDefinitionsFactory.Method(
context,
record.Identifier.ValueText(),
equalsMethodVar,
$"{record.Identifier.ValueText()}.Equals", "Equals",
$"MethodAttributes.Public | MethodAttributes.HideBySig | {Constants.Cecil.InterfaceMethodDefinitionAttributes}", //TODO: No NEWSLOT if in derived record
[new ParameterSpec("other", recordTypeDefinitionVariable, RefKind.None, Constants.ParameterAttributes.None)],
Array.Empty<string>(),
ctx => ctx.TypeResolver.Bcl.System.Boolean, out var methodDefinitionVariable);

context.WriteCecilExpressions([..exps, $"{recordTypeDefinitionVariable}.Methods.Add({equalsMethodVar});"]);

// Compare each unique primary constructor parameter to compute equality.
using (context.DefinitionVariables.WithVariable(methodDefinitionVariable))
{
var uniqueParameters = GetUniqueParameters(context, record);
var equalityDataByType = GenerateEqualityComparerMethods(context, uniqueParameters);

List<InstructionRepresentation> instructions = new();
instructions.AddRange(
[
// reference records only
OpCodes.Ldarg_0,
OpCodes.Ldarg_1,
OpCodes.Beq_S.WithBranchOperand("ReferenceEquals"),

OpCodes.Ldarg_1,
OpCodes.Brfalse_S.WithBranchOperand("NotEquals"),
]);

// TODO: reference records only
instructions.AddRange(
[
OpCodes.Ldarg_0,
OpCodes.Callvirt.WithOperand(_equalityContractGetMethodVar),
OpCodes.Ldarg_1,
OpCodes.Callvirt.WithOperand(_equalityContractGetMethodVar),
OpCodes.Call.WithOperand(TypeEqualityOperator(context)),
OpCodes.Brfalse_S.WithBranchOperand("NotEquals")
]);

foreach (var parameter in uniqueParameters)
{
// load default comparer for parameter type.
// IL_001a: call class [System.Collections]System.Collections.Generic.EqualityComparer`1<!0> class [System.Collections]System.Collections.Generic.EqualityComparer`1<int32>::get_Default()
var paramDefVar = context.DefinitionVariables.GetVariable(Utils.BackingFieldNameForAutoProperty(parameter.Identifier.ValueText()), VariableMemberKind.Field, record.Identifier.ValueText());

// Get the default comparer for the parameter type
var parameterType = context.SemanticModel.GetTypeInfo(parameter.Type!).Type.EnsureNotNull();
instructions.Add(OpCodes.Call.WithOperand(equalityDataByType[parameterType.Name].GetDefaultMethodVar));

// load property backing field for 'this'
instructions.Add(OpCodes.Ldarg_0);
instructions.Add(OpCodes.Ldfld.WithOperand(paramDefVar.VariableName));

// load property backing field for 'other'
instructions.Add(OpCodes.Ldarg_1);
instructions.Add(OpCodes.Ldfld.WithOperand(paramDefVar.VariableName));

// compares both backing fields.
instructions.Add(OpCodes.Callvirt.WithOperand(equalityDataByType[parameterType.Name].EqualsMethodVar));
instructions.Add(OpCodes.Brfalse.WithBranchOperand("NotEquals"));
}
instructions.AddRange(
[
OpCodes.Br_S.WithBranchOperand("ReferenceEquals"), // if the code reached this point all properties matched.
OpCodes.Ldc_I4_0.WithInstructionMarker("NotEquals"),
OpCodes.Ret,
OpCodes.Ldc_I4_1.WithInstructionMarker("ReferenceEquals"),
OpCodes.Ret
]);

var ilVar = context.Naming.ILProcessor("Equals");
var equalsExps = CecilDefinitionsFactory.MethodBody2(context.Naming, equalsMethodVar, ilVar, instructions.ToArray());
context.WriteCecilExpressions(equalsExps);
}
}

private IDictionary<string, (string GetDefaultMethodVar, string EqualsMethodVar)> GenerateEqualityComparerMethods(IVisitorContext context, IReadOnlyList<ParameterSyntax> uniqueParameters)
{
Dictionary<string, (string, string)> equalityComparerDataByType = new();

foreach (var parameter in uniqueParameters)
{
var openEqualityComparerType = context.TypeResolver.Resolve(context.SemanticModel.Compilation.GetTypeByMetadataName(typeof(EqualityComparer<>).FullName!));
var parameterType = context.SemanticModel.GetTypeInfo(parameter.Type!).Type.EnsureNotNull();
if (equalityComparerDataByType.ContainsKey(parameterType.Name))
continue;

var equalityComparerOfParameterType = openEqualityComparerType.MakeGenericInstanceType(context.TypeResolver.Resolve(parameterType));
var openGetDefaultMethodVar = context.Naming.SyntheticVariable("openget_Default", ElementKind.LocalVariable);

var getDefaultMethodVar = context.Naming.SyntheticVariable($"get_Default_{parameterType.Name}", ElementKind.MemberReference);
string[] defaultPropertyGetterExps =
[
$$"""var {{openGetDefaultMethodVar}} = assembly.MainModule.ImportReference(typeof(System.Collections.Generic.EqualityComparer<>)).Resolve().Methods.First(m => m.Name == "get_Default");""",
$$"""var {{getDefaultMethodVar}} = new MethodReference("get_Default", assembly.MainModule.ImportReference({{openGetDefaultMethodVar}}).ReturnType)""",
"{",
$"\tDeclaringType = {equalityComparerOfParameterType},",
$"\tHasThis = {openGetDefaultMethodVar}.HasThis,",
$"\tExplicitThis = {openGetDefaultMethodVar}.ExplicitThis,",
$"\tCallingConvention = {openGetDefaultMethodVar}.CallingConvention,",
"};"
];

context.WriteCecilExpressions(defaultPropertyGetterExps);

var equalsMethodVar = context.Naming.SyntheticVariable("Equals", ElementKind.MemberReference);
var equalityComparerOpenEqualsMethodVar = context.Naming.SyntheticVariable("Equals", ElementKind.LocalVariable);
string[] equalityComparerEqualsMethodExps = [
$$"""var {{equalityComparerOpenEqualsMethodVar}} = assembly.MainModule.ImportReference(typeof(System.Collections.Generic.EqualityComparer<>)).Resolve().Methods.First(m => m.Name == "Equals");""",
$"""var {equalsMethodVar} = new MethodReference("Equals", assembly.MainModule.ImportReference({equalityComparerOpenEqualsMethodVar}).ReturnType)""",
"{",
$"\tDeclaringType = {equalityComparerOfParameterType},",
$"\tHasThis = {equalityComparerOpenEqualsMethodVar}.HasThis,",
$"\tExplicitThis = {equalityComparerOpenEqualsMethodVar}.ExplicitThis,",
$"\tCallingConvention = {equalityComparerOpenEqualsMethodVar}.CallingConvention",
"};",
$"{equalsMethodVar}.Parameters.Add({equalityComparerOpenEqualsMethodVar}.Parameters[0]);",
$"{equalsMethodVar}.Parameters.Add({equalityComparerOpenEqualsMethodVar}.Parameters[1]);"
];
context.WriteCecilExpressions(equalityComparerEqualsMethodExps);

equalityComparerDataByType[parameterType.Name] = (getDefaultMethodVar, equalsMethodVar);
}

return equalityComparerDataByType;
}

/// <summary>
/// Each primary constructor parameter that does not have a matching parameter
/// in the base record (or the base record of the base record and so on) will have
/// a property associated with it.
/// </summary>
/// <param name="context"></param>
/// <param name="type"></param>
/// <returns>Returns a list of parameters that does not exist in the base type of the <paramref name="type"/> or in its parents.</returns>
static IReadOnlyList<ParameterSyntax> GetUniqueParameters(IVisitorContext context, TypeDeclarationSyntax type)
{
var records = context.SemanticModel.SyntaxTree.GetRoot().DescendantNodesAndSelf().OfType<RecordDeclarationSyntax>().ToArray();
List<ParameterSyntax> basesParameters = new();
var current = type;
while(true)
{
if (current.BaseList?.Types.Count is null or 0)
break;

var baseRecordName = ((IdentifierNameSyntax) current.BaseList!.Types.First().Type).Identifier;
var baseRecord = records.Single(r => r.Identifier.ValueText() == baseRecordName.ValueText);
basesParameters.AddRange(baseRecord.ParameterList!.Parameters);

current = baseRecord;
}

return (IReadOnlyList<ParameterSyntax>)
type.ParameterList?.Parameters.Where(candidate => basesParameters.All(bp => candidate.Identifier.ValueText != bp.Identifier.ValueText)).ToList()
?? Array.Empty<ParameterSyntax>();
}

private void AddEqualityContractPropertyIfNeeded(IVisitorContext context, string recordTypeDefinitionVariable, TypeDeclarationSyntax record)
{
if (record.IsKind(SyntaxKind.RecordStructDeclaration))
Expand Down Expand Up @@ -52,20 +229,13 @@ private void AddEqualityContractPropertyIfNeeded(IVisitorContext context, string
null);

var getterIlVar = context.Naming.ILProcessor("EqualityContract_get");

var getTypeFromHandleSymbol = (IMethodSymbol) context.RoslynTypeSystem.SystemType.GetMembers("GetTypeFromHandle").First();
context.WriteCecilExpression($"var {getterIlVar} = {_equalityContractGetMethodVar}.Body.GetILProcessor();");
context.WriteNewLine();
var getTypeFromHandleSymbol = (IMethodSymbol) context.RoslynTypeSystem.SystemType.GetMembers("GetTypeFromHandle").First();
var getterBodyExps = CecilDefinitionsFactory.MethodBody2(
context.Naming,
_equalityContractGetMethodVar,
getterIlVar,
[
OpCodes.Ldtoken.WithOperand(recordTypeDefinitionVariable),
OpCodes.Call.WithOperand(getTypeFromHandleSymbol.MethodResolverExpression(context)),
OpCodes.Ret
]);

context.WriteCecilExpressions(getterBodyExps);
context.EmitCilInstruction(getterIlVar, OpCodes.Ldtoken, recordTypeDefinitionVariable);
context.EmitCilInstruction(getterIlVar, OpCodes.Call, getTypeFromHandleSymbol.MethodResolverExpression(context));
context.EmitCilInstruction(getterIlVar, OpCodes.Ret);
}

private static bool HasBaseRecord(TypeDeclarationSyntax record)
Expand All @@ -81,5 +251,18 @@ candidate is RecordDeclarationSyntax candidateBase

return found != null;
}

private static string TypeEqualityOperator(IVisitorContext context)
{
var typeEqualityOperator = context.RoslynTypeSystem.SystemType.GetMembers("op_Equality")
.OfType<IMethodSymbol>()
.Single(Has2SystemTypeParameters).MethodResolverExpression(context);

return typeEqualityOperator;

bool Has2SystemTypeParameters(IMethodSymbol candidate) =>
candidate.Parameters.Length == 2
&& SymbolEqualityComparer.Default.Equals(candidate.Parameters[0].Type, candidate.Parameters[1].Type);
//&& SymbolEqualityComparer.Default.Equals(context.RoslynTypeSystem.SystemType.WithNullableAnnotation(NullableAnnotation.Annotated), candidate.Parameters[0].Type);
}
}

0 comments on commit 0b0c67e

Please sign in to comment.