Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Improve overload resolution #20054

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1905,7 +1905,7 @@ object Trees {
case MethodTpe(_, _, x: MethodType) => !x.isImplicitMethod
case _ => true
}}
val alternatives = ctx.typer.resolveOverloaded(allAlts, proto)
val alternatives = ctx.typer.resolveOverloaded(allAlts, proto, receiver.srcPos)
assert(alternatives.size == 1,
i"${if (alternatives.isEmpty) "no" else "multiple"} overloads available for " +
i"$method on ${receiver.tpe.widenDealiasKeepAnnots} with targs: $targs%, %; args: $args%, %; expectedType: $expectedType." +
Expand Down
68 changes: 51 additions & 17 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2062,7 +2062,7 @@ trait Applications extends Compatibility {
* Two trials: First, without implicits or SAM conversions enabled. Then,
* if the first finds no eligible candidates, with implicits and SAM conversions enabled.
*/
def resolveOverloaded(alts: List[TermRef], pt: Type)(using Context): List[TermRef] =
def resolveOverloaded(alts: List[TermRef], pt: Type, srcPos: SrcPos)(using Context): List[TermRef] =
record("resolveOverloaded")

/** Is `alt` a method or polytype whose result type after the first value parameter
Expand Down Expand Up @@ -2100,7 +2100,7 @@ trait Applications extends Compatibility {
case Nil => chosen
case alt2 :: Nil => alt2
case alts2 =>
resolveOverloaded(alts2, pt) match {
resolveOverloaded(alts2, pt, srcPos) match {
case alt2 :: Nil => alt2
case _ => chosen
}
Expand All @@ -2115,12 +2115,12 @@ trait Applications extends Compatibility {
val alts0 = alts.filterConserve(_.widen.stripPoly.isImplicitMethod)
if alts0 ne alts then return resolve(alts0)
else if alts.exists(_.widen.stripPoly.isContextualMethod) then
return resolveMapped(alts, alt => stripImplicit(alt.widen), pt)
return resolveMapped(alts, alt => stripImplicit(alt.widen), pt, srcPos)
case _ =>

var found = withoutMode(Mode.ImplicitsEnabled)(resolveOverloaded1(alts, pt))
var found = withoutMode(Mode.ImplicitsEnabled)(resolveOverloaded1(alts, pt, srcPos))
if found.isEmpty && ctx.mode.is(Mode.ImplicitsEnabled) then
found = resolveOverloaded1(alts, pt)
found = resolveOverloaded1(alts, pt, srcPos)
found match
case alt :: Nil => adaptByResult(alt, alts) :: Nil
case _ => found
Expand Down Expand Up @@ -2167,10 +2167,44 @@ trait Applications extends Compatibility {
* It might be called twice from the public `resolveOverloaded` method, once with
* implicits and SAM conversions enabled, and once without.
*/
private def resolveOverloaded1(alts: List[TermRef], pt: Type)(using Context): List[TermRef] =
private def resolveOverloaded1(alts: List[TermRef], pt: Type, srcPos: SrcPos)(using Context): List[TermRef] =
trace(i"resolve over $alts%, %, pt = $pt", typr, show = true) {
record(s"resolveOverloaded1", alts.length)

val sv = Feature.sourceVersion
val isOldPriorityVersion: Boolean = sv.isAtMost(SourceVersion.`3.6`)
val isWarnPriorityChangeVersion = sv == SourceVersion.`3.6` || sv == SourceVersion.`3.7-migration`

inline def warnOnPriorityChange(oldCands: List[TermRef], newCands: List[TermRef])(f: List[TermRef] => List[TermRef]): List[TermRef] =

def doWarn(oldChoice: String, newChoice: String): Unit =
val (change, whichChoice) =
if isOldPriorityVersion
then ("will change", "Current choice ")
else ("has changed", "Previous choice")

val msg = // uses oldCands as the list of alternatives since they should be a superset of newCands
em"""Overloading resolution for ${err.expectedTypeStr(pt)} between alternatives
| ${oldCands map (_.info)}%\n %
|$change.
|$whichChoice : $oldChoice
|New choice from Scala 3.7: $newChoice"""

report.warning(msg, srcPos)
end doWarn

lazy val oldRes = f(oldCands)
val newRes = f(newCands)

if isWarnPriorityChangeVersion then (oldRes, newRes) match
case (oldAlt :: Nil, newAlt :: Nil) if oldAlt != newAlt => doWarn(oldAlt.info.show, newAlt.info.show)
case (oldAlt :: Nil, Nil) => doWarn(oldAlt.info.show, "none")
case (Nil, newAlt :: Nil) => doWarn("none", newAlt.info.show)
case _ => // neither scheme has determined an alternative

if isOldPriorityVersion then oldRes else newRes
end warnOnPriorityChange

def isDetermined(alts: List[TermRef]) = alts.isEmpty || alts.tail.isEmpty

/** The shape of given tree as a type; cannot handle named arguments. */
Expand Down Expand Up @@ -2299,7 +2333,7 @@ trait Applications extends Compatibility {
TypeOps.boundsViolations(targs1, tp.paramInfos, _.substParams(tp, _), NoType).isEmpty
val alts2 = alts1.filter(withinBounds)
if isDetermined(alts2) then alts2
else resolveMapped(alts1, _.widen.appliedTo(targs1.tpes), pt1)
else resolveMapped(alts1, _.widen.appliedTo(targs1.tpes), pt1, srcPos)

case pt =>
val compat = alts.filterConserve(normalizedCompatible(_, pt, keepConstraint = false))
Expand Down Expand Up @@ -2357,37 +2391,37 @@ trait Applications extends Compatibility {
candidates
else
val found = narrowMostSpecific(candidates)
if found.length <= 1 then found
if isDetermined(found) then found
else
val deepPt = pt.deepenProto
deepPt match
case pt @ FunProto(_, PolyProto(targs, resType)) =>
// try to narrow further with snd argument list and following type params
resolveMapped(candidates,
skipParamClause(pt.typedArgs().tpes, targs.tpes), resType)
warnOnPriorityChange(candidates, found):
resolveMapped(_, skipParamClause(pt.typedArgs().tpes, targs.tpes), resType, srcPos)
case pt @ FunProto(_, resType: FunOrPolyProto) =>
// try to narrow further with snd argument list
resolveMapped(candidates,
skipParamClause(pt.typedArgs().tpes, Nil), resType)
warnOnPriorityChange(candidates, found):
resolveMapped(_, skipParamClause(pt.typedArgs().tpes, Nil), resType, srcPos)
case _ =>
// prefer alternatives that need no eta expansion
val noCurried = alts.filterConserve(!resultIsMethod(_))
val noCurriedCount = noCurried.length
if noCurriedCount == 1 then
noCurried
else if noCurriedCount > 1 && noCurriedCount < alts.length then
resolveOverloaded1(noCurried, pt)
resolveOverloaded1(noCurried, pt, srcPos)
else
// prefer alternatves that match without default parameters
val noDefaults = alts.filterConserve(!_.symbol.hasDefaultParams)
val noDefaultsCount = noDefaults.length
if noDefaultsCount == 1 then
noDefaults
else if noDefaultsCount > 1 && noDefaultsCount < alts.length then
resolveOverloaded1(noDefaults, pt)
resolveOverloaded1(noDefaults, pt, srcPos)
else if deepPt ne pt then
// try again with a deeper known expected type
resolveOverloaded1(alts, deepPt)
resolveOverloaded1(alts, deepPt, srcPos)
else
candidates
}
Expand All @@ -2414,7 +2448,7 @@ trait Applications extends Compatibility {
* type is mapped with `f`, alternatives with non-existing types or symbols are dropped, and the
* expected type is `pt`. Map the results back to the original alternatives.
*/
def resolveMapped(alts: List[TermRef], f: TermRef => Type, pt: Type)(using Context): List[TermRef] =
def resolveMapped(alts: List[TermRef], f: TermRef => Type, pt: Type, srcPos: SrcPos)(using Context): List[TermRef] =
val reverseMapping = alts.flatMap { alt =>
val t = f(alt)
if t.exists && alt.symbol.exists then
Expand All @@ -2437,7 +2471,7 @@ trait Applications extends Compatibility {
}
val mapped = reverseMapping.map(_._1)
overload.println(i"resolve mapped: ${mapped.map(_.widen)}%, % with $pt")
resolveOverloaded(mapped, pt)(using ctx.retractMode(Mode.SynthesizeExtMethodReceiver))
resolveOverloaded(mapped, pt, srcPos)(using ctx.retractMode(Mode.SynthesizeExtMethodReceiver))
.map(reverseMapping.toMap)

/** Try to typecheck any arguments in `pt` that are function values missing a
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4022,7 +4022,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def altRef(alt: SingleDenotation) = TermRef(ref.prefix, ref.name, alt)
val alts = altDenots.map(altRef)

resolveOverloaded(alts, pt) match
resolveOverloaded(alts, pt, tree.srcPos) match
case alt :: Nil =>
readaptSimplified(tree.withType(alt))
case Nil =>
Expand Down
1 change: 1 addition & 0 deletions compiler/test/dotc/pos-test-pickling.blacklist
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ i7445b.scala
i15525.scala
i19955a.scala
i19955b.scala
i20053.scala
i20053b.scala

# alias types at different levels of dereferencing
Expand Down
27 changes: 27 additions & 0 deletions tests/neg/multiparamlist-overload-3.6.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
-- [E007] Type Mismatch Error: tests/neg/multiparamlist-overload-3.6.scala:33:21 ---------------------------------------
33 | val r = f(new B)(new A) // error since resolves to R2 in 3.6 (and 3.7) as expected
| ^^^^^
| Found: A
| Required: B
|
| longer explanation available when compiling with `-explain`
-- Warning: tests/neg/multiparamlist-overload-3.6.scala:20:10 ----------------------------------------------------------
20 | val r = f(new B)(new C) // resolves to R1 in 3.6
| ^
| Overloading resolution for arguments (B)(C) between alternatives
| (x: B)(y: B): R3
| (x: B)(y: A): R2
| (x: A)(y: C): R1
| will change.
| Current choice : (x: A)(y: C): R1
| New choice from Scala 3.7: (x: B)(y: B): R3
-- Warning: tests/neg/multiparamlist-overload-3.6.scala:40:12 ----------------------------------------------------------
40 | val r = f(new B)(new A) // resolves to R1 in 3.6
| ^
| Overloading resolution for arguments (B)(A) between alternatives
| (x: B)(y: C): R3
| (x: B)(y: B): R2
| (x: A)(y: A): R1
| will change.
| Current choice : (x: A)(y: A): R1
| New choice from Scala 3.7: none
43 changes: 43 additions & 0 deletions tests/neg/multiparamlist-overload-3.6.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import scala.language.`3.6`

class A
class B extends A
class C extends B

class R1
class R2
class R3

// The alternatives are ordered from most genereal to most specific in each test,
// with respect to a lexicographic ordering by parameter list.


object Test1:
def f(x: A)(y: C) = new R1
def f(x: B)(y: A) = new R2
def f(x: B)(y: B) = new R3

val r = f(new B)(new C) // resolves to R1 in 3.6
val _: R1 = r
end Test1


object Test2:
// R1 is the only applicable alternative in both parts
// but it is only resolved to in Part2 by adding (an unapplicable) R3

object Part1:
def f(x: A)(y: A) = new R1
def f(x: B)(y: B) = new R2

val r = f(new B)(new A) // error since resolves to R2 in 3.6 (and 3.7) as expected

object Part2:
def f(x: A)(y: A) = new R1
def f(x: B)(y: B) = new R2
def f(x: B)(y: C) = new R3

val r = f(new B)(new A) // resolves to R1 in 3.6
val _: R1 = r

end Test2
25 changes: 25 additions & 0 deletions tests/neg/multiparamlist-overload-3.7.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
-- [E007] Type Mismatch Error: tests/neg/multiparamlist-overload-3.7.scala:33:21 ---------------------------------------
33 | val r = f(new B)(new A) // error since resolves to R2 in 3.7 (and 3.6), as expected
| ^^^^^
| Found: A
| Required: B
|
| longer explanation available when compiling with `-explain`
-- [E134] Type Error: tests/neg/multiparamlist-overload-3.7.scala:40:12 ------------------------------------------------
40 | val r = f(new B)(new A) // error since resolves to R2 in 3.7, as in Part1
| ^
| None of the overloaded alternatives of method f in object Part2 with types
| (x: B)(y: C): R3
| (x: B)(y: B): R2
| (x: A)(y: A): R1
| match arguments (B)(A)
-- Warning: tests/neg/multiparamlist-overload-3.7.scala:20:10 ----------------------------------------------------------
20 | val r = f(new B)(new C) // resolves to R3 in 3.7
| ^
| Overloading resolution for arguments (B)(C) between alternatives
| (x: B)(y: B): R3
| (x: B)(y: A): R2
| (x: A)(y: C): R1
| has changed.
| Previous choice : (x: A)(y: C): R1
| New choice from Scala 3.7: (x: B)(y: B): R3
42 changes: 42 additions & 0 deletions tests/neg/multiparamlist-overload-3.7.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import scala.language.`3.7-migration`

class A
class B extends A
class C extends B

class R1
class R2
class R3

// The alternatives are ordered from most genereal to most specific in each test,
// with respect to a lexicographic ordering by parameter list.


object Test1:
def f(x: A)(y: C) = new R1
def f(x: B)(y: A) = new R2
def f(x: B)(y: B) = new R3

val r = f(new B)(new C) // resolves to R3 in 3.7
val _: R3 = r
end Test1


object Test2:
// R1 is the only applicable alternative in both parts
// but it is never resolved to since R2 has a more specific 1st parameter list

object Part1:
def f(x: A)(y: A) = new R1
def f(x: B)(y: B) = new R2

val r = f(new B)(new A) // error since resolves to R2 in 3.7 (and 3.6), as expected

object Part2:
def f(x: A)(y: A) = new R1
def f(x: B)(y: B) = new R2
def f(x: B)(y: C) = new R3

val r = f(new B)(new A) // error since resolves to R2 in 3.7, as in Part1

end Test2
22 changes: 22 additions & 0 deletions tests/neg/scalatest-overload-3.7.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import scala.language.`3.7`

class TestBody1
class TestBody2

class StartWithWord
class EndWithWord

class Matchers:
extension (leftSideString: String)
def should(body: TestBody1): Unit = ()
def should(body: TestBody2): Unit = ()

extension [T](leftSideValue: T)
def should(word: StartWithWord)(using T <:< String): Unit = ()
def should(word: EndWithWord)(using T <:< String): Unit = ()

def endWith(rightSideString: String): EndWithWord = new EndWithWord

class Test extends Matchers:
def test(): Unit =
"hello world" should endWith ("world") // error
Loading
Loading