Skip to content

Commit

Permalink
Relax overriding by stripping nulls deeply
Browse files Browse the repository at this point in the history
  • Loading branch information
noti0na1 committed Nov 23, 2021
1 parent 2ef89b2 commit 9445d16
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 66 deletions.
77 changes: 52 additions & 25 deletions compiler/src/dotty/tools/dotc/core/NullOpsDecorator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,41 @@ import Types._
/** Defines operations on nullable types and tree. */
object NullOpsDecorator:

private class StripNullsMap(isDeep: Boolean)(using Context) extends TypeMap:
def strip(tp: Type): Type = tp match
case tp @ OrType(lhs, rhs) =>
val llhs = this(lhs)
val rrhs = this(rhs)
if rrhs.isNullType then llhs
else if llhs.isNullType then rrhs
else derivedOrType(tp, llhs, rrhs)
case tp @ AndType(tp1, tp2) =>
// We cannot `tp.derivedAndType(strip(tp1), strip(tp2))` directly,
// since `stripNull((A | Null) & B)` would produce the wrong
// result `(A & B) | Null`.
val tp1s = this(tp1)
val tp2s = this(tp2)
if isDeep || (tp1s ne tp1) && (tp2s ne tp2) then
derivedAndType(tp, tp1s, tp2s)
else tp
case tp: TypeBounds =>
mapOver(tp)
case _ => tp

def stripOver(tp: Type): Type = tp match
case appTp @ AppliedType(tycon, targs) =>
derivedAppliedType(appTp, tycon, targs.map(this))
case ptp: PolyType =>
derivedLambdaType(ptp)(ptp.paramInfos, this(ptp.resType))
case mtp: MethodType =>
mapOver(mtp)
case _ => strip(tp)

override def apply(tp: Type): Type =
if isDeep then stripOver(tp) else strip(tp)

end StripNullsMap

extension (self: Type)
/** Syntactically strips the nullability from this type.
* If the type is `T1 | ... | Tn`, and `Ti` references to `Null`,
Expand All @@ -17,38 +52,30 @@ object NullOpsDecorator:
* The type will not be changed if explicit-nulls is not enabled.
*/
def stripNull(using Context): Type = {
def strip(tp: Type): Type =
val tpWiden = tp.widenDealias
val tpStripped = tpWiden match {
case tp @ OrType(lhs, rhs) =>
val llhs = strip(lhs)
val rrhs = strip(rhs)
if rrhs.isNullType then llhs
else if llhs.isNullType then rrhs
else tp.derivedOrType(llhs, rrhs)
case tp @ AndType(tp1, tp2) =>
// We cannot `tp.derivedAndType(strip(tp1), strip(tp2))` directly,
// since `stripNull((A | Null) & B)` would produce the wrong
// result `(A & B) | Null`.
val tp1s = strip(tp1)
val tp2s = strip(tp2)
if (tp1s ne tp1) && (tp2s ne tp2) then
tp.derivedAndType(tp1s, tp2s)
else tp
case tp @ TypeBounds(lo, hi) =>
tp.derivedTypeBounds(strip(lo), strip(hi))
case tp => tp
}
if tpStripped ne tpWiden then tpStripped else tp

if ctx.explicitNulls then strip(self) else self
if ctx.explicitNulls then
val selfw = self.widenDealias
val selfws = new StripNullsMap(false)(selfw)
if selfws ne selfw then selfws else self
else self
}

/** Is self (after widening and dealiasing) a type of the form `T | Null`? */
def isNullableUnion(using Context): Boolean = {
val stripped = self.stripNull
stripped ne self
}

/** Strips nulls from this type deeply.
* Compaired to `stripNull`, `stripNullsDeep` will apply `stripNull` to
* each member of function types as well.
*/
def stripNullsDeep(using Context): Type =
if ctx.explicitNulls then
val selfw = self.widenDealias
val selfws = new StripNullsMap(true)(selfw)
if selfws ne selfw then selfws else self
else self

end extension

import ast.tpd._
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1112,8 +1112,10 @@ object Types {
*/
def matches(that: Type)(using Context): Boolean = {
record("matches")
val thisTp1 = this.stripNullsDeep
val thatTp1 = that.stripNullsDeep
withoutMode(Mode.SafeNulls)(
TypeComparer.matchesType(this, that, relaxed = !ctx.phase.erasedTypes))
TypeComparer.matchesType(thisTp1, thatTp1, relaxed = !ctx.phase.erasedTypes))
}

