Skip to content

Commit

Permalink
[Compiler Add] allow multiple smart contract exist in one project (#908)
Browse files Browse the repository at this point in the history
* allow multiple smart contract exist in one project

* fix remaining issues

* update unit tests

* code optimization

* Update tests/Neo.SmartContract.TestEngine/TestEngine.cs

* Test it: Compile to artifacts

* clean changes

* fix complication issue

* multiple smart contract topology analysis. Making it easier to support cross contract call.

* Move artifact generation to the test project

* fix conflict

* update neo

* fix error

* this pr apply latest neo to devpack

* update signle contract check

* add comments

---------

Co-authored-by: Shargon <[email protected]>
  • Loading branch information
Jim8y and shargon authored Feb 25, 2024
1 parent 7f64dc4 commit 7add5a9
Show file tree
Hide file tree
Showing 20 changed files with 442 additions and 287 deletions.
318 changes: 101 additions & 217 deletions src/Neo.Compiler.CSharp/CompilationContext.cs

Large diffs are not rendered by default.

249 changes: 249 additions & 0 deletions src/Neo.Compiler.CSharp/CompilationEngine.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
// Copyright (C) 2015-2024 The Neo Project.
//
// The Neo.Compiler.CSharp is free software distributed under the MIT
// software license, see the accompanying file LICENSE in the main directory
// of the project or http://www.opensource.org/licenses/mit-license.php
// for more details.
//
// Redistribution and use in source and binary forms with or without
// modifications are permitted.

extern alias scfx;

using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Neo.Json;
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Xml.Linq;
using Akka.Util.Internal;
using BigInteger = System.Numerics.BigInteger;

namespace Neo.Compiler
{
public class CompilationEngine
{
internal Compilation? Compilation;
internal Options Options { get; private set; }
private static readonly MetadataReference[] CommonReferences;
private static readonly Dictionary<string, MetadataReference> MetaReferences = new();
internal readonly Dictionary<INamedTypeSymbol, CompilationContext> Contexts = new(SymbolEqualityComparer.Default);

static CompilationEngine()
{
string coreDir = Path.GetDirectoryName(typeof(object).Assembly.Location)!;
CommonReferences = new MetadataReference[]
{
MetadataReference.CreateFromFile(Path.Combine(coreDir, "System.Runtime.dll")),
MetadataReference.CreateFromFile(Path.Combine(coreDir, "System.Runtime.InteropServices.dll")),
MetadataReference.CreateFromFile(typeof(string).Assembly.Location),
MetadataReference.CreateFromFile(typeof(DisplayNameAttribute).Assembly.Location),
MetadataReference.CreateFromFile(typeof(BigInteger).Assembly.Location)
};
}

public CompilationEngine(Options options)
{
Options = options;
}

public List<CompilationContext> Compile(IEnumerable<string> sourceFiles, IEnumerable<MetadataReference> references)
{
IEnumerable<SyntaxTree> syntaxTrees = sourceFiles.OrderBy(p => p).Select(p => CSharpSyntaxTree.ParseText(File.ReadAllText(p), options: Options.GetParseOptions(), path: p));
CSharpCompilationOptions compilationOptions = new(OutputKind.DynamicallyLinkedLibrary, deterministic: true, nullableContextOptions: Options.Nullable);
Compilation = CSharpCompilation.Create(null, syntaxTrees, references, compilationOptions);
return CompileProjectContracts(Compilation);
}

public List<CompilationContext> CompileSources(string[] sourceFiles)
{
List<MetadataReference> references = new(CommonReferences)
{
MetadataReference.CreateFromFile(typeof(scfx.Neo.SmartContract.Framework.SmartContract).Assembly.Location)
};
return Compile(sourceFiles, references);
}

public List<CompilationContext> CompileProject(string csproj)
{
Compilation = GetCompilation(csproj);
return CompileProjectContracts(Compilation);
}

private List<CompilationContext> CompileProjectContracts(Compilation compilation)
{
var classDependencies = new Dictionary<INamedTypeSymbol, List<INamedTypeSymbol>>(SymbolEqualityComparer.Default);
var allSmartContracts = new HashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default);

foreach (var tree in compilation.SyntaxTrees)
{
var semanticModel = compilation.GetSemanticModel(tree);
var classNodes = tree.GetRoot().DescendantNodes().OfType<ClassDeclarationSyntax>();

foreach (var classNode in classNodes)
{
var classSymbol = semanticModel.GetDeclaredSymbol(classNode);
if (classSymbol != null && IsDerivedFromSmartContract(classSymbol, "Neo.SmartContract.Framework.SmartContract", semanticModel))
{
allSmartContracts.Add(classSymbol);
classDependencies[classSymbol] = new List<INamedTypeSymbol>();
foreach (var member in classSymbol.GetMembers())
{
var memberTypeSymbol = (member as IFieldSymbol)?.Type ?? (member as IPropertySymbol)?.Type;
if (memberTypeSymbol is INamedTypeSymbol namedTypeSymbol && allSmartContracts.Contains(namedTypeSymbol))
{
classDependencies[classSymbol].Add(namedTypeSymbol);
}
}
}
}
}

// Verify if there is any valid smart contract class
if (classDependencies.Count == 0) throw new FormatException("No valid neo SmartContract found. Please make sure your contract is subclass of SmartContract and is not abstract.");
// Check contract dependencies, make sure there is no cycle in the dependency graph
var sortedClasses = TopologicalSort(classDependencies);
foreach (var classSymbol in sortedClasses)
{
new CompilationContext(this, classSymbol).Compile();
}

return Contexts.Select(p => p.Value).ToList();
}

private static List<INamedTypeSymbol> TopologicalSort(Dictionary<INamedTypeSymbol, List<INamedTypeSymbol>> dependencies)
{
var sorted = new List<INamedTypeSymbol>();
var visited = new HashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default);
var visiting = new HashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default); // 添加中间状态以检测循环依赖

void Visit(INamedTypeSymbol classSymbol)
{
if (visited.Contains(classSymbol))
{
return;
}
if (!visiting.Add(classSymbol))
{
throw new InvalidOperationException("Cyclic dependency detected");
}

if (dependencies.TryGetValue(classSymbol, out var dependency))
{
foreach (var dep in dependency)
{
Visit(dep);
}
}

visiting.Remove(classSymbol);
visited.Add(classSymbol);
sorted.Add(classSymbol);
}

foreach (var classSymbol in dependencies.Keys)
{
Visit(classSymbol);
}

return sorted;
}

