From 9445d164334a403eeff0e926e57c8af09d18b7dc Mon Sep 17 00:00:00 2001 From: noti0na1 Date: Tue, 23 Nov 2021 14:14:07 -0500 Subject: [PATCH] Relax overriding by stripping nulls deeply --- .../tools/dotc/core/NullOpsDecorator.scala | 77 +++++++++++++------ .../src/dotty/tools/dotc/core/Types.scala | 4 +- .../dotc/transform/OverridingPairs.scala | 24 +++--- .../tools/dotc/transform/ResolveSuper.scala | 13 ++-- .../explicit-nulls/neg/opaque-nullable.scala | 47 +++++++++++ .../explicit-nulls/pos/opaque-nullable.scala | 26 ------- .../pos/override-type-params.scala | 18 +++++ 7 files changed, 143 insertions(+), 66 deletions(-) create mode 100644 tests/explicit-nulls/neg/opaque-nullable.scala delete mode 100644 tests/explicit-nulls/pos/opaque-nullable.scala create mode 100644 tests/explicit-nulls/pos/override-type-params.scala diff --git a/compiler/src/dotty/tools/dotc/core/NullOpsDecorator.scala b/compiler/src/dotty/tools/dotc/core/NullOpsDecorator.scala index d0799ca89d24..9d39186430ca 100644 --- a/compiler/src/dotty/tools/dotc/core/NullOpsDecorator.scala +++ b/compiler/src/dotty/tools/dotc/core/NullOpsDecorator.scala @@ -9,6 +9,41 @@ import Types._ /** Defines operations on nullable types and tree. */ object NullOpsDecorator: + private class StripNullsMap(isDeep: Boolean)(using Context) extends TypeMap: + def strip(tp: Type): Type = tp match + case tp @ OrType(lhs, rhs) => + val llhs = this(lhs) + val rrhs = this(rhs) + if rrhs.isNullType then llhs + else if llhs.isNullType then rrhs + else derivedOrType(tp, llhs, rrhs) + case tp @ AndType(tp1, tp2) => + // We cannot `tp.derivedAndType(strip(tp1), strip(tp2))` directly, + // since `stripNull((A | Null) & B)` would produce the wrong + // result `(A & B) | Null`. + val tp1s = this(tp1) + val tp2s = this(tp2) + if isDeep || (tp1s ne tp1) && (tp2s ne tp2) then + derivedAndType(tp, tp1s, tp2s) + else tp + case tp: TypeBounds => + mapOver(tp) + case _ => tp + + def stripOver(tp: Type): Type = tp match + case appTp @ AppliedType(tycon, targs) => + derivedAppliedType(appTp, tycon, targs.map(this)) + case ptp: PolyType => + derivedLambdaType(ptp)(ptp.paramInfos, this(ptp.resType)) + case mtp: MethodType => + mapOver(mtp) + case _ => strip(tp) + + override def apply(tp: Type): Type = + if isDeep then stripOver(tp) else strip(tp) + + end StripNullsMap + extension (self: Type) /** Syntactically strips the nullability from this type. * If the type is `T1 | ... | Tn`, and `Ti` references to `Null`, @@ -17,31 +52,11 @@ object NullOpsDecorator: * The type will not be changed if explicit-nulls is not enabled. */ def stripNull(using Context): Type = { - def strip(tp: Type): Type = - val tpWiden = tp.widenDealias - val tpStripped = tpWiden match { - case tp @ OrType(lhs, rhs) => - val llhs = strip(lhs) - val rrhs = strip(rhs) - if rrhs.isNullType then llhs - else if llhs.isNullType then rrhs - else tp.derivedOrType(llhs, rrhs) - case tp @ AndType(tp1, tp2) => - // We cannot `tp.derivedAndType(strip(tp1), strip(tp2))` directly, - // since `stripNull((A | Null) & B)` would produce the wrong - // result `(A & B) | Null`. - val tp1s = strip(tp1) - val tp2s = strip(tp2) - if (tp1s ne tp1) && (tp2s ne tp2) then - tp.derivedAndType(tp1s, tp2s) - else tp - case tp @ TypeBounds(lo, hi) => - tp.derivedTypeBounds(strip(lo), strip(hi)) - case tp => tp - } - if tpStripped ne tpWiden then tpStripped else tp - - if ctx.explicitNulls then strip(self) else self + if ctx.explicitNulls then + val selfw = self.widenDealias + val selfws = new StripNullsMap(false)(selfw) + if selfws ne selfw then selfws else self + else self } /** Is self (after widening and dealiasing) a type of the form `T | Null`? */ @@ -49,6 +64,18 @@ object NullOpsDecorator: val stripped = self.stripNull stripped ne self } + + /** Strips nulls from this type deeply. + * Compaired to `stripNull`, `stripNullsDeep` will apply `stripNull` to + * each member of function types as well. + */ + def stripNullsDeep(using Context): Type = + if ctx.explicitNulls then + val selfw = self.widenDealias + val selfws = new StripNullsMap(true)(selfw) + if selfws ne selfw then selfws else self + else self + end extension import ast.tpd._ diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index b384cbdcb084..2a310fbef76f 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -1112,8 +1112,10 @@ object Types { */ def matches(that: Type)(using Context): Boolean = { record("matches") + val thisTp1 = this.stripNullsDeep + val thatTp1 = that.stripNullsDeep withoutMode(Mode.SafeNulls)( - TypeComparer.matchesType(this, that, relaxed = !ctx.phase.erasedTypes)) + TypeComparer.matchesType(thisTp1, thatTp1, relaxed = !ctx.phase.erasedTypes)) } /** This is the same as `matches` except that it also matches => T with T and diff --git a/compiler/src/dotty/tools/dotc/transform/OverridingPairs.scala b/compiler/src/dotty/tools/dotc/transform/OverridingPairs.scala index 437dfea9f156..28ed037405b2 100644 --- a/compiler/src/dotty/tools/dotc/transform/OverridingPairs.scala +++ b/compiler/src/dotty/tools/dotc/transform/OverridingPairs.scala @@ -5,6 +5,7 @@ package transform import core._ import Flags._, Symbols._, Contexts._, Scopes._, Decorators._, Types.Type import NameKinds.DefaultGetterName +import NullOpsDecorator._ import collection.mutable import collection.immutable.BitSet import scala.annotation.tailrec @@ -215,15 +216,20 @@ object OverridingPairs: } ) else - // releaxed override check for explicit nulls if one of the symbols is Java defined, - // force `Null` being a subtype of reference types during override checking - val relaxedCtxForNulls = + def matchNullaryLoosely = member.matchNullaryLoosely || other.matchNullaryLoosely || fallBack + // default getters are not checked for compatibility + member.name.is(DefaultGetterName) || { if ctx.explicitNulls && (member.is(JavaDefined) || other.is(JavaDefined)) then - ctx.retractMode(Mode.SafeNulls) - else ctx - member.name.is(DefaultGetterName) // default getters are not checked for compatibility - || memberTp.overrides(otherTp, - member.matchNullaryLoosely || other.matchNullaryLoosely || fallBack - )(using relaxedCtxForNulls) + // releaxed override check for explicit nulls if one of the symbols is Java defined, + // force `Null` being a subtype of reference types during override checking. + // `stripNullsDeep` is used here because we may encounter type parameters + // (`T | Null` is not a subtype of `T` even if we retract Mode.SafeNulls). + val memberTp1 = memberTp.stripNullsDeep + val otherTp1 = otherTp.stripNullsDeep + withoutMode(Mode.SafeNulls)( + memberTp1.overrides(otherTp1, matchNullaryLoosely)) + else + memberTp.overrides(otherTp, matchNullaryLoosely) + } end OverridingPairs diff --git a/compiler/src/dotty/tools/dotc/transform/ResolveSuper.scala b/compiler/src/dotty/tools/dotc/transform/ResolveSuper.scala index 40c860bf3bdc..8b5b6dbdd50a 100644 --- a/compiler/src/dotty/tools/dotc/transform/ResolveSuper.scala +++ b/compiler/src/dotty/tools/dotc/transform/ResolveSuper.scala @@ -13,6 +13,7 @@ import Names._ import StdNames._ import NameOps._ import NameKinds._ +import NullOpsDecorator._ import ResolveSuper._ import reporting.IllegalSuperAccessor @@ -110,11 +111,13 @@ object ResolveSuper { // Since the super class can be Java defined, // we use releaxed overriding check for explicit nulls if one of the symbols is Java defined. // This forces `Null` being a subtype of reference types during override checking. - val relaxedCtxForNulls = - if ctx.explicitNulls && (sym.is(JavaDefined) || acc.is(JavaDefined)) then - ctx.retractMode(Mode.SafeNulls) - else ctx - if (!(otherTp.overrides(accTp, matchLoosely = true)(using relaxedCtxForNulls))) + val overridesSuper = if ctx.explicitNulls && (sym.is(JavaDefined) || acc.is(JavaDefined)) then + val otherTp1 = otherTp.stripNullsDeep + val accTp1 = accTp.stripNullsDeep + withoutMode(Mode.SafeNulls)(otherTp1.overrides(accTp1, matchLoosely = true)) + else + otherTp.overrides(accTp, matchLoosely = true) + if !overridesSuper then report.error(IllegalSuperAccessor(base, memberName, targetName, acc, accTp, other.symbol, otherTp), base.srcPos) bcs = bcs.tail diff --git a/tests/explicit-nulls/neg/opaque-nullable.scala b/tests/explicit-nulls/neg/opaque-nullable.scala new file mode 100644 index 000000000000..1d3a22249d29 --- /dev/null +++ b/tests/explicit-nulls/neg/opaque-nullable.scala @@ -0,0 +1,47 @@ +// Unboxed option type using unions + null + opaque. +// Relies on the fact that Null is not a subtype of AnyRef. +// Test suggested by Sébastien Doeraene. + +object Nullables { + opaque type Nullable[+A <: AnyRef] = A | Null // disjoint by construction! + + object Nullable: + def apply[A <: AnyRef](x: A | Null): Nullable[A] = x + + def some[A <: AnyRef](x: A): Nullable[A] = x + def none: Nullable[Nothing] = null + + extension [A <: AnyRef](x: Nullable[A]) + def isEmpty: Boolean = x == null + def get: A | Null = x + + extension [A <: AnyRef, B <: AnyRef](x: Nullable[A]) + def flatMap(f: A => Nullable[B]): Nullable[B] = + if (x == null) null + else f(x) + + def map(f: A => B): Nullable[B] = x.flatMap(f) + + def test1 = + val s1: Nullable[String] = Nullable("hello") + val s2: Nullable[String] = "world" + val s3: Nullable[String] = Nullable.none + val s4: Nullable[String] = null + + s1.isEmpty + s1.flatMap((x) => true) + + assert(s2 != null) +} + +def test2 = + import Nullables._ + + val s1: Nullable[String] = Nullable("hello") + val s2: Nullable[String] = Nullable.none + val s3: Nullable[String] = null // error: don't leak nullable union + + s1.isEmpty + s1.flatMap((x) => Nullable(true)) + + assert(s2 == null) // error diff --git a/tests/explicit-nulls/pos/opaque-nullable.scala b/tests/explicit-nulls/pos/opaque-nullable.scala deleted file mode 100644 index a7f626054ad3..000000000000 --- a/tests/explicit-nulls/pos/opaque-nullable.scala +++ /dev/null @@ -1,26 +0,0 @@ -// Unboxed option type using unions + null + opaque. -// Relies on the fact that Null is not a subtype of AnyRef. -// Test suggested by Sébastien Doeraene. - -opaque type Nullable[+A <: AnyRef] = A | Null // disjoint by construction! - -object Nullable { - def apply[A <: AnyRef](x: A | Null): Nullable[A] = x - - def some[A <: AnyRef](x: A): Nullable[A] = x - def none: Nullable[Nothing] = null - - extension [A <: AnyRef](x: Nullable[A]) - def isEmpty: Boolean = x == null - - extension [A <: AnyRef, B <: AnyRef](x: Nullable[A]) - def flatMap(f: A => Nullable[B]): Nullable[B] = - if (x == null) null - else f(x) - - val s1: Nullable[String] = "hello" - val s2: Nullable[String] = null - - s1.isEmpty - s1.flatMap((x) => true) -} diff --git a/tests/explicit-nulls/pos/override-type-params.scala b/tests/explicit-nulls/pos/override-type-params.scala new file mode 100644 index 000000000000..7f59409a4c3c --- /dev/null +++ b/tests/explicit-nulls/pos/override-type-params.scala @@ -0,0 +1,18 @@ +// Testing relaxed overriding check for explicit nulls. +// The relaxed check is only enabled if one of the members is Java defined. + +import java.util.Comparator + +class C1[T <: AnyRef] extends Ordering[T]: + override def compare(o1: T, o2: T): Int = 0 + +// The following overriding is not allowed, because `compare` +// has already been declared in Scala class `Ordering`. +// class C2[T <: AnyRef] extends Ordering[T]: +// override def compare(o1: T | Null, o2: T | Null): Int = 0 + +class D1[T <: AnyRef] extends Comparator[T]: + override def compare(o1: T, o2: T): Int = 0 + +class D2[T <: AnyRef] extends Comparator[T]: + override def compare(o1: T | Null, o2: T | Null): Int = 0