Skip to content

Commit

Permalink
Merge pull request scala#15350 from dotty-staging/fix-check-ctx
Browse files Browse the repository at this point in the history
Fix checking ctx to carry correct modes
  • Loading branch information
odersky authored Jun 21, 2022
2 parents de3a82c + 6095a12 commit 0059d1d
Show file tree
Hide file tree
Showing 13 changed files with 152 additions and 28 deletions.
17 changes: 11 additions & 6 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1010,12 +1010,17 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {

/** `tree ne null` (might need a cast to be type correct) */
def testNotNull(using Context): Tree = {
val receiver = if (tree.tpe.isBottomType)
// If the receiver is of type `Nothing` or `Null`, add an ascription so that the selection
// succeeds: e.g. `null.ne(null)` doesn't type, but `(null: AnyRef).ne(null)` does.
Typed(tree, TypeTree(defn.AnyRefType))
else tree.ensureConforms(defn.ObjectType)
receiver.select(defn.Object_ne).appliedTo(nullLiteral).withSpan(tree.span)
// If the receiver is of type `Nothing` or `Null`, add an ascription or cast
// so that the selection succeeds.
// e.g. `null.ne(null)` doesn't type, but `(null: AnyRef).ne(null)` does.
val receiver =
if tree.tpe.isBottomType then
if ctx.explicitNulls then tree.cast(defn.AnyRefType)
else Typed(tree, TypeTree(defn.AnyRefType))
else tree.ensureConforms(defn.ObjectType)
// also need to cast the null literal to AnyRef in explicit nulls
val nullLit = if ctx.explicitNulls then nullLiteral.cast(defn.AnyRefType) else nullLiteral
receiver.select(defn.Object_ne).appliedTo(nullLit).withSpan(tree.span)
}

/** If inititializer tree is `_`, the default value of its type,
Expand Down
10 changes: 9 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/Erasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ class Erasure extends Phase with DenotTransformer {
override def checkPostCondition(tree: tpd.Tree)(using Context): Unit = {
assertErased(tree)
tree match {
case _: tpd.Import => assert(false, i"illegal tree: $tree")
case res: tpd.This =>
assert(!ExplicitOuter.referencesOuter(ctx.owner.lexicallyEnclosingClass, res),
i"Reference to $res from ${ctx.owner.showLocated}")
Expand Down Expand Up @@ -1034,14 +1035,21 @@ object Erasure {
typed(tree.arg, pt)

override def typedStats(stats: List[untpd.Tree], exprOwner: Symbol)(using Context): (List[Tree], Context) = {
val stats0 = addRetainedInlineBodies(stats)(using preErasureCtx)
// discard Imports first, since Bridges will use tree's symbol
val stats0 = addRetainedInlineBodies(stats.filter(!_.isInstanceOf[untpd.Import]))(using preErasureCtx)
val stats1 =
if (takesBridges(ctx.owner)) new Bridges(ctx.owner.asClass, erasurePhase).add(stats0)
else stats0
val (stats2, finalCtx) = super.typedStats(stats1, exprOwner)
(stats2.filterConserve(!_.isEmpty), finalCtx)
}

/** Finally drops all (language-) imports in erasure.
* Since some of the language imports change the subtyping,
* we cannot check the trees before erasure.
*/
override def typedImport(tree: untpd.Import)(using Context) = EmptyTree

override def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree =
trace(i"adapting ${tree.showSummary()}: ${tree.tpe} to $pt", show = true) {
if ctx.phase != erasurePhase && ctx.phase != erasurePhase.next then
Expand Down
10 changes: 9 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/MixinOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Symbols._, Types._, Contexts._, DenotTransformers._, Flags._
import util.Spans._
import SymUtils._
import StdNames._, NameOps._
import typer.Nullables

class MixinOps(cls: ClassSymbol, thisPhase: DenotTransformer)(using Context) {
import ast.tpd._
Expand Down Expand Up @@ -80,13 +81,20 @@ class MixinOps(cls: ClassSymbol, thisPhase: DenotTransformer)(using Context) {
prefss =>
val (targs, vargss) = splitArgs(prefss)
val tapp = superRef(target).appliedToTypeTrees(targs)
vargss match
val rhs = vargss match
case Nil | List(Nil) =>
// Overriding is somewhat loose about `()T` vs `=> T`, so just pick
// whichever makes sense for `target`
tapp.ensureApplied
case _ =>
tapp.appliedToArgss(vargss)
if ctx.explicitNulls && target.is(JavaDefined) && !ctx.phase.erasedTypes then
// We may forward to a super Java member in resolveSuper phase.
// Since this is still before erasure, the type can be nullable
// and causes error during checking. So we need to enable
// unsafe-nulls to construct the rhs.
Block(Nullables.importUnsafeNulls :: Nil, rhs)
else rhs

private def competingMethodsIterator(meth: Symbol): Iterator[Symbol] =
cls.baseClasses.iterator
Expand Down
9 changes: 0 additions & 9 deletions compiler/src/dotty/tools/dotc/transform/PruneErasedDefs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import Decorators.*
* The phase also replaces all expressions that appear in an erased context by
* default values. This is necessary so that subsequent checking phases such
* as IsInstanceOfChecker don't give false negatives.
* Finally, the phase drops (language-) imports.
*/
class PruneErasedDefs extends MiniPhase with SymTransformer { thisTransform =>
import tpd._
Expand Down Expand Up @@ -56,18 +55,10 @@ class PruneErasedDefs extends MiniPhase with SymTransformer { thisTransform =>
checkErasedInExperimental(tree.symbol)
tree

override def transformOther(tree: Tree)(using Context): Tree = tree match
case tree: Import => EmptyTree
case _ => tree

def checkErasedInExperimental(sym: Symbol)(using Context): Unit =
// Make an exception for Scala 2 experimental macros to allow dual Scala 2/3 macros under non experimental mode
if sym.is(Erased, butNot = Macro) && sym != defn.Compiletime_erasedValue && !sym.isInExperimentalScope then
Feature.checkExperimentalFeature("erased", sym.sourcePos)

override def checkPostCondition(tree: Tree)(using Context): Unit = tree match
case _: tpd.Import => assert(false, i"illegal tree: $tree")
case _ =>
}

object PruneErasedDefs {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ class SyntheticMembers(thisPhase: DenotTransformer) {

def nameRef: Tree =
if isJavaEnumValue then
Select(This(clazz), nme.name).ensureApplied
val name = Select(This(clazz), nme.name).ensureApplied
if ctx.explicitNulls then name.cast(defn.StringType) else name
else
identifierRef

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/transform/TreeChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class TreeChecker extends Phase with SymTransformer {

val checkingCtx = ctx
.fresh
.setMode(Mode.ImplicitsEnabled)
.addMode(Mode.ImplicitsEnabled)
.setReporter(new ThrowingReporter(ctx.reporter))

val checker = inContext(ctx) {
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/Nullables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ import ast.Trees.mods
object Nullables:
import ast.tpd._

def importUnsafeNulls(using Context): Import = Import(
ref(defn.LanguageModule),
List(untpd.ImportSelector(untpd.Ident(nme.unsafeNulls), EmptyTree, EmptyTree)))

inline def unsafeNullsEnabled(using Context): Boolean =
ctx.explicitNulls && !ctx.mode.is(Mode.SafeNulls)

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/ReTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class ReTyper(nestingLevel: Int = 0) extends Typer(nestingLevel) with ReChecking
override def typedSuper(tree: untpd.Super, pt: Type)(using Context): Tree =
promote(tree)

override def typedImport(tree: untpd.Import, sym: Symbol)(using Context): Tree =
override def typedImport(tree: untpd.Import)(using Context): Tree =
promote(tree)

override def typedTyped(tree: untpd.Typed, pt: Type)(using Context): Tree = {
Expand Down
18 changes: 15 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1059,7 +1059,18 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val (stats1, exprCtx) = withoutMode(Mode.Pattern) {
typedBlockStats(tree.stats)
}
val expr1 = typedExpr(tree.expr, pt.dropIfProto)(using exprCtx)
var expr1 = typedExpr(tree.expr, pt.dropIfProto)(using exprCtx)

// If unsafe nulls is enabled inside a block but not enabled outside
// and the type does not conform the expected type without unsafe nulls,
// we will cast the last expression to the expected type.
// See: tests/explicit-nulls/pos/unsafe-block.scala
if ctx.mode.is(Mode.SafeNulls)
&& !exprCtx.mode.is(Mode.SafeNulls)
&& pt.isValueType
&& !inContext(exprCtx.addMode(Mode.SafeNulls))(expr1.tpe <:< pt) then
expr1 = expr1.cast(pt)

ensureNoLocalRefs(
cpy.Block(tree)(stats1, expr1)
.withType(expr1.tpe)
Expand Down Expand Up @@ -2602,7 +2613,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
|The selector is not a member of an object or package.""")
else typd(imp.expr, AnySelectionProto)

def typedImport(imp: untpd.Import, sym: Symbol)(using Context): Tree =
def typedImport(imp: untpd.Import)(using Context): Tree =
val sym = retrieveSym(imp)
val expr1 = typedImportQualifier(imp, typedExpr(_, _)(using ctx.withOwner(sym)))
checkLegalImportPath(expr1)
val selectors1 = typedSelectors(imp.selectors)
Expand Down Expand Up @@ -2869,7 +2881,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
case tree: untpd.If => typedIf(tree, pt)
case tree: untpd.Function => typedFunction(tree, pt)
case tree: untpd.Closure => typedClosure(tree, pt)
case tree: untpd.Import => typedImport(tree, retrieveSym(tree))
case tree: untpd.Import => typedImport(tree)
case tree: untpd.Export => typedExport(tree)
case tree: untpd.Match => typedMatch(tree, pt)
case tree: untpd.Return => typedReturn(tree)
Expand Down
12 changes: 12 additions & 0 deletions tests/explicit-nulls/pos/enums.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
enum ErrorMessageID(val isActive: Boolean = true) extends java.lang.Enum[ErrorMessageID]:

case NoExplanationID // errorNumber: -1
case EmptyCatchOrFinallyBlockID extends ErrorMessageID(isActive = false) // errorNumber: 0

def errorNumber = ordinal - 1

enum Color(val rgb: Int):
case Red extends Color(0xFF0000)
case Green extends Color(0x00FF00)
case Blue extends Color(0x0000FF)

5 changes: 5 additions & 0 deletions tests/explicit-nulls/pos/test-not-null.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// testNotNull can be inserted during PatternMatcher
def f(xs: List[String]) =
xs.zipWithIndex.collect {
case (arg, idx) => idx
}
67 changes: 67 additions & 0 deletions tests/explicit-nulls/pos/unsafe-block.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
def trim(x: String | Null): String =
import scala.language.unsafeNulls
// The type of `x.trim()` is `String | Null`.
// Although `String | Null` conforms the expected type `String`,
// we still need to cast the expression to the expected type here,
// because outside the scope we don't have `unsafeNulls` anymore.
x.trim()

class TestDefs:

def f1: String | Null = null
def f2: Array[String | Null] | Null = null
def f3: Array[String] | Null = null

def h1a: String =
import scala.language.unsafeNulls
f1

def h1b: String | Null =
import scala.language.unsafeNulls
f1

def h2a: Array[String] =
import scala.language.unsafeNulls
f2

def h2b: Array[String | Null] =
import scala.language.unsafeNulls
f2

def h3a: Array[String] =
import scala.language.unsafeNulls
f3

def h3b: Array[String | Null] =
import scala.language.unsafeNulls
f3

class TestVals:

val f1: String | Null = null
val f2: Array[String | Null] | Null = null
val f3: Array[String] | Null = null

val h1a: String =
import scala.language.unsafeNulls
f1

val h1b: String | Null =
import scala.language.unsafeNulls
f1

val h2a: Array[String] =
import scala.language.unsafeNulls
f2

val h2b: Array[String | Null] =
import scala.language.unsafeNulls
f2

val h3a: Array[String] =
import scala.language.unsafeNulls
f3

val h3b: Array[String | Null] =
import scala.language.unsafeNulls
f3
21 changes: 16 additions & 5 deletions tests/explicit-nulls/pos/unsafe-chain.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
import java.nio.file.FileSystems
import java.util.ArrayList

def directorySeparator: String =
import scala.language.unsafeNulls
FileSystems.getDefault().getSeparator()
class A:

def directorySeparator: String =
import scala.language.unsafeNulls
FileSystems.getDefault().getSeparator()

def getFirstOfFirst(xs: ArrayList[ArrayList[ArrayList[String]]]): String =
import scala.language.unsafeNulls
xs.get(0).get(0).get(0)

def getFirstOfFirst(xs: ArrayList[ArrayList[ArrayList[String]]]): String =
class B:
import scala.language.unsafeNulls
xs.get(0).get(0).get(0)

def directorySeparator: String =
FileSystems.getDefault().getSeparator()

def getFirstOfFirst(xs: ArrayList[ArrayList[ArrayList[String]]]): String =
xs.get(0).get(0).get(0)

0 comments on commit 0059d1d

Please sign in to comment.