static bool IsDerivedFromSmartContract(INamedTypeSymbol classSymbol, string smartContractFullyQualifiedName, SemanticModel semanticModel)
{
var baseType = classSymbol.BaseType;
while (baseType != null)
{
if (baseType.ToDisplayString() == smartContractFullyQualifiedName)
{
return true;
}
baseType = baseType.BaseType;
}
return false;
}

public Compilation GetCompilation(string csproj)
{
string folder = Path.GetDirectoryName(csproj)!;
string obj = Path.Combine(folder, "obj");
HashSet<string> sourceFiles = Directory.EnumerateFiles(folder, "*.cs", SearchOption.AllDirectories)
.Where(p => !p.StartsWith(obj))
.GroupBy(Path.GetFileName)
.Select(g => g.First())
.ToHashSet(StringComparer.OrdinalIgnoreCase);
List<MetadataReference> references = new(CommonReferences);
CSharpCompilationOptions compilationOptions = new(OutputKind.DynamicallyLinkedLibrary, deterministic: true, nullableContextOptions: Options.Nullable);
XDocument document = XDocument.Load(csproj);
sourceFiles.UnionWith(document.Root!.Elements("ItemGroup").Elements("Compile").Attributes("Include").Select(p => Path.GetFullPath(p.Value, folder)));
Process.Start(new ProcessStartInfo
{
FileName = "dotnet",
Arguments = $"restore \"{csproj}\"",
WorkingDirectory = folder
})!.WaitForExit();
string assetsPath = Path.Combine(folder, "obj", "project.assets.json");
JObject assets = (JObject)JToken.Parse(File.ReadAllBytes(assetsPath))!;
foreach (var (name, package) in ((JObject)assets["targets"]![0]!).Properties)
{
MetadataReference? reference = GetReference(name, (JObject)package!, assets, folder, Options, compilationOptions);
if (reference is not null) references.Add(reference);
}
IEnumerable<SyntaxTree> syntaxTrees = sourceFiles.OrderBy(p => p).Select(p => CSharpSyntaxTree.ParseText(File.ReadAllText(p), options: Options.GetParseOptions(), path: p));
return CSharpCompilation.Create(assets["project"]!["restore"]!["projectName"]!.GetString(), syntaxTrees, references, compilationOptions);
}

