Skip to content

Commit

Permalink
adds support for handling IEnumerable<T> and IEnumerable in foreach
Browse files Browse the repository at this point in the history
part of #235
  • Loading branch information
adrianoc committed Jul 22, 2023
1 parent 54cd5c9 commit 7647cfa
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 18 deletions.
119 changes: 119 additions & 0 deletions Cecilifier.Core.Tests/Tests/Unit/ForEachStatementTests.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using NUnit.Framework;

namespace Cecilifier.Core.Tests.Tests.Unit;
Expand Down Expand Up @@ -40,4 +44,119 @@ static void Main()
\s+il_main_\d+.Emit\(OpCodes.Callvirt, m_getEnumerator_\d+\);
"""), "GetEnumerator() defined in the snippet should be used.");
}

[Test]
public void GenericEnumerable()
{
var result = RunCecilifier("""
using System;
class Foo
{
public void M(System.Collections.Generic.IList<int> e)
{
foreach(var v in e)
Console.WriteLine(v);
}
}
""");
var cecilifiedCode = result.GeneratedCode.ReadToEnd();
Assert.That(cecilifiedCode, Does.Match("""
\s+//foreach\(var v in e\)...
\s+il_M_2.Emit\(OpCodes.Ldarg_1\);
"""), "enumerable passed to the method should be used.");

Assert.That(cecilifiedCode, Does.Match("""
\s+//variable to store the returned 'IEnumerator<T>'.
\s+il_M_\d+.Emit\(OpCodes.Callvirt, .+ImportReference\(.+ResolveMethod\(.+System.Collections.Generic.IEnumerable<System.Int32>.+, "GetEnumerator",.+\)\)\);
"""), "IEnumerable<int>.GetEnumerator() defined in the snippet should be used.");
}

[Test]
public void EnumerableImplementingGenericAndNonGenericIEnumerator()
{
// The difference from this test to the one above is very small, but very important: this test checks that List<T>.Enumerable, a value type
// implementing the enumerator pattern ...
var listOfTEnumerator = typeof(List<>.Enumerator);

var expected= new[] { typeof(IEnumerator<>), typeof(IEnumerator), typeof(IDisposable) };
CollectionAssert.AreEquivalent(
expected,
listOfTEnumerator.GetInterfaces().Select(itf => itf.IsConstructedGenericType ? itf.GetGenericTypeDefinition() : itf));

var result = RunCecilifier("""
using System;
class Foo
{
public void M(System.Collections.Generic.List<int> e)
{
foreach(var v in e)
Console.WriteLine(v);
}
}
""");
var cecilifiedCode = result.GeneratedCode.ReadToEnd();
Assert.That(cecilifiedCode, Does.Match("""
//variable to store the returned 'IEnumerator<T>'.
\s+il_M_\d+.Emit\(OpCodes.Callvirt, .+ImportReference\(.+ResolveMethod\(typeof\(System.Collections.Generic.List<System.Int32>\), "GetEnumerator",.+\)\)\);
\s+var l_enumerator_\d+ = new VariableDefinition\(.+ImportReference\(typeof\(System.Collections.Generic.List<int>.Enumerator\)\)\);
"""));

Assert.That(cecilifiedCode, Does.Match("""
il_M_\d+.Emit\(OpCodes.Ldloca, l_enumerator_\d+\);
\s+il_M_\d+.Emit\(OpCodes.Call, .+ImportReference\(.+ResolveMethod\(typeof\(.+List<System.Int32>.Enumerator\), "MoveNext",.+\)\)\);
"""));

Assert.That(cecilifiedCode, Does.Match("""
il_M_\d+.Emit\(OpCodes.Ldloca, l_enumerator_\d+\);
\s+il_M_\d+.Emit\(OpCodes.Call, .+ImportReference\(.+ResolveMethod\(typeof\(.+List<System.Int32>.Enumerator\), "get_Current",.+\)\)\);
"""));
}

// I've considered adding a test for instantiated IEnumerable<T> (for instance, IEnumerable<int>) but it doesn't look like to add any value since the generated code
// is very similar to the one in this test and any open/closed differences should be covered by generics handling.
[Test]
public void OpenIEnumerable()
{
var result = RunCecilifier("void Run<T>(System.Collections.Generic.IEnumerable<T> e) { foreach(var v in e) {} }");
var cecilifiedCode = result.GeneratedCode.ReadToEnd();

Assert.That(cecilifiedCode, Does.Match("""
il_run_\d+.Emit\(OpCodes.Callvirt, .+ImportReference\(.+ResolveMethod\(typeof\(System.Collections.IEnumerator\), "MoveNext",.+\)\)\);
"""));
Assert.That(cecilifiedCode, Does.Match("""var l_openget_Current_\d+ = .+ImportReference\(typeof\(.+IEnumerator<>\)\).Resolve\(\).Methods.First\(m => m.Name == "get_Current"\);"""));

/*
IL_0000: ldarg.0
IL_0001: callvirt instance class [System.Runtime]System.Collections.Generic.IEnumerator`1<!0> class [System.Runtime]System.Collections.Generic.IEnumerable`1<!!T>::GetEnumerator()
IL_0006: stloc.0
.try
{
// sequence point: hidden
IL_0007: br.s IL_0010
// loop start (head: IL_0010)
IL_0009: ldloc.0
IL_000a: callvirt instance !0 class [System.Runtime]System.Collections.Generic.IEnumerator`1<!!T>::get_Current()
IL_000f: pop
IL_0010: ldloc.0
IL_0011: callvirt instance bool [System.Runtime]System.Collections.IEnumerator::MoveNext()
IL_0016: brtrue.s IL_0009
// end loop
IL_0018: leave.s IL_0024
} // end .try
finally
{
// sequence point: hidden
IL_001a: ldloc.0
IL_001b: brfalse.s IL_0023
IL_001d: ldloc.0
IL_001e: callvirt instance void [System.Runtime]System.IDisposable::Dispose()
// sequence point: hidden
IL_0023: endfinally
} // end handler
*/
}
}
88 changes: 71 additions & 17 deletions Cecilifier.Core/AST/StatementVisitor.ForEach.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,46 +13,52 @@ public override void VisitForEachStatement(ForEachStatementSyntax node)
{
ExpressionVisitor.Visit(Context, _ilVar, node.Expression);

var forEachTargetType = Context.GetTypeInfo(node.Expression).Type.EnsureNotNull();

var getEnumeratorMethod = forEachTargetType.GetMembers("GetEnumerator").OfType<IMethodSymbol>().Single();
var enumerableType = Context.GetTypeInfo(node.Expression).Type.EnsureNotNull();
var getEnumeratorMethod = GetEnumeratorMethodFor(enumerableType);
var enumeratorType = EnumeratorTypeFor(getEnumeratorMethod);
var moveNextMethod = enumeratorType.GetMembers("MoveNext").Single().EnsureNotNull<ISymbol, IMethodSymbol>();
var currentGetter = getEnumeratorMethod.ReturnType.GetMembers("get_Current").Single().EnsureNotNull<ISymbol, IMethodSymbol>();

var enumeratorMoveNextMethod = MoveNextMethodFor(enumeratorType);
var enumeratorCurrentMethod = CurrentMethodFor(enumeratorType);

ProcessForEach(node, getEnumeratorMethod, enumeratorCurrentMethod, enumeratorMoveNextMethod);
}

private void ProcessForEach(ForEachStatementSyntax node, IMethodSymbol getEnumeratorMethod, IMethodSymbol enumeratorCurrentMethod, IMethodSymbol enumeratorMoveNextMethod)
{
// Adds a variable to store current value in the foreach loop.
Context.WriteNewLine();
Context.WriteComment("variable to store current value in the foreach loop.");
var foreachCurrentValueVarName = CodeGenerationHelpers.AddLocalVariableToCurrentMethod(Context, node.Identifier.ValueText, Context.TypeResolver.Resolve(currentGetter.GetMemberType())).VariableName;
var foreachCurrentValueVarName = CodeGenerationHelpers.AddLocalVariableToCurrentMethod(Context, node.Identifier.ValueText, Context.TypeResolver.Resolve(enumeratorCurrentMethod.GetMemberType())).VariableName;

// Get the enumerator..
Context.WriteNewLine();
Context.WriteComment("variable to store the returned 'IEnumerator<T>'.");
AddMethodCall(_ilVar, getEnumeratorMethod);
var enumeratorVariableName = CodeGenerationHelpers.StoreTopOfStackInLocalVariable(Context, _ilVar, "enumerator", getEnumeratorMethod.ReturnType).VariableName;

var endOfLoopLabelVar = Context.Naming.Label("endForEach");
CreateCilInstruction(_ilVar, endOfLoopLabelVar, OpCodes.Nop);

// loop while enumerable.MoveNext() == true

var forEachLoopBegin = AddCilInstructionWithLocalVariable(_ilVar, OpCodes.Nop);

Context.EmitCilInstruction(_ilVar, OpCodes.Ldloc, enumeratorVariableName);
AddMethodCall(_ilVar, moveNextMethod);

var loadOpCode = getEnumeratorMethod.ReturnType.IsValueType || getEnumeratorMethod.ReturnType.TypeKind == TypeKind.TypeParameter ? OpCodes.Ldloca : OpCodes.Ldloc;
//var loadOpCode = OpCodes.Ldloc;
Context.EmitCilInstruction(_ilVar, loadOpCode, enumeratorVariableName);
AddMethodCall(_ilVar, enumeratorMoveNextMethod);
Context.EmitCilInstruction(_ilVar, OpCodes.Brfalse, endOfLoopLabelVar);

Context.EmitCilInstruction(_ilVar, OpCodes.Ldloc, enumeratorVariableName);
AddMethodCall(_ilVar, currentGetter);
Context.EmitCilInstruction(_ilVar, loadOpCode, enumeratorVariableName);
AddMethodCall(_ilVar, enumeratorCurrentMethod);
Context.EmitCilInstruction(_ilVar, OpCodes.Stloc, foreachCurrentValueVarName);

// process body of foreach
Context.WriteNewLine();
Context.WriteComment("foreach body");
node.Statement.Accept(this);
Context.WriteComment("end of foreach body");
Context.WriteNewLine();

Context.EmitCilInstruction(_ilVar, OpCodes.Br, forEachLoopBegin);
Context.WriteNewLine();
Context.WriteComment("end of foreach loop");
Expand All @@ -65,14 +71,62 @@ public override void VisitForEachVariableStatement(ForEachVariableStatementSynta
base.VisitForEachVariableStatement(node);
}

private IMethodSymbol GetEnumeratorMethodFor(ITypeSymbol enumerableType)
{
var interfacesToCheck = new[] { Context.RoslynTypeSystem.SystemCollectionsGenericIEnumerableOfT, Context.RoslynTypeSystem.SystemCollectionsIEnumerable };
return GetMethodOnTypeOrImplementedInterfaces(enumerableType, interfacesToCheck, "GetEnumerator");
}

private IMethodSymbol GetMethodOnTypeOrImplementedInterfaces(ITypeSymbol inType, ITypeSymbol[] interfacesToCheck, string methodName)
{
var found = inType.GetMembers(methodName).SingleOrDefault();
if (found != null)
return (IMethodSymbol) found;

int i = -1;
while (found == null && ++i < interfacesToCheck.Length)
{
found = inType.Interfaces.SingleOrDefault(itf => SymbolEqualityComparer.Default.Equals(itf.OriginalDefinition, interfacesToCheck[i]))?.EnsureNotNull<ISymbol, ITypeSymbol>().GetMembers(methodName).SingleOrDefault();
}

return found.EnsureNotNull<ISymbol, IMethodSymbol>();
}

/*
* MoveNext() method may be implemented in ...
* 1. IEnumerator
* 2. A type following the enumerator pattern
*/
private IMethodSymbol MoveNextMethodFor(ITypeSymbol enumeratorType)
{
var interfacesToCheck = new[] { Context.RoslynTypeSystem.SystemCollectionsGenericIEnumeratorOfT, Context.RoslynTypeSystem.SystemCollectionsIEnumerator };
return GetMethodOnTypeOrImplementedInterfaces(enumeratorType, interfacesToCheck, "MoveNext");
}

private IMethodSymbol CurrentMethodFor(ITypeSymbol enumeratorType)
{
var interfacesToCheck = new[] { Context.RoslynTypeSystem.SystemCollectionsGenericIEnumeratorOfT, Context.RoslynTypeSystem.SystemCollectionsIEnumerator };
return GetMethodOnTypeOrImplementedInterfaces(enumeratorType, interfacesToCheck, "get_Current");
}

/*
* either the type returned by GetEnumerator() implements `IEnumerator` interface *or*
* it abides to the enumerator patterns, i.e, it has the following members:
* it abides to the enumerator pattern, i.e, it has the following members:
* 1. public bool MoveNext() method
* 2. public T Current property ('T' can be any type)
*/
private ITypeSymbol EnumeratorTypeFor(IMethodSymbol getEnumeratorMethod)
{
var moveNext = getEnumeratorMethod.ReturnType.GetMembers("MoveNext").SingleOrDefault();
if (moveNext != null)
return getEnumeratorMethod.ReturnType;

if (SymbolEqualityComparer.Default.Equals(getEnumeratorMethod.ReturnType.OriginalDefinition, Context.RoslynTypeSystem.SystemCollectionsGenericIEnumeratorOfT))
return getEnumeratorMethod.ReturnType;

if (SymbolEqualityComparer.Default.Equals(getEnumeratorMethod.ReturnType.OriginalDefinition, Context.RoslynTypeSystem.SystemCollectionsIEnumerator))
return getEnumeratorMethod.ReturnType;

var enumeratorType = getEnumeratorMethod.ReturnType.Interfaces.SingleOrDefault(itf => itf.Name == "IEnumerator");
return enumeratorType ?? getEnumeratorMethod.ReturnType;
}
Expand Down
4 changes: 4 additions & 0 deletions Cecilifier.Core/Misc/CecilifierContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ public TypeInfo GetTypeInfo(ExpressionSyntax expressionSyntax)

public void WriteCecilExpression(string expression)
{
if (expression.Contains("Ldloc"))
{
int x = 01;
}
CecilifiedLineNumber += expression.CountNewLines();
output.AddLast($"{identation}{expression}");
}
Expand Down
12 changes: 11 additions & 1 deletion Cecilifier.Core/TypeSystem/RoslynTypeSystem.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using Cecilifier.Core.AST;
using Microsoft.CodeAnalysis;
Expand Down Expand Up @@ -34,7 +36,11 @@ public RoslynTypeSystem(IVisitorContext ctx)
IsByRefLikeAttribute = ctx.SemanticModel.Compilation.GetTypeByMetadataName(typeof(IsByRefLikeAttribute).FullName);
SystemObsoleteAttribute = ctx.SemanticModel.Compilation.GetTypeByMetadataName(typeof(ObsoleteAttribute).FullName);
SystemValueType = ctx.SemanticModel.Compilation.GetTypeByMetadataName(typeof(ValueType).FullName);
SystemRuntimeCompilerServicesRuntimeHelpers = ctx.SemanticModel.Compilation.GetTypeByMetadataName(typeof(RuntimeHelpers).FullName);
SystemRuntimeCompilerServicesRuntimeHelpers = ctx.SemanticModel.Compilation.GetTypeByMetadataName(typeof(RuntimeHelpers).FullName);
SystemCollectionsIEnumerator = ctx.SemanticModel.Compilation.GetSpecialType(SpecialType.System_Collections_IEnumerator);
SystemCollectionsGenericIEnumeratorOfT = ctx.SemanticModel.Compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerator_T);
SystemCollectionsIEnumerable = ctx.SemanticModel.Compilation.GetSpecialType(SpecialType.System_Collections_IEnumerable);
SystemCollectionsGenericIEnumerableOfT = ctx.SemanticModel.Compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerable_T);
}

public ITypeSymbol SystemIndex { get; }
Expand All @@ -51,6 +57,10 @@ public RoslynTypeSystem(IVisitorContext ctx)
public ITypeSymbol SystemBoolean { get; }
public ITypeSymbol SystemActivator { get; }
public ITypeSymbol SystemIDisposable { get; }
public ITypeSymbol SystemCollectionsIEnumerator { get; }
public ITypeSymbol SystemCollectionsGenericIEnumeratorOfT { get; }
public ITypeSymbol SystemCollectionsIEnumerable { get; }
public ITypeSymbol SystemCollectionsGenericIEnumerableOfT { get; }
public ITypeSymbol CallerArgumentExpressionAttribute { get; }
public ITypeSymbol IsReadOnlyAttribute { get; }
public ITypeSymbol IsByRefLikeAttribute { get; }
Expand Down

0 comments on commit 7647cfa

Please sign in to comment.