Skip to content

Commit

Permalink
Replace quoted type variables in signature of HOAS pattern result
Browse files Browse the repository at this point in the history
To be able to construct the lambda returned by the HOAS pattern we need:
first resolve the type variables and then use the result to construct the
signature of the lambdas.

To simplify this transformation, `QuoteMatcher` returns a `Seq[MatchResult]`
instead of an untyped `Tuple` containing `Expr[?]`. The tuple is created
once we have accumulated and processed all extracted values.

Fixes #15165
  • Loading branch information
nicolasstucki committed Mar 3, 2023
1 parent dbdca17 commit 20174d7
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 34 deletions.
87 changes: 60 additions & 27 deletions compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ object QuoteMatcher {
/** Sequence of matched expressions.
* These expressions are part of the scrutinee and will be bound to the quote pattern term splices.
*/
type MatchingExprs = Seq[Expr[Any]]
type MatchingExprs = Seq[MatchResult]

/** A map relating equivalent symbols from the scrutinee and the pattern
* For example in
Expand Down Expand Up @@ -141,12 +141,13 @@ object QuoteMatcher {
extension (scrutinee0: Tree)

/** Check that the trees match and return the contents from the pattern holes.
* Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes.
* Return a sequence containing all the contents in the holes.
* If it does not match, continues to the `optional` with `None`.
*
* @param scrutinee The tree being matched
* @param pattern The pattern tree that the scrutinee should match. Contains `patternHole` holes.
* @param `summon[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
* @return `None` if it did not match or `Some(tup: MatchingExprs)` if it matched where `tup` contains the contents of the holes.
* @return The sequence with the contents of the holes of the matched expression.
*/
private def =?= (pattern0: Tree)(using Env, Context): optional[MatchingExprs] =

Expand Down Expand Up @@ -205,31 +206,12 @@ object QuoteMatcher {
// Matches an open term and wraps it into a lambda that provides the free variables
case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil)
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) =>
def hoasClosure = {
val names: List[TermName] = args.map {
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
case arg => arg.symbol.name.asTermName
}
val argTypes = args.map(x => x.tpe.widenTermRefExpr)
val methTpe = MethodType(names)(_ => argTypes, _ => pattern.tpe)
val meth = newAnonFun(ctx.owner, methTpe)
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
val argsMap = args.map(_.symbol).zip(lambdaArgss.head).toMap
val body = new TreeMap {
override def transform(tree: Tree)(using Context): Tree =
tree match
case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
case tree => super.transform(tree)
}.transform(scrutinee)
TreeOps(body).changeNonLocalOwners(meth)
}
Closure(meth, bodyFn)
}
val env = summon[Env]
val capturedArgs = args.map(_.symbol)
val captureEnv = summon[Env].filter((k, v) => !capturedArgs.contains(v))
val captureEnv = env.filter((k, v) => !capturedArgs.contains(v))
withEnv(captureEnv) {
scrutinee match
case ClosedPatternTerm(scrutinee) => matched(hoasClosure)
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, args, env)
case _ => notMatched
}

Expand Down Expand Up @@ -453,16 +435,67 @@ object QuoteMatcher {
accumulator.apply(Set.empty, term)
}

enum MatchResult:
/** Closed pattern extracted value
* @param tree Scrutinee sub-tree that matched
*/
case ClosedTree(tree: Tree)
/** HOAS pattern extracted value
*
* @param tree Scrutinee sub-tree that matched
* @param patternTpe Type of the pattern hole (from the pattern)
* @param args HOAS arguments (from the pattern)
* @param env Mapping between scrutinee and pattern variables
*/
case OpenTree(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)

/** Return the expression that was extracted from a hole.
*
* If it was a closed expression it returns that expression. Otherwise,
* if it is a HOAS pattern, the surrounding lambda is generated using
* `mapTypeHoles` to create the signature of the lambda.
*
* This expression is assumed to be a valid expression in the given splice scope.
*/
def toExpr(mapTypeHoles: TypeMap, spliceScope: Scope)(using Context): Expr[Any] = this match
case MatchResult.ClosedTree(tree) =>
new ExprImpl(tree, spliceScope)
case MatchResult.OpenTree(tree, patternTpe, args, env) =>
val names: List[TermName] = args.map {
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
case arg => arg.symbol.name.asTermName
}
val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr))
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
val meth = newAnonFun(ctx.owner, methTpe)
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
val argsMap = args.view.map(_.symbol).zip(lambdaArgss.head).toMap
val body = new TreeMap {
override def transform(tree: Tree)(using Context): Tree =
tree match
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
case tree => super.transform(tree)
}.transform(tree)
TreeOps(body).changeNonLocalOwners(meth)
}
val hoasClosure = Closure(meth, bodyFn)
new ExprImpl(hoasClosure, spliceScope)