private MetadataReference? GetReference(string name, JObject package, JObject assets, string folder, Options options, CSharpCompilationOptions compilationOptions)
{
string assemblyName = Path.GetDirectoryName(name)!;
if (!MetaReferences.TryGetValue(assemblyName, out var reference))
{
switch (assets["libraries"]![name]!["type"]!.GetString())
{
case "package":
string packagesPath = assets["project"]!["restore"]!["packagesPath"]!.GetString();
string namePath = assets["libraries"]![name]!["path"]!.GetString();
string[] files = ((JArray)assets["libraries"]![name]!["files"]!)
.Select(p => p!.GetString())
.Where(p => p.StartsWith("src/"))
.ToArray();
if (files.Length == 0)
{
JObject? dllFiles = (JObject?)(package["compile"] ?? package["runtime"]);
if (dllFiles is null) return null;
foreach (var (file, _) in dllFiles.Properties)
{
if (file.EndsWith("_._")) continue;
string path = Path.Combine(packagesPath, namePath, file);
if (!File.Exists(path)) continue;
reference = MetadataReference.CreateFromFile(path);
break;
}
if (reference is null) return null;
}
else
{
IEnumerable<SyntaxTree> st = files.OrderBy(p => p).Select(p => Path.Combine(packagesPath, namePath, p)).Select(p => CSharpSyntaxTree.ParseText(File.ReadAllText(p), path: p));
CSharpCompilation cr = CSharpCompilation.Create(assemblyName, st, CommonReferences, compilationOptions);
reference = cr.ToMetadataReference();
}
break;
case "project":
string msbuildProject = assets["libraries"]![name]!["msbuildProject"]!.GetString();
msbuildProject = Path.GetFullPath(msbuildProject, folder);
reference = GetCompilation(msbuildProject).ToMetadataReference();
break;
default:
throw new NotSupportedException();
}
MetaReferences.Add(assemblyName, reference);
}
return reference;
}
}
}
12 changes: 6 additions & 6 deletions src/Neo.Compiler.CSharp/MethodConvert/CallHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ private Instruction Call(InteropDescriptor descriptor)

