diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 70eabb4f7e..b2173cd7b9 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1200,28 +1200,6 @@ namespace Slang return parent; } - void countDistanceToGloablScope(DeclRef const& leftDecl, - DeclRef const& rightDecl, - int& leftDistance, int& rightDistance) - { - leftDistance = 0; - rightDistance = 0; - - DeclRef decl = leftDecl; - while(decl) - { - leftDistance++; - decl = decl.getParent(); - } - - decl = rightDecl; - while(decl) - { - rightDistance++; - decl = decl.getParent(); - } - } - // Returns -1 if left is preferred, 1 if right is preferred, and 0 if they are equal. // int SemanticsVisitor::CompareLookupResultItems( @@ -1347,23 +1325,6 @@ namespace Slang } } - // We need to consider the distance of the declarations to the global scope to resolve this case: - // float f(float x); - // struct S - // { - // float f(float x); - // float g(float y) { return f(y); } // will call S::f() instead of ::f() - // } - // We don't need to know the call site of 'f(y)', but only need to count the two candidates' distance to the global scope, - // because this function will only choose the valid candidates. So if there is situation like this: - // void main() { S s; s.f(1.0);} or - // struct T { float g(y) { f(y); } }, there won't be ambiguity. - // So we just need to count which declaration is farther from the global scope and favor the farther one. - int leftDistance = 0; - int rightDistance = 0; - countDistanceToGloablScope(left.declRef, right.declRef, leftDistance, rightDistance); - if (leftDistance != rightDistance) - return leftDistance > rightDistance ? -1 : 1; // TODO: We should generalize above rules such that in a tie a declaration // A::m is better than B::m when all other factors are equal and @@ -1479,6 +1440,70 @@ namespace Slang return 0; } + int getScopeRank(DeclRef const& left, + DeclRef const& right, Slang::Scope* referenceSiteScope) + { + if (!referenceSiteScope) + return 0; + + DeclRef prefixDecl = referenceSiteScope->containerDecl; + + // Hold the path from reference site to the root + // key: Decl node, value: distance from reference site + Dictionary refPath; + for (auto node = prefixDecl; node != nullptr; node = node.getParent()) + { + Decl* key = node.getDecl(); + uint32_t value = (uint32_t)refPath.getCount(); + refPath.add(key, value); + } + + // find the common prefix decl of reference site and left + int leftDistance = 0; + int rightDistance = 0; + auto distanceToCommonPrefix = [](DeclRef const& candidate, Dictionary refPath) -> int + { + uint32_t distanceToReferenceSite = 0; + uint32_t distanceToCandidate = 0; + + // Sanity check + if (candidate.getDecl() == nullptr) + return -1; + + // search from candidate to root, once we found the first node in the reference path, that is the first + // common prefix, and we can stop searching. + for (auto node = candidate; node != nullptr; node = node.getParent()) + { + Decl* key = node.getDecl(); + if (refPath.tryGetValue(key, distanceToReferenceSite)) + { + break; + } + distanceToCandidate++; + } + + // If we don't find the common prefix, there must be something wrong, return the max value. + if (distanceToReferenceSite == 0) + return -1; + + return distanceToReferenceSite + distanceToCandidate; + }; + + leftDistance = distanceToCommonPrefix(left, refPath); + rightDistance = distanceToCommonPrefix(right, refPath); + + if (leftDistance == rightDistance) + return 0; + + if (leftDistance == -1) + return 1; + + if (rightDistance == -1) + return -1; + + return leftDistance < rightDistance ? -1 : 1; + } + int SemanticsVisitor::CompareOverloadCandidates( OverloadCandidate* left, OverloadCandidate* right) @@ -1558,6 +1583,42 @@ namespace Slang if (externExportDiff) return externExportDiff; + // We need to consider the distance of the declarations to the global scope to resolve this case: + // float f(float x); + // struct S + // { + // float f(float x); + // float g(float y) { return f(y); } // will call S::f() instead of ::f() + // } + // we will count the distance from the reference site to the declaration in the scope tree. + + // NOTE: We CAN'T do this for the generic function, because generic lookup is little bit complicated. + // It will go through multiple passes of candidates compare. + // In the first pass, it will lookup all the generic candidates that matches the generic parameter only, + // e.g., the following generic functions are totally different, but they will be selected as candidates + // because the function name and the generic parameters are the same: + // void func(Z0 a, Z1 b); + // void func(Z0 a, Z1 b, Z0 c); + // void func(Z0 a, Z1 b, Z0 c, Z1 d); + // + // So in this case, we should not consider the scope rank and overload rank at all, because there is only + // one of above candidates is valid, and the rank calculation doesn't consider the correctness of the + // candidates, so it could select the wrong candidate. + // + // In the next pass, the lookup system will match the input parameters in those candidates to find out the valid + // match, the "flavor" field will become "Func" or "Expr". So the rank calculation can be applied. + if (left->flavor == OverloadCandidate::Flavor::Generic || + left->flavor == OverloadCandidate::Flavor::UnspecializedGeneric || + right->flavor == OverloadCandidate::Flavor::Generic || + right->flavor == OverloadCandidate::Flavor::UnspecializedGeneric) + { + return 0; + } + + auto scopeRank = getScopeRank(left->item.declRef, right->item.declRef, this->m_outerScope); + if (scopeRank) + return scopeRank; + // If we reach here, we will attempt to use overload rank to break the ties. auto overloadRankDiff = getOverloadRank(right->item.declRef) - getOverloadRank(left->item.declRef); if (overloadRankDiff) diff --git a/tests/bugs/overload-ambiguous-1.slang b/tests/bugs/overload-ambiguous-1.slang new file mode 100644 index 0000000000..9f9c6e5bc5 --- /dev/null +++ b/tests/bugs/overload-ambiguous-1.slang @@ -0,0 +1,65 @@ +// https://github.com/shader-slang/slang/issues/4476 + +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cuda -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -dx12 -shaderobj +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +namespace A1 +{ + uint func() + { + return 1u; + } + + namespace A2 + { + uint func() + { + return 2u; + } + + namespace A3 + { + uint func() + { + return 3u; + } + + uint test2() + { + return func(); // choose A3::func() + } + } + + namespace A4 + { + uint test() + { + return func(); // choose A2::func() + } + } + } +} + +[numthreads(1, 1, 1)] +[shader("compute")] +void computeMain(uint3 threadID: SV_DispatchThreadID) +{ + using namespace A1; + using namespace A1::A2; + using namespace A1::A2::A3; + using namespace A1::A2::A4; + outputBuffer[0] = test(); + // BUF: 2 + + outputBuffer[1] = func(); // choose the A1::func() + // BUF-NEXT: 1 + + outputBuffer[2] = test2(); + // BUF-NEXT: 3 +} diff --git a/tests/bugs/overload-ambiguous-2.slang b/tests/bugs/overload-ambiguous-2.slang new file mode 100644 index 0000000000..46af9f0919 --- /dev/null +++ b/tests/bugs/overload-ambiguous-2.slang @@ -0,0 +1,67 @@ +// https://github.com/shader-slang/slang/issues/4476 + +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cuda -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -dx12 -shaderobj +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -shaderobj + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +namespace A +{ + struct Struct1 + { + uint data; + }; + + Struct1 myFunc(Struct1 inputS1) + { + Struct1 s1; + s1.data = inputS1.data + 2U; + return s1; + } +}; + + +A::Struct1 myFunc(A::Struct1 inputS1) +{ + A::Struct1 s1; + s1.data = inputS1.data + 5U; + return s1; +} + +namespace A +{ + struct Struct2 + { + Struct1 s1; + } + + Struct2 myFunc(Struct2 inputS2) + { + Struct2 s2; + // We want to cover a corner case in our compiler where: + // when looking up "myFunc", the compiler should find + // Struct1 A::myFunc(Struct1 inputS1) + // and it won't be ambiguous with the global "myFunc". + s2.s1 = myFunc(inputS2.s1); + return s2; + } +}; + +[numthreads(1, 1, 1)] +[shader("compute")] +void computeMain(uint3 threadID: SV_DispatchThreadID) +{ + using namespace A; + + Struct2<10> input = {threadID.x}; + + Struct2<20> output; + output = myFunc<10, 20>(input); + outputBuffer[0] = output.s1.data; + + // BUF: 2 +} diff --git a/tests/bugs/overload-ambiguous.slang b/tests/bugs/overload-ambiguous.slang index 1b74cb68c2..d764f72e42 100644 --- a/tests/bugs/overload-ambiguous.slang +++ b/tests/bugs/overload-ambiguous.slang @@ -6,7 +6,7 @@ //TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -dx12 -shaderobj //TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -shaderobj -//TEST_INPUT:ubuffer(data=[0 0], stride=4):out,name=outputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; @@ -34,7 +34,18 @@ struct DataObtainer } } -RWStructuredBuffer output; +uint myFunc(uint a) +{ + return a + 1u; +} + +__generic +uint myFunc(T a) +{ + uint b = __intCast(a); + return b + 2u; +} + [numthreads(1, 1, 1)] [shader("compute")] @@ -43,6 +54,10 @@ void computeMain(uint3 threadID: SV_DispatchThreadID) DataObtainer obtainer = {2u}; outputBuffer[0] = obtainer.getValue(); outputBuffer[1] = obtainer.getValue2(); + + uint a = 1u; + outputBuffer[2] = myFunc(a); // will call myFunc(uint) which more specialized // BUF: 2 // BUF-NEXT: 1 + // BUF-NEXT: 2 } diff --git a/tests/diagnostics/overload-ambiguous.slang b/tests/diagnostics/overload-ambiguous.slang new file mode 100644 index 0000000000..0c8f7bd216 --- /dev/null +++ b/tests/diagnostics/overload-ambiguous.slang @@ -0,0 +1,45 @@ +// https://github.com/shader-slang/slang/issues/4476 + +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): +RWStructuredBuffer outputBuffer; + +namespace A1 +{ + uint func() + { + return 1u; + } + + namespace A2 + { + uint func() + { + return 2u; + } + } +} +namespace B1 +{ + uint func() + { + return 4u; + } +} + +[numthreads(1, 1, 1)] +[shader("compute")] +void computeMain(uint3 threadID: SV_DispatchThreadID) +{ + using namespace A1; + using namespace A1::A2; + using namespace B1; + using namespace C1; + + // Only A1::func() and B1::func() will cause ambiguity because the distance from + // the reference site to those two functions declaration are the same. + outputBuffer[0] = func(); + // CHECK-NOT: {{.*}}A2::func() -> uint + // CHECK: ambiguous call to 'func' with arguments of type () + // CHECK: candidate: func B1::func() -> uint + // CHECK: candidate: func A1::func() -> uint +}