diff --git a/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns1.dotty.check b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns1.dotty.check new file mode 100644 index 0000000000..c0a1abb8d9 --- /dev/null +++ b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns1.dotty.check @@ -0,0 +1,3 @@ +[ Error ] InvalidTypedPatterns1.scala:13:47: Unsupported pattern: _:B + case MyOtherClass(s1: A, MyOtherClass(s2: B, moc2: MyOtherClass[A, B])) => s1 // s2: B is invalid + ^ \ No newline at end of file diff --git a/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns1.scala b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns1.scala new file mode 100644 index 0000000000..83e34c0f9c --- /dev/null +++ b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns1.scala @@ -0,0 +1,15 @@ +import stainless.lang._ + +object InvalidTypedPatterns1 { + + case class MyOtherClass[S, T](s1: S, s2: S, t: T) + + object MyOtherClass { + def unapply[S, T](moc: MyOtherClass[S, T]): Option[(S, MyOtherClass[S, T])] = Some((moc.s1, moc)) + } + + + def test[A, B](moc: MyOtherClass[A, B]): A = moc match { + case MyOtherClass(s1: A, MyOtherClass(s2: B, moc2: MyOtherClass[A, B])) => s1 // s2: B is invalid + } +} \ No newline at end of file diff --git a/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns1.scalac.check b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns1.scalac.check new file mode 100644 index 0000000000..38bef44ccb --- /dev/null +++ b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns1.scalac.check @@ -0,0 +1,3 @@ +[ Error ] InvalidTypedPatterns1.scala:13:45: Unsupported pattern: (_: B) + case MyOtherClass(s1: A, MyOtherClass(s2: B, moc2: MyOtherClass[A, B])) => s1 // s2: B is invalid + ^^^ \ No newline at end of file diff --git a/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns2.dotty.check b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns2.dotty.check new file mode 100644 index 0000000000..ed7cfff05f --- /dev/null +++ b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns2.dotty.check @@ -0,0 +1 @@ +[ Error ] Stainless does not support type B & A \ No newline at end of file diff --git a/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns2.scala b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns2.scala new file mode 100644 index 0000000000..3d0d0bb9af --- /dev/null +++ b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns2.scala @@ -0,0 +1,5 @@ +object InvalidTypedPatterns2 { + def test[A, B](a: A, b: B): Unit = { + val (aa1: A, bb: A) = (a, b) // bb: A is invalid + } +} \ No newline at end of file diff --git a/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns2.scalac.check b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns2.scalac.check new file mode 100644 index 0000000000..5d65653f0b --- /dev/null +++ b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns2.scalac.check @@ -0,0 +1,3 @@ +[ Error ] InvalidTypedPatterns2.scala:3:20: Unsupported pattern: (_: A) + val (aa1: A, bb: A) = (a, b) // bb: A is invalid + ^^^ \ No newline at end of file diff --git a/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns3.dotty.check b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns3.dotty.check new file mode 100644 index 0000000000..cc9b7c844b --- /dev/null +++ b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns3.dotty.check @@ -0,0 +1,3 @@ +[ Error ] InvalidTypedPatterns3.scala:5:29: Unsupported pattern: _:B + case MyClass(a1: A, a2: B, b: B) => a1 // a2: B is invalid + ^ \ No newline at end of file diff --git a/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns3.scala b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns3.scala new file mode 100644 index 0000000000..b7647a95dd --- /dev/null +++ b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns3.scala @@ -0,0 +1,7 @@ +object InvalidTypedPatterns3 { + case class MyClass[S, T](s1: S, s2: S, t: T) + + def test[A, B](mc: MyClass[A, B]): A = mc match { + case MyClass(a1: A, a2: B, b: B) => a1 // a2: B is invalid + } +} \ No newline at end of file diff --git a/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns3.scalac.check b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns3.scalac.check new file mode 100644 index 0000000000..7d9bc3be82 --- /dev/null +++ b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns3.scalac.check @@ -0,0 +1,3 @@ +[ Error ] InvalidTypedPatterns3.scala:5:27: Unsupported pattern: (_: B) + case MyClass(a1: A, a2: B, b: B) => a1 // a2: B is invalid + ^^^ \ No newline at end of file diff --git a/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns4.dotty.check b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns4.dotty.check new file mode 100644 index 0000000000..197317cdd8 --- /dev/null +++ b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns4.dotty.check @@ -0,0 +1,3 @@ +[ Error ] InvalidTypedPatterns4.scala:5:31: Unsupported pattern: _:String + case MyClass(a1: Int, a2: String, b: String) => a1 // a2: String is invalid + ^^^^^^ \ No newline at end of file diff --git a/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns4.scala b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns4.scala new file mode 100644 index 0000000000..7645b63c2f --- /dev/null +++ b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns4.scala @@ -0,0 +1,7 @@ +object InvalidTypedPatterns4 { + case class MyClass[S, T](s1: S, s2: S, t: T) + + def test5(mc: MyClass[Int, String]): Int = mc match { + case MyClass(a1: Int, a2: String, b: String) => a1 // a2: String is invalid + } +} \ No newline at end of file diff --git a/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns4.scalac.check b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns4.scalac.check new file mode 100644 index 0000000000..39924a7f21 --- /dev/null +++ b/frontends/benchmarks/extraction/invalid/InvalidTypedPatterns4.scalac.check @@ -0,0 +1,5 @@ +[ Error ] InvalidTypedPatterns4.scala:5:31: scrutinee is incompatible with pattern type; +[ Error ] found : String +[ Error ] required: Int + case MyClass(a1: Int, a2: String, b: String) => a1 // a2: String is invalid + ^ \ No newline at end of file diff --git a/frontends/benchmarks/extraction/valid/StateMonadTypeRefIssue.scala b/frontends/benchmarks/extraction/valid/StateMonadTypeRefIssue.scala new file mode 100644 index 0000000000..fbc8d3825f --- /dev/null +++ b/frontends/benchmarks/extraction/valid/StateMonadTypeRefIssue.scala @@ -0,0 +1,38 @@ +object StateMonad { + + case class St[A,S](fun: S => (A,S)) { + def ^=(that: St[A,S])(implicit anyS: S): St[A,S] = { + assert(this.fun(anyS) == that.fun(anyS)) + that + } + + def qed: Unit = + ??? + + def flatMap[B](f: A => St[B,S]): St[B,S] = + St[B,S] { + s0 => + val (a,s1) = fun(s0) + f(a).fun(s1) + } + + def unit[A,S](a: A) = + St[A,S] { + (s:S) => (a,s) + } + + def leftUnit[A,S,B](a: A, f: A => St[B,S])(implicit anyS: S): Unit = { + (unit(a).flatMap(f) ^= + St((s:S) => (a:A,s:S)).flatMap(f) ^= + St { (s0:S) => + val (a1:A,s1:S) = ((s:S) => (a,s))(s0) + f(a1).fun(s1) } ^= + St { s0 => + val (a1:A,s1:S) = (a,s0) + f(a1).fun(s1) } ^= + St(s0 => f(a).fun(s0)) ^= + f(a)) + .qed + }.ensuring(_ => unit(a).flatMap(f) == f(a)) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/verification/valid/RedundantlyTypedPatterns.scala b/frontends/benchmarks/verification/valid/RedundantlyTypedPatterns.scala new file mode 100644 index 0000000000..8bb2964d95 --- /dev/null +++ b/frontends/benchmarks/verification/valid/RedundantlyTypedPatterns.scala @@ -0,0 +1,41 @@ +import stainless.lang._ + +object RedundantlyTypedPatterns { + + case class MyClass[S, T](s1: S, s2: S, t: T) + + case class MyOtherClass[S, T](s1: S, s2: S, t: T) + + case class YetAnotherClass(i1: Int, s: String, i2: Int) + + object MyOtherClass { + def unapply[S, T](moc: MyOtherClass[S, T]): Option[(S, MyOtherClass[S, T])] = Some((moc.s1, moc)) + } + + + def test1[A, B](moc: MyOtherClass[A, B]): A = moc match { + case MyOtherClass(s1: A, MyOtherClass(s2: A, moc2: MyOtherClass[A, B])) => s1 + } + + def test2[A, B](moc: MyOtherClass[A, B]): A = moc match { + case MyOtherClass(s1, MyOtherClass(s2, moc2)) => s1 + } + + def test3[A, B](a: A, b: B): Unit = { + val (aa1: A, bb: B) = (a, b) + val (aa2: A, ab1: (A, B)) = (a, (a, b)) + val (ab2: (A, B), aa3: A) = ((a, b), a) + } + + def test4[A, B](mc: MyClass[A, B]): A = mc match { + case MyClass(a1: A, a2: A, b: B) => a1 + } + + def test5(mc: MyClass[Int, String]): Int = mc match { + case MyClass(a1: Int, a2: Int, b: String) => a1 + } + + def test6(yac: YetAnotherClass): Int = yac match { + case YetAnotherClass(a1: Int, a2: String, a3: Int) => a1 + } +} \ No newline at end of file diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/CodeExtraction.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/CodeExtraction.scala index ce943d6faf..8a1fe8a8ca 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/CodeExtraction.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/CodeExtraction.scala @@ -15,6 +15,7 @@ import core.StdNames._ import core.Symbols._ import core.Types._ import core.Flags._ +import core.Constants._ import core.NameKinds import util.{NoSourcePosition, SourcePosition} import stainless.ast.SymbolIdentifier @@ -896,24 +897,32 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using }._2 } - private def extractPattern(p: tpd.Tree, binder: Option[xt.ValDef] = None)(using dctx: DefContext): (xt.Pattern, DefContext) = p match { + // Note: `expectedTpe` is used to check for redundant type checks that can appear in some patterns. + // For instance, the following expression (assuming a: A and b: B) is a valid pattern: + // val (aa: A, bb: B) = (a, b) + // When we recursively traverse the pattern aa: A and bb: B, we set `expectedTpe` to be `A` and `B` respectively + // since these type tests are redundant. If we do not do so, we would be falling into an "Unsupported pattern" error. + // Note that this pattern will be correctly rejected as "Unsupported pattern" (in fact, it cannot be even tested at runtime): + // val (aa: B, bb: B) = (a, b) + private def extractPattern(p: tpd.Tree, expectedTpe: Option[xt.Type], binder: Option[xt.ValDef] = None)(using dctx: DefContext): (xt.Pattern, DefContext) = p match { case b @ Bind(name, t @ Typed(pat, tpt)) => val vd = xt.ValDef(FreshIdentifier(name.toString), extractType(tpt), annotationsOf(b.symbol, ignoreOwner = true)).setPos(b.sourcePos) val pctx = dctx.withNewVar(b.symbol -> (() => vd.toVariable)) - extractPattern(t, Some(vd))(using pctx) + extractPattern(t, expectedTpe, Some(vd))(using pctx) case b @ Bind(name, pat) => val vd = xt.ValDef(FreshIdentifier(name.toString), extractType(b), annotationsOf(b.symbol, ignoreOwner = true)).setPos(b.sourcePos) val pctx = dctx.withNewVar(b.symbol -> (() => vd.toVariable)) - extractPattern(pat, Some(vd))(using pctx) + extractPattern(pat, expectedTpe, Some(vd))(using pctx) case t @ Typed(Ident(nme.WILDCARD), tpt) => extractType(tpt)(using dctx.setResolveTypes(true)) match { case ct: xt.ClassType => (xt.InstanceOfPattern(binder, ct).setPos(p.sourcePos), dctx) - + case lt if expectedTpe.contains(lt) => + (xt.WildcardPattern(binder), dctx) case lt => - outOfSubsetError(tpt, "Invalid type "+tpt.tpe+" for .isInstanceOf") + outOfSubsetError(p, s"Unsupported pattern: ${p.show}") } case Ident(nme.WILDCARD) => @@ -924,7 +933,7 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using case ct: xt.ClassType => (xt.ClassPattern(binder, ct, Seq()).setPos(p.sourcePos), dctx) case _ => - outOfSubsetError(s, "Invalid instance pattern: "+s) + outOfSubsetError(p, s"Unsupported pattern: ${p.show}") } case id @ Ident(_) if id.symbol isOneOf (Case | Module) => @@ -932,23 +941,7 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using case ct: xt.ClassType => (xt.ClassPattern(binder, ct, Seq()).setPos(p.sourcePos), dctx) case _ => - outOfSubsetError(id, "Invalid instance pattern: "+id) - } - - case a @ Apply(fn, args) => - extractType(a)(using dctx.setResolveTypes(true)) match { - case ct: xt.ClassType => - val (subPatterns, subDctx) = args.map(extractPattern(_)).unzip - val nctx = subDctx.foldLeft(dctx)(_ union _) - (xt.ClassPattern(binder, ct, subPatterns).setPos(p.sourcePos), nctx) - - case xt.TupleType(argsTpes) => - val (subPatterns, subDctx) = args.map(extractPattern(_)).unzip - val nctx = subDctx.foldLeft(dctx)(_ union _) - (xt.TuplePattern(binder, subPatterns).setPos(p.sourcePos), nctx) - - case _ => - outOfSubsetError(a, "Invalid type "+a.tpe+" for .isInstanceOf") + outOfSubsetError(p, s"Unsupported pattern: ${p.show}") } case ExBigIntPattern(l) => @@ -963,8 +956,10 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using case ExUnitLiteral() => (xt.LiteralPattern(binder, xt.UnitLiteral()), dctx) case ExStringLiteral(s) => (xt.LiteralPattern(binder, xt.StringLiteral(s)), dctx) - case t @ Typed(UnApply(f, _, pats), tp) => - val (subPatterns, subDctx) = pats.map(extractPattern(_)).unzip + case t @ Typed(un@UnApply(f, _, pats), tp) => + val subPatTps = resolveUnapplySubPatternsTps(un) + assert(subPatTps.size == pats.size) + val (subPatterns, subDctx) = pats.zip(subPatTps).map { case (pat, tp) => extractPattern(pat, Some(tp)) }.unzip val nctx = subDctx.foldLeft(dctx)(_ union _) val sym = f.symbol @@ -978,8 +973,10 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using (xt.UnapplyPattern(binder, Seq(), id, tps, subPatterns).setPos(t.sourcePos), nctx) } - case UnApply(f, _, pats) => - val (subPatterns, subDctx) = pats.map(extractPattern(_)).unzip + case un@UnApply(f, _, pats) => + val subPatTps = resolveUnapplySubPatternsTps(un) + assert(subPatTps.size == pats.size) + val (subPatterns, subDctx) = pats.zip(subPatTps).map { case (pat, tp) => extractPattern(pat, Some(tp)) }.unzip val nctx = subDctx.foldLeft(dctx)(_ union _) val sym = f.symbol @@ -997,11 +994,58 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using } case _ => - outOfSubsetError(p, "Unsupported pattern: "+p) + outOfSubsetError(p, s"Unsupported pattern: ${p.show}") + } + + private def resolveUnapplySubPatternsTps(un: tpd.UnApply)(using dctx: DefContext): Seq[xt.Type] = { + def classFieldsAccessors(tpe: Type): Seq[Type] = { + // We only keep the fields of the constructor (and disregard the inherited ones) + val fields = tpe.fields.filter { denot => + val sym = denot.symbol + !sym.is(Accessor) && (sym.is(ParamAccessor) || sym.is(CaseAccessor)) + } + fields.map(_.info) + } + def resolve(resTpe: Type): Seq[xt.Type] = { + val subPatTps = resTpe match { + // The return type is Option[T] where T may be a tuple - in which case we flatten it + case AppliedType(opt, List(underlying)) if opt.typeSymbol == optionClassSym || opt.typeSymbol == optionSymbol => + underlying match { + case at@AppliedType(tr: TypeRef, tps) if TupleSymbol.unapply(tr.classSymbol).isDefined => + val AppliedType(_, theTps) = at.dealias: @unchecked + theTps + case _ => Seq(underlying) + } + // The following two cases are for patterns that do not "return" any value such as None + case ConstantType(Constant(true)) => Seq.empty + case _ if resTpe.typeSymbol == defn.BooleanClass => Seq.empty + // The following two cases are for ADT patterns such as Left/Right .unapply which are typically compiler-generated + case at@AppliedType(tt@TypeRef(_, _), args) if tt.symbol.isClass => + classFieldsAccessors(at) + case tt@TypeRef(_, _) if tt.symbol.isClass => + classFieldsAccessors(tt) + case _ => + outOfSubsetError(un, s"Unsupported pattern: ${un.show}") + } + subPatTps.map(extractType(_)(using dctx, un.sourcePos)) + } + un.fun.tpe match { + case mt: MethodType => resolve(mt.resultType) + case tr: TermRef => + // If we have a TermRef, this `unapply` method does not take type parameter. + // We can unveil its underlying type with `info` (and cry if we get something else...) + tr.info match { + case mt: MethodType => resolve(mt.resultType) + case _ => + outOfSubsetError(un, s"Unsupported pattern: ${un.show}") + } + case _ => + outOfSubsetError(un, s"Unsupported pattern: ${un.show}") + } } private def extractMatchCase(cd: tpd.CaseDef)(using dctx: DefContext): xt.MatchCase = { - val (recPattern, ndctx) = extractPattern(cd.pat) + val (recPattern, ndctx) = extractPattern(cd.pat, None) val recBody = extractTree(cd.body)(using ndctx) if (cd.guard == tpd.EmptyTree) { @@ -1578,7 +1622,7 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using case e => (xt.TupleSelect(e, 1).setPos(e), xt.TupleSelect(e, 2).setPos(e)) }, extractType(tpt)) - case ExClassConstruction(tpe, args) => + case ExClassConstruction(tpe, args) => extractType(tpe)(using dctx, tr.sourcePos) match { case lct: xt.LocalClassType => xt.LocalClassConstructor(lct, args map extractTree) case ct: xt.ClassType => xt.ClassConstructor(ct, args map extractTree) diff --git a/frontends/scalac/src/main/scala/stainless/frontends/scalac/ASTExtractors.scala b/frontends/scalac/src/main/scala/stainless/frontends/scalac/ASTExtractors.scala index 0a5eaeadbf..ce1635a29e 100644 --- a/frontends/scalac/src/main/scala/stainless/frontends/scalac/ASTExtractors.scala +++ b/frontends/scalac/src/main/scala/stainless/frontends/scalac/ASTExtractors.scala @@ -113,6 +113,7 @@ trait ASTExtractors { protected lazy val arraySym = classFromName("scala.Array") protected lazy val someClassSym = classFromName("scala.Some") + protected lazy val optionClassSym = classFromName("scala.Option") protected lazy val byNameSym = classFromName("scala.") protected lazy val bigIntSym = classFromName("scala.math.BigInt") protected lazy val stringSym = classFromName("java.lang.String") diff --git a/frontends/scalac/src/main/scala/stainless/frontends/scalac/CodeExtraction.scala b/frontends/scalac/src/main/scala/stainless/frontends/scalac/CodeExtraction.scala index b2289d45b6..13c1f42b98 100644 --- a/frontends/scalac/src/main/scala/stainless/frontends/scalac/CodeExtraction.scala +++ b/frontends/scalac/src/main/scala/stainless/frontends/scalac/CodeExtraction.scala @@ -745,24 +745,32 @@ trait CodeExtraction extends ASTExtractors { }._2 } - private def extractPattern(p: Tree, binder: Option[xt.ValDef] = None)(using dctx: DefContext): (xt.Pattern, DefContext) = p match { + // Note: `expectedTpe` is used to check for redundant type checks that can appear in some patterns. + // For instance, the following expression (assuming a: A and b: B) is a valid pattern: + // val (aa: A, bb: B) = (a, b) + // When we recursively traverse the pattern aa: A and bb: B, we set `expectedTpe` to be `A` and `B` respectively + // since these type tests are redundant. If we do not do so, we would be falling into an "Unsupported pattern" error. + // Note that this pattern will be correctly rejected as "Unsupported pattern" (in fact, it cannot be even tested at runtime): + // val (aa: B, bb: B) = (a, b) + private def extractPattern(p: Tree, expectedTpe: Option[xt.Type], binder: Option[xt.ValDef] = None)(using dctx: DefContext): (xt.Pattern, DefContext) = p match { case b @ Bind(name, t @ Typed(pat, tpt)) => val vd = xt.ValDef(FreshIdentifier(name.toString), extractType(tpt), annotationsOf(b.symbol, ignoreOwner = true)).setPos(b.pos) val pctx = dctx.withNewVar(b.symbol, () => vd.toVariable) - extractPattern(t, Some(vd))(using pctx) + extractPattern(t, expectedTpe, Some(vd))(using pctx) case b @ Bind(name, pat) => val vd = xt.ValDef(FreshIdentifier(name.toString), extractType(b), annotationsOf(b.symbol, ignoreOwner = true)).setPos(b.pos) val pctx = dctx.withNewVar(b.symbol, () => vd.toVariable) - extractPattern(pat, Some(vd))(using pctx) + extractPattern(pat, expectedTpe, Some(vd))(using pctx) case t @ Typed(Ident(nme.WILDCARD), tpt) => extractType(tpt) match { case ct: xt.ClassType => (xt.InstanceOfPattern(binder, ct).setPos(p.pos), dctx) - + case lt if expectedTpe.contains(lt) => + (xt.WildcardPattern(binder), dctx) case lt => - outOfSubsetError(tpt, "Invalid type "+tpt.tpe+" for .isInstanceOf") + outOfSubsetError(p, s"Unsupported pattern: $p") } case Ident(nme.WILDCARD) => @@ -773,7 +781,7 @@ trait CodeExtraction extends ASTExtractors { case ct: xt.ClassType => (xt.ClassPattern(binder, ct, Seq()).setPos(p.pos), dctx) case _ => - outOfSubsetError(s, "Invalid instance pattern: " + s) + outOfSubsetError(p, s"Unsupported pattern: $p") } case id @ Ident(_) if id.tpe.typeSymbol.isCase => @@ -781,23 +789,31 @@ trait CodeExtraction extends ASTExtractors { case ct: xt.ClassType => (xt.ClassPattern(binder, ct, Seq()).setPos(p.pos), dctx) case _ => - outOfSubsetError(id, "Invalid instance pattern: " + id) + outOfSubsetError(p, s"Unsupported pattern: $p") } case a @ Apply(fn, args) => extractType(a) match { case ct: xt.ClassType => - val (subPatterns, subDctx) = args.map(extractPattern(_)).unzip + val TypeRef(_, sym, targs) = a.tpe: @unchecked + val fields = sym.constrParamAccessors + val subPatTps = fields.map { fld => + val substedTpe = fld.info.subst(sym.typeParams, targs) + extractType(substedTpe)(using dctx, a.pos) + } + Predef.assert(subPatTps.size == args.size) + val (subPatterns, subDctx) = args.zip(subPatTps).map { case (arg, tpe) => extractPattern(arg, Some(tpe)) }.unzip val nctx = subDctx.foldLeft(dctx)(_ union _) (xt.ClassPattern(binder, ct, subPatterns).setPos(p.pos), nctx) case xt.TupleType(argsTpes) => - val (subPatterns, subDctx) = args.map(extractPattern(_)).unzip + Predef.assert(args.size == argsTpes.size) + val (subPatterns, subDctx) = args.zip(argsTpes).map { case (arg, tpe) => extractPattern(arg, Some(tpe)) }.unzip val nctx = subDctx.foldLeft(dctx)(_ union _) (xt.TuplePattern(binder, subPatterns).setPos(p.pos), nctx) case _ => - outOfSubsetError(a, "Invalid type "+a.tpe+" for .isInstanceOf") + outOfSubsetError(p, s"Unsupported pattern: $p") } case ExBigIntPattern(n: Literal) => @@ -812,11 +828,13 @@ trait CodeExtraction extends ASTExtractors { case ExUnitLiteral() => (xt.LiteralPattern(binder, xt.UnitLiteral()), dctx) case ExStringLiteral(s) => (xt.LiteralPattern(binder, xt.StringLiteral(s)), dctx) - case up @ ExUnapplyPattern(t, args) => - val (sub, ctx) = args.map (extractPattern(_)).unzip + case up @ ExUnapplyPattern(fun, args) => + val subPatTps = resolveUnapplySubPatternsTps(up, fun) + Predef.assert(subPatTps.size == args.size) + val (sub, ctx) = args.zip(subPatTps).map { case (pat, tp) => extractPattern(pat, Some(tp)) }.unzip val nctx = ctx.foldLeft(dctx)(_ union _) - val id = getIdentifier(t.symbol) - val tps = t match { + val id = getIdentifier(fun.symbol) + val tps = fun match { case TypeApply(_, tps) => tps.map(extractType) case _ => Seq.empty } @@ -824,11 +842,27 @@ trait CodeExtraction extends ASTExtractors { (xt.UnapplyPattern(binder, Seq(), id, tps, sub).setPos(up.pos), ctx.foldLeft(dctx)(_ union _)) case _ => - outOfSubsetError(p, "Unsupported pattern: " + p) + outOfSubsetError(p, s"Unsupported pattern: $p") + } + + private def resolveUnapplySubPatternsTps(un: Tree, fun: Tree)(using dctx: DefContext): Seq[xt.Type] = { + val subPatTps = fun.tpe match { + case mt: MethodType => + mt.resultType match { + case TypeRef(_, sym, List(underlying)) if sym == optionSymbol || sym == optionClassSym => + underlying match { + case TypeRef(_, sym, tps) if isTuple(sym, tps.size) => tps + case _ => Seq(underlying) + } + case _ => outOfSubsetError(un, s"Unsupported pattern: $un") + } + case _ => outOfSubsetError(un, s"Unsupported pattern: $un") + } + subPatTps.map(extractType(_)(using dctx, un.pos)) } private def extractMatchCase(cd: CaseDef)(using DefContext): xt.MatchCase = { - val (recPattern, ndctx) = extractPattern(cd.pat) + val (recPattern, ndctx) = extractPattern(cd.pat, None) val recBody = extractTree(cd.body)(using ndctx) if(cd.guard == EmptyTree) { @@ -1279,7 +1313,7 @@ trait CodeExtraction extends ASTExtractors { case swap @ ExSwapExpression(array1, index1, array2, index2) => xt.Swap(extractTree(array1), extractTree(index1), extractTree(array2), extractTree(index2)) - + case cellSwap @ ExCellSwapExpression(cell1, cell2) => xt.CellSwap(extractTree(cell1), extractTree(cell2))