private Instruction Call(UInt160 hash, string method, ushort parametersCount, bool hasReturnValue, CallFlags callFlags = CallFlags.All)
{
ushort token = context.AddMethodToken(hash, method, parametersCount, hasReturnValue, callFlags);
ushort token = _context.AddMethodToken(hash, method, parametersCount, hasReturnValue, callFlags);
return AddInstruction(new Instruction
{
OpCode = OpCode.CALLT,
Expand All @@ -58,7 +58,7 @@ private void Call(SemanticModel model, IMethodSymbol symbol, bool instanceOnStac
}
else
{
convert = context.ConvertMethod(model, symbol);
convert = _context.ConvertMethod(model, symbol);
methodCallingConvention = convert._callingConvention;
}
bool isConstructor = symbol.MethodKind == MethodKind.Constructor;
Expand Down Expand Up @@ -103,8 +103,8 @@ private void Call(SemanticModel model, IMethodSymbol symbol, ExpressionSyntax? i
else
{
convert = symbol.ReducedFrom is null
? context.ConvertMethod(model, symbol)
: context.ConvertMethod(model, symbol.ReducedFrom);
? _context.ConvertMethod(model, symbol)
: _context.ConvertMethod(model, symbol.ReducedFrom);
methodCallingConvention = convert._callingConvention;
}
if (!symbol.IsStatic && methodCallingConvention != CallingConvention.Cdecl)
Expand Down Expand Up @@ -143,7 +143,7 @@ private void Call(SemanticModel model, IMethodSymbol symbol, CallingConvention c
}
else
{
convert = context.ConvertMethod(model, symbol);
convert = _context.ConvertMethod(model, symbol);
methodCallingConvention = convert._callingConvention;
}
int pc = symbol.Parameters.Length;
Expand Down Expand Up @@ -175,7 +175,7 @@ private void Call(SemanticModel model, IMethodSymbol symbol, CallingConvention c

private void EmitCall(MethodConvert target)
{
if (target._inline && !context.Options.NoInline)
if (target._inline && !_context.Options.NoInline)
for (int i = 0; i < target._instructions.Count - 1; i++)
AddInstruction(target._instructions[i].Clone());
else
Expand Down
8 changes: 4 additions & 4 deletions src/Neo.Compiler.CSharp/MethodConvert/ConstructorConvert.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,19 @@ private void ProcessConstructorInitializer(SemanticModel model)

private void ProcessStaticFields(SemanticModel model)
{
foreach (INamedTypeSymbol @class in context.StaticFieldSymbols.Select(p => p.ContainingType).Distinct<INamedTypeSymbol>(SymbolEqualityComparer.Default).ToArray())
foreach (INamedTypeSymbol @class in _context.StaticFieldSymbols.Select(p => p.ContainingType).Distinct<INamedTypeSymbol>(SymbolEqualityComparer.Default).ToArray())
{
foreach (IFieldSymbol field in @class.GetAllMembers().OfType<IFieldSymbol>())
{
if (field.IsConst || !field.IsStatic) continue;
ProcessFieldInitializer(model, field, null, () =>
{
byte index = context.AddStaticField(field);
byte index = _context.AddStaticField(field);
AccessSlot(OpCode.STSFLD, index);
});
}
}
foreach (var (fieldIndex, type) in context.VTables)
foreach (var (fieldIndex, type) in _context.VTables)
{
IMethodSymbol[] virtualMethods = type.GetAllMembers().OfType<IMethodSymbol>().Where(p => p.IsVirtualMethod()).ToArray();
for (int i = virtualMethods.Length - 1; i >= 0; i--)
Expand All @@ -85,7 +85,7 @@ private void ProcessStaticFields(SemanticModel model)
}
else
{
MethodConvert convert = context.ConvertMethod(model, method);
MethodConvert convert = _context.ConvertMethod(model, method);
Jump(OpCode.PUSHA, convert._startTarget);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ private void ConvertFieldIdentifierNameCoalesceAssignment(SemanticModel model, I
JumpTarget endTarget = new();
if (left.IsStatic)
{
byte index = context.AddStaticField(left);
byte index = _context.AddStaticField(left);
AccessSlot(OpCode.LDSFLD, index);
AddInstruction(OpCode.ISNULL);
Jump(OpCode.JMPIF_L, assignmentTarget);
Expand Down Expand Up @@ -232,7 +232,7 @@ private void ConvertFieldMemberAccessCoalesceAssignment(SemanticModel model, Mem
JumpTarget endTarget = new();
if (field.IsStatic)
{
byte index = context.AddStaticField(field);
byte index = _context.AddStaticField(field);
AccessSlot(OpCode.LDSFLD, index);
AddInstruction(OpCode.ISNULL);
Jump(OpCode.JMPIF_L, assignmentTarget);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ private void ConvertFieldIdentifierNameComplexAssignment(SemanticModel model, IT
{
if (left.IsStatic)
{
byte index = context.AddStaticField(left);
byte index = _context.AddStaticField(left);
AccessSlot(OpCode.LDSFLD, index);
ConvertExpression(model, right);
EmitComplexAssignmentOperator(type, operatorToken);
Expand Down Expand Up @@ -184,7 +184,7 @@ private void ConvertFieldMemberAccessComplexAssignment(SemanticModel model, ITyp
{
if (field.IsStatic)
{
byte index = context.AddStaticField(field);
byte index = _context.AddStaticField(field);
AccessSlot(OpCode.LDSFLD, index);
ConvertExpression(model, right);
EmitComplexAssignmentOperator(type, operatorToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ private void ConvertIdentifierNameAssignment(SemanticModel model, IdentifierName
case IFieldSymbol field:
if (field.IsStatic)
{
byte index = context.AddStaticField(field);
byte index = _context.AddStaticField(field);
AccessSlot(OpCode.STSFLD, index);
}
else
Expand Down Expand Up @@ -139,7 +139,7 @@ private void ConvertMemberAccessAssignment(SemanticModel model, MemberAccessExpr
case IFieldSymbol field:
if (field.IsStatic)
{
byte index = context.AddStaticField(field);
byte index = _context.AddStaticField(field);
AccessSlot(OpCode.STSFLD, index);
}
else
Expand Down
Loading

0 comments on commit 7add5a9

Please sign in to comment.