private inline def notMatched: optional[MatchingExprs] =
optional.break()

private inline def matched: MatchingExprs =
Seq.empty

private inline def matched(tree: Tree)(using Context): MatchingExprs =
Seq(new ExprImpl(tree, SpliceScope.getCurrent))
Seq(MatchResult.ClosedTree(tree))

private def matchedOpen(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)(using Context): MatchingExprs =
Seq(MatchResult.OpenTree(tree, patternTpe, args, env))

extension (self: MatchingExprs)
private inline def &&& (that: MatchingExprs): MatchingExprs = self ++ that
/** Concatenates the contents of two successful matchings */
def &&& (that: MatchingExprs): MatchingExprs = self ++ that
end extension

}
23 changes: 16 additions & 7 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3137,18 +3137,27 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
ctx1.gadtState.addToConstraint(typeHoles)
ctx1

val matchings = QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1)

// After matching and doing all subtype checks, we have to approximate all the type bindings
// that we have found, seal them in a quoted.Type and add them to the result
def typeHoleApproximation(sym: Symbol) =
val fromAboveAnnot = sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot)
val fullBounds = ctx1.gadt.fullBounds(sym)
val tp = if fromAboveAnnot then fullBounds.hi else fullBounds.lo
reflect.TypeReprMethods.asType(tp)
matchings.map { tup =>
val results = typeHoles.map(typeHoleApproximation) ++ tup
Tuple.fromIArray(results.toArray.asInstanceOf[IArray[Object]])
if fromAboveAnnot then fullBounds.hi else fullBounds.lo

QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1).map { matchings =>
import QuoteMatcher.MatchResult.*
lazy val spliceScope = SpliceScope.getCurrent
val typeHoleApproximations = typeHoles.map(typeHoleApproximation)
val typeHoleMapping = Map(typeHoles.zip(typeHoleApproximations)*)
val typeHoleMap = new Types.TypeMap {
def apply(tp: Types.Type): Types.Type = tp match
case Types.TypeRef(Types.NoPrefix, _) => typeHoleMapping.getOrElse(tp.typeSymbol, tp)
case _ => mapOver(tp)
}
val matchedExprs = matchings.map(_.toExpr(typeHoleMap, spliceScope))
val matchedTypes = typeHoleApproximations.map(reflect.TypeReprMethods.asType)
val results = matchedTypes ++ matchedExprs
Tuple.fromIArray(IArray.unsafeFromArray(results.toArray))
}
}

Expand Down
9 changes: 9 additions & 0 deletions tests/pos-macros/i15165a/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import scala.quoted.*

inline def valToFun[T](inline expr: T): T =
${ impl('expr) }

def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
expr match
case '{ { val ident = ($a: α); $rest(ident): T } } =>
'{ { (y: α) => $rest(y) }.apply(???) }
4 changes: 4 additions & 0 deletions tests/pos-macros/i15165a/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test = valToFun {
val a: Int = 1
a + 1
}
16 changes: 16 additions & 0 deletions tests/pos-macros/i15165b/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import scala.quoted.*

inline def valToFun[T](inline expr: T): T =
${ impl('expr) }

def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
expr match
case '{ { val ident = ($a: α); $rest(ident): T } } =>
'{
{ (y: α) =>
${
val bound = '{ ${ rest }(y) }
Expr.betaReduce(bound)
}
}.apply($a)
}
4 changes: 4 additions & 0 deletions tests/pos-macros/i15165b/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test = valToFun {
val a: Int = 1
a + 1
}
9 changes: 9 additions & 0 deletions tests/pos-macros/i15165c/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import scala.quoted.*

inline def valToFun[T](inline expr: T): T =
${ impl('expr) }

def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
expr match
case '{ type α; { val ident = ($a: `α`); $rest(ident): `α` & T } } =>
'{ { (y: α) => $rest(y) }.apply(???) }
4 changes: 4 additions & 0 deletions tests/pos-macros/i15165c/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test = valToFun {
val a: Int = 1
a + 1
}

0 comments on commit 20174d7

Please sign in to comment.