Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor pattern matching in global init checker, skipping cases when safe to do so #22179

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 81 additions & 34 deletions compiler/src/dotty/tools/dotc/transform/init/Objects.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import util.{ SourcePosition, NoSourcePosition }
import config.Printers.init as printer
import reporting.StoreReporter
import reporting.trace as log
import reporting.trace.force as forcelog
import typer.Applications.*

import Errors.*
Expand Down Expand Up @@ -91,6 +92,7 @@ class Objects(using Context @constructorOnly):
* ve ::= ObjectRef(class) // global object
* | OfClass(class, vs[outer], ctor, args, env) // instance of a class
* | OfArray(object[owner], regions)
* | BaseOrUnknownValue // Int, String, etc., and values without source
* | Fun(..., env) // value elements that can be contained in ValueSet
* vs ::= ValueSet(ve) // set of abstract values
* Bottom ::= ValueSet(Empty)
Expand Down Expand Up @@ -233,6 +235,11 @@ class Objects(using Context @constructorOnly):
case class ValueSet(values: ListSet[ValueElement]) extends Value:
def show(using Context) = values.map(_.show).mkString("[", ",", "]")

// Represents common base values like Int, String, etc.
// and also values loaded without source
case object BaseOrUnknownValue extends ValueElement:
def show(using Context): String = "BaseOrUnknownValue"

