Skip to content

Commit

Permalink
Fix handling of static virtual method implementation checking (dotnet…
Browse files Browse the repository at this point in the history
…#54710)

- It turns out that GetMethodDescFromMemberDefOrRefOrSpec and FindOrCreateAssociatedMethodDesc are not safe to use on a MemberRef whent  the associated MethodTable is not fully loaded.
- Instead only use that feature when working with a MethodDef or a fully loaded type, and when working with a not fully loaded type, use MemberLoader::FindMethod instead.
- When running the resolution algorithm for doing constraint validation, it also is not necessary to fully resolve to the exact correct MethodDesc, which as that process uses FindOrCreateAssociatedMethodDesc needs to be avoided.
- The above was not evident as in many cases, the validation algorithm did not run as it was misplaced and located directly before the call to SetIsFullyLoaded. That code path is only followed if the type is able to fully load without circular dependencies. (Test case CuriouslyRecurringGenericWithUnimplementedMethod added to cover that scenario)
- In addition, while investigating these issues, I realized we were lacking checks that the constraints on the impl and decl method were not checked at during type load, but that work was instead deferred to dispatch time. Along with the constraint check there was also a set of accessibility checks that had been missed that are common to all MethodImpl handling. Fix by adding tweaking the logic to share most of that code.
  • Loading branch information
davidwrighton authored Jun 25, 2021
1 parent 24adc91 commit 333c4e7
Show file tree
Hide file tree
Showing 10 changed files with 290 additions and 49 deletions.
84 changes: 54 additions & 30 deletions src/coreclr/vm/methodtable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5658,6 +5658,22 @@ void MethodTable::DoFullyLoad(Generics::RecursionGraph * const pVisited, const

}

// Validate implementation of virtual static methods on all implemented interfaces unless:
// 1) The type resides in a module where sanity checks are disabled (such as System.Private.CoreLib, or an
// R2R module with type checks disabled)
// 2) There are no virtual static methods defined on any of the interfaces implemented by this type;
// 3) The type is abstract in which case it's allowed to leave some virtual static methods unimplemented
// akin to equivalent behavior of virtual instance method overriding in abstract classes;
// 4) The type is a not the typical type definition. (The typical type is always checked)

if (fNeedsSanityChecks &&
IsTypicalTypeDefinition() &&
!IsAbstract())
{
if (HasVirtualStaticMethods())
VerifyThatAllVirtualStaticMethodsAreImplemented();
}

if (locals.fBailed)
{
// We couldn't complete security checks on some dependency because it is already being processed by one of our callers.
Expand All @@ -5671,22 +5687,6 @@ void MethodTable::DoFullyLoad(Generics::RecursionGraph * const pVisited, const
}
else
{
// Validate implementation of virtual static methods on all implemented interfaces unless:
// 1) The type resides in the system module (System.Private.CoreLib); we own this module and ensure
// its consistency by other means not requiring runtime checks;
// 2) There are no virtual static methods defined on any of the interfaces implemented by this type;
// 3) The method is abstract in which case it's allowed to leave some virtual static methods unimplemented
// akin to equivalent behavior of virtual instance method overriding in abstract classes;
// 4) The type is a shared generic in which case we generally don't have enough information to perform
// the validation.
if (!GetModule()->IsSystem() &&
HasVirtualStaticMethods() &&
!IsAbstract() &&
!IsSharedByGenericInstantiations())
{
VerifyThatAllVirtualStaticMethodsAreImplemented();
}

// Finally, mark this method table as fully loaded
SetIsFullyLoaded();
}
Expand Down Expand Up @@ -9201,7 +9201,7 @@ MethodDesc *MethodTable::GetDefaultConstructor(BOOL forceBoxedEntryPoint /* = FA
//==========================================================================================
// Finds the (non-unboxing) MethodDesc that implements the interface virtual static method pInterfaceMD.
MethodDesc *
MethodTable::ResolveVirtualStaticMethod(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL allowNullResult, BOOL checkDuplicates, BOOL allowVariantMatches)
MethodTable::ResolveVirtualStaticMethod(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL allowNullResult, BOOL verifyImplemented, BOOL allowVariantMatches)
{
if (!pInterfaceMD->IsSharedByGenericMethodInstantiations() && !pInterfaceType->IsSharedByGenericInstantiations())
{
Expand Down Expand Up @@ -9231,7 +9231,7 @@ MethodTable::ResolveVirtualStaticMethod(MethodTable* pInterfaceType, MethodDesc*
// Search for match on a per-level in the type hierarchy
for (MethodTable* pMT = this; pMT != nullptr; pMT = pMT->GetParentMethodTable())
{
MethodDesc* pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pInterfaceType, pInterfaceMD, checkDuplicates);
MethodDesc* pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pInterfaceType, pInterfaceMD, verifyImplemented);
if (pMD != nullptr)
{
return pMD;
Expand Down Expand Up @@ -9273,7 +9273,7 @@ MethodTable::ResolveVirtualStaticMethod(MethodTable* pInterfaceType, MethodDesc*
{
// Variant or equivalent matching interface found
// Attempt to resolve on variance matched interface
pMD = pMT->TryResolveVirtualStaticMethodOnThisType(it.GetInterface(), pInterfaceMD, checkDuplicates);
pMD = pMT->TryResolveVirtualStaticMethodOnThisType(it.GetInterface(), pInterfaceMD, verifyImplemented);
if (pMD != nullptr)
{
return pMD;
Expand All @@ -9295,7 +9295,7 @@ MethodTable::ResolveVirtualStaticMethod(MethodTable* pInterfaceType, MethodDesc*
// Try to locate the appropriate MethodImpl matching a given interface static virtual method.
// Returns nullptr on failure.
MethodDesc*
MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL checkDuplicates)
MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented)
{
HRESULT hr = S_OK;
IMDInternalImport* pMDInternalImport = GetMDImport();
Expand Down Expand Up @@ -9347,13 +9347,39 @@ MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType
{
continue;
}
MethodDesc *pMethodDecl = MemberLoader::GetMethodDescFromMemberDefOrRefOrSpec(
MethodDesc *pMethodDecl;

if ((TypeFromToken(methodDecl) == mdtMethodDef) || pInterfaceMT->IsFullyLoaded())
{
pMethodDecl = MemberLoader::GetMethodDescFromMemberDefOrRefOrSpec(
GetModule(),
methodDecl,
&sigTypeContext,
/* strictMetadataChecks */ FALSE,
/* allowInstParam */ FALSE,
/* owningTypeLoadLevel */ CLASS_LOAD_EXACTPARENTS);
}
else if (TypeFromToken(methodDecl) == mdtMemberRef)
{
LPCUTF8 szMember;
PCCOR_SIGNATURE pSig;
DWORD cSig;

IfFailThrow(pMDInternalImport->GetNameAndSigOfMemberRef(methodDecl, &pSig, &cSig, &szMember));

// Do a quick name check to avoid excess use of FindMethod
if (strcmp(szMember, pInterfaceMD->GetName()) != 0)
{
continue;
}

pMethodDecl = MemberLoader::FindMethod(pInterfaceMT, szMember, pSig, cSig, GetModule());
}
else
{
COMPlusThrow(kTypeLoadException, E_FAIL);
}

if (pMethodDecl == nullptr)
{
COMPlusThrow(kTypeLoadException, E_FAIL);
Expand All @@ -9369,13 +9395,11 @@ MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType
COMPlusThrow(kTypeLoadException, E_FAIL);
}

MethodDesc *pMethodImpl = MemberLoader::GetMethodDescFromMemberDefOrRefOrSpec(
MethodDesc *pMethodImpl = MemberLoader::GetMethodDescFromMethodDef(
GetModule(),
methodBody,
&sigTypeContext,
/* strictMetadataChecks */ FALSE,
/* allowInstParam */ FALSE,
/* owningTypeLoadLevel */ CLASS_LOAD_EXACTPARENTS);
FALSE,
CLASS_LOAD_EXACTPARENTS);
if (pMethodImpl == nullptr)
{
COMPlusThrow(kTypeLoadException, E_FAIL);
Expand All @@ -9388,7 +9412,7 @@ MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType
COMPlusThrow(kTypeLoadException, E_FAIL);
}

if (pInterfaceMD->HasMethodInstantiation() || pMethodImpl->HasMethodInstantiation() || HasInstantiation())
if (!verifyImplemented)
{
pMethodImpl = pMethodImpl->FindOrCreateAssociatedMethodDesc(
pMethodImpl,
Expand All @@ -9398,11 +9422,11 @@ MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType
/* allowInstParam */ FALSE,
/* forceRemotableMethod */ FALSE,
/* allowCreate */ TRUE,
/* level */ CLASS_LOAD_EXACTPARENTS);
/* level */ CLASS_LOADED);
}
if (pMethodImpl != nullptr)
{
if (!checkDuplicates)
if (!verifyImplemented)
{
return pMethodImpl;
}
Expand Down Expand Up @@ -9432,7 +9456,7 @@ MethodTable::VerifyThatAllVirtualStaticMethodsAreImplemented()
MethodDesc *pMD = it.GetMethodDesc();
if (pMD->IsVirtual() &&
pMD->IsStatic() &&
!ResolveVirtualStaticMethod(pInterfaceMT, pMD, /* allowNullResult */ TRUE, /* checkDuplicates */ TRUE, /* allowVariantMatches */ FALSE))
!ResolveVirtualStaticMethod(pInterfaceMT, pMD, /* allowNullResult */ TRUE, /* verifyImplemented */ TRUE, /* allowVariantMatches */ FALSE))
{
IMDInternalImport* pInternalImport = GetModule()->GetMDImport();
GetModule()->GetAssembly()->ThrowTypeLoadException(pInternalImport, GetCl(), pMD->GetName(), IDS_CLASSLOAD_STATICVIRTUAL_NOTIMPL);
Expand Down
8 changes: 6 additions & 2 deletions src/coreclr/vm/methodtable.h
Original file line number Diff line number Diff line change
Expand Up @@ -2285,7 +2285,11 @@ class MethodTable


// Resolve virtual static interface method pInterfaceMD on this type.
MethodDesc *ResolveVirtualStaticMethod(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL allowNullResult, BOOL checkDuplicates = FALSE, BOOL allowVariantMatches = TRUE);
//
// Specify allowNullResult to return NULL instead of throwing if the there is no implementation
// Specify verifyImplemented to verify that there is a match, but do not actually return a final useable MethodDesc
// Specify allowVariantMatches to permit generic interface variance
MethodDesc *ResolveVirtualStaticMethod(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL allowNullResult, BOOL verifyImplemented = FALSE, BOOL allowVariantMatches = TRUE);

// Try a partial resolve of the constraint call, up to generic code sharing.
//
Expand Down Expand Up @@ -2402,7 +2406,7 @@ class MethodTable

// Try to resolve a given static virtual method override on this type. Return nullptr
// when not found.
MethodDesc *TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL checkDuplicates);
MethodDesc *TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented);

