diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 38fd65cb4c54..eee0e802528c 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -2497,9 +2497,15 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling def provablyDisjoint(tp1: Type, tp2: Type)(using Context): Boolean = trace(i"provable disjoint $tp1, $tp2", matchTypes) { // println(s"provablyDisjoint(${tp1.show}, ${tp2.show})") - def isEnumValueOrModule(ref: TermRef): Boolean = + def isEnumValue(ref: TermRef): Boolean = val sym = ref.termSymbol - sym.isAllOf(EnumCase, butNot=JavaDefined) || sym.is(Module) + sym.isAllOf(EnumCase, butNot=JavaDefined) + + def isEnumValueOrModule(ref: TermRef): Boolean = + isEnumValue(ref) || ref.termSymbol.is(Module) || (ref.info match { + case tp: TermRef => isEnumValueOrModule(tp) + case _ => false + }) /** Can we enumerate all instantiations of this type? */ def isClosedSum(tp: Symbol): Boolean = @@ -2512,8 +2518,21 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling def decompose(sym: Symbol, tp: Type): List[Type] = sym.children.map(x => refineUsingParent(tp, x)).filter(_.exists) + def fullyInstantiated(tp: Type): Boolean = new TypeAccumulator[Boolean] { + override def apply(x: Boolean, t: Type) = + x && { + t match { + case tp: TypeRef if tp.symbol.isAbstractOrParamType => false + case _: SkolemType | _: TypeVar => false + case _ => foldOver(x, t) + } + } + }.apply(true, tp) + (tp1.dealias, tp2.dealias) match { - case (tp1: TypeRef, tp2: TypeRef) if tp1.symbol == defn.SingletonClass || tp2.symbol == defn.SingletonClass => + case (tp1: TypeRef, _) if tp1.symbol == defn.SingletonClass => + false + case (_, tp2: TypeRef) if tp2.symbol == defn.SingletonClass => false case (tp1: ConstantType, tp2: ConstantType) => tp1 != tp2 @@ -2557,21 +2576,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling // doesn't have type tags, meaning that users cannot write patterns // that do type tests on higher kinded types. def invariantDisjoint(tp1: Type, tp2: Type, tparam: TypeParamInfo): Boolean = - covariantDisjoint(tp1, tp2, tparam) || !isSameType(tp1, tp2) && { - // We can only trust a "no" from `isSameType` when both - // `tp1` and `tp2` are fully instantiated. - def fullyInstantiated(tp: Type): Boolean = new TypeAccumulator[Boolean] { - override def apply(x: Boolean, t: Type) = - x && { - t match { - case tp: TypeRef if tp.symbol.isAbstractOrParamType => false - case _: SkolemType | _: TypeVar => false - case _ => foldOver(x, t) - } - } - }.apply(true, tp) - fullyInstantiated(tp1) && fullyInstantiated(tp2) - } + covariantDisjoint(tp1, tp2, tparam) || + !isSameType(tp1, tp2) && + fullyInstantiated(tp1) && // We can only trust a "no" from `isSameType` when + fullyInstantiated(tp2) // both `tp1` and `tp2` are fully instantiated. args1.lazyZip(args2).lazyZip(tycon1.typeParams).exists { (arg1, arg2, tparam) => @@ -2608,11 +2616,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling provablyDisjoint(tp1, gadtBounds(tp2.symbol).hi) || provablyDisjoint(tp1, tp2.superType) case (tp1: TermRef, tp2: TermRef) if isEnumValueOrModule(tp1) && isEnumValueOrModule(tp2) => tp1.termSymbol != tp2.termSymbol - case (tp1: TermRef, tp2: TypeRef) if isEnumValueOrModule(tp1) && !tp1.classSymbols.exists(_.derivesFrom(tp2.classSymbol)) => - // Note: enum values may have multiple parents - true - case (tp1: TypeRef, tp2: TermRef) if isEnumValueOrModule(tp2) && !tp2.classSymbols.exists(_.derivesFrom(tp1.classSymbol)) => - true + case (tp1: TermRef, tp2: TypeRef) if isEnumValue(tp1) => + fullyInstantiated(tp2) && !tp1.classSymbols.exists(_.derivesFrom(tp2.symbol)) + case (tp1: TypeRef, tp2: TermRef) if isEnumValue(tp2) => + fullyInstantiated(tp1) && !tp2.classSymbols.exists(_.derivesFrom(tp1.symbol)) case (tp1: Type, tp2: Type) if defn.isTupleType(tp1) => provablyDisjoint(tp1.toNestedPairs, tp2) case (tp1: Type, tp2: Type) if defn.isTupleType(tp2) => diff --git a/tests/neg/12549.scala b/tests/neg/12549.scala new file mode 100644 index 000000000000..4e2a5531efa0 --- /dev/null +++ b/tests/neg/12549.scala @@ -0,0 +1,18 @@ +enum Bool { + case True + case False +} + +import Bool.* + +type Not[B <: Bool] = B match { + case True.type => False.type + case False.type => True.type + case _ => "unreachable" +} + +def foo[B <: Bool & Singleton]: Unit = { + implicitly[Not[B] =:= "unreachable"] // error + + () +}