/** A cold alias which should not be used during initialization.
*
* Cold is not ValueElement since RefSet containing Cold is equivalent to Cold
Expand Down Expand Up @@ -602,6 +609,12 @@ class Objects(using Context @constructorOnly):
case (ValueSet(values), b : ValueElement) => ValueSet(values + b)
case (a : ValueElement, b : ValueElement) => ValueSet(ListSet(a, b))

def remove(b: Value): Value = (a, b) match
case (ValueSet(values1), b: ValueElement) => ValueSet(values1 - b)
case (ValueSet(values1), ValueSet(values2)) => ValueSet(values1.removedAll(values2))
case (a: Ref, b: Ref) if a.equals(b) => Bottom
case _ => a

def widen(height: Int)(using Context): Value =
if height == 0 then Cold
else
Expand Down Expand Up @@ -630,12 +643,15 @@ class Objects(using Context @constructorOnly):
if baseClasses.isEmpty then a
else filterClass(baseClasses.head) // could have called ClassSymbol, but it does not handle OrType and AndType

// Filter the value according to a class symbol, and only leaves the sub-values
// which could represent an object of the given class
def filterClass(sym: Symbol)(using Context): Value =
if !sym.isClass then a
else
val klass = sym.asClass
a match
case Cold => Cold
case BaseOrUnknownValue => BaseOrUnknownValue
case ref: Ref => if ref.klass.isSubClass(klass) then ref else Bottom
case ValueSet(values) => values.map(v => v.filterClass(klass)).join
case arr: OfArray => if defn.ArrayClass.isSubClass(klass) then arr else Bottom
Expand Down Expand Up @@ -668,6 +684,13 @@ class Objects(using Context @constructorOnly):
case Bottom =>
Bottom

// Bottom arguments mean unreachable call
case _ if args.map(_.value).contains(Bottom) =>
Bottom

case BaseOrUnknownValue =>
BaseOrUnknownValue

case arr: OfArray =>
val target = resolve(defn.ArrayClass, meth)

Expand All @@ -686,7 +709,7 @@ class Objects(using Context @constructorOnly):
Bottom
else
// Array.length is OK
Bottom
BaseOrUnknownValue

case ref: Ref =>
val isLocal = !meth.owner.isClass
Expand All @@ -707,7 +730,7 @@ class Objects(using Context @constructorOnly):
arr
else if target.equals(defn.Predef_classOf) then
// Predef.classOf is a stub method in tasty and is replaced in backend
Bottom
BaseOrUnknownValue
else if target.hasSource then
val cls = target.owner.enclosingClass.asClass
val ddef = target.defTree.asInstanceOf[DefDef]
Expand All @@ -730,7 +753,7 @@ class Objects(using Context @constructorOnly):
}
}
else
Bottom
BaseOrUnknownValue
else if target.exists then
select(ref, target, receiver, needResolve = false)
else
Expand Down Expand Up @@ -798,7 +821,7 @@ class Objects(using Context @constructorOnly):
}
else
// no source code available
Bottom
BaseOrUnknownValue

case _ =>
report.warning("[Internal error] unexpected constructor call, meth = " + ctor + ", this = " + value + Trace.show, Trace.position)
Expand All @@ -818,6 +841,9 @@ class Objects(using Context @constructorOnly):
report.warning("Using cold alias", Trace.position)
Bottom

case BaseOrUnknownValue =>
BaseOrUnknownValue

case ref: Ref =>
val target = if needResolve then resolve(ref.klass, field) else field
if target.is(Flags.Lazy) then
Expand All @@ -826,7 +852,7 @@ class Objects(using Context @constructorOnly):
val rhs = target.defTree.asInstanceOf[ValDef].rhs
eval(rhs, ref, target.owner.asClass, cacheResult = true)
else
Bottom
BaseOrUnknownValue
else if target.exists then
def isNextFieldOfColonColon: Boolean = ref.klass == defn.ConsClass && target.name.toString == "next"
if target.isOneOf(Flags.Mutable) && !isNextFieldOfColonColon then
Expand All @@ -842,24 +868,24 @@ class Objects(using Context @constructorOnly):
Bottom
else
// initialization error, reported by the initialization checker
Bottom
BaseOrUnknownValue
else if ref.hasVal(target) then
ref.valValue(target)
else if ref.isObjectRef && ref.klass.hasSource then
report.warning("Access uninitialized field " + field.show + ". " + Trace.show, Trace.position)
Bottom
else
// initialization error, reported by the initialization checker
Bottom
BaseOrUnknownValue

else
if ref.klass.isSubClass(receiver.widenSingleton.classSymbol) then
report.warning("[Internal error] Unexpected resolution failure: ref.klass = " + ref.klass.show + ", field = " + field.show + Trace.show, Trace.position)
Bottom
else
// This is possible due to incorrect type cast.
// See tests/init/pos/Type.scala
Bottom
// This is possible due to incorrect type cast or accessing standard library objects
// See tests/init/pos/Type.scala / tests/init/warn/unapplySeq-implicit-arg2.scala
BaseOrUnknownValue

case fun: Fun =>
report.warning("[Internal error] unexpected tree in selecting a function, fun = " + fun.code.show + Trace.show, fun.code)
Expand All @@ -869,7 +895,7 @@ class Objects(using Context @constructorOnly):
report.warning("[Internal error] unexpected tree in selecting an array, array = " + arr.show + Trace.show, Trace.position)
Bottom

case Bottom =>
case Bottom => // TODO: add a value for packages?
if field.isStaticObject then accessObject(field.moduleClass.asClass)
else Bottom

Expand All @@ -895,7 +921,7 @@ class Objects(using Context @constructorOnly):
case Cold =>
report.warning("Assigning to cold aliases is forbidden. " + Trace.show, Trace.position)

case Bottom =>
case BaseOrUnknownValue | Bottom =>

case ValueSet(values) =>
values.foreach(ref => assign(ref, field, rhs, rhsTyp))
Expand Down Expand Up @@ -930,6 +956,9 @@ class Objects(using Context @constructorOnly):
report.warning("[Internal error] unexpected outer in instantiating a class, outer = " + outer.show + ", class = " + klass.show + ", " + Trace.show, Trace.position)
Bottom

case BaseOrUnknownValue =>
BaseOrUnknownValue

case outer: (Ref | Cold.type | Bottom.type) =>
if klass == defn.ArrayClass then
args.head.tree.tpe match
Expand Down Expand Up @@ -1018,6 +1047,7 @@ class Objects(using Context @constructorOnly):
case Cold =>
report.warning("Calling cold by-name alias. " + Trace.show, Trace.position)
Bottom
case BaseOrUnknownValue => BaseOrUnknownValue
case _: ValueSet | _: Ref | _: OfArray =>
report.warning("[Internal error] Unexpected by-name value " + value.show + ". " + Trace.show, Trace.position)
Bottom
Expand Down Expand Up @@ -1206,7 +1236,7 @@ class Objects(using Context @constructorOnly):
evalType(expr.tpe, thisV, klass)

case Literal(_) =>
Bottom
BaseOrUnknownValue

case Typed(expr, tpt) =>
if tpt.tpe.hasAnnotation(defn.UncheckedAnnot) then
Expand Down Expand Up @@ -1341,29 +1371,25 @@ class Objects(using Context @constructorOnly):
def getMemberMethod(receiver: Type, name: TermName, tp: Type): Denotation =
receiver.member(name).suchThat(receiver.memberInfo(_) <:< tp)

def evalCase(caseDef: CaseDef): Value =
evalPattern(scrutinee, caseDef.pat)
eval(caseDef.guard, thisV, klass)
eval(caseDef.body, thisV, klass)

/** Abstract evaluation of patterns.
*
* It augments the local environment for bound pattern variables. As symbols are globally
* unique, we can put them in a single environment.
*
* Currently, we assume all cases are reachable, thus all patterns are assumed to match.
*/
def evalPattern(scrutinee: Value, pat: Tree): Value = log("match " + scrutinee.show + " against " + pat.show, printer, (_: Value).show):
def evalPattern(scrutinee: Value, pat: Tree): (Type, Value) = log("match " + scrutinee.show + " against " + pat.show, printer, (_: (Type, Value))._2.show):
val trace2 = Trace.trace.add(pat)
pat match
case Alternative(pats) =>
for pat <- pats do evalPattern(scrutinee, pat)
scrutinee
val (types, values) = pats.map(evalPattern(scrutinee, _)).unzip()
val orType = types.fold(defn.NothingType)(OrType(_, _, false))
(orType, values.join)

case bind @ Bind(_, pat) =>
val value = evalPattern(scrutinee, pat)
val (tpe, value) = evalPattern(scrutinee, pat)
initLocal(bind.symbol, value)
scrutinee
(tpe, value)

