diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index 385917f9b368..5b89c9bbacd1 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -141,9 +141,17 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] => loop(tree, Nil) /** All term arguments of an application in a single flattened list */ + def allTermArguments(tree: Tree): List[Tree] = unsplice(tree) match { + case Apply(fn, args) => allArguments(fn) ::: args + case TypeApply(fn, args) => allArguments(fn) + case Block(_, expr) => allArguments(expr) + case _ => Nil + } + + /** All type and term arguments of an application in a single flattened list */ def allArguments(tree: Tree): List[Tree] = unsplice(tree) match { case Apply(fn, args) => allArguments(fn) ::: args - case TypeApply(fn, _) => allArguments(fn) + case TypeApply(fn, args) => allArguments(fn) ::: args case Block(_, expr) => allArguments(expr) case _ => Nil } diff --git a/compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala b/compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala index 98d9a0ca85f6..c79e75895a7b 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala @@ -21,6 +21,9 @@ import Decorators.* * @param newOwners New owners, replacing previous owners. * @param substFrom The symbols that need to be substituted. * @param substTo The substitution targets. + * @param cpy A tree copier that is used to create new trees. + * @param alwaysCopySymbols If set, symbols are always copied, even when they + * are not impacted by the transformation. * * The reason the substitution is broken out from the rest of the type map is * that all symbols have to be substituted at the same time. If we do not do this, @@ -38,7 +41,9 @@ class TreeTypeMap( val newOwners: List[Symbol] = Nil, val substFrom: List[Symbol] = Nil, val substTo: List[Symbol] = Nil, - cpy: tpd.TreeCopier = tpd.cpy)(using Context) extends tpd.TreeMap(cpy) { + cpy: tpd.TreeCopier = tpd.cpy, + alwaysCopySymbols: Boolean = false, +)(using Context) extends tpd.TreeMap(cpy) { import tpd.* def copy( @@ -48,7 +53,7 @@ class TreeTypeMap( newOwners: List[Symbol], substFrom: List[Symbol], substTo: List[Symbol])(using Context): TreeTypeMap = - new TreeTypeMap(typeMap, treeMap, oldOwners, newOwners, substFrom, substTo) + new TreeTypeMap(typeMap, treeMap, oldOwners, newOwners, substFrom, substTo, cpy, alwaysCopySymbols) /** If `sym` is one of `oldOwners`, replace by corresponding symbol in `newOwners` */ def mapOwner(sym: Symbol): Symbol = sym.subst(oldOwners, newOwners) @@ -207,7 +212,7 @@ class TreeTypeMap( * between original and mapped symbols. */ def withMappedSyms(syms: List[Symbol]): TreeTypeMap = - withMappedSyms(syms, mapSymbols(syms, this)) + withMappedSyms(syms, mapSymbols(syms, this, mapAlways = alwaysCopySymbols)) /** The tree map with the substitution between originals `syms` * and mapped symbols `mapped`. Also goes into mapped classes @@ -229,6 +234,10 @@ class TreeTypeMap( tmap1 } + def withAlwaysCopySymbols: TreeTypeMap = + if alwaysCopySymbols then this + else new TreeTypeMap(typeMap, treeMap, oldOwners, newOwners, substFrom, substTo, cpy, alwaysCopySymbols = true) + override def toString = def showSyms(syms: List[Symbol]) = syms.map(sym => s"$sym#${sym.id}").mkString(", ") diff --git a/compiler/src/dotty/tools/dotc/core/Annotations.scala b/compiler/src/dotty/tools/dotc/core/Annotations.scala index b4cdeba4600b..08e206b1a850 100644 --- a/compiler/src/dotty/tools/dotc/core/Annotations.scala +++ b/compiler/src/dotty/tools/dotc/core/Annotations.scala @@ -3,8 +3,9 @@ package dotc package core import Symbols.*, Types.*, Contexts.*, Constants.*, Phases.* -import ast.tpd, tpd.* -import util.Spans.Span +import ast.{tpd, untpd, TreeTypeMap} +import tpd.* +import util.Spans.{Span, NoSpan} import printing.{Showable, Printer} import printing.Texts.Text @@ -30,8 +31,8 @@ object Annotations { def derivedAnnotation(tree: Tree)(using Context): Annotation = if (tree eq this.tree) this else Annotation(tree) - /** All arguments to this annotation in a single flat list */ - def arguments(using Context): List[Tree] = tpd.allArguments(tree) + /** All term arguments of this annotation in a single flat list */ + def arguments(using Context): List[Tree] = tpd.allTermArguments(tree) def argument(i: Int)(using Context): Option[Tree] = { val args = arguments @@ -54,18 +55,26 @@ object Annotations { * type, since ranges cannot be types of trees. */ def mapWith(tm: TypeMap)(using Context) = - val args = arguments + val args = tpd.allArguments(tree) if args.isEmpty then this else + // Checks if `tm` would result in any change by applying it to types + // inside the annotations' arguments and checking if the resulting types + // are different. val findDiff = new TreeAccumulator[Type]: def apply(x: Type, tree: Tree)(using Context): Type = if tm.isRange(x) then x else val tp1 = tm(tree.tpe) - foldOver(if tp1 frozen_=:= tree.tpe then x else tp1, tree) + foldOver(if !tp1.exists || (tp1 frozen_=:= tree.tpe) then x else tp1, tree) val diff = findDiff(NoType, args) if tm.isRange(diff) then EmptyAnnotation - else if diff.exists then derivedAnnotation(tm.mapOver(tree)) + else if diff.exists then + // If the annotation has been transformed, we need to make sure that the + // symbol are copied so that we don't end up with the same symbol in different + // trees, which would lead to a crash in pickling. + val mappedTree = TreeTypeMap(typeMap = tm, alwaysCopySymbols = true).transform(tree) + derivedAnnotation(mappedTree) else this /** Does this annotation refer to a parameter of `tl`? */ diff --git a/compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala index 86076517021a..3d8080e72a29 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala @@ -33,6 +33,7 @@ object PositionPickler: pickler: TastyPickler, addrOfTree: TreeToAddr, treeAnnots: untpd.MemberDef => List[tpd.Tree], + typeAnnots: List[tpd.Tree], relativePathReference: String, source: SourceFile, roots: List[Tree], @@ -136,6 +137,9 @@ object PositionPickler: } for (root <- roots) traverse(root, NoSource) + + for annotTree <- typeAnnots do + traverse(annotTree, NoSource) end picklePositions end PositionPickler diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala index 6659348fb5de..7fd6444746ce 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala @@ -41,6 +41,10 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) { */ private val annotTrees = util.EqHashMap[untpd.MemberDef, mutable.ListBuffer[Tree]]() + /** A set of annotation trees appearing in annotated types. + */ + private val annotatedTypeTrees = mutable.ListBuffer[Tree]() + /** A map from member definitions to their doc comments, so that later * parallel comment pickling does not need to access symbols of trees (which * would involve accessing symbols of named types and possibly changing phases @@ -57,6 +61,8 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) { val ts = annotTrees.lookup(tree) if ts == null then Nil else ts.toList + def typeAnnots: List[Tree] = annotatedTypeTrees.toList + def docString(tree: untpd.MemberDef): Option[Comment] = Option(docStrings.lookup(tree)) @@ -278,6 +284,7 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) { case tpe: AnnotatedType => writeByte(ANNOTATEDtype) withLength { pickleType(tpe.parent, richTypes); pickleTree(tpe.annot.tree) } + annotatedTypeTrees += tpe.annot.tree case tpe: AndType => writeByte(ANDtype) withLength { pickleType(tpe.tp1, richTypes); pickleType(tpe.tp2, richTypes) } diff --git a/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala b/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala index 6d6e2ff01ad4..67a354919d5b 100644 --- a/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala +++ b/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala @@ -224,7 +224,7 @@ object PickledQuotes { if tree.span.exists then val positionWarnings = new mutable.ListBuffer[Message]() val reference = ctx.settings.sourceroot.value - PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference, + PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference, ctx.compilationUnit.source, tree :: Nil, positionWarnings) positionWarnings.foreach(report.warning(_)) diff --git a/compiler/src/dotty/tools/dotc/transform/Pickler.scala b/compiler/src/dotty/tools/dotc/transform/Pickler.scala index dd24f38990df..c8c071064ab8 100644 --- a/compiler/src/dotty/tools/dotc/transform/Pickler.scala +++ b/compiler/src/dotty/tools/dotc/transform/Pickler.scala @@ -322,7 +322,7 @@ class Pickler extends Phase { if tree.span.exists then val reference = ctx.settings.sourceroot.value PositionPickler.picklePositions( - pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference, + pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference, unit.source, tree :: Nil, positionWarnings, scratch.positionBuffer, scratch.pickledIndices) diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index 0feee53ca50f..af8721d2a479 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -2,7 +2,7 @@ package dotty.tools package dotc package transform -import dotty.tools.dotc.ast.{Trees, tpd, untpd, desugar} +import dotty.tools.dotc.ast.{Trees, TreeTypeMap, tpd, untpd, desugar} import scala.collection.mutable import core.* import dotty.tools.dotc.typer.Checking @@ -158,7 +158,14 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => val saved = inJavaAnnot inJavaAnnot = annot.symbol.is(JavaDefined) if (inJavaAnnot) checkValidJavaAnnotation(annot) - try transform(annot) + try + val res = transform(annot) + if res ne annot then + // If the annotation has been transformed, we need to make sure that the + // symbol are copied so that we don't end up with the same symbol in different + // trees, which would lead to a crash in pickling. + TreeTypeMap(alwaysCopySymbols = true)(res) + else res finally inJavaAnnot = saved } diff --git a/tests/pos/annot-17939.scala b/tests/pos/annot-17939.scala new file mode 100644 index 000000000000..2b3adf0ac1cc --- /dev/null +++ b/tests/pos/annot-17939.scala @@ -0,0 +1,7 @@ +class qualified[T](f: T => Boolean) extends annotation.StaticAnnotation + +class Box[T](val x: T) +class Box2(val x: Int) + +class A(a: String @qualified((x: Int) => Box(3).x == 3)) // crash +class A2(a2: String @qualified((x: Int) => Box2(3).x == 3)) // works diff --git a/tests/pos/annot-17939b.scala b/tests/pos/annot-17939b.scala new file mode 100644 index 000000000000..a48f4690d0b2 --- /dev/null +++ b/tests/pos/annot-17939b.scala @@ -0,0 +1,10 @@ +import scala.annotation.Annotation +class myRefined(f: ? => Boolean) extends Annotation + +def test(axes: Int) = true + +trait Tensor: + def mean(axes: Int): Int @myRefined(_ => test(axes)) + +class TensorImpl() extends Tensor: + def mean(axes: Int) = ??? diff --git a/tests/pos/annot-17939c.scala b/tests/pos/annot-17939c.scala new file mode 100644 index 000000000000..5babdf389114 --- /dev/null +++ b/tests/pos/annot-17939c.scala @@ -0,0 +1,5 @@ +class qualified(f: Int => Boolean) extends annotation.StaticAnnotation +class Box[T](val y: T) +def Test = + val x: String @qualified((x: Int) => Box(42).y == 2) = ??? + val y = x diff --git a/tests/pos/annot-18064.scala b/tests/pos/annot-18064.scala new file mode 100644 index 000000000000..b6a67ea9ebe7 --- /dev/null +++ b/tests/pos/annot-18064.scala @@ -0,0 +1,9 @@ +//> using options "-Xprint:typer" + +class myAnnot[T]() extends annotation.Annotation + +trait Tensor[T]: + def add: Tensor[T] @myAnnot[T]() + +class TensorImpl[A]() extends Tensor[A]: + def add /* : Tensor[A] @myAnnot[A] */ = this diff --git a/tests/pos/annot-19846.scala b/tests/pos/annot-19846.scala new file mode 100644 index 000000000000..09c24a5cf3cf --- /dev/null +++ b/tests/pos/annot-19846.scala @@ -0,0 +1,8 @@ +class qualified[T](predicate: T => Boolean) extends annotation.StaticAnnotation + +class EqualPair(val x: Int, val y: Int @qualified[Int](it => it == x)) + +@main def main = + val p = EqualPair(42, 42) + val y = p.y + println(42) diff --git a/tests/pos/annot-19846b.scala b/tests/pos/annot-19846b.scala new file mode 100644 index 000000000000..951a3c8116ff --- /dev/null +++ b/tests/pos/annot-19846b.scala @@ -0,0 +1,7 @@ +class lambdaAnnot(g: () => Int) extends annotation.StaticAnnotation + +def f(x: Int): Int @lambdaAnnot(() => x) = x + +object Test: + val y: Int = ??? + val z /* : Int @lambdaAnnot(() => y) */ = f(y) diff --git a/tests/pos/annot-5789.scala b/tests/pos/annot-5789.scala new file mode 100644 index 000000000000..bdf4438c9d5d --- /dev/null +++ b/tests/pos/annot-5789.scala @@ -0,0 +1,10 @@ +class Annot[T] extends scala.annotation.Annotation + +class D[T](val f: Int@Annot[T]) + +object A{ + def main(a:Array[String]) = { + val c = new D[Int](1) + c.f + } +} diff --git a/tests/printing/annot-18064.check b/tests/printing/annot-18064.check new file mode 100644 index 000000000000..d93ddb95afee --- /dev/null +++ b/tests/printing/annot-18064.check @@ -0,0 +1,16 @@ +[[syntax trees at end of typer]] // tests/printing/annot-18064.scala +package { + class myAnnot[T >: Nothing <: Any]() extends annotation.Annotation() { + T + } + trait Tensor[T >: Nothing <: Any]() extends Object { + T + def add: Tensor[Tensor.this.T] @myAnnot[T] + } + class TensorImpl[A >: Nothing <: Any]() extends Object(), Tensor[ + TensorImpl.this.A] { + A + def add: Tensor[A] @myAnnot[A] = this + } +} + diff --git a/tests/printing/annot-18064.scala b/tests/printing/annot-18064.scala new file mode 100644 index 000000000000..b6a67ea9ebe7 --- /dev/null +++ b/tests/printing/annot-18064.scala @@ -0,0 +1,9 @@ +//> using options "-Xprint:typer" + +class myAnnot[T]() extends annotation.Annotation + +trait Tensor[T]: + def add: Tensor[T] @myAnnot[T]() + +class TensorImpl[A]() extends Tensor[A]: + def add /* : Tensor[A] @myAnnot[A] */ = this diff --git a/tests/printing/annot-19846b.check b/tests/printing/annot-19846b.check new file mode 100644 index 000000000000..3f63a46c4286 --- /dev/null +++ b/tests/printing/annot-19846b.check @@ -0,0 +1,33 @@ +[[syntax trees at end of typer]] // tests/printing/annot-19846b.scala +package { + class lambdaAnnot(g: () => Int) extends scala.annotation.Annotation(), + annotation.StaticAnnotation { + private[this] val g: () => Int + } + final lazy module val Test: Test = new Test() + final module class Test() extends Object() { this: Test.type => + val y: Int = ??? + val z: + Int @lambdaAnnot( + { + def $anonfun(): Int = Test.y + closure($anonfun) + } + ) + = f(Test.y) + } + final lazy module val annot-19846b$package: annot-19846b$package = + new annot-19846b$package() + final module class annot-19846b$package() extends Object() { + this: annot-19846b$package.type => + def f(x: Int): + Int @lambdaAnnot( + { + def $anonfun(): Int = x + closure($anonfun) + } + ) + = x + } +} + diff --git a/tests/printing/annot-19846b.scala b/tests/printing/annot-19846b.scala new file mode 100644 index 000000000000..951a3c8116ff --- /dev/null +++ b/tests/printing/annot-19846b.scala @@ -0,0 +1,7 @@ +class lambdaAnnot(g: () => Int) extends annotation.StaticAnnotation + +def f(x: Int): Int @lambdaAnnot(() => x) = x + +object Test: + val y: Int = ??? + val z /* : Int @lambdaAnnot(() => y) */ = f(y)