Skip to content

Commit

Permalink
Place staged type captures in Quote AST
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolasstucki committed May 11, 2023
1 parent d103f8c commit 0ee80f9
Show file tree
Hide file tree
Showing 25 changed files with 249 additions and 251 deletions.
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/CompilationUnit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ object CompilationUnit {
if tree.symbol.is(Flags.Inline) then
containsInline = true
tree match
case tpd.Quote(_) =>
case _: tpd.Quote =>
containsQuote = true
case tree: tpd.Apply if tree.symbol == defn.QuotedTypeModule_of =>
containsQuote = true
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1976,7 +1976,7 @@ object desugar {
trees foreach collect
case Block(Nil, expr) =>
collect(expr)
case Quote(body) =>
case Quote(body, _) =>
new UntypedTreeTraverser {
def traverse(tree: untpd.Tree)(using Context): Unit = tree match {
case Splice(expr) => collect(expr)
Expand Down
21 changes: 13 additions & 8 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -690,9 +690,14 @@ object Trees {
* when type checking. TASTy files will not contain type quotes. Type quotes are used again
* in the `staging` phase to represent the reification of `Type.of[T]]`.
*
* Type tags `tags` are always empty before the `staging` phase. Tags for stage inconsistent
* types are added in the `staging` phase to level 0 quotes. Tags for types that refer to
* definitions in an outer quote are added in the `splicing` phase
*
* @param body The tree that was quoted
* @param tags Term references to instances of `Type[T]` for `T`s that are used in the quote
*/
case class Quote[+T <: Untyped] private[ast] (body: Tree[T])(implicit @constructorOnly src: SourceFile)
case class Quote[+T <: Untyped] private[ast] (body: Tree[T], tags: List[Tree[T]])(implicit @constructorOnly src: SourceFile)
extends TermTree[T] {
type ThisTree[+T <: Untyped] = Quote[T]

Expand Down Expand Up @@ -1313,9 +1318,9 @@ object Trees {
case tree: Inlined if (call eq tree.call) && (bindings eq tree.bindings) && (expansion eq tree.expansion) => tree
case _ => finalize(tree, untpd.Inlined(call, bindings, expansion)(sourceFile(tree)))
}
def Quote(tree: Tree)(body: Tree)(using Context): Quote = tree match {
case tree: Quote if (body eq tree.body) => tree
case _ => finalize(tree, untpd.Quote(body)(sourceFile(tree)))
def Quote(tree: Tree)(body: Tree, tags: List[Tree])(using Context): Quote = tree match {
case tree: Quote if (body eq tree.body) && (tags eq tree.tags) => tree
case _ => finalize(tree, untpd.Quote(body, tags)(sourceFile(tree)))
}
def Splice(tree: Tree)(expr: Tree)(using Context): Splice = tree match {
case tree: Splice if (expr eq tree.expr) => tree
Expand Down Expand Up @@ -1558,8 +1563,8 @@ object Trees {
case Thicket(trees) =>
val trees1 = transform(trees)
if (trees1 eq trees) tree else Thicket(trees1)
case tree @ Quote(body) =>
cpy.Quote(tree)(transform(body)(using quoteContext))
case Quote(body, tags) =>
cpy.Quote(tree)(transform(body)(using quoteContext), transform(tags))
case tree @ Splice(expr) =>
cpy.Splice(tree)(transform(expr)(using spliceContext))
case tree @ Hole(isTerm, idx, args, content, tpt) =>
Expand Down Expand Up @@ -1703,8 +1708,8 @@ object Trees {
this(this(x, arg), annot)
case Thicket(ts) =>
this(x, ts)
case Quote(body) =>
this(x, body)(using quoteContext)
case Quote(body, tags) =>
this(this(x, body)(using quoteContext), tags)
case Splice(expr) =>
this(x, expr)(using spliceContext)
case Hole(_, _, args, content, tpt) =>
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
def Inlined(call: Tree, bindings: List[MemberDef], expansion: Tree)(using Context): Inlined =
ta.assignType(untpd.Inlined(call, bindings, expansion), bindings, expansion)

def Quote(body: Tree)(using Context): Quote =
untpd.Quote(body).withBodyType(body.tpe)
def Quote(body: Tree, tags: List[Tree])(using Context): Quote =
untpd.Quote(body, tags).withBodyType(body.tpe)

def Splice(expr: Tree, tpe: Type)(using Context): Splice =
untpd.Splice(expr).withType(tpe)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def SeqLiteral(elems: List[Tree], elemtpt: Tree)(implicit src: SourceFile): SeqLiteral = new SeqLiteral(elems, elemtpt)
def JavaSeqLiteral(elems: List[Tree], elemtpt: Tree)(implicit src: SourceFile): JavaSeqLiteral = new JavaSeqLiteral(elems, elemtpt)
def Inlined(call: tpd.Tree, bindings: List[MemberDef], expansion: Tree)(implicit src: SourceFile): Inlined = new Inlined(call, bindings, expansion)
def Quote(body: Tree)(implicit src: SourceFile): Quote = new Quote(body)
def Quote(body: Tree, tags: List[Tree])(implicit src: SourceFile): Quote = new Quote(body, tags)
def Splice(expr: Tree)(implicit src: SourceFile): Splice = new Splice(expr)
def TypeTree()(implicit src: SourceFile): TypeTree = new TypeTree()
def InferredTypeTree()(implicit src: SourceFile): TypeTree = new InferredTypeTree()
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Phases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ object Phases {
private var mySbtExtractDependenciesPhase: Phase = _
private var myPicklerPhase: Phase = _
private var myInliningPhase: Phase = _
private var myStagingPhase: Phase = _
private var mySplicingPhase: Phase = _
private var myFirstTransformPhase: Phase = _
private var myCollectNullableFieldsPhase: Phase = _
Expand All @@ -235,6 +236,7 @@ object Phases {
final def sbtExtractDependenciesPhase: Phase = mySbtExtractDependenciesPhase
final def picklerPhase: Phase = myPicklerPhase
final def inliningPhase: Phase = myInliningPhase
final def stagingPhase: Phase = myStagingPhase
final def splicingPhase: Phase = mySplicingPhase
final def firstTransformPhase: Phase = myFirstTransformPhase
final def collectNullableFieldsPhase: Phase = myCollectNullableFieldsPhase
Expand Down Expand Up @@ -262,6 +264,7 @@ object Phases {
mySbtExtractDependenciesPhase = phaseOfClass(classOf[sbt.ExtractDependencies])
myPicklerPhase = phaseOfClass(classOf[Pickler])
myInliningPhase = phaseOfClass(classOf[Inlining])
myStagingPhase = phaseOfClass(classOf[Staging])
mySplicingPhase = phaseOfClass(classOf[Splicing])
myFirstTransformPhase = phaseOfClass(classOf[FirstTransform])
myCollectNullableFieldsPhase = phaseOfClass(classOf[CollectNullableFields])
Expand Down Expand Up @@ -449,6 +452,7 @@ object Phases {
def sbtExtractDependenciesPhase(using Context): Phase = ctx.base.sbtExtractDependenciesPhase
def picklerPhase(using Context): Phase = ctx.base.picklerPhase
def inliningPhase(using Context): Phase = ctx.base.inliningPhase
def stagingPhase(using Context): Phase = ctx.base.stagingPhase
def splicingPhase(using Context): Phase = ctx.base.splicingPhase
def firstTransformPhase(using Context): Phase = ctx.base.firstTransformPhase
def refchecksPhase(using Context): Phase = ctx.base.refchecksPhase
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ class TreePickler(pickler: TastyPickler) {
pickleTree(hi)
pickleTree(alias)
}
case tree @ Quote(body) =>
case tree @ Quote(body, Nil) =>
// TODO: Add QUOTE tag to TASTy
assert(body.isTerm,
"""Quote with type should not be pickled.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1269,7 +1269,7 @@ class TreeUnpickler(reader: TastyReader,

def quotedExpr(fn: Tree, args: List[Tree]): Tree =
val TypeApply(_, targs) = fn: @unchecked
untpd.Quote(args.head).withBodyType(targs.head.tpe)
untpd.Quote(args.head, Nil).withBodyType(targs.head.tpe)

def splicedExpr(fn: Tree, args: List[Tree]): Tree =
val TypeApply(_, targs) = fn: @unchecked
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/inlines/Inliner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ class Inliner(val call: tpd.Tree)(using Context):

override def typedQuote(tree: untpd.Quote, pt: Type)(using Context): Tree =
super.typedQuote(tree, pt) match
case Quote(Splice(inner)) => inner
case Quote(Splice(inner), _) => inner
case tree1 =>
ctx.compilationUnit.needsStaging = true
tree1
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,7 @@ object Parsers {
}
}
in.nextToken()
Quote(t)
Quote(t, Nil)
}
else
if !in.featureEnabled(Feature.symbolLiterals) then
Expand Down Expand Up @@ -2480,7 +2480,7 @@ object Parsers {
val body =
if (in.token == LBRACKET) inBrackets(typ())
else stagedBlock()
Quote(body)
Quote(body, Nil)
}
}
case NEW =>
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -726,11 +726,12 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
"Thicket {" ~~ toTextGlobal(trees, "\n") ~~ "}"
case MacroTree(call) =>
keywordStr("macro ") ~ toTextGlobal(call)
case tree @ Quote(body) =>
case tree @ Quote(body, tags) =>
val tagsText = (keywordStr("<") ~ toTextGlobal(tags, ", ") ~ keywordStr(">")).provided(tree.tags.nonEmpty)
val exprTypeText = (keywordStr("[") ~ toTextGlobal(tree.bodyType) ~ keywordStr("]")).provided(printDebug && tree.typeOpt.exists)
val open = if (body.isTerm) keywordStr("{") else keywordStr("[")
val close = if (body.isTerm) keywordStr("}") else keywordStr("]")
keywordStr("'") ~ exprTypeText ~ open ~ toTextGlobal(body) ~ close
keywordStr("'") ~ tagsText ~ exprTypeText ~ open ~ toTextGlobal(body) ~ close
case Splice(expr) =>
val spliceTypeText = (keywordStr("[") ~ toTextGlobal(tree.typeOpt) ~ keywordStr("]")).provided(printDebug && tree.typeOpt.exists)
keywordStr("$") ~ spliceTypeText ~ keywordStr("{") ~ toTextGlobal(expr) ~ keywordStr("}")
Expand Down
80 changes: 40 additions & 40 deletions compiler/src/dotty/tools/dotc/staging/CrossStageSafety.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,36 @@ import dotty.tools.dotc.util.Property
import dotty.tools.dotc.util.Spans._
import dotty.tools.dotc.util.SrcPos

/** Checks that staging level consistency holds and heals staged types .
/** Checks that staging level consistency holds and heals staged types.
*
* Local term references are level consistent if and only if they are used at the same level as their definition.
*
* Local type references can be used at the level of their definition or lower. If used used at a higher level,
* it will be healed if possible, otherwise it is inconsistent.
*
* Type healing consists in transforming a level inconsistent type `T` into `summon[Type[T]].Underlying`.
* Healing a type consists in replacing locally defined types defined at staging level 0 and used in higher levels.
* For each type local `T` that is defined at level 0 and used in a quote, we summon a tag `t: Type[T]`. This `t`
* tag must be defined at level 0. The tags will be listed in the `tags` of the level 0 quote (`'<t>{ ... }`) and
* each reference to `T` will be replaced by `t.Underlying` in the body of the quote.
*
* We delay the healing of types in quotes at level 1 or higher until those quotes reach level 0. At this point
* more types will be statically known and fewer types will need to be healed. This also keeps the nested quotes
* in their original form, we do not want macro users to see any artifacts of this phase in quoted expressions
* they might inspect.
*
* Type heal example:
*
* As references to types do not necessarily have an associated tree it is not always possible to replace the types directly.
* Instead we always generate a type alias for it and place it at the start of the surrounding quote. This also avoids duplication.
* For example:
* '{
* val x: List[T] = List[T]()
* '{ .. T .. }
* ()
* }
*
* is transformed to
*
* '{
* type t$1 = summon[Type[T]].Underlying
* val x: List[t$1] = List[t$1]();
* '<t>{ // where `t` is a given term of type `Type[T]`
* val x: List[t.Underlying] = List[t.Underlying]();
* '{ .. t.Underlying .. }
* ()
* }
*
Expand All @@ -56,11 +64,18 @@ class CrossStageSafety extends TreeMapWithStages {
case tree: Quote =>
if (ctx.property(InAnnotation).isDefined)
report.error("Cannot have a quote in an annotation", tree.srcPos)
val body1 = transformQuoteBody(tree.body, tree.span)
val stripAnnotationsDeep: TypeMap = new TypeMap:
def apply(tp: Type): Type = mapOver(tp.stripAnnots)
val bodyType1 = healType(tree.srcPos)(stripAnnotationsDeep(tree.bodyType))
cpy.Quote(tree)(body1).withBodyType(bodyType1)

val tree1 =
val stripAnnotationsDeep: TypeMap = new TypeMap:
def apply(tp: Type): Type = mapOver(tp.stripAnnots)
val bodyType1 = healType(tree.srcPos)(stripAnnotationsDeep(tree.bodyType))
tree.withBodyType(bodyType1)

if level == 0 then
val (tags, body1) = inContextWithQuoteTypeTags { transform(tree1.body)(using quoteContext) }
cpy.Quote(tree1)(body1, tags)
else
super.transform(tree1)

case CancelledSplice(tree) =>
transform(tree) // Optimization: `${ 'x }` --> `x`
Expand All @@ -74,22 +89,18 @@ class CrossStageSafety extends TreeMapWithStages {
case tree @ QuotedTypeOf(body) =>
if (ctx.property(InAnnotation).isDefined)
report.error("Cannot have a quote in an annotation", tree.srcPos)
body.tpe match
case DirectTypeOf(termRef) =>
// Optimization: `quoted.Type.of[x.Underlying](quotes)` --> `x`
ref(termRef).withSpan(tree.span)
case _ =>
transformQuoteBody(body, tree.span) match
case DirectTypeOf.Healed(termRef) =>
// Optimization: `quoted.Type.of[@SplicedType type T = x.Underlying; T](quotes)` --> `x`
ref(termRef).withSpan(tree.span)
case transformedBody =>
val quotes = transform(tree.args.head)
// `quoted.Type.of[<body>](quotes)` --> `quoted.Type.of[<body2>](quotes)`
val TypeApply(fun, _) = tree.fun: @unchecked
if level != 0 then cpy.Apply(tree)(cpy.TypeApply(tree.fun)(fun, transformedBody :: Nil), quotes :: Nil)
else tpd.Quote(transformedBody).select(nme.apply).appliedTo(quotes).withSpan(tree.span)

if level == 0 then
val (tags, body1) = inContextWithQuoteTypeTags { transform(body)(using quoteContext) }
val quotes = transform(tree.args.head)
tags match
case tag :: Nil if body1.isType && body1.tpe =:= tag.tpe.select(tpnme.Underlying) =>
tag // Optimization: `quoted.Type.of[x.Underlying](quotes)` --> `x`
case _ =>
// `quoted.Type.of[<body>](<quotes>)` --> `'[<body1>].apply(<quotes>)`
tpd.Quote(body1, tags).select(nme.apply).appliedTo(quotes).withSpan(tree.span)
else
super.transform(tree)
case _: DefDef if tree.symbol.isInlineMethod =>
tree

Expand Down Expand Up @@ -137,17 +148,6 @@ class CrossStageSafety extends TreeMapWithStages {
super.transform(tree)
end transform

private def transformQuoteBody(body: Tree, span: Span)(using Context): Tree = {
val taggedTypes = new QuoteTypeTags(span)
val contextWithQuote =
if level == 0 then contextWithQuoteTypeTags(taggedTypes)(using quoteContext)
else quoteContext
val transformedBody = transform(body)(using contextWithQuote)
taggedTypes.getTypeTags match
case Nil => transformedBody
case tags => tpd.Block(tags, transformedBody).withSpan(body.span)
}

def transformTypeAnnotationSplices(tp: Type)(using Context) = new TypeMap {
def apply(tp: Type): Type = tp match
case tp: AnnotatedType =>
Expand Down Expand Up @@ -234,7 +234,7 @@ class CrossStageSafety extends TreeMapWithStages {
def unapply(tree: Splice): Option[Tree] =
def rec(tree: Tree): Option[Tree] = tree match
case Block(Nil, expr) => rec(expr)
case Quote(inner) => Some(inner)
case Quote(inner, _) => Some(inner)
case _ => None
rec(tree.expr)
}
25 changes: 0 additions & 25 deletions compiler/src/dotty/tools/dotc/staging/DirectTypeOf.scala

This file was deleted.

12 changes: 5 additions & 7 deletions compiler/src/dotty/tools/dotc/staging/HealType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {
*
* If `T` is a reference to a type at the wrong level, try to heal it by replacing it with
* a type tag of type `quoted.Type[T]`.
* The tag is generated by an instance of `QuoteTypeTags` directly if the splice is explicit
* The tag is recorded by an instance of `QuoteTypeTags` directly if the splice is explicit
* or indirectly by `tryHeal`.
*/
def apply(tp: Type): Type =
Expand All @@ -43,11 +43,9 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {

private def healTypeRef(tp: TypeRef): Type =
tp.prefix match
case NoPrefix if tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) =>
tp
case prefix: TermRef if tp.symbol.isTypeSplice =>
checkNotWildcardSplice(tp)
if level == 0 then tp else getQuoteTypeTags.getTagRef(prefix)
if level == 0 then tp else getTagRef(prefix)
case _: NamedType | _: ThisType | NoPrefix =>
if levelInconsistentRootOfPath(tp).exists then
tryHeal(tp)
Expand All @@ -58,7 +56,7 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {

private object NonSpliceAlias:
def unapply(tp: TypeRef)(using Context): Option[Type] = tp.underlying match
case TypeAlias(alias) if !tp.symbol.isTypeSplice && !tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) => Some(alias)
case TypeAlias(alias) if !tp.symbol.isTypeSplice => Some(alias)
case _ => None

private def checkNotWildcardSplice(splice: TypeRef): Unit =
Expand All @@ -78,7 +76,7 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {

/** Try to heal reference to type `T` used in a higher level than its definition.
* Returns a reference to a type tag generated by `QuoteTypeTags` that contains a
* reference to a type alias containing the equivalent of `${summon[quoted.Type[T]]}`.
* reference to a type alias containing the equivalent of `${summon[quoted.Type[T]]}.Underlying`.
* Emits an error if `T` cannot be healed and returns `T`.
*/
protected def tryHeal(tp: TypeRef): Type = {
Expand All @@ -88,7 +86,7 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {
case tp: TermRef =>
ctx.typer.checkStable(tp, pos, "type witness")
if levelOf(tp.symbol) > 0 then tp.select(tpnme.Underlying)
else getQuoteTypeTags.getTagRef(tp)
else getTagRef(tp)
case _: SearchFailureType =>
report.error(
ctx.typer.missingArgMsg(tag, reqType, "")
Expand Down
Loading

0 comments on commit 0ee80f9

Please sign in to comment.