case UnApply(fun, implicits, pats) =>
given Trace = trace2
Expand All @@ -1372,6 +1398,10 @@ class Objects(using Context @constructorOnly):
val funRef = fun1.tpe.asInstanceOf[TermRef]
val unapplyResTp = funRef.widen.finalResultType

val receiverType = fun1 match
case ident: Ident => funRef.prefix
case select: Select => select.qualifier.tpe

val receiver = fun1 match
case ident: Ident =>
evalType(funRef.prefix, thisV, klass)
Expand Down Expand Up @@ -1460,17 +1490,20 @@ class Objects(using Context @constructorOnly):
end if
end if
end if
scrutinee
// TODO: receiverType is the companion object type, not the class itself;
// cannot filter scritunee by this type
(receiverType, scrutinee)

case Ident(nme.WILDCARD) | Ident(nme.WILDCARD_STAR) =>
scrutinee
(defn.ThrowableType, scrutinee)

case Typed(pat, _) =>
evalPattern(scrutinee, pat)
case Typed(pat, typeTree) =>
val (_, value) = evalPattern(scrutinee.filterType(typeTree.tpe), pat)
(typeTree.tpe, value)

case tree =>
// For all other trees, the semantics is normal.
eval(tree, thisV, klass)
(defn.ThrowableType, eval(tree, thisV, klass))

end evalPattern

Expand All @@ -1481,15 +1514,15 @@ class Objects(using Context @constructorOnly):
// call .lengthCompare or .length
val lengthCompareDenot = getMemberMethod(scrutineeType, nme.lengthCompare, lengthCompareType)
if lengthCompareDenot.exists then
call(scrutinee, lengthCompareDenot.symbol, ArgInfo(Bottom, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
call(scrutinee, lengthCompareDenot.symbol, ArgInfo(BaseOrUnknownValue, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
else
val lengthDenot = getMemberMethod(scrutineeType, nme.length, lengthType)
call(scrutinee, lengthDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
end if

// call .apply
val applyDenot = getMemberMethod(scrutineeType, nme.apply, applyType(elemType))
val applyRes = call(scrutinee, applyDenot.symbol, ArgInfo(Bottom, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
val applyRes = call(scrutinee, applyDenot.symbol, ArgInfo(BaseOrUnknownValue, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)

if isWildcardStarArgList(pats) then
if pats.size == 1 then
Expand All @@ -1500,7 +1533,7 @@ class Objects(using Context @constructorOnly):
else
// call .drop
val dropDenot = getMemberMethod(scrutineeType, nme.drop, dropType(elemType))
val dropRes = call(scrutinee, dropDenot.symbol, ArgInfo(Bottom, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
val dropRes = call(scrutinee, dropDenot.symbol, ArgInfo(BaseOrUnknownValue, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
for pat <- pats.init do evalPattern(applyRes, pat)
evalPattern(dropRes, pats.last)
end if
Expand All @@ -1510,8 +1543,21 @@ class Objects(using Context @constructorOnly):
end if
end evalSeqPatterns

def canSkipCase(remainingScrutinee: Value, catchValue: Value) =
(remainingScrutinee == Bottom && scrutinee != Bottom) ||
(catchValue == Bottom && remainingScrutinee != Bottom)

cases.map(evalCase).join
var remainingScrutinee = scrutinee
val caseResults: mutable.ArrayBuffer[Value] = mutable.ArrayBuffer()
for caseDef <- cases do
val (tpe, value) = evalPattern(remainingScrutinee, caseDef.pat)
eval(caseDef.guard, thisV, klass)
if !canSkipCase(remainingScrutinee, value) then
caseResults.addOne(eval(caseDef.body, thisV, klass))
if catchesAllOf(caseDef, tpe) then
remainingScrutinee = remainingScrutinee.remove(value)

caseResults.join
end patternMatch

/** Handle semantics of leaf nodes
Expand All @@ -1529,12 +1575,12 @@ class Objects(using Context @constructorOnly):
def evalType(tp: Type, thisV: ThisValue, klass: ClassSymbol, elideObjectAccess: Boolean = false): Contextual[Value] = log("evaluating " + tp.show, printer, (_: Value).show) {
tp match
case _: ConstantType =>
Bottom
BaseOrUnknownValue

case tmref: TermRef if tmref.prefix == NoPrefix =>
val sym = tmref.symbol
if sym.is(Flags.Package) then
Bottom
Bottom // TODO: package value?
else if sym.owner.isClass then
// The typer incorrectly assigns a TermRef with NoPrefix for `config`,
// while the actual denotation points to the symbol of the class member
Expand Down Expand Up @@ -1768,6 +1814,7 @@ class Objects(using Context @constructorOnly):
else
thisV match
case Bottom => Bottom
case BaseOrUnknownValue => BaseOrUnknownValue
case Cold => Cold
case ref: Ref =>
val outerCls = klass.owner.lexicallyEnclosingClass.asClass
Expand Down
Loading