Skip to content

Commit

Permalink
Improve getExactClasses to support classes as base types (#92440)
Browse files Browse the repository at this point in the history
Fixes #88547

```csharp
using System.Runtime.CompilerServices;

public class ClassA
{
    public virtual int GetValue() => 42;
}

public class ClassB : ClassA
{
    // we don't even need to override GetValue here
}

class MyClass
{
    static void Main()
    {
        Test(new ClassB());
    }

    [MethodImpl(MethodImplOptions.NoInlining)]
    static int Test(ClassA c) => c.GetValue();
}
```

Old codegen:

```
; Method MyClass:Test(ClassA):int (FullOpts)
       sub      rsp, 40
       mov      rax, qword ptr [rcx]
       call     [rax+30H]ClassA:GetValue():int:this
       nop
       add      rsp, 40
       ret
; Total bytes of code: 16
```

New codegen:

```
00007FF66DD8AFB0  cmp         dword ptr [rcx],ecx
00007FF66DD8AFB2  mov         eax,2Ah
00007FF66DD8AFB7  ret
```
  • Loading branch information
MichalStrehovsky authored Sep 22, 2023
1 parent f61d3c7 commit e0a4bdd
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 25 deletions.
64 changes: 45 additions & 19 deletions src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/ILScanner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,8 @@ private sealed class ScannedDevirtualizationManager : DevirtualizationManager
private HashSet<TypeDesc> _constructedTypes = new HashSet<TypeDesc>();
private HashSet<TypeDesc> _canonConstructedTypes = new HashSet<TypeDesc>();
private HashSet<TypeDesc> _unsealedTypes = new HashSet<TypeDesc>();
private Dictionary<TypeDesc, HashSet<TypeDesc>> _interfaceImplementators = new();
private HashSet<TypeDesc> _disqualifiedInterfaces = new();
private Dictionary<TypeDesc, HashSet<TypeDesc>> _implementators = new();
private HashSet<TypeDesc> _disqualifiedTypes = new();

public ScannedDevirtualizationManager(NodeFactory factory, ImmutableArray<DependencyNodeCore<NodeFactory>> markedNodes)
{
Expand All @@ -437,8 +437,8 @@ public ScannedDevirtualizationManager(NodeFactory factory, ImmutableArray<Depend
{
// If the interface is implemented through IDynamicInterfaceCastable, there might be
// no real upper bound on the number of actual classes implementing it.
if (CanAssumeWholeProgramViewOnInterfaceUse(factory, type, baseInterface))
_disqualifiedInterfaces.Add(baseInterface);
if (CanAssumeWholeProgramViewOnTypeUse(factory, type, baseInterface))
_disqualifiedTypes.Add(baseInterface);
}
}
}
Expand All @@ -457,14 +457,23 @@ public ScannedDevirtualizationManager(NodeFactory factory, ImmutableArray<Depend

if (type is not MetadataType { IsAbstract: true })
{
// Record all interfaces this class implements to _interfaceImplementators
// Record all interfaces this class implements to _implementators
foreach (DefType baseInterface in type.RuntimeInterfaces)
{
if (CanAssumeWholeProgramViewOnInterfaceUse(factory, type, baseInterface))
if (CanAssumeWholeProgramViewOnTypeUse(factory, type, baseInterface))
{
RecordImplementation(baseInterface, type);
}
}

// Record all base types of this class
for (DefType @base = type.BaseType; @base != null; @base = @base.BaseType)
{
if (CanAssumeWholeProgramViewOnTypeUse(factory, type, @base))
{
RecordImplementation(@base, type);
}
}
}

if (type.IsCanonicalSubtype(CanonicalFormKind.Any))
Expand All @@ -474,7 +483,13 @@ public ScannedDevirtualizationManager(NodeFactory factory, ImmutableArray<Depend
// due to MakeGenericType.
foreach (DefType baseInterface in type.RuntimeInterfaces)
{
_disqualifiedInterfaces.Add(baseInterface);
_disqualifiedTypes.Add(baseInterface);
}

// Same for base classes
for (DefType @base = type.BaseType; @base != null; @base = @base.BaseType)
{
_disqualifiedTypes.Add(@base);
}
}
else if (type.IsArray || type.GetTypeDefinition() == factory.ArrayOfTEnumeratorType)
Expand All @@ -490,7 +505,7 @@ public ScannedDevirtualizationManager(NodeFactory factory, ImmutableArray<Depend
{
// Limit to the generic ones - ICollection<T>, etc.
if (baseInterface.HasInstantiation)
_disqualifiedInterfaces.Add(baseInterface);
_disqualifiedTypes.Add(baseInterface);
}
}
}
Expand All @@ -513,22 +528,23 @@ public ScannedDevirtualizationManager(NodeFactory factory, ImmutableArray<Depend
}
}

