Skip to content

Commit

Permalink
Lower the priority of looking up the rank of scope (#5065)
Browse files Browse the repository at this point in the history
* Lower the priority of looking up the rank of scope

In the previous change of #5060, we propose a way to resolve
the ambiguous call when considering the scope of a function.

But this rule should be considered as a low priority than "specialized
candidate", aka. we should consider more "specialized candiate" first.

* Count distance between reference site to declaration site

Compare the candidate by calculating distance
from reference site to declaration site via nearest common prefix
in the scope tree.

This will involve finding the common parent node of two child nodes
and how sum the distance from the common parent to the two child nodes.

* Change the priority higher than 'getOverloadRank'

* Don't evaluate the scope rank algorithm on generic

If the candidate is generic function, the function parameters
won't be checked before 'CompareOverloadCandidates', so it will
results in that the candidates this function could be invalid.

We should not evaluate the distance algorithm in this case, instead
we will evaluate later when the candidate is in flavor of Func or Expr
since then all the type checks for the function will be done.
  • Loading branch information
kaizhangNV authored Sep 18, 2024
1 parent 2d83875 commit 3240799
Show file tree
Hide file tree
Showing 5 changed files with 294 additions and 41 deletions.
139 changes: 100 additions & 39 deletions source/slang/slang-check-overload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1200,28 +1200,6 @@ namespace Slang
return parent;
}

void countDistanceToGloablScope(DeclRef<Slang::Decl> const& leftDecl,
DeclRef<Slang::Decl> const& rightDecl,
int& leftDistance, int& rightDistance)
{
leftDistance = 0;
rightDistance = 0;

DeclRef<Decl> decl = leftDecl;
while(decl)
{
leftDistance++;
decl = decl.getParent();
}

decl = rightDecl;
while(decl)
{
rightDistance++;
decl = decl.getParent();
}
}

// Returns -1 if left is preferred, 1 if right is preferred, and 0 if they are equal.
//
int SemanticsVisitor::CompareLookupResultItems(
Expand Down Expand Up @@ -1347,23 +1325,6 @@ namespace Slang
}
}

// We need to consider the distance of the declarations to the global scope to resolve this case:
// float f(float x);
// struct S
// {
// float f(float x);
// float g(float y) { return f(y); } // will call S::f() instead of ::f()
// }
// We don't need to know the call site of 'f(y)', but only need to count the two candidates' distance to the global scope,
// because this function will only choose the valid candidates. So if there is situation like this:
// void main() { S s; s.f(1.0);} or
// struct T { float g(y) { f(y); } }, there won't be ambiguity.
// So we just need to count which declaration is farther from the global scope and favor the farther one.
int leftDistance = 0;
int rightDistance = 0;
countDistanceToGloablScope(left.declRef, right.declRef, leftDistance, rightDistance);
if (leftDistance != rightDistance)
return leftDistance > rightDistance ? -1 : 1;

// TODO: We should generalize above rules such that in a tie a declaration
// A::m is better than B::m when all other factors are equal and
Expand Down Expand Up @@ -1479,6 +1440,70 @@ namespace Slang
return 0;
}

int getScopeRank(DeclRef<Decl> const& left,
DeclRef<Decl> const& right, Slang::Scope* referenceSiteScope)
{
if (!referenceSiteScope)
return 0;

DeclRef<Decl> prefixDecl = referenceSiteScope->containerDecl;

// Hold the path from reference site to the root
// key: Decl node, value: distance from reference site
Dictionary<Decl*, uint32_t> refPath;
for (auto node = prefixDecl; node != nullptr; node = node.getParent())
{
Decl* key = node.getDecl();
uint32_t value = (uint32_t)refPath.getCount();
refPath.add(key, value);
}

// find the common prefix decl of reference site and left
int leftDistance = 0;
int rightDistance = 0;
auto distanceToCommonPrefix = [](DeclRef<Decl> const& candidate, Dictionary<Decl*, uint32_t> refPath) -> int
{
uint32_t distanceToReferenceSite = 0;
uint32_t distanceToCandidate = 0;

// Sanity check
if (candidate.getDecl() == nullptr)
return -1;

// search from candidate to root, once we found the first node in the reference path, that is the first
// common prefix, and we can stop searching.
for (auto node = candidate; node != nullptr; node = node.getParent())
{
Decl* key = node.getDecl();
if (refPath.tryGetValue(key, distanceToReferenceSite))
{
break;
}
distanceToCandidate++;
}

// If we don't find the common prefix, there must be something wrong, return the max value.
if (distanceToReferenceSite == 0)
return -1;

return distanceToReferenceSite + distanceToCandidate;
};

leftDistance = distanceToCommonPrefix(left, refPath);
rightDistance = distanceToCommonPrefix(right, refPath);

if (leftDistance == rightDistance)
return 0;

if (leftDistance == -1)
return 1;

if (rightDistance == -1)
return -1;

return leftDistance < rightDistance ? -1 : 1;
}

