Skip to content

Commit

Permalink
Merge pull request #13292 from dotty-staging/fix-merge-constraint-3
Browse files Browse the repository at this point in the history
 Reimplement constraint merging for correctness
  • Loading branch information
odersky authored Aug 18, 2021
2 parents dd2962b + b9d8ca9 commit 45ce129
Show file tree
Hide file tree
Showing 11 changed files with 185 additions and 108 deletions.
13 changes: 0 additions & 13 deletions compiler/src/dotty/tools/dotc/core/Constraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,6 @@ abstract class Constraint extends Showable {
*/
def uninstVars: collection.Seq[TypeVar]

/** The weakest constraint that subsumes both this constraint and `other`.
* The constraints should be _compatible_, meaning that a type lambda
* occurring in both constraints is associated with the same typevars in each.
*
* @param otherHasErrors If true, handle incompatible constraints by
* returning an approximate constraint, instead of
* failing with an exception
*/
def & (other: Constraint, otherHasErrors: Boolean)(using Context): Constraint

/** Whether `tl` is present in both `this` and `that` but is associated with
* different TypeVars there, meaning that the constraints cannot be merged.
*/
Expand All @@ -183,7 +173,4 @@ abstract class Constraint extends Showable {
* of athe type lambda that is associated with the typevar itself.
*/
def checkConsistentVars()(using Context): Unit

/** A string describing the constraint's contents without a header or trailer */
def contentsToString(using Context): String
}
99 changes: 10 additions & 89 deletions compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ object OrderingConstraint {
type ParamOrdering = ArrayValuedMap[List[TypeParamRef]]

/** A new constraint with given maps */
private def newConstraint(boundsMap: ParamBounds, lowerMap: ParamOrdering, upperMap: ParamOrdering)(using Context) : OrderingConstraint = {
val result = new OrderingConstraint(boundsMap, lowerMap, upperMap)
ctx.run.recordConstraintSize(result, result.boundsMap.size)
result
}
private def newConstraint(boundsMap: ParamBounds, lowerMap: ParamOrdering, upperMap: ParamOrdering)(using Context) : OrderingConstraint =
if boundsMap.isEmpty && lowerMap.isEmpty && upperMap.isEmpty then
empty
else
val result = new OrderingConstraint(boundsMap, lowerMap, upperMap)
ctx.run.recordConstraintSize(result, result.boundsMap.size)
result

/** A lens for updating a single entry array in one of the three constraint maps */
abstract class ConstraintLens[T <: AnyRef: ClassTag] {
Expand Down Expand Up @@ -457,48 +459,6 @@ class OrderingConstraint(private val boundsMap: ParamBounds,

// ----------- Joins -----------------------------------------------------

def & (other: Constraint, otherHasErrors: Boolean)(using Context): OrderingConstraint = {

def merge[T](m1: ArrayValuedMap[T], m2: ArrayValuedMap[T], join: (T, T) => T): ArrayValuedMap[T] = {
var merged = m1
def mergeArrays(xs1: Array[T], xs2: Array[T]) = {
val xs = xs1.clone
for (i <- xs.indices) xs(i) = join(xs1(i), xs2(i))
xs
}
m2.foreachBinding { (poly, xs2) =>
merged = merged.updated(poly,
if (m1.contains(poly)) mergeArrays(m1(poly), xs2) else xs2)
}
merged
}

def mergeParams(ps1: List[TypeParamRef], ps2: List[TypeParamRef]) =
ps2.foldLeft(ps1)((ps1, p2) => if (ps1.contains(p2)) ps1 else p2 :: ps1)

// Must be symmetric
def mergeEntries(e1: Type, e2: Type): Type =
(e1, e2) match {
case _ if e1 eq e2 => e1
case (e1: TypeBounds, e2: TypeBounds) => e1 & e2
case (e1: TypeBounds, _) if e1 contains e2 => e2
case (_, e2: TypeBounds) if e2 contains e1 => e1
case (tv1: TypeVar, tv2: TypeVar) if tv1 eq tv2 => e1
case _ =>
if (otherHasErrors)
e1
else
throw new AssertionError(i"cannot merge $this with $other, mergeEntries($e1, $e2) failed")
}

val that = other.asInstanceOf[OrderingConstraint]

new OrderingConstraint(
merge(this.boundsMap, that.boundsMap, mergeEntries),
merge(this.lowerMap, that.lowerMap, mergeParams),
merge(this.upperMap, that.upperMap, mergeParams))
}.showing(i"constraint merge $this with $other = $result", constr)

def hasConflictingTypeVarsFor(tl: TypeLambda, that: Constraint): Boolean =
contains(tl) && that.contains(tl) &&
// Since TypeVars are allocated in bulk for each type lambda, we only have
Expand Down Expand Up @@ -641,49 +601,10 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
upperMap.foreachBinding((_, paramss) => paramss.foreach(_.foreach(checkClosedType(_, "upper"))))
end checkClosed

// ---------- toText -----------------------------------------------------

private def contentsToText(printer: Printer): Text =
//Printer.debugPrintUnique = true
def entryText(tp: Type) = tp match {
case tp: TypeBounds =>
tp.toText(printer)
case _ =>
" := " ~ tp.toText(printer)
}
val indent = 3
val uninstVarsText = " uninstantiated variables: " ~
Text(uninstVars.map(_.toText(printer)), ", ")
val constrainedText =
" constrained types: " ~ Text(domainLambdas map (_.toText(printer)), ", ")
val boundsText =
" bounds: " ~ {
val assocs =
for (param <- domainParams)
yield (" " * indent) ~ param.toText(printer) ~ entryText(entry(param))
Text(assocs, "\n")
}
val orderingText =
" ordering: " ~ {
val deps =
for {
param <- domainParams
ups = minUpper(param)
if ups.nonEmpty
}
yield
(" " * indent) ~ param.toText(printer) ~ " <: " ~
Text(ups.map(_.toText(printer)), ", ")
Text(deps, "\n")
}
//Printer.debugPrintUnique = false
Text.lines(List(uninstVarsText, constrainedText, boundsText, orderingText))
// ---------- Printing -----------------------------------------------------

override def toText(printer: Printer): Text =
Text.lines(List("Constraint(", contentsToText(printer), ")"))

def contentsToString(using Context): String =
contentsToText(ctx.printer).show
printer.toText(this)

override def toString: String = {
def entryText(tp: Type): String = tp match {
Expand All @@ -692,7 +613,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
}
val constrainedText =
" constrained types = " + domainLambdas.mkString("\n")
val boundsText = domainLambdas
val boundsText =
" bounds = " + {
val assocs =
for (param <- domainParams)
Expand Down
38 changes: 35 additions & 3 deletions compiler/src/dotty/tools/dotc/core/TyperState.scala
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,43 @@ class TyperState() {
*/
def mergeConstraintWith(that: TyperState)(using Context): Unit =
that.ensureNotConflicting(constraint)
constraint = constraint & (that.constraint, otherHasErrors = that.reporter.errorsReported)
for tvar <- constraint.uninstVars do
if !isOwnedAnywhere(this, tvar) then includeVar(tvar)

val comparingCtx =
if ctx.typerState == this then ctx
else ctx.fresh.setTyperState(this)

comparing(typeComparer =>
val other = that.constraint
val res = other.domainLambdas.forall(tl =>
// Integrate the type lambdas from `other`
constraint.contains(tl) || other.isRemovable(tl) || {
val tvars = tl.paramRefs.map(other.typeVarOfParam(_)).collect { case tv: TypeVar => tv }
tvars.foreach(tvar => if !tvar.inst.exists && !isOwnedAnywhere(this, tvar) then includeVar(tvar))
typeComparer.addToConstraint(tl, tvars)
}) &&
// Integrate the additional constraints on type variables from `other`
constraint.uninstVars.forall(tv =>
val p = tv.origin
val otherLos = other.lower(p)
val otherHis = other.upper(p)
val otherEntry = other.entry(p)
( (otherLos eq constraint.lower(p)) || otherLos.forall(_ <:< p)) &&
( (otherHis eq constraint.upper(p)) || otherHis.forall(p <:< _)) &&
((otherEntry eq constraint.entry(p)) || otherEntry.match
case NoType =>
true
case tp: TypeBounds =>
tp.contains(tv)
case tp =>
tv =:= tp
)
)
assert(res || ctx.reporter.errorsReported, i"cannot merge $constraint with $other.")
)(using comparingCtx)

for tl <- constraint.domainLambdas do
if constraint.isRemovable(tl) then constraint = constraint.remove(tl)
end mergeConstraintWith

/** Take ownership of `tvar`.
*
Expand Down
41 changes: 41 additions & 0 deletions compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,47 @@ class PlainPrinter(_ctx: Context) extends Printer {
case _ => "{...}"
s"import $exprStr.$selectorStr"

def toText(c: OrderingConstraint): Text =
val savedConstraint = ctx.typerState.constraint
try
// The current TyperState constraint determines how type variables are printed
ctx.typerState.constraint = c
def entryText(tp: Type) = tp match {
case tp: TypeBounds =>
toText(tp)
case _ =>
" := " ~ toText(tp)
}
val indent = 3
val uninstVarsText = " uninstantiated variables: " ~
Text(c.uninstVars.map(toText), ", ")
val constrainedText =
" constrained types: " ~ Text(c.domainLambdas.map(toText), ", ")
val boundsText =
" bounds: " ~ {
val assocs =
for (param <- c.domainParams)
yield (" " * indent) ~ toText(param) ~ entryText(c.entry(param))
Text(assocs, "\n")
}
val orderingText =
" ordering: " ~ {
val deps =
for {
param <- c.domainParams
ups = c.minUpper(param)
if ups.nonEmpty
}
yield
(" " * indent) ~ toText(param) ~ " <: " ~
Text(ups.map(toText), ", ")
Text(deps, "\n")
}
//Printer.debugPrintUnique = false
Text.lines(List(uninstVarsText, constrainedText, boundsText, orderingText))
finally
ctx.typerState.constraint = savedConstraint

def plain: PlainPrinter = this

protected def keywordStr(text: String): String = coloredStr(text, SyntaxHighlighting.KeywordColor)
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/printing/Printer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ abstract class Printer {
/** Textual representation of info relating to an import clause */
def toText(result: ImportInfo): Text

/** Textual representation of a constraint */
def toText(c: OrderingConstraint): Text

/** Render element within highest precedence */
def toTextLocal(elem: Showable): Text =
atPrec(DotPrec) { elem.toText(this) }
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ object ErrorReporting {
"the empty constraint"
else
i"""a constraint with:
|${c.contentsToString}"""
|$c"""
i"""
|${TypeComparer.explained(_.isSubType(found, expected), header)}
|
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ object ProtoTypes {
def typedArgs(norm: (untpd.Tree, Int) => untpd.Tree = sameTree)(using Context): List[Tree] =
if state.typedArgs.size == args.length then state.typedArgs
else
val passedCtx = ctx
val passedTyperState = ctx.typerState
inContext(protoCtx.withUncommittedTyperState) {
val protoTyperState = ctx.typerState
Expand Down Expand Up @@ -409,8 +410,7 @@ object ProtoTypes {
tvar.instantiate(fromBelow = false)
case _ =>
}

passedTyperState.mergeConstraintWith(protoTyperState)
passedTyperState.mergeConstraintWith(protoTyperState)(using passedCtx)
end if
args1
}
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/util/SimpleIdentityMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import collection.mutable.ListBuffer
* It has linear complexity for `apply`, `updated`, and `remove`.
*/
abstract class SimpleIdentityMap[K <: AnyRef, +V >: Null <: AnyRef] extends (K => V) {
final def isEmpty: Boolean = this eq SimpleIdentityMap.myEmpty
def size: Int
def apply(k: K): V
def remove(k: K): SimpleIdentityMap[K, V]
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/util/Stats.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ import collection.mutable
aggregate()
println()
println(hits.toList.sortBy(_._2).map{ case (x, y) => s"$x -> $y" } mkString "\n")
hits.clear()
}
}
else op
Expand Down
54 changes: 54 additions & 0 deletions compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package dotty.tools
package dotc.core

import vulpix.TestConfiguration

import dotty.tools.dotc.core.Contexts.{*, given}
import dotty.tools.dotc.core.Decorators.{*, given}
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.typer.ProtoTypes.constrained

import org.junit.Test

import dotty.tools.DottyTest

class ConstraintsTest:

@Test def mergeParamsTransitivity: Unit =
inCompilerContext(TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S, T, R]: Any }") {
val tp = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda])
val List(s, t, r) = tp.paramRefs

val innerCtx = ctx.fresh.setExploreTyperState()
inContext(innerCtx) {
s <:< t
}

t <:< r

ctx.typerState.mergeConstraintWith(innerCtx.typerState)
assert(s frozen_<:< r,
i"Merging constraints `?S <: ?T` and `?T <: ?R` should result in `?S <:< ?R`: ${ctx.typerState.constraint}")
}
end mergeParamsTransitivity

@Test def mergeBoundsTransitivity: Unit =
inCompilerContext(TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S, T]: Any }") {
val tp = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda])
val List(s, t) = tp.paramRefs

val innerCtx = ctx.fresh.setExploreTyperState()
inContext(innerCtx) {
s <:< t
}

defn.IntType <:< s

ctx.typerState.mergeConstraintWith(innerCtx.typerState)
assert(defn.IntType frozen_<:< t,
i"Merging constraints `?S <: ?T` and `Int <: ?S` should result in `Int <:< ?T`: ${ctx.typerState.constraint}")
}
end mergeBoundsTransitivity
37 changes: 37 additions & 0 deletions tests/pos/i12730.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
class ComponentSimple

class Props {
def apply(props: Any): Any = ???
}

class Foo[C] {
def build: ComponentSimple = ???
}

class Bar[E] {
def render(r: E => Any): Unit = {}
}

trait Conv[A, B] {
def apply(a: A): B
}

object Test {
def toComponentCtor[F](c: ComponentSimple): Props = ???

def defaultToNoBackend[G, H](ev: G => Foo[H]): Conv[Foo[H], Bar[H]] = ???

def conforms[A]: A => A = ???

def problem = Main // crashes

def foo[H]: Foo[H] = ???

val NameChanger =
foo
.build

val Main =
defaultToNoBackend(conforms).apply(foo)
.render(_ => toComponentCtor(NameChanger)(13))
}

0 comments on commit 45ce129

Please sign in to comment.