private static bool CanAssumeWholeProgramViewOnInterfaceUse(NodeFactory factory, TypeDesc implementingType, DefType interfaceType)
private static bool CanAssumeWholeProgramViewOnTypeUse(NodeFactory factory, TypeDesc implementingType, DefType baseType)
{
if (!interfaceType.HasInstantiation)
if (!baseType.HasInstantiation)
{
return true;
}

// If there are variance considerations, bail
if (VariantInterfaceMethodUseNode.IsVariantInterfaceImplementation(factory, implementingType, interfaceType))
if (baseType.IsInterface
&& VariantInterfaceMethodUseNode.IsVariantInterfaceImplementation(factory, implementingType, baseType))
{
return false;
}

if (interfaceType.IsCanonicalSubtype(CanonicalFormKind.Any)
|| interfaceType.ConvertToCanonForm(CanonicalFormKind.Specific) != interfaceType
|| interfaceType.Context.SupportsUniversalCanon)
if (baseType.IsCanonicalSubtype(CanonicalFormKind.Any)
|| baseType.ConvertToCanonForm(CanonicalFormKind.Specific) != baseType
|| baseType.Context.SupportsUniversalCanon)
{
// If the interface has a canonical form, we might not have a full view of all implementers.
// E.g. if we have:
Expand All @@ -549,10 +565,10 @@ private void RecordImplementation(TypeDesc type, TypeDesc implType)
Debug.Assert(!implType.IsInterface);

HashSet<TypeDesc> implList;
if (!_interfaceImplementators.TryGetValue(type, out implList))
if (!_implementators.TryGetValue(type, out implList))
{
implList = new();
_interfaceImplementators[type] = implList;
_implementators[type] = implList;
}
implList.Add(implType);
}
Expand Down Expand Up @@ -604,13 +620,23 @@ protected override MethodDesc ResolveVirtualMethod(MethodDesc declMethod, DefTyp

public override TypeDesc[] GetImplementingClasses(TypeDesc type)
{
if (_disqualifiedInterfaces.Contains(type))
if (_disqualifiedTypes.Contains(type))
return null;

if (type.IsInterface && _interfaceImplementators.TryGetValue(type, out HashSet<TypeDesc> implementations))
if (_implementators.TryGetValue(type, out HashSet<TypeDesc> implementations))
{
var types = new TypeDesc[implementations.Count];
TypeDesc[] types;
int index = 0;
if (!type.IsInterface && type is not MetadataType { IsAbstract: true })
{
types = new TypeDesc[implementations.Count + 1];
types[index++] = type;
}
else
{
types = new TypeDesc[implementations.Count];
}

foreach (TypeDesc implementation in implementations)
{
types[index++] = implementation;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2196,12 +2196,6 @@ private int getExactClasses(CORINFO_CLASS_STRUCT_* baseType, int maxExactClasses
return 1;
}

if (!type.IsInterface)
{
// TODO: handle classes
return 0;
}

TypeDesc[] implClasses = _compilation.GetImplementingClasses(type);
if (implClasses == null || implClasses.Length > maxExactClasses)
{
Expand Down

0 comments on commit e0a4bdd

Please sign in to comment.