From f047ab08fd9873b7e6a071bf31e7213742d4b12a Mon Sep 17 00:00:00 2001 From: Adriano Carlos Verona Date: Fri, 16 Jun 2023 08:39:43 -0400 Subject: [PATCH] adds support for foreach on classes implementing the enumerable pattern (part of #235) the 'enumerator' pattern can be summarized as, give a type that implements: 1. a 'public bool MoveNext()' method and 2. has a public property named 'Current' it can be used as the target of a foreach. --- .../Tests/Unit/ForEachStatementTests.cs | 43 ++++++++++ .../AST/StatementVisitor.ForEach.cs | 80 +++++++++++++++++++ Cecilifier.Core/AST/StatementVisitor.cs | 1 - Cecilifier.Core/AST/SyntaxWalkerBase.cs | 1 + 4 files changed, 124 insertions(+), 1 deletion(-) create mode 100644 Cecilifier.Core.Tests/Tests/Unit/ForEachStatementTests.cs create mode 100644 Cecilifier.Core/AST/StatementVisitor.ForEach.cs diff --git a/Cecilifier.Core.Tests/Tests/Unit/ForEachStatementTests.cs b/Cecilifier.Core.Tests/Tests/Unit/ForEachStatementTests.cs new file mode 100644 index 00000000..45847583 --- /dev/null +++ b/Cecilifier.Core.Tests/Tests/Unit/ForEachStatementTests.cs @@ -0,0 +1,43 @@ +using NUnit.Framework; + +namespace Cecilifier.Core.Tests.Tests.Unit; + +[TestFixture] +public class ForEachStatementTests : CecilifierUnitTestBase +{ + // https://cutt.ly/swrhz6VE + //[TestCase("struct")] + [TestCase("sealed class")] + public void NonDisposableGetEnumeratorPattern(string enumeratorKind) + { + // Compiler uses GetEnumerator() method, does not require implementing IEnumerable + var result = RunCecilifier($$""" + public {{enumeratorKind}} Enumerator + { + public int Current => 1; + public bool MoveNext() => false; + + public Enumerator GetEnumerator() => default(Enumerator); + } + + //TODO: change to top level statements when order of visiting of top level/classes gets fixed. + class Driver + { + static void Main() + { + foreach(var v in new Enumerator()) {} + } + } + """); + var cecilifiedCode = result.GeneratedCode.ReadToEnd(); + Assert.That(cecilifiedCode, Does.Match(""" + \s+//foreach\(var v in new Enumerator\(\)\) {} + \s+il_main_\d+.Emit\(OpCodes.Newobj, ctor_enumerator_\d+\); + """), "enumerator type defined in the snippet should be used."); + + Assert.That(cecilifiedCode, Does.Match(""" + \s+//variable to store the returned 'IEnumerator'. + \s+il_main_\d+.Emit\(OpCodes.Callvirt, m_getEnumerator_\d+\); + """), "GetEnumerator() defined in the snippet should be used."); + } +} diff --git a/Cecilifier.Core/AST/StatementVisitor.ForEach.cs b/Cecilifier.Core/AST/StatementVisitor.ForEach.cs new file mode 100644 index 00000000..8a5bf6ab --- /dev/null +++ b/Cecilifier.Core/AST/StatementVisitor.ForEach.cs @@ -0,0 +1,80 @@ +using System.Linq; +using Cecilifier.Core.Extensions; +using Cecilifier.Core.Misc; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Mono.Cecil.Cil; + +namespace Cecilifier.Core.AST +{ + internal partial class StatementVisitor + { + public override void VisitForEachStatement(ForEachStatementSyntax node) + { + ExpressionVisitor.Visit(Context, _ilVar, node.Expression); + + var forEachTargetType = Context.GetTypeInfo(node.Expression).Type.EnsureNotNull(); + + var getEnumeratorMethod = forEachTargetType.GetMembers("GetEnumerator").OfType().Single(); + var enumeratorType = EnumeratorTypeFor(getEnumeratorMethod); + var moveNextMethod = enumeratorType.GetMembers("MoveNext").Single().EnsureNotNull(); + var currentGetter = getEnumeratorMethod.ReturnType.GetMembers("get_Current").Single().EnsureNotNull(); + + // Adds a variable to store current value in the foreach loop. + Context.WriteNewLine(); + Context.WriteComment("variable to store current value in the foreach loop."); + var foreachCurrentValueVarName = CodeGenerationHelpers.AddLocalVariableToCurrentMethod(Context, node.Identifier.ValueText, Context.TypeResolver.Resolve(currentGetter.GetMemberType())).VariableName; + + // Get the enumerator.. + Context.WriteNewLine(); + Context.WriteComment("variable to store the returned 'IEnumerator'."); + AddMethodCall(_ilVar, getEnumeratorMethod); + var enumeratorVariableName = CodeGenerationHelpers.StoreTopOfStackInLocalVariable(Context, _ilVar, "enumerator", getEnumeratorMethod.ReturnType).VariableName; + + var endOfLoopLabelVar = Context.Naming.Label("endForEach"); + CreateCilInstruction(_ilVar, endOfLoopLabelVar, OpCodes.Nop); + + // loop while enumerable.MoveNext() == true + + var forEachLoopBegin = AddCilInstructionWithLocalVariable(_ilVar, OpCodes.Nop); + + Context.EmitCilInstruction(_ilVar, OpCodes.Ldloc, enumeratorVariableName); + AddMethodCall(_ilVar, moveNextMethod); + Context.EmitCilInstruction(_ilVar, OpCodes.Brfalse, endOfLoopLabelVar); + + Context.EmitCilInstruction(_ilVar, OpCodes.Ldloc, enumeratorVariableName); + AddMethodCall(_ilVar, currentGetter); + Context.EmitCilInstruction(_ilVar, OpCodes.Stloc, foreachCurrentValueVarName); + + // process body of foreach + Context.WriteNewLine(); + Context.WriteComment("foreach body"); + node.Statement.Accept(this); + Context.WriteComment("end of foreach body"); + Context.WriteNewLine(); + + Context.EmitCilInstruction(_ilVar, OpCodes.Br, forEachLoopBegin); + Context.WriteNewLine(); + Context.WriteComment("end of foreach loop"); + Context.WriteCecilExpression($"{_ilVar}.Append({endOfLoopLabelVar});"); + Context.WriteNewLine(); + } + + public override void VisitForEachVariableStatement(ForEachVariableStatementSyntax node) + { + base.VisitForEachVariableStatement(node); + } + + /* + * either the type returned by GetEnumerator() implements `IEnumerator` interface *or* + * it abides to the enumerator patterns, i.e, it has the following members: + * 1. public bool MoveNext() method + * 2. public T Current property ('T' can be any type) + */ + private ITypeSymbol EnumeratorTypeFor(IMethodSymbol getEnumeratorMethod) + { + var enumeratorType = getEnumeratorMethod.ReturnType.Interfaces.SingleOrDefault(itf => itf.Name == "IEnumerator"); + return enumeratorType ?? getEnumeratorMethod.ReturnType; + } + } +} diff --git a/Cecilifier.Core/AST/StatementVisitor.cs b/Cecilifier.Core/AST/StatementVisitor.cs index 20df31eb..aa1d2429 100644 --- a/Cecilifier.Core/AST/StatementVisitor.cs +++ b/Cecilifier.Core/AST/StatementVisitor.cs @@ -267,7 +267,6 @@ void FinallyBlockHandler(string finallyEndVar) } public override void VisitLocalFunctionStatement(LocalFunctionStatementSyntax node) => node.Accept(new MethodDeclarationVisitor(Context)); - public override void VisitForEachStatement(ForEachStatementSyntax node) { LogUnsupportedSyntax(node); } public override void VisitWhileStatement(WhileStatementSyntax node) { LogUnsupportedSyntax(node); } public override void VisitLockStatement(LockStatementSyntax node) { LogUnsupportedSyntax(node); } public override void VisitUnsafeStatement(UnsafeStatementSyntax node) { LogUnsupportedSyntax(node); } diff --git a/Cecilifier.Core/AST/SyntaxWalkerBase.cs b/Cecilifier.Core/AST/SyntaxWalkerBase.cs index 1807665b..cec804d2 100644 --- a/Cecilifier.Core/AST/SyntaxWalkerBase.cs +++ b/Cecilifier.Core/AST/SyntaxWalkerBase.cs @@ -72,6 +72,7 @@ protected void AddMethodCall(string ilVar, IMethodSymbol method, bool isAccessOn } else { + EnsureForwardedMethod(Context, method, Array.Empty()); var operand = method.MethodResolverExpression(Context); Context.EmitCilInstruction(ilVar, opCode, operand); }