public:
static MethodDesc *MapMethodDeclToMethodImpl(MethodDesc *pMDDecl);
Expand Down
80 changes: 68 additions & 12 deletions src/coreclr/vm/methodtablebuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1186,21 +1186,45 @@ MethodTableBuilder::bmtInterfaceEntry::CreateSlotTable(
CONSISTENCY_CHECK(m_pImplTable == NULL);

SLOT_INDEX cSlots = (SLOT_INDEX)GetInterfaceType()->GetMethodTable()->GetNumVirtuals();
bmtInterfaceSlotImpl * pST = new (pStackingAllocator) bmtInterfaceSlotImpl[cSlots];
SLOT_INDEX cSlotsTotal = cSlots;

if (GetInterfaceType()->GetMethodTable()->HasVirtualStaticMethods())
{
MethodTable::MethodIterator it(GetInterfaceType()->GetMethodTable());
for (; it.IsValid(); it.Next())
{
MethodDesc *pDeclMD = it.GetDeclMethodDesc();
if (pDeclMD->IsStatic() && pDeclMD->IsVirtual())
{
cSlotsTotal++;
}
}
}

bmtInterfaceSlotImpl * pST = new (pStackingAllocator) bmtInterfaceSlotImpl[cSlotsTotal];


MethodTable::MethodIterator it(GetInterfaceType()->GetMethodTable());
for (; it.IsValid(); it.Next())
{
if (!it.IsVirtual())
MethodDesc *pDeclMD = it.GetDeclMethodDesc();
if (!pDeclMD->IsVirtual())
{
break;
}

bmtRTMethod * pCurMethod = new (pStackingAllocator)
bmtRTMethod(GetInterfaceType(), it.GetDeclMethodDesc());

CONSISTENCY_CHECK(m_cImplTable == it.GetSlotNumber());
pST[m_cImplTable++] = bmtInterfaceSlotImpl(pCurMethod, INVALID_SLOT_INDEX);
if (pDeclMD->IsStatic())
{
pST[cSlots + m_cImplTableStatics++] = bmtInterfaceSlotImpl(pCurMethod, INVALID_SLOT_INDEX);
}
else
{
CONSISTENCY_CHECK(m_cImplTable == it.GetSlotNumber());
pST[m_cImplTable++] = bmtInterfaceSlotImpl(pCurMethod, INVALID_SLOT_INDEX);
}
}

m_pImplTable = pST;
Expand Down Expand Up @@ -4808,16 +4832,16 @@ VOID MethodTableBuilder::TestMethodImpl(
{
BuildMethodTableThrowException(IDS_CLASSLOAD_MI_NONVIRTUAL_DECL);
}
if (!IsMdVirtual(dwImplAttrs))
if ((IsMdVirtual(dwImplAttrs) && IsMdStatic(dwImplAttrs)) || (!IsMdVirtual(dwImplAttrs) && !IsMdStatic(dwImplAttrs)))
{
BuildMethodTableThrowException(IDS_CLASSLOAD_MI_MUSTBEVIRTUAL);
}
// Virtual methods cannot be static
if (IsMdStatic(dwDeclAttrs))
// Virtual methods on classes/valuetypes cannot be static
if (IsMdStatic(dwDeclAttrs) && !hDeclMethod.GetOwningType().IsInterface())
{
BuildMethodTableThrowException(IDS_CLASSLOAD_STATICVIRTUAL);
}
if (IsMdStatic(dwImplAttrs))
if ((!!IsMdStatic(dwImplAttrs)) != (!!IsMdStatic(dwDeclAttrs)))
{
BuildMethodTableThrowException(IDS_CLASSLOAD_STATICVIRTUAL);
}
Expand Down Expand Up @@ -5421,14 +5445,14 @@ MethodTableBuilder::PlaceVirtualMethods()
// that the name+signature corresponds to. Used by ProcessMethodImpls and ProcessInexactMethodImpls
// Always returns the first match that it finds. Affects the ambiguities in code:#ProcessInexactMethodImpls_Ambiguities
MethodTableBuilder::bmtMethodHandle
MethodTableBuilder::FindDeclMethodOnInterfaceEntry(bmtInterfaceEntry *pItfEntry, MethodSignature &declSig)
MethodTableBuilder::FindDeclMethodOnInterfaceEntry(bmtInterfaceEntry *pItfEntry, MethodSignature &declSig, bool searchForStaticMethods)
{
STANDARD_VM_CONTRACT;

bmtMethodHandle declMethod;

bmtInterfaceEntry::InterfaceSlotIterator slotIt =
pItfEntry->IterateInterfaceSlots(GetStackingAllocator());
pItfEntry->IterateInterfaceSlots(GetStackingAllocator(), searchForStaticMethods);
// Check for exact match
for (; !slotIt.AtEnd(); slotIt.Next())
{
Expand Down Expand Up @@ -5656,7 +5680,7 @@ MethodTableBuilder::ProcessMethodImpls()
DeclaredMethodIterator it(*this);
while (it.Next())
{
if (!IsMdVirtual(it.Attrs()) && it.IsMethodImpl())
if (!IsMdVirtual(it.Attrs()) && it.IsMethodImpl() && bmtProp->fNoSanityChecks)
{
// Non-virtual methods can only be classified as methodImpl when implementing
// static virtual methods.
Expand Down Expand Up @@ -5839,7 +5863,7 @@ MethodTableBuilder::ProcessMethodImpls()
}

// 3. Find the matching method.
declMethod = FindDeclMethodOnInterfaceEntry(pItfEntry, declSig);
declMethod = FindDeclMethodOnInterfaceEntry(pItfEntry, declSig, !IsMdVirtual(it.Attrs())); // Search for statics when the impl is non-virtual
}
else
{
Expand Down Expand Up @@ -5874,6 +5898,14 @@ MethodTableBuilder::ProcessMethodImpls()
BuildMethodTableThrowException(IDS_CLASSLOAD_MI_MUSTBEVIRTUAL, it.Token());
}

if (!IsMdVirtual(it.Attrs()) && it.IsMethodImpl() && IsMdStatic(it.Attrs()))
{
// Non-virtual methods can only be classified as methodImpl when implementing
// static virtual methods.
ValidateStaticMethodImpl(declMethod, *it);//bmtMethodHandle(pCurImplMethod));
continue;
}

if (bmtMetaData->rgMethodImplTokens[m].fRequiresCovariantReturnTypeChecking)
{
it->GetMethodDesc()->SetRequiresCovariantReturnTypeChecking();
Expand Down Expand Up @@ -6744,6 +6776,30 @@ MethodTableBuilder::PlaceParentDeclarationOnClass(
(*pSlotIndex)++;
} // MethodTableBuilder::PlaceParentDeclarationOnClass

VOID MethodTableBuilder::ValidateStaticMethodImpl(
bmtMethodHandle hDecl,
bmtMethodHandle hImpl)
{
// While we don't want to place the static method impl declarations on the class/interface, we do
// need to validate the method constraints and signature are compatible
if (!bmtProp->fNoSanityChecks)
{
///////////////////////////////
// Verify the signatures match

MethodImplCompareSignatures(
hDecl,
hImpl,
FALSE /* allowCovariantReturn */,
IDS_CLASSLOAD_CONSTRAINT_MISMATCH_ON_INTERFACE_METHOD_IMPL);

///////////////////////////////
// Validate the method impl.

TestMethodImpl(hDecl, hImpl);
}
}

//*******************************************************************************
// This will validate that all interface methods that were matched during
// layout also validate against type constraints.
Expand Down
Loading

0 comments on commit 333c4e7

Please sign in to comment.