diff --git a/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala b/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala index 9a77cba97400..5477628a30a3 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala @@ -109,7 +109,7 @@ object QuoteMatcher { /** Sequence of matched expressions. * These expressions are part of the scrutinee and will be bound to the quote pattern term splices. */ - type MatchingExprs = Seq[Expr[Any]] + type MatchingExprs = Seq[MatchResult] /** A map relating equivalent symbols from the scrutinee and the pattern * For example in @@ -141,12 +141,13 @@ object QuoteMatcher { extension (scrutinee0: Tree) /** Check that the trees match and return the contents from the pattern holes. - * Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes. + * Return a sequence containing all the contents in the holes. + * If it does not match, continues to the `optional` with `None`. * * @param scrutinee The tree being matched * @param pattern The pattern tree that the scrutinee should match. Contains `patternHole` holes. * @param `summon[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`. - * @return `None` if it did not match or `Some(tup: MatchingExprs)` if it matched where `tup` contains the contents of the holes. + * @return The sequence with the contents of the holes of the matched expression. */ private def =?= (pattern0: Tree)(using Env, Context): optional[MatchingExprs] = @@ -205,31 +206,12 @@ object QuoteMatcher { // Matches an open term and wraps it into a lambda that provides the free variables case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil) if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) => - def hoasClosure = { - val names: List[TermName] = args.map { - case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName - case arg => arg.symbol.name.asTermName - } - val argTypes = args.map(x => x.tpe.widenTermRefExpr) - val methTpe = MethodType(names)(_ => argTypes, _ => pattern.tpe) - val meth = newAnonFun(ctx.owner, methTpe) - def bodyFn(lambdaArgss: List[List[Tree]]): Tree = { - val argsMap = args.map(_.symbol).zip(lambdaArgss.head).toMap - val body = new TreeMap { - override def transform(tree: Tree)(using Context): Tree = - tree match - case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree) - case tree => super.transform(tree) - }.transform(scrutinee) - TreeOps(body).changeNonLocalOwners(meth) - } - Closure(meth, bodyFn) - } + val env = summon[Env] val capturedArgs = args.map(_.symbol) - val captureEnv = summon[Env].filter((k, v) => !capturedArgs.contains(v)) + val captureEnv = env.filter((k, v) => !capturedArgs.contains(v)) withEnv(captureEnv) { scrutinee match - case ClosedPatternTerm(scrutinee) => matched(hoasClosure) + case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, args, env) case _ => notMatched } @@ -453,6 +435,52 @@ object QuoteMatcher { accumulator.apply(Set.empty, term) } + enum MatchResult: + /** Closed pattern extracted value + * @param tree Scrutinee sub-tree that matched + */ + case ClosedTree(tree: Tree) + /** HOAS pattern extracted value + * + * @param tree Scrutinee sub-tree that matched + * @param patternTpe Type of the pattern hole (from the pattern) + * @param args HOAS arguments (from the pattern) + * @param env Mapping between scrutinee and pattern variables + */ + case OpenTree(tree: Tree, patternTpe: Type, args: List[Tree], env: Env) + + /** Return the expression that was extracted from a hole. + * + * If it was a closed expression it returns that expression. Otherwise, + * if it is a HOAS pattern, the surrounding lambda is generated using + * `mapTypeHoles` to create the signature of the lambda. + * + * This expression is assumed to be a valid expression in the given splice scope. + */ + def toExpr(mapTypeHoles: TypeMap, spliceScope: Scope)(using Context): Expr[Any] = this match + case MatchResult.ClosedTree(tree) => + new ExprImpl(tree, spliceScope) + case MatchResult.OpenTree(tree, patternTpe, args, env) => + val names: List[TermName] = args.map { + case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName + case arg => arg.symbol.name.asTermName + } + val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr)) + val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe)) + val meth = newAnonFun(ctx.owner, methTpe) + def bodyFn(lambdaArgss: List[List[Tree]]): Tree = { + val argsMap = args.view.map(_.symbol).zip(lambdaArgss.head).toMap + val body = new TreeMap { + override def transform(tree: Tree)(using Context): Tree = + tree match + case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree) + case tree => super.transform(tree) + }.transform(tree) + TreeOps(body).changeNonLocalOwners(meth) + } + val hoasClosure = Closure(meth, bodyFn) + new ExprImpl(hoasClosure, spliceScope) + private inline def notMatched: optional[MatchingExprs] = optional.break() @@ -460,9 +488,14 @@ object QuoteMatcher { Seq.empty private inline def matched(tree: Tree)(using Context): MatchingExprs = - Seq(new ExprImpl(tree, SpliceScope.getCurrent)) + Seq(MatchResult.ClosedTree(tree)) + + private def matchedOpen(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)(using Context): MatchingExprs = + Seq(MatchResult.OpenTree(tree, patternTpe, args, env)) extension (self: MatchingExprs) - private inline def &&& (that: MatchingExprs): MatchingExprs = self ++ that + /** Concatenates the contents of two successful matchings */ + def &&& (that: MatchingExprs): MatchingExprs = self ++ that + end extension } diff --git a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala index 378f3f6a5c40..a94f40a8d068 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala @@ -3137,18 +3137,27 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler ctx1.gadtState.addToConstraint(typeHoles) ctx1 - val matchings = QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1) - // After matching and doing all subtype checks, we have to approximate all the type bindings // that we have found, seal them in a quoted.Type and add them to the result def typeHoleApproximation(sym: Symbol) = val fromAboveAnnot = sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot) val fullBounds = ctx1.gadt.fullBounds(sym) - val tp = if fromAboveAnnot then fullBounds.hi else fullBounds.lo - reflect.TypeReprMethods.asType(tp) - matchings.map { tup => - val results = typeHoles.map(typeHoleApproximation) ++ tup - Tuple.fromIArray(results.toArray.asInstanceOf[IArray[Object]]) + if fromAboveAnnot then fullBounds.hi else fullBounds.lo + + QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1).map { matchings => + import QuoteMatcher.MatchResult.* + lazy val spliceScope = SpliceScope.getCurrent + val typeHoleApproximations = typeHoles.map(typeHoleApproximation) + val typeHoleMapping = Map(typeHoles.zip(typeHoleApproximations)*) + val typeHoleMap = new Types.TypeMap { + def apply(tp: Types.Type): Types.Type = tp match + case Types.TypeRef(Types.NoPrefix, _) => typeHoleMapping.getOrElse(tp.typeSymbol, tp) + case _ => mapOver(tp) + } + val matchedExprs = matchings.map(_.toExpr(typeHoleMap, spliceScope)) + val matchedTypes = typeHoleApproximations.map(reflect.TypeReprMethods.asType) + val results = matchedTypes ++ matchedExprs + Tuple.fromIArray(IArray.unsafeFromArray(results.toArray)) } } diff --git a/tests/pos-macros/i15165a/Macro_1.scala b/tests/pos-macros/i15165a/Macro_1.scala new file mode 100644 index 000000000000..8838d4c06bd1 --- /dev/null +++ b/tests/pos-macros/i15165a/Macro_1.scala @@ -0,0 +1,9 @@ +import scala.quoted.* + +inline def valToFun[T](inline expr: T): T = + ${ impl('expr) } + +def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] = + expr match + case '{ { val ident = ($a: α); $rest(ident): T } } => + '{ { (y: α) => $rest(y) }.apply(???) } diff --git a/tests/pos-macros/i15165a/Test_2.scala b/tests/pos-macros/i15165a/Test_2.scala new file mode 100644 index 000000000000..f7caa67b2df7 --- /dev/null +++ b/tests/pos-macros/i15165a/Test_2.scala @@ -0,0 +1,4 @@ +def test = valToFun { + val a: Int = 1 + a + 1 +} diff --git a/tests/pos-macros/i15165b/Macro_1.scala b/tests/pos-macros/i15165b/Macro_1.scala new file mode 100644 index 000000000000..5d62db37e313 --- /dev/null +++ b/tests/pos-macros/i15165b/Macro_1.scala @@ -0,0 +1,16 @@ +import scala.quoted.* + +inline def valToFun[T](inline expr: T): T = + ${ impl('expr) } + +def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] = + expr match + case '{ { val ident = ($a: α); $rest(ident): T } } => + '{ + { (y: α) => + ${ + val bound = '{ ${ rest }(y) } + Expr.betaReduce(bound) + } + }.apply($a) + } diff --git a/tests/pos-macros/i15165b/Test_2.scala b/tests/pos-macros/i15165b/Test_2.scala new file mode 100644 index 000000000000..f7caa67b2df7 --- /dev/null +++ b/tests/pos-macros/i15165b/Test_2.scala @@ -0,0 +1,4 @@ +def test = valToFun { + val a: Int = 1 + a + 1 +} diff --git a/tests/pos-macros/i15165c/Macro_1.scala b/tests/pos-macros/i15165c/Macro_1.scala new file mode 100644 index 000000000000..036363bf274f --- /dev/null +++ b/tests/pos-macros/i15165c/Macro_1.scala @@ -0,0 +1,9 @@ +import scala.quoted.* + +inline def valToFun[T](inline expr: T): T = + ${ impl('expr) } + +def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] = + expr match + case '{ type α; { val ident = ($a: `α`); $rest(ident): `α` & T } } => + '{ { (y: α) => $rest(y) }.apply(???) } diff --git a/tests/pos-macros/i15165c/Test_2.scala b/tests/pos-macros/i15165c/Test_2.scala new file mode 100644 index 000000000000..f7caa67b2df7 --- /dev/null +++ b/tests/pos-macros/i15165c/Test_2.scala @@ -0,0 +1,4 @@ +def test = valToFun { + val a: Int = 1 + a + 1 +}