int SemanticsVisitor::CompareOverloadCandidates(
OverloadCandidate* left,
OverloadCandidate* right)
Expand Down Expand Up @@ -1558,6 +1583,42 @@ namespace Slang
if (externExportDiff)
return externExportDiff;

// We need to consider the distance of the declarations to the global scope to resolve this case:
// float f(float x);
// struct S
// {
// float f(float x);
// float g(float y) { return f(y); } // will call S::f() instead of ::f()
// }
// we will count the distance from the reference site to the declaration in the scope tree.

// NOTE: We CAN'T do this for the generic function, because generic lookup is little bit complicated.
// It will go through multiple passes of candidates compare.
// In the first pass, it will lookup all the generic candidates that matches the generic parameter only,
// e.g., the following generic functions are totally different, but they will be selected as candidates
// because the function name and the generic parameters are the same:
// void func<let Z0 : uint, let Z1 : uint>(Z0 a, Z1 b);
// void func<let Z0 : uint, let Z1 : uint>(Z0 a, Z1 b, Z0 c);
// void func<let Z0 : uint, let Z1 : uint>(Z0 a, Z1 b, Z0 c, Z1 d);
//
// So in this case, we should not consider the scope rank and overload rank at all, because there is only
// one of above candidates is valid, and the rank calculation doesn't consider the correctness of the
// candidates, so it could select the wrong candidate.
//
// In the next pass, the lookup system will match the input parameters in those candidates to find out the valid
// match, the "flavor" field will become "Func" or "Expr". So the rank calculation can be applied.
if (left->flavor == OverloadCandidate::Flavor::Generic ||
left->flavor == OverloadCandidate::Flavor::UnspecializedGeneric ||
right->flavor == OverloadCandidate::Flavor::Generic ||
right->flavor == OverloadCandidate::Flavor::UnspecializedGeneric)
{
return 0;
}

auto scopeRank = getScopeRank(left->item.declRef, right->item.declRef, this->m_outerScope);
if (scopeRank)
return scopeRank;

// If we reach here, we will attempt to use overload rank to break the ties.
auto overloadRankDiff = getOverloadRank(right->item.declRef) - getOverloadRank(left->item.declRef);
if (overloadRankDiff)
Expand Down
65 changes: 65 additions & 0 deletions tests/bugs/overload-ambiguous-1.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// https://github.com/shader-slang/slang/issues/4476

//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cuda -compute -shaderobj
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -shaderobj
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -shaderobj
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -dx12 -shaderobj
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -shaderobj

//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<uint> outputBuffer;

namespace A1
{
uint func()
{
return 1u;
}

namespace A2
{
uint func()
{
return 2u;
}

namespace A3
{
uint func()
{
return 3u;
}

uint test2()
{
return func(); // choose A3::func()
}
}

namespace A4
{
uint test()
{
return func(); // choose A2::func()
}
}
}
}

[numthreads(1, 1, 1)]
[shader("compute")]
void computeMain(uint3 threadID: SV_DispatchThreadID)
{
using namespace A1;
using namespace A1::A2;
using namespace A1::A2::A3;
using namespace A1::A2::A4;
outputBuffer[0] = test();
// BUF: 2

outputBuffer[1] = func(); // choose the A1::func()
// BUF-NEXT: 1

outputBuffer[2] = test2();
// BUF-NEXT: 3
}
67 changes: 67 additions & 0 deletions tests/bugs/overload-ambiguous-2.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// https://github.com/shader-slang/slang/issues/4476

//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cuda -compute -shaderobj
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -shaderobj
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -shaderobj
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -dx12 -shaderobj
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -shaderobj

//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
RWStructuredBuffer<uint> outputBuffer;

namespace A
{
struct Struct1<let SIZE : uint>
{
uint data;
};

Struct1<Z1> myFunc<let Z0 : uint, let Z1 : uint>(Struct1<Z0> inputS1)
{
Struct1<Z1> s1;
s1.data = inputS1.data + 2U;
return s1;
}
};


A::Struct1<Z1> myFunc<let Z0 : uint, let Z1 : uint>(A::Struct1<Z0> inputS1)
{
A::Struct1<Z1> s1;
s1.data = inputS1.data + 5U;
return s1;
}

namespace A
{
struct Struct2<let SIZE : uint>
{
Struct1<SIZE> s1;
}

Struct2<Z1> myFunc<let Z0 : uint, let Z1 : uint>(Struct2<Z0> inputS2)
{
Struct2<Z1> s2;
// We want to cover a corner case in our compiler where:
// when looking up "myFunc", the compiler should find
// Struct1<Z1> A::myFunc<let Z0 : uint, let Z1 : uint>(Struct1<Z0> inputS1)
// and it won't be ambiguous with the global "myFunc".
s2.s1 = myFunc<Z0, Z1>(inputS2.s1);
return s2;
}
};

[numthreads(1, 1, 1)]
[shader("compute")]
void computeMain(uint3 threadID: SV_DispatchThreadID)
{
using namespace A;

Struct2<10> input = {threadID.x};

Struct2<20> output;
output = myFunc<10, 20>(input);
outputBuffer[0] = output.s1.data;

// BUF: 2
}
19 changes: 17 additions & 2 deletions tests/bugs/overload-ambiguous.slang
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -dx12 -shaderobj
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -shaderobj

//TEST_INPUT:ubuffer(data=[0 0], stride=4):out,name=outputBuffer
//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<uint> outputBuffer;


Expand Down Expand Up @@ -34,7 +34,18 @@ struct DataObtainer
}
}

RWStructuredBuffer<uint> output;
uint myFunc(uint a)
{
return a + 1u;
}

__generic<T: __BuiltinIntegerType>
uint myFunc(T a)
{
uint b = __intCast<uint, T>(a);
return b + 2u;
}


[numthreads(1, 1, 1)]
[shader("compute")]
Expand All @@ -43,6 +54,10 @@ void computeMain(uint3 threadID: SV_DispatchThreadID)
DataObtainer obtainer = {2u};
outputBuffer[0] = obtainer.getValue();
outputBuffer[1] = obtainer.getValue2();

uint a = 1u;
outputBuffer[2] = myFunc(a); // will call myFunc(uint) which more specialized
// BUF: 2
// BUF-NEXT: 1
// BUF-NEXT: 2
}
45 changes: 45 additions & 0 deletions tests/diagnostics/overload-ambiguous.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// https://github.com/shader-slang/slang/issues/4476

//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):
RWStructuredBuffer<uint> outputBuffer;

namespace A1
{
uint func()
{
return 1u;
}

namespace A2
{
uint func()
{
return 2u;
}
}
}
namespace B1
{
uint func()
{
return 4u;
}
}

[numthreads(1, 1, 1)]
[shader("compute")]
void computeMain(uint3 threadID: SV_DispatchThreadID)
{
using namespace A1;
using namespace A1::A2;
using namespace B1;
using namespace C1;

// Only A1::func() and B1::func() will cause ambiguity because the distance from
// the reference site to those two functions declaration are the same.
outputBuffer[0] = func();
// CHECK-NOT: {{.*}}A2::func() -> uint
// CHECK: ambiguous call to 'func' with arguments of type ()
// CHECK: candidate: func B1::func() -> uint
// CHECK: candidate: func A1::func() -> uint
}

0 comments on commit 3240799

Please sign in to comment.