Skip to content

Commit

Permalink
Allow for redundant type checks for pattern matching (#1489)
Browse files Browse the repository at this point in the history
  • Loading branch information
mario-bucev authored Dec 1, 2023
1 parent 5cf8cde commit b46f01f
Show file tree
Hide file tree
Showing 17 changed files with 263 additions and 47 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[ Error ] InvalidTypedPatterns1.scala:13:47: Unsupported pattern: _:B
case MyOtherClass(s1: A, MyOtherClass(s2: B, moc2: MyOtherClass[A, B])) => s1 // s2: B is invalid
^
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import stainless.lang._

object InvalidTypedPatterns1 {

case class MyOtherClass[S, T](s1: S, s2: S, t: T)

object MyOtherClass {
def unapply[S, T](moc: MyOtherClass[S, T]): Option[(S, MyOtherClass[S, T])] = Some((moc.s1, moc))
}


def test[A, B](moc: MyOtherClass[A, B]): A = moc match {
case MyOtherClass(s1: A, MyOtherClass(s2: B, moc2: MyOtherClass[A, B])) => s1 // s2: B is invalid
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[ Error ] InvalidTypedPatterns1.scala:13:45: Unsupported pattern: (_: B)
case MyOtherClass(s1: A, MyOtherClass(s2: B, moc2: MyOtherClass[A, B])) => s1 // s2: B is invalid
^^^
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[ Error ] Stainless does not support type B & A
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
object InvalidTypedPatterns2 {
def test[A, B](a: A, b: B): Unit = {
val (aa1: A, bb: A) = (a, b) // bb: A is invalid
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[ Error ] InvalidTypedPatterns2.scala:3:20: Unsupported pattern: (_: A)
val (aa1: A, bb: A) = (a, b) // bb: A is invalid
^^^
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[ Error ] InvalidTypedPatterns3.scala:5:29: Unsupported pattern: _:B
case MyClass(a1: A, a2: B, b: B) => a1 // a2: B is invalid
^
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
object InvalidTypedPatterns3 {
case class MyClass[S, T](s1: S, s2: S, t: T)

def test[A, B](mc: MyClass[A, B]): A = mc match {
case MyClass(a1: A, a2: B, b: B) => a1 // a2: B is invalid
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[ Error ] InvalidTypedPatterns3.scala:5:27: Unsupported pattern: (_: B)
case MyClass(a1: A, a2: B, b: B) => a1 // a2: B is invalid
^^^
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[ Error ] InvalidTypedPatterns4.scala:5:31: Unsupported pattern: _:String
case MyClass(a1: Int, a2: String, b: String) => a1 // a2: String is invalid
^^^^^^
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
object InvalidTypedPatterns4 {
case class MyClass[S, T](s1: S, s2: S, t: T)

def test5(mc: MyClass[Int, String]): Int = mc match {
case MyClass(a1: Int, a2: String, b: String) => a1 // a2: String is invalid
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[ Error ] InvalidTypedPatterns4.scala:5:31: scrutinee is incompatible with pattern type;
[ Error ] found : String
[ Error ] required: Int
case MyClass(a1: Int, a2: String, b: String) => a1 // a2: String is invalid
^
38 changes: 38 additions & 0 deletions frontends/benchmarks/extraction/valid/StateMonadTypeRefIssue.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
object StateMonad {

case class St[A,S](fun: S => (A,S)) {
def ^=(that: St[A,S])(implicit anyS: S): St[A,S] = {
assert(this.fun(anyS) == that.fun(anyS))
that
}

def qed: Unit =
???

def flatMap[B](f: A => St[B,S]): St[B,S] =
St[B,S] {
s0 =>
val (a,s1) = fun(s0)
f(a).fun(s1)
}

def unit[A,S](a: A) =
St[A,S] {
(s:S) => (a,s)
}

def leftUnit[A,S,B](a: A, f: A => St[B,S])(implicit anyS: S): Unit = {
(unit(a).flatMap(f) ^=
St((s:S) => (a:A,s:S)).flatMap(f) ^=
St { (s0:S) =>
val (a1:A,s1:S) = ((s:S) => (a,s))(s0)
f(a1).fun(s1) } ^=
St { s0 =>
val (a1:A,s1:S) = (a,s0)
f(a1).fun(s1) } ^=
St(s0 => f(a).fun(s0)) ^=
f(a))
.qed
}.ensuring(_ => unit(a).flatMap(f) == f(a))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import stainless.lang._

object RedundantlyTypedPatterns {

case class MyClass[S, T](s1: S, s2: S, t: T)

case class MyOtherClass[S, T](s1: S, s2: S, t: T)

case class YetAnotherClass(i1: Int, s: String, i2: Int)

object MyOtherClass {
def unapply[S, T](moc: MyOtherClass[S, T]): Option[(S, MyOtherClass[S, T])] = Some((moc.s1, moc))
}


def test1[A, B](moc: MyOtherClass[A, B]): A = moc match {
case MyOtherClass(s1: A, MyOtherClass(s2: A, moc2: MyOtherClass[A, B])) => s1
}

def test2[A, B](moc: MyOtherClass[A, B]): A = moc match {
case MyOtherClass(s1, MyOtherClass(s2, moc2)) => s1
}

def test3[A, B](a: A, b: B): Unit = {
val (aa1: A, bb: B) = (a, b)
val (aa2: A, ab1: (A, B)) = (a, (a, b))
val (ab2: (A, B), aa3: A) = ((a, b), a)
}

def test4[A, B](mc: MyClass[A, B]): A = mc match {
case MyClass(a1: A, a2: A, b: B) => a1
}

def test5(mc: MyClass[Int, String]): Int = mc match {
case MyClass(a1: Int, a2: Int, b: String) => a1
}

def test6(yac: YetAnotherClass): Int = yac match {
case YetAnotherClass(a1: Int, a2: String, a3: Int) => a1
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import core.StdNames._
import core.Symbols._
import core.Types._
import core.Flags._
import core.Constants._
import core.NameKinds
import util.{NoSourcePosition, SourcePosition}
import stainless.ast.SymbolIdentifier
Expand Down Expand Up @@ -896,24 +897,32 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using
}._2
}

private def extractPattern(p: tpd.Tree, binder: Option[xt.ValDef] = None)(using dctx: DefContext): (xt.Pattern, DefContext) = p match {
// Note: `expectedTpe` is used to check for redundant type checks that can appear in some patterns.
// For instance, the following expression (assuming a: A and b: B) is a valid pattern:
// val (aa: A, bb: B) = (a, b)
// When we recursively traverse the pattern aa: A and bb: B, we set `expectedTpe` to be `A` and `B` respectively
// since these type tests are redundant. If we do not do so, we would be falling into an "Unsupported pattern" error.
// Note that this pattern will be correctly rejected as "Unsupported pattern" (in fact, it cannot be even tested at runtime):
// val (aa: B, bb: B) = (a, b)
private def extractPattern(p: tpd.Tree, expectedTpe: Option[xt.Type], binder: Option[xt.ValDef] = None)(using dctx: DefContext): (xt.Pattern, DefContext) = p match {
case b @ Bind(name, t @ Typed(pat, tpt)) =>
val vd = xt.ValDef(FreshIdentifier(name.toString), extractType(tpt), annotationsOf(b.symbol, ignoreOwner = true)).setPos(b.sourcePos)
val pctx = dctx.withNewVar(b.symbol -> (() => vd.toVariable))
extractPattern(t, Some(vd))(using pctx)
extractPattern(t, expectedTpe, Some(vd))(using pctx)

case b @ Bind(name, pat) =>
val vd = xt.ValDef(FreshIdentifier(name.toString), extractType(b), annotationsOf(b.symbol, ignoreOwner = true)).setPos(b.sourcePos)
val pctx = dctx.withNewVar(b.symbol -> (() => vd.toVariable))
extractPattern(pat, Some(vd))(using pctx)
extractPattern(pat, expectedTpe, Some(vd))(using pctx)

case t @ Typed(Ident(nme.WILDCARD), tpt) =>
extractType(tpt)(using dctx.setResolveTypes(true)) match {
case ct: xt.ClassType =>
(xt.InstanceOfPattern(binder, ct).setPos(p.sourcePos), dctx)

case lt if expectedTpe.contains(lt) =>
(xt.WildcardPattern(binder), dctx)
case lt =>
outOfSubsetError(tpt, "Invalid type "+tpt.tpe+" for .isInstanceOf")
outOfSubsetError(p, s"Unsupported pattern: ${p.show}")
}

case Ident(nme.WILDCARD) =>
Expand All @@ -924,31 +933,15 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using
case ct: xt.ClassType =>
(xt.ClassPattern(binder, ct, Seq()).setPos(p.sourcePos), dctx)
case _ =>
outOfSubsetError(s, "Invalid instance pattern: "+s)
outOfSubsetError(p, s"Unsupported pattern: ${p.show}")
}

case id @ Ident(_) if id.symbol isOneOf (Case | Module) =>
extractType(id)(using dctx.setResolveTypes(true)) match {
case ct: xt.ClassType =>
(xt.ClassPattern(binder, ct, Seq()).setPos(p.sourcePos), dctx)
case _ =>
outOfSubsetError(id, "Invalid instance pattern: "+id)
}

case a @ Apply(fn, args) =>
extractType(a)(using dctx.setResolveTypes(true)) match {
case ct: xt.ClassType =>
val (subPatterns, subDctx) = args.map(extractPattern(_)).unzip
val nctx = subDctx.foldLeft(dctx)(_ union _)
(xt.ClassPattern(binder, ct, subPatterns).setPos(p.sourcePos), nctx)

case xt.TupleType(argsTpes) =>
val (subPatterns, subDctx) = args.map(extractPattern(_)).unzip
val nctx = subDctx.foldLeft(dctx)(_ union _)
(xt.TuplePattern(binder, subPatterns).setPos(p.sourcePos), nctx)

case _ =>
outOfSubsetError(a, "Invalid type "+a.tpe+" for .isInstanceOf")
outOfSubsetError(p, s"Unsupported pattern: ${p.show}")
}

case ExBigIntPattern(l) =>
Expand All @@ -963,8 +956,10 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using
case ExUnitLiteral() => (xt.LiteralPattern(binder, xt.UnitLiteral()), dctx)
case ExStringLiteral(s) => (xt.LiteralPattern(binder, xt.StringLiteral(s)), dctx)

case t @ Typed(UnApply(f, _, pats), tp) =>
val (subPatterns, subDctx) = pats.map(extractPattern(_)).unzip
case t @ Typed(un@UnApply(f, _, pats), tp) =>
val subPatTps = resolveUnapplySubPatternsTps(un)
assert(subPatTps.size == pats.size)
val (subPatterns, subDctx) = pats.zip(subPatTps).map { case (pat, tp) => extractPattern(pat, Some(tp)) }.unzip
val nctx = subDctx.foldLeft(dctx)(_ union _)

val sym = f.symbol
Expand All @@ -978,8 +973,10 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using
(xt.UnapplyPattern(binder, Seq(), id, tps, subPatterns).setPos(t.sourcePos), nctx)
}

case UnApply(f, _, pats) =>
val (subPatterns, subDctx) = pats.map(extractPattern(_)).unzip
case un@UnApply(f, _, pats) =>
val subPatTps = resolveUnapplySubPatternsTps(un)
assert(subPatTps.size == pats.size)
val (subPatterns, subDctx) = pats.zip(subPatTps).map { case (pat, tp) => extractPattern(pat, Some(tp)) }.unzip
val nctx = subDctx.foldLeft(dctx)(_ union _)

val sym = f.symbol
Expand All @@ -997,11 +994,58 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using
}

case _ =>
outOfSubsetError(p, "Unsupported pattern: "+p)
outOfSubsetError(p, s"Unsupported pattern: ${p.show}")
}

private def resolveUnapplySubPatternsTps(un: tpd.UnApply)(using dctx: DefContext): Seq[xt.Type] = {
def classFieldsAccessors(tpe: Type): Seq[Type] = {
// We only keep the fields of the constructor (and disregard the inherited ones)
val fields = tpe.fields.filter { denot =>
val sym = denot.symbol
!sym.is(Accessor) && (sym.is(ParamAccessor) || sym.is(CaseAccessor))
}
fields.map(_.info)
}
def resolve(resTpe: Type): Seq[xt.Type] = {
val subPatTps = resTpe match {
// The return type is Option[T] where T may be a tuple - in which case we flatten it
case AppliedType(opt, List(underlying)) if opt.typeSymbol == optionClassSym || opt.typeSymbol == optionSymbol =>
underlying match {
case at@AppliedType(tr: TypeRef, tps) if TupleSymbol.unapply(tr.classSymbol).isDefined =>
val AppliedType(_, theTps) = at.dealias: @unchecked
theTps
case _ => Seq(underlying)
}
// The following two cases are for patterns that do not "return" any value such as None
case ConstantType(Constant(true)) => Seq.empty
case _ if resTpe.typeSymbol == defn.BooleanClass => Seq.empty
// The following two cases are for ADT patterns such as Left/Right .unapply which are typically compiler-generated
case at@AppliedType(tt@TypeRef(_, _), args) if tt.symbol.isClass =>
classFieldsAccessors(at)
case tt@TypeRef(_, _) if tt.symbol.isClass =>
classFieldsAccessors(tt)
case _ =>
outOfSubsetError(un, s"Unsupported pattern: ${un.show}")
}
subPatTps.map(extractType(_)(using dctx, un.sourcePos))
}
un.fun.tpe match {
case mt: MethodType => resolve(mt.resultType)
case tr: TermRef =>
// If we have a TermRef, this `unapply` method does not take type parameter.
// We can unveil its underlying type with `info` (and cry if we get something else...)
tr.info match {
case mt: MethodType => resolve(mt.resultType)
case _ =>
outOfSubsetError(un, s"Unsupported pattern: ${un.show}")
}
case _ =>
outOfSubsetError(un, s"Unsupported pattern: ${un.show}")
}
}

private def extractMatchCase(cd: tpd.CaseDef)(using dctx: DefContext): xt.MatchCase = {
val (recPattern, ndctx) = extractPattern(cd.pat)
val (recPattern, ndctx) = extractPattern(cd.pat, None)
val recBody = extractTree(cd.body)(using ndctx)

if (cd.guard == tpd.EmptyTree) {
Expand Down Expand Up @@ -1578,7 +1622,7 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using
case e => (xt.TupleSelect(e, 1).setPos(e), xt.TupleSelect(e, 2).setPos(e))
}, extractType(tpt))

case ExClassConstruction(tpe, args) =>
case ExClassConstruction(tpe, args) =>
extractType(tpe)(using dctx, tr.sourcePos) match {
case lct: xt.LocalClassType => xt.LocalClassConstructor(lct, args map extractTree)
case ct: xt.ClassType => xt.ClassConstructor(ct, args map extractTree)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ trait ASTExtractors {

protected lazy val arraySym = classFromName("scala.Array")
protected lazy val someClassSym = classFromName("scala.Some")
protected lazy val optionClassSym = classFromName("scala.Option")
protected lazy val byNameSym = classFromName("scala.<byname>")
protected lazy val bigIntSym = classFromName("scala.math.BigInt")
protected lazy val stringSym = classFromName("java.lang.String")
Expand Down
Loading

0 comments on commit b46f01f

Please sign in to comment.