/** This is the same as `matches` except that it also matches => T with T and
Expand Down
24 changes: 15 additions & 9 deletions compiler/src/dotty/tools/dotc/transform/OverridingPairs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package transform
import core._
import Flags._, Symbols._, Contexts._, Scopes._, Decorators._, Types.Type
import NameKinds.DefaultGetterName
import NullOpsDecorator._
import collection.mutable
import collection.immutable.BitSet
import scala.annotation.tailrec
Expand Down Expand Up @@ -215,15 +216,20 @@ object OverridingPairs:
}
)
else
// releaxed override check for explicit nulls if one of the symbols is Java defined,
// force `Null` being a subtype of reference types during override checking
val relaxedCtxForNulls =
def matchNullaryLoosely = member.matchNullaryLoosely || other.matchNullaryLoosely || fallBack
// default getters are not checked for compatibility
member.name.is(DefaultGetterName) || {
if ctx.explicitNulls && (member.is(JavaDefined) || other.is(JavaDefined)) then
ctx.retractMode(Mode.SafeNulls)
else ctx
member.name.is(DefaultGetterName) // default getters are not checked for compatibility
|| memberTp.overrides(otherTp,
member.matchNullaryLoosely || other.matchNullaryLoosely || fallBack
)(using relaxedCtxForNulls)
// releaxed override check for explicit nulls if one of the symbols is Java defined,
// force `Null` being a subtype of reference types during override checking.
// `stripNullsDeep` is used here because we may encounter type parameters
// (`T | Null` is not a subtype of `T` even if we retract Mode.SafeNulls).
val memberTp1 = memberTp.stripNullsDeep
val otherTp1 = otherTp.stripNullsDeep
withoutMode(Mode.SafeNulls)(
memberTp1.overrides(otherTp1, matchNullaryLoosely))
else
memberTp.overrides(otherTp, matchNullaryLoosely)
}

end OverridingPairs
13 changes: 8 additions & 5 deletions compiler/src/dotty/tools/dotc/transform/ResolveSuper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import Names._
import StdNames._
import NameOps._
import NameKinds._
import NullOpsDecorator._
import ResolveSuper._
import reporting.IllegalSuperAccessor

Expand Down Expand Up @@ -110,11 +111,13 @@ object ResolveSuper {
// Since the super class can be Java defined,
// we use releaxed overriding check for explicit nulls if one of the symbols is Java defined.
// This forces `Null` being a subtype of reference types during override checking.
val relaxedCtxForNulls =
if ctx.explicitNulls && (sym.is(JavaDefined) || acc.is(JavaDefined)) then
ctx.retractMode(Mode.SafeNulls)
else ctx
if (!(otherTp.overrides(accTp, matchLoosely = true)(using relaxedCtxForNulls)))
val overridesSuper = if ctx.explicitNulls && (sym.is(JavaDefined) || acc.is(JavaDefined)) then
val otherTp1 = otherTp.stripNullsDeep
val accTp1 = accTp.stripNullsDeep
withoutMode(Mode.SafeNulls)(otherTp1.overrides(accTp1, matchLoosely = true))
else
otherTp.overrides(accTp, matchLoosely = true)
if !overridesSuper then
report.error(IllegalSuperAccessor(base, memberName, targetName, acc, accTp, other.symbol, otherTp), base.srcPos)

bcs = bcs.tail
Expand Down
47 changes: 47 additions & 0 deletions tests/explicit-nulls/neg/opaque-nullable.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Unboxed option type using unions + null + opaque.
// Relies on the fact that Null is not a subtype of AnyRef.
// Test suggested by Sébastien Doeraene.

object Nullables {
opaque type Nullable[+A <: AnyRef] = A | Null // disjoint by construction!

object Nullable:
def apply[A <: AnyRef](x: A | Null): Nullable[A] = x

def some[A <: AnyRef](x: A): Nullable[A] = x
def none: Nullable[Nothing] = null

extension [A <: AnyRef](x: Nullable[A])
def isEmpty: Boolean = x == null
def get: A | Null = x

extension [A <: AnyRef, B <: AnyRef](x: Nullable[A])
def flatMap(f: A => Nullable[B]): Nullable[B] =
if (x == null) null
else f(x)

def map(f: A => B): Nullable[B] = x.flatMap(f)

def test1 =
val s1: Nullable[String] = Nullable("hello")
val s2: Nullable[String] = "world"
val s3: Nullable[String] = Nullable.none
val s4: Nullable[String] = null

s1.isEmpty
s1.flatMap((x) => true)

assert(s2 != null)
}

def test2 =
import Nullables._

val s1: Nullable[String] = Nullable("hello")
val s2: Nullable[String] = Nullable.none
val s3: Nullable[String] = null // error: don't leak nullable union

s1.isEmpty
s1.flatMap((x) => Nullable(true))

assert(s2 == null) // error
26 changes: 0 additions & 26 deletions tests/explicit-nulls/pos/opaque-nullable.scala

This file was deleted.

18 changes: 18 additions & 0 deletions tests/explicit-nulls/pos/override-type-params.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Testing relaxed overriding check for explicit nulls.
// The relaxed check is only enabled if one of the members is Java defined.

import java.util.Comparator

class C1[T <: AnyRef] extends Ordering[T]:
override def compare(o1: T, o2: T): Int = 0

// The following overriding is not allowed, because `compare`
// has already been declared in Scala class `Ordering`.
// class C2[T <: AnyRef] extends Ordering[T]:
// override def compare(o1: T | Null, o2: T | Null): Int = 0

class D1[T <: AnyRef] extends Comparator[T]:
override def compare(o1: T, o2: T): Int = 0

class D2[T <: AnyRef] extends Comparator[T]:
override def compare(o1: T | Null, o2: T | Null): Int = 0

0 comments on commit 9445d16

Please sign in to comment.