diff --git a/compiler/src/dotty/tools/dotc/util/optional.scala b/compiler/src/dotty/tools/dotc/util/optional.scala new file mode 100644 index 000000000000..cb62315d3c98 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/util/optional.scala @@ -0,0 +1,19 @@ +package dotty.tools.dotc.util + +import scala.util.boundary + +/** Return type that indicates that the method returns a T or aborts to the enclosing boundary with a `None` */ +type optional[T] = boundary.Label[None.type] ?=> T + +/** A prompt for `Option`, which establishes a boundary which `_.?` on `Option` can return */ +object optional: + inline def apply[T](inline body: optional[T]): Option[T] = + boundary(Some(body)) + + extension [T](r: Option[T]) + inline def ? (using label: boundary.Label[None.type]): T = r match + case Some(x) => x + case None => boundary.break(None) + + inline def break()(using label: boundary.Label[None.type]): Nothing = + boundary.break(None) diff --git a/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala b/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala index 7c952dbbe142..5477628a30a3 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala @@ -1,7 +1,6 @@ package scala.quoted package runtime.impl - import dotty.tools.dotc.ast.tpd import dotty.tools.dotc.core.Contexts.* import dotty.tools.dotc.core.Flags.* @@ -9,6 +8,7 @@ import dotty.tools.dotc.core.Names.* import dotty.tools.dotc.core.Types.* import dotty.tools.dotc.core.StdNames.nme import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.util.optional /** Matches a quoted tree against a quoted pattern tree. * A quoted pattern tree may have type and term holes in addition to normal terms. @@ -103,12 +103,13 @@ import dotty.tools.dotc.core.Symbols.* object QuoteMatcher { import tpd.* - // TODO improve performance - // TODO use flag from Context. Maybe -debug or add -debug-macros private inline val debug = false - import Matching._ + /** Sequence of matched expressions. + * These expressions are part of the scrutinee and will be bound to the quote pattern term splices. + */ + type MatchingExprs = Seq[MatchResult] /** A map relating equivalent symbols from the scrutinee and the pattern * For example in @@ -121,32 +122,34 @@ object QuoteMatcher { private def withEnv[T](env: Env)(body: Env ?=> T): T = body(using env) - def treeMatch(scrutineeTree: Tree, patternTree: Tree)(using Context): Option[Tuple] = + def treeMatch(scrutineeTree: Tree, patternTree: Tree)(using Context): Option[MatchingExprs] = given Env = Map.empty - scrutineeTree =?= patternTree + optional: + scrutineeTree =?= patternTree /** Check that all trees match with `mtch` and concatenate the results with &&& */ - private def matchLists[T](l1: List[T], l2: List[T])(mtch: (T, T) => Matching): Matching = (l1, l2) match { + private def matchLists[T](l1: List[T], l2: List[T])(mtch: (T, T) => MatchingExprs): optional[MatchingExprs] = (l1, l2) match { case (x :: xs, y :: ys) => mtch(x, y) &&& matchLists(xs, ys)(mtch) case (Nil, Nil) => matched case _ => notMatched } extension (scrutinees: List[Tree]) - private def =?= (patterns: List[Tree])(using Env, Context): Matching = + private def =?= (patterns: List[Tree])(using Env, Context): optional[MatchingExprs] = matchLists(scrutinees, patterns)(_ =?= _) extension (scrutinee0: Tree) /** Check that the trees match and return the contents from the pattern holes. - * Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes. + * Return a sequence containing all the contents in the holes. + * If it does not match, continues to the `optional` with `None`. * * @param scrutinee The tree being matched * @param pattern The pattern tree that the scrutinee should match. Contains `patternHole` holes. * @param `summon[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`. - * @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes. + * @return The sequence with the contents of the holes of the matched expression. */ - private def =?= (pattern0: Tree)(using Env, Context): Matching = + private def =?= (pattern0: Tree)(using Env, Context): optional[MatchingExprs] = /* Match block flattening */ // TODO move to cases /** Normalize the tree */ @@ -203,31 +206,12 @@ object QuoteMatcher { // Matches an open term and wraps it into a lambda that provides the free variables case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil) if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) => - def hoasClosure = { - val names: List[TermName] = args.map { - case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName - case arg => arg.symbol.name.asTermName - } - val argTypes = args.map(x => x.tpe.widenTermRefExpr) - val methTpe = MethodType(names)(_ => argTypes, _ => pattern.tpe) - val meth = newAnonFun(ctx.owner, methTpe) - def bodyFn(lambdaArgss: List[List[Tree]]): Tree = { - val argsMap = args.map(_.symbol).zip(lambdaArgss.head).toMap - val body = new TreeMap { - override def transform(tree: Tree)(using Context): Tree = - tree match - case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree) - case tree => super.transform(tree) - }.transform(scrutinee) - TreeOps(body).changeNonLocalOwners(meth) - } - Closure(meth, bodyFn) - } + val env = summon[Env] val capturedArgs = args.map(_.symbol) - val captureEnv = summon[Env].filter((k, v) => !capturedArgs.contains(v)) + val captureEnv = env.filter((k, v) => !capturedArgs.contains(v)) withEnv(captureEnv) { scrutinee match - case ClosedPatternTerm(scrutinee) => matched(hoasClosure) + case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, args, env) case _ => notMatched } @@ -431,7 +415,6 @@ object QuoteMatcher { case _ => scrutinee val pattern = patternTree.symbol - devirtualizedScrutinee == pattern || summon[Env].get(devirtualizedScrutinee).contains(pattern) || devirtualizedScrutinee.allOverriddenSymbols.contains(pattern) @@ -452,32 +435,67 @@ object QuoteMatcher { accumulator.apply(Set.empty, term) } - /** Result of matching a part of an expression */ - private type Matching = Option[Tuple] - - private object Matching { - - def notMatched: Matching = None - - val matched: Matching = Some(Tuple()) - - def matched(tree: Tree)(using Context): Matching = - Some(Tuple1(new ExprImpl(tree, SpliceScope.getCurrent))) - - extension (self: Matching) - def asOptionOfTuple: Option[Tuple] = self - - /** Concatenates the contents of two successful matchings or return a `notMatched` */ - def &&& (that: => Matching): Matching = self match { - case Some(x) => - that match { - case Some(y) => Some(x ++ y) - case _ => None - } - case _ => None - } - end extension - - } + enum MatchResult: + /** Closed pattern extracted value + * @param tree Scrutinee sub-tree that matched + */ + case ClosedTree(tree: Tree) + /** HOAS pattern extracted value + * + * @param tree Scrutinee sub-tree that matched + * @param patternTpe Type of the pattern hole (from the pattern) + * @param args HOAS arguments (from the pattern) + * @param env Mapping between scrutinee and pattern variables + */ + case OpenTree(tree: Tree, patternTpe: Type, args: List[Tree], env: Env) + + /** Return the expression that was extracted from a hole. + * + * If it was a closed expression it returns that expression. Otherwise, + * if it is a HOAS pattern, the surrounding lambda is generated using + * `mapTypeHoles` to create the signature of the lambda. + * + * This expression is assumed to be a valid expression in the given splice scope. + */ + def toExpr(mapTypeHoles: TypeMap, spliceScope: Scope)(using Context): Expr[Any] = this match + case MatchResult.ClosedTree(tree) => + new ExprImpl(tree, spliceScope) + case MatchResult.OpenTree(tree, patternTpe, args, env) => + val names: List[TermName] = args.map { + case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName + case arg => arg.symbol.name.asTermName + } + val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr)) + val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe)) + val meth = newAnonFun(ctx.owner, methTpe) + def bodyFn(lambdaArgss: List[List[Tree]]): Tree = { + val argsMap = args.view.map(_.symbol).zip(lambdaArgss.head).toMap + val body = new TreeMap { + override def transform(tree: Tree)(using Context): Tree = + tree match + case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree) + case tree => super.transform(tree) + }.transform(tree) + TreeOps(body).changeNonLocalOwners(meth) + } + val hoasClosure = Closure(meth, bodyFn) + new ExprImpl(hoasClosure, spliceScope) + + private inline def notMatched: optional[MatchingExprs] = + optional.break() + + private inline def matched: MatchingExprs = + Seq.empty + + private inline def matched(tree: Tree)(using Context): MatchingExprs = + Seq(MatchResult.ClosedTree(tree)) + + private def matchedOpen(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)(using Context): MatchingExprs = + Seq(MatchResult.OpenTree(tree, patternTpe, args, env)) + + extension (self: MatchingExprs) + /** Concatenates the contents of two successful matchings */ + def &&& (that: MatchingExprs): MatchingExprs = self ++ that + end extension } diff --git a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala index 6688d6e81a89..d3685bee9e23 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala @@ -3137,20 +3137,27 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler ctx1.gadtState.addToConstraint(typeHoles) ctx1 - val matchings = QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1) - - if typeHoles.isEmpty then matchings - else { - // After matching and doing all subtype checks, we have to approximate all the type bindings - // that we have found, seal them in a quoted.Type and add them to the result - def typeHoleApproximation(sym: Symbol) = - val fromAboveAnnot = sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot) - val fullBounds = ctx1.gadt.fullBounds(sym) - val tp = if fromAboveAnnot then fullBounds.hi else fullBounds.lo - reflect.TypeReprMethods.asType(tp) - matchings.map { tup => - Tuple.fromIArray(typeHoles.map(typeHoleApproximation).toArray.asInstanceOf[IArray[Object]]) ++ tup + // After matching and doing all subtype checks, we have to approximate all the type bindings + // that we have found, seal them in a quoted.Type and add them to the result + def typeHoleApproximation(sym: Symbol) = + val fromAboveAnnot = sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot) + val fullBounds = ctx1.gadt.fullBounds(sym) + if fromAboveAnnot then fullBounds.hi else fullBounds.lo + + QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1).map { matchings => + import QuoteMatcher.MatchResult.* + lazy val spliceScope = SpliceScope.getCurrent + val typeHoleApproximations = typeHoles.map(typeHoleApproximation) + val typeHoleMapping = Map(typeHoles.zip(typeHoleApproximations)*) + val typeHoleMap = new Types.TypeMap { + def apply(tp: Types.Type): Types.Type = tp match + case Types.TypeRef(Types.NoPrefix, _) => typeHoleMapping.getOrElse(tp.typeSymbol, tp) + case _ => mapOver(tp) } + val matchedExprs = matchings.map(_.toExpr(typeHoleMap, spliceScope)) + val matchedTypes = typeHoleApproximations.map(reflect.TypeReprMethods.asType) + val results = matchedTypes ++ matchedExprs + Tuple.fromIArray(IArray.unsafeFromArray(results.toArray)) } } diff --git a/tests/pos-macros/i15165a/Macro_1.scala b/tests/pos-macros/i15165a/Macro_1.scala new file mode 100644 index 000000000000..8838d4c06bd1 --- /dev/null +++ b/tests/pos-macros/i15165a/Macro_1.scala @@ -0,0 +1,9 @@ +import scala.quoted.* + +inline def valToFun[T](inline expr: T): T = + ${ impl('expr) } + +def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] = + expr match + case '{ { val ident = ($a: α); $rest(ident): T } } => + '{ { (y: α) => $rest(y) }.apply(???) } diff --git a/tests/pos-macros/i15165a/Test_2.scala b/tests/pos-macros/i15165a/Test_2.scala new file mode 100644 index 000000000000..f7caa67b2df7 --- /dev/null +++ b/tests/pos-macros/i15165a/Test_2.scala @@ -0,0 +1,4 @@ +def test = valToFun { + val a: Int = 1 + a + 1 +} diff --git a/tests/pos-macros/i15165b/Macro_1.scala b/tests/pos-macros/i15165b/Macro_1.scala new file mode 100644 index 000000000000..5d62db37e313 --- /dev/null +++ b/tests/pos-macros/i15165b/Macro_1.scala @@ -0,0 +1,16 @@ +import scala.quoted.* + +inline def valToFun[T](inline expr: T): T = + ${ impl('expr) } + +def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] = + expr match + case '{ { val ident = ($a: α); $rest(ident): T } } => + '{ + { (y: α) => + ${ + val bound = '{ ${ rest }(y) } + Expr.betaReduce(bound) + } + }.apply($a) + } diff --git a/tests/pos-macros/i15165b/Test_2.scala b/tests/pos-macros/i15165b/Test_2.scala new file mode 100644 index 000000000000..f7caa67b2df7 --- /dev/null +++ b/tests/pos-macros/i15165b/Test_2.scala @@ -0,0 +1,4 @@ +def test = valToFun { + val a: Int = 1 + a + 1 +} diff --git a/tests/pos-macros/i15165c/Macro_1.scala b/tests/pos-macros/i15165c/Macro_1.scala new file mode 100644 index 000000000000..036363bf274f --- /dev/null +++ b/tests/pos-macros/i15165c/Macro_1.scala @@ -0,0 +1,9 @@ +import scala.quoted.* + +inline def valToFun[T](inline expr: T): T = + ${ impl('expr) } + +def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] = + expr match + case '{ type α; { val ident = ($a: `α`); $rest(ident): `α` & T } } => + '{ { (y: α) => $rest(y) }.apply(???) } diff --git a/tests/pos-macros/i15165c/Test_2.scala b/tests/pos-macros/i15165c/Test_2.scala new file mode 100644 index 000000000000..f7caa67b2df7 --- /dev/null +++ b/tests/pos-macros/i15165c/Test_2.scala @@ -0,0 +1,4 @@ +def test = valToFun { + val a: Int = 1 + a + 1 +}