From db00a61ddd21e95d6c29ff9960210145e36e9dd1 Mon Sep 17 00:00:00 2001 From: Manodasan Wignarajah Date: Sun, 13 Oct 2024 16:46:06 -0700 Subject: [PATCH] Reuse generated code when possible (#1821) * Initialize size improvements * Also optimize ccw * Fix build * Add comments and renaming --- .../WinRT.SourceGenerator/AotOptimizer.cs | 173 +++++++++++++----- 1 file changed, 130 insertions(+), 43 deletions(-) diff --git a/src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs b/src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs index 11d952e61..31534e200 100644 --- a/src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs +++ b/src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs @@ -773,11 +773,11 @@ private static void GenerateVtableAttributes( GenerateVtableAttributes(sourceProductionContext.AddSource, value.vtableAttributes, value.context.properties.isCsWinRTComponent, value.context.escapedAssemblyName); } - internal static string GenerateVtableEntry(VtableAttribute vtableAttribute, string escapedAssemblyName) + internal static string GenerateVtableEntry(VtableEntry vtableEntry, string escapedAssemblyName) { StringBuilder source = new(); - foreach (var genericInterface in vtableAttribute.GenericInterfaces) + foreach (var genericInterface in vtableEntry.GenericInterfaces) { source.AppendLine(GenericVtableInitializerStrings.GetInstantiationInitFunction( genericInterface.GenericDefinition, @@ -785,9 +785,9 @@ internal static string GenerateVtableEntry(VtableAttribute vtableAttribute, stri escapedAssemblyName)); } - if (vtableAttribute.IsDelegate) + if (vtableEntry.IsDelegate) { - var @interface = vtableAttribute.Interfaces.First(); + var @interface = vtableEntry.Interfaces.First(); source.AppendLine(); source.AppendLine($$""" var delegateInterface = new global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry @@ -799,7 +799,7 @@ internal static string GenerateVtableEntry(VtableAttribute vtableAttribute, stri return global::WinRT.DelegateTypeDetails<{{@interface}}>.GetExposedInterfaces(delegateInterface); """); } - else if (vtableAttribute.Interfaces.Any()) + else if (vtableEntry.Interfaces.Any()) { source.AppendLine(); source.AppendLine($$""" @@ -807,7 +807,7 @@ internal static string GenerateVtableEntry(VtableAttribute vtableAttribute, stri { """); - foreach (var @interface in vtableAttribute.Interfaces) + foreach (var @interface in vtableEntry.Interfaces) { var genericStartIdx = @interface.IndexOf('<'); var interfaceStaticsMethod = @interface[..(genericStartIdx == -1 ? @interface.Length : genericStartIdx)] + "Methods"; @@ -840,6 +840,10 @@ internal static string GenerateVtableEntry(VtableAttribute vtableAttribute, stri internal static void GenerateVtableAttributes(Action addSource, ImmutableArray vtableAttributes, bool isCsWinRTComponentFromAotOptimizer, string escapedAssemblyName) { + var vtableEntryToVtableClassName = new Dictionary(); + StringBuilder vtableClassesSource = new(); + bool firstVtableClass = true; + // Using ToImmutableHashSet to avoid duplicate entries from the use of partial classes by the developer // to split out their implementation. When they do that, we will get multiple entries here for that // and try to generate the same attribute and file with the same data as we use the semantic model @@ -850,11 +854,10 @@ internal static void GenerateVtableAttributes(Action addSource, // from the AOT optimizer, then any public types are not handled // right now as they are handled by the WinRT component source generator // calling this. - if (((isCsWinRTComponentFromAotOptimizer && !vtableAttribute.IsPublic) || !isCsWinRTComponentFromAotOptimizer) && + if (((isCsWinRTComponentFromAotOptimizer && !vtableAttribute.IsPublic) || !isCsWinRTComponentFromAotOptimizer) && vtableAttribute.Interfaces.Any()) { StringBuilder source = new(); - source.AppendLine("using static WinRT.TypeExtensions;\n"); if (!vtableAttribute.IsGlobalNamespace) { source.AppendLine($$""" @@ -863,6 +866,16 @@ namespace {{vtableAttribute.Namespace}} """); } + // Check if this class shares the same vtable as another class. If so, reuse the same generated class for it. + VtableEntry entry = new(vtableAttribute.Interfaces, vtableAttribute.GenericInterfaces, vtableAttribute.IsDelegate); + bool vtableEntryExists = vtableEntryToVtableClassName.TryGetValue(entry, out var ccwClassName); + if (!vtableEntryExists) + { + var @namespace = vtableAttribute.IsGlobalNamespace ? "" : $"{vtableAttribute.Namespace}."; + ccwClassName = GeneratorHelper.EscapeTypeNameForIdentifier(@namespace + vtableAttribute.ClassName); + vtableEntryToVtableClassName.Add(entry, ccwClassName); + } + var escapedClassName = GeneratorHelper.EscapeTypeNameForIdentifier(vtableAttribute.ClassName); // Simple case when the type is not nested @@ -874,7 +887,7 @@ namespace {{vtableAttribute.Namespace}} } source.AppendLine($$""" - [global::WinRT.WinRTExposedType(typeof({{escapedClassName}}WinRTTypeDetails))] + [global::WinRT.WinRTExposedType(typeof(global::WinRT.{{escapedAssemblyName}}VtableClasses.{{ccwClassName}}WinRTTypeDetails))] partial class {{vtableAttribute.ClassName}} { } @@ -900,7 +913,7 @@ partial class {{vtableAttribute.ClassName}} } source.AppendLine($$""" - [global::WinRT.WinRTExposedType(typeof({{escapedClassName}}WinRTTypeDetails))] + [global::WinRT.WinRTExposedType(typeof(global::WinRT.{{escapedAssemblyName}}VtableClasses.{{ccwClassName}}WinRTTypeDetails))] partial {{classHierarchy[0].GetTypeKeyword()}} {{classHierarchy[0].QualifiedName}} { } @@ -913,62 +926,78 @@ partial class {{vtableAttribute.ClassName}} } } - source.AppendLine(); - source.AppendLine($$""" - internal sealed class {{escapedClassName}}WinRTTypeDetails : global::WinRT.IWinRTExposedTypeDetails + // Only generate class, if this is the first time we run into this set of vtables. + if (!vtableEntryExists) + { + if (firstVtableClass) + { + vtableClassesSource.AppendLine($$""" + namespace WinRT.{{escapedAssemblyName}}VtableClasses + { + """); + firstVtableClass = false; + } + else + { + vtableClassesSource.AppendLine(); + } + + vtableClassesSource.AppendLine($$""" + internal sealed class {{ccwClassName}}WinRTTypeDetails : global::WinRT.IWinRTExposedTypeDetails { public global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry[] GetExposedInterfaces() { """); - if (vtableAttribute.Interfaces.Any()) - { - foreach (var genericInterface in vtableAttribute.GenericInterfaces) + if (vtableAttribute.Interfaces.Any()) { - source.AppendLine(GenericVtableInitializerStrings.GetInstantiationInitFunction( - genericInterface.GenericDefinition, - genericInterface.GenericParameters, - escapedAssemblyName)); - } + foreach (var genericInterface in vtableAttribute.GenericInterfaces) + { + vtableClassesSource.AppendLine(GenericVtableInitializerStrings.GetInstantiationInitFunction( + genericInterface.GenericDefinition, + genericInterface.GenericParameters, + escapedAssemblyName)); + } - source.AppendLine(); - source.AppendLine($$""" + vtableClassesSource.AppendLine(); + vtableClassesSource.AppendLine($$""" return new global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry[] { """); - foreach (var @interface in vtableAttribute.Interfaces) - { - var genericStartIdx = @interface.IndexOf('<'); - var interfaceStaticsMethod = @interface[..(genericStartIdx == -1 ? @interface.Length : genericStartIdx)] + "Methods"; - if (genericStartIdx != -1) + foreach (var @interface in vtableAttribute.Interfaces) { - interfaceStaticsMethod += @interface[genericStartIdx..@interface.Length]; - } + var genericStartIdx = @interface.IndexOf('<'); + var interfaceStaticsMethod = @interface[..(genericStartIdx == -1 ? @interface.Length : genericStartIdx)] + "Methods"; + if (genericStartIdx != -1) + { + interfaceStaticsMethod += @interface[genericStartIdx..@interface.Length]; + } - source.AppendLine($$""" + vtableClassesSource.AppendLine($$""" new global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry { IID = global::ABI.{{interfaceStaticsMethod}}.IID, Vtable = global::ABI.{{interfaceStaticsMethod}}.AbiToProjectionVftablePtr }, """); - } - source.AppendLine($$""" + } + vtableClassesSource.AppendLine($$""" }; """); - } - else - { - source.AppendLine($$""" + } + else + { + vtableClassesSource.AppendLine($$""" return global::System.Array.Empty(); """); - } + } - source.AppendLine($$""" + vtableClassesSource.AppendLine($$""" } } """); + } if (!vtableAttribute.IsGlobalNamespace) { @@ -979,6 +1008,12 @@ internal sealed class {{escapedClassName}}WinRTTypeDetails : global::WinRT.IWinR addSource($"{prefix}{escapedClassName}.WinRTVtable.g.cs", source.ToString()); } } + + if (vtableClassesSource.Length != 0) + { + vtableClassesSource.AppendLine("}"); + addSource($"WinRTCCWVtable.g.cs", vtableClassesSource.ToString()); + } } private static void GenerateCCWForGenericInstantiation( @@ -1444,12 +1479,37 @@ private static ComWrappers.ComInterfaceEntry[] LookupVtableEntries(Type type) """); } + // We gather all the class names that have the same vtable and generate it + // as part of one if to reduce generated code. + var vtableEntryToClassNameList = new Dictionary>(); foreach (var vtableAttribute in value.vtableAttributes.ToImmutableHashSet()) { + VtableEntry entry = new(vtableAttribute.Interfaces, vtableAttribute.GenericInterfaces, vtableAttribute.IsDelegate); + if (!vtableEntryToClassNameList.TryGetValue(entry, out var classNameList)) + { + classNameList = new List(); + vtableEntryToClassNameList.Add(entry, classNameList); + } + classNameList.Add(vtableAttribute.VtableLookupClassName); + } + + foreach (var vtableEntry in vtableEntryToClassNameList) + { + source.AppendLine($$""" + if (typeName == "{{vtableEntry.Value[0]}}" + """); + + for (var i = 1; i < vtableEntry.Value.Count; i++) + { + source.AppendLine($$""" + || typeName == "{{vtableEntry.Value[i]}}" + """); + } + source.AppendLine($$""" - if (typeName == "{{vtableAttribute.VtableLookupClassName}}") + ) { - {{GenerateVtableEntry(vtableAttribute, value.context.escapedAssemblyName)}} + {{GenerateVtableEntry(vtableEntry.Key, value.context.escapedAssemblyName)}} } """); } @@ -1469,12 +1529,34 @@ private static string LookupRuntimeClassName(Type type) string typeName = type.ToString(); """); + var runtimeClassNameToClassNameList = new Dictionary>(); foreach (var vtableAttribute in value.vtableAttributes.ToImmutableHashSet().Where(static v => !string.IsNullOrEmpty(v.RuntimeClassName))) + { + if (!runtimeClassNameToClassNameList.TryGetValue(vtableAttribute.RuntimeClassName, out var classNameList)) + { + classNameList = new List(); + runtimeClassNameToClassNameList.Add(vtableAttribute.RuntimeClassName, classNameList); + } + classNameList.Add(vtableAttribute.VtableLookupClassName); + } + + foreach (var entry in runtimeClassNameToClassNameList) { source.AppendLine($$""" - if (typeName == "{{vtableAttribute.VtableLookupClassName}}") + if (typeName == "{{entry.Value[0]}}" + """); + + for (var i = 1; i < entry.Value.Count; i++) + { + source.AppendLine($$""" + || typeName == "{{entry.Value[i]}}" + """); + } + + source.AppendLine($$""" + ) { - return "{{vtableAttribute.RuntimeClassName}}"; + return "{{entry.Key}}"; } """); } @@ -1630,6 +1712,11 @@ internal sealed record VtableAttribute( bool IsPublic, string RuntimeClassName = default); + sealed record VtableEntry( + EquatableArray Interfaces, + EquatableArray GenericInterfaces, + bool IsDelegate); + internal readonly record struct BindableCustomProperty( string Name, string Type,