Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for redundant type checks for pattern matching #1489

Merged
merged 1 commit into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[ Error ] InvalidTypedPatterns1.scala:13:47: Unsupported pattern: _:B
case MyOtherClass(s1: A, MyOtherClass(s2: B, moc2: MyOtherClass[A, B])) => s1 // s2: B is invalid
^
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import stainless.lang._

object InvalidTypedPatterns1 {

case class MyOtherClass[S, T](s1: S, s2: S, t: T)

object MyOtherClass {
def unapply[S, T](moc: MyOtherClass[S, T]): Option[(S, MyOtherClass[S, T])] = Some((moc.s1, moc))
}


def test[A, B](moc: MyOtherClass[A, B]): A = moc match {
case MyOtherClass(s1: A, MyOtherClass(s2: B, moc2: MyOtherClass[A, B])) => s1 // s2: B is invalid
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[ Error ] InvalidTypedPatterns1.scala:13:45: Unsupported pattern: (_: B)
case MyOtherClass(s1: A, MyOtherClass(s2: B, moc2: MyOtherClass[A, B])) => s1 // s2: B is invalid
^^^
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[ Error ] Stainless does not support type B & A
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
object InvalidTypedPatterns2 {
def test[A, B](a: A, b: B): Unit = {
val (aa1: A, bb: A) = (a, b) // bb: A is invalid
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[ Error ] InvalidTypedPatterns2.scala:3:20: Unsupported pattern: (_: A)
val (aa1: A, bb: A) = (a, b) // bb: A is invalid
^^^
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[ Error ] InvalidTypedPatterns3.scala:5:29: Unsupported pattern: _:B
case MyClass(a1: A, a2: B, b: B) => a1 // a2: B is invalid
^
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
object InvalidTypedPatterns3 {
case class MyClass[S, T](s1: S, s2: S, t: T)

def test[A, B](mc: MyClass[A, B]): A = mc match {
case MyClass(a1: A, a2: B, b: B) => a1 // a2: B is invalid
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[ Error ] InvalidTypedPatterns3.scala:5:27: Unsupported pattern: (_: B)
case MyClass(a1: A, a2: B, b: B) => a1 // a2: B is invalid
^^^
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[ Error ] InvalidTypedPatterns4.scala:5:31: Unsupported pattern: _:String
case MyClass(a1: Int, a2: String, b: String) => a1 // a2: String is invalid
^^^^^^
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
object InvalidTypedPatterns4 {
case class MyClass[S, T](s1: S, s2: S, t: T)

def test5(mc: MyClass[Int, String]): Int = mc match {
case MyClass(a1: Int, a2: String, b: String) => a1 // a2: String is invalid
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[ Error ] InvalidTypedPatterns4.scala:5:31: scrutinee is incompatible with pattern type;
[ Error ] found : String
[ Error ] required: Int
case MyClass(a1: Int, a2: String, b: String) => a1 // a2: String is invalid
^
38 changes: 38 additions & 0 deletions frontends/benchmarks/extraction/valid/StateMonadTypeRefIssue.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
object StateMonad {

case class St[A,S](fun: S => (A,S)) {
def ^=(that: St[A,S])(implicit anyS: S): St[A,S] = {
assert(this.fun(anyS) == that.fun(anyS))
that
}

def qed: Unit =
???

def flatMap[B](f: A => St[B,S]): St[B,S] =
St[B,S] {
s0 =>
val (a,s1) = fun(s0)
f(a).fun(s1)
}

def unit[A,S](a: A) =
St[A,S] {
(s:S) => (a,s)
}

def leftUnit[A,S,B](a: A, f: A => St[B,S])(implicit anyS: S): Unit = {
(unit(a).flatMap(f) ^=
St((s:S) => (a:A,s:S)).flatMap(f) ^=
St { (s0:S) =>
val (a1:A,s1:S) = ((s:S) => (a,s))(s0)
f(a1).fun(s1) } ^=
St { s0 =>
val (a1:A,s1:S) = (a,s0)
f(a1).fun(s1) } ^=
St(s0 => f(a).fun(s0)) ^=
f(a))
.qed
}.ensuring(_ => unit(a).flatMap(f) == f(a))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import stainless.lang._

object RedundantlyTypedPatterns {

case class MyClass[S, T](s1: S, s2: S, t: T)

case class MyOtherClass[S, T](s1: S, s2: S, t: T)

case class YetAnotherClass(i1: Int, s: String, i2: Int)

object MyOtherClass {
def unapply[S, T](moc: MyOtherClass[S, T]): Option[(S, MyOtherClass[S, T])] = Some((moc.s1, moc))
}


def test1[A, B](moc: MyOtherClass[A, B]): A = moc match {
case MyOtherClass(s1: A, MyOtherClass(s2: A, moc2: MyOtherClass[A, B])) => s1
}

def test2[A, B](moc: MyOtherClass[A, B]): A = moc match {
case MyOtherClass(s1, MyOtherClass(s2, moc2)) => s1
}

def test3[A, B](a: A, b: B): Unit = {
val (aa1: A, bb: B) = (a, b)
val (aa2: A, ab1: (A, B)) = (a, (a, b))
val (ab2: (A, B), aa3: A) = ((a, b), a)
}

def test4[A, B](mc: MyClass[A, B]): A = mc match {
case MyClass(a1: A, a2: A, b: B) => a1
}

def test5(mc: MyClass[Int, String]): Int = mc match {
case MyClass(a1: Int, a2: Int, b: String) => a1
}

def test6(yac: YetAnotherClass): Int = yac match {
case YetAnotherClass(a1: Int, a2: String, a3: Int) => a1
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import core.StdNames._
import core.Symbols._
import core.Types._
import core.Flags._
import core.Constants._
import core.NameKinds
import util.{NoSourcePosition, SourcePosition}
import stainless.ast.SymbolIdentifier
Expand Down Expand Up @@ -896,24 +897,32 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using
}._2
}

private def extractPattern(p: tpd.Tree, binder: Option[xt.ValDef] = None)(using dctx: DefContext): (xt.Pattern, DefContext) = p match {
// Note: `expectedTpe` is used to check for redundant type checks that can appear in some patterns.
// For instance, the following expression (assuming a: A and b: B) is a valid pattern:
// val (aa: A, bb: B) = (a, b)
// When we recursively traverse the pattern aa: A and bb: B, we set `expectedTpe` to be `A` and `B` respectively
// since these type tests are redundant. If we do not do so, we would be falling into an "Unsupported pattern" error.
// Note that this pattern will be correctly rejected as "Unsupported pattern" (in fact, it cannot be even tested at runtime):
// val (aa: B, bb: B) = (a, b)
private def extractPattern(p: tpd.Tree, expectedTpe: Option[xt.Type], binder: Option[xt.ValDef] = None)(using dctx: DefContext): (xt.Pattern, DefContext) = p match {
case b @ Bind(name, t @ Typed(pat, tpt)) =>
val vd = xt.ValDef(FreshIdentifier(name.toString), extractType(tpt), annotationsOf(b.symbol, ignoreOwner = true)).setPos(b.sourcePos)
val pctx = dctx.withNewVar(b.symbol -> (() => vd.toVariable))
extractPattern(t, Some(vd))(using pctx)
extractPattern(t, expectedTpe, Some(vd))(using pctx)

case b @ Bind(name, pat) =>
val vd = xt.ValDef(FreshIdentifier(name.toString), extractType(b), annotationsOf(b.symbol, ignoreOwner = true)).setPos(b.sourcePos)
val pctx = dctx.withNewVar(b.symbol -> (() => vd.toVariable))
extractPattern(pat, Some(vd))(using pctx)
extractPattern(pat, expectedTpe, Some(vd))(using pctx)

case t @ Typed(Ident(nme.WILDCARD), tpt) =>
extractType(tpt)(using dctx.setResolveTypes(true)) match {
case ct: xt.ClassType =>
(xt.InstanceOfPattern(binder, ct).setPos(p.sourcePos), dctx)

case lt if expectedTpe.contains(lt) =>
(xt.WildcardPattern(binder), dctx)
case lt =>
outOfSubsetError(tpt, "Invalid type "+tpt.tpe+" for .isInstanceOf")
outOfSubsetError(p, s"Unsupported pattern: ${p.show}")
}

case Ident(nme.WILDCARD) =>
Expand All @@ -924,31 +933,15 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using
case ct: xt.ClassType =>
(xt.ClassPattern(binder, ct, Seq()).setPos(p.sourcePos), dctx)
case _ =>
outOfSubsetError(s, "Invalid instance pattern: "+s)
outOfSubsetError(p, s"Unsupported pattern: ${p.show}")
}

case id @ Ident(_) if id.symbol isOneOf (Case | Module) =>
extractType(id)(using dctx.setResolveTypes(true)) match {
case ct: xt.ClassType =>
(xt.ClassPattern(binder, ct, Seq()).setPos(p.sourcePos), dctx)
case _ =>
outOfSubsetError(id, "Invalid instance pattern: "+id)
}

case a @ Apply(fn, args) =>
extractType(a)(using dctx.setResolveTypes(true)) match {
case ct: xt.ClassType =>
val (subPatterns, subDctx) = args.map(extractPattern(_)).unzip
val nctx = subDctx.foldLeft(dctx)(_ union _)
(xt.ClassPattern(binder, ct, subPatterns).setPos(p.sourcePos), nctx)

case xt.TupleType(argsTpes) =>
val (subPatterns, subDctx) = args.map(extractPattern(_)).unzip
val nctx = subDctx.foldLeft(dctx)(_ union _)
(xt.TuplePattern(binder, subPatterns).setPos(p.sourcePos), nctx)

case _ =>
outOfSubsetError(a, "Invalid type "+a.tpe+" for .isInstanceOf")
outOfSubsetError(p, s"Unsupported pattern: ${p.show}")
}

case ExBigIntPattern(l) =>
Expand All @@ -963,8 +956,10 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using
case ExUnitLiteral() => (xt.LiteralPattern(binder, xt.UnitLiteral()), dctx)
case ExStringLiteral(s) => (xt.LiteralPattern(binder, xt.StringLiteral(s)), dctx)

case t @ Typed(UnApply(f, _, pats), tp) =>
val (subPatterns, subDctx) = pats.map(extractPattern(_)).unzip
case t @ Typed(un@UnApply(f, _, pats), tp) =>
val subPatTps = resolveUnapplySubPatternsTps(un)
assert(subPatTps.size == pats.size)
val (subPatterns, subDctx) = pats.zip(subPatTps).map { case (pat, tp) => extractPattern(pat, Some(tp)) }.unzip
val nctx = subDctx.foldLeft(dctx)(_ union _)

val sym = f.symbol
Expand All @@ -978,8 +973,10 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using
(xt.UnapplyPattern(binder, Seq(), id, tps, subPatterns).setPos(t.sourcePos), nctx)
}

case UnApply(f, _, pats) =>
val (subPatterns, subDctx) = pats.map(extractPattern(_)).unzip
case un@UnApply(f, _, pats) =>
val subPatTps = resolveUnapplySubPatternsTps(un)
assert(subPatTps.size == pats.size)
val (subPatterns, subDctx) = pats.zip(subPatTps).map { case (pat, tp) => extractPattern(pat, Some(tp)) }.unzip
val nctx = subDctx.foldLeft(dctx)(_ union _)

val sym = f.symbol
Expand All @@ -997,11 +994,58 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using
}

case _ =>
outOfSubsetError(p, "Unsupported pattern: "+p)
outOfSubsetError(p, s"Unsupported pattern: ${p.show}")
}

private def resolveUnapplySubPatternsTps(un: tpd.UnApply)(using dctx: DefContext): Seq[xt.Type] = {
def classFieldsAccessors(tpe: Type): Seq[Type] = {
// We only keep the fields of the constructor (and disregard the inherited ones)
val fields = tpe.fields.filter { denot =>
val sym = denot.symbol
!sym.is(Accessor) && (sym.is(ParamAccessor) || sym.is(CaseAccessor))
}
fields.map(_.info)
}
def resolve(resTpe: Type): Seq[xt.Type] = {
val subPatTps = resTpe match {
// The return type is Option[T] where T may be a tuple - in which case we flatten it
case AppliedType(opt, List(underlying)) if opt.typeSymbol == optionClassSym || opt.typeSymbol == optionSymbol =>
underlying match {
case at@AppliedType(tr: TypeRef, tps) if TupleSymbol.unapply(tr.classSymbol).isDefined =>
val AppliedType(_, theTps) = at.dealias: @unchecked
theTps
case _ => Seq(underlying)
}
// The following two cases are for patterns that do not "return" any value such as None
case ConstantType(Constant(true)) => Seq.empty
case _ if resTpe.typeSymbol == defn.BooleanClass => Seq.empty
// The following two cases are for ADT patterns such as Left/Right .unapply which are typically compiler-generated
case at@AppliedType(tt@TypeRef(_, _), args) if tt.symbol.isClass =>
classFieldsAccessors(at)
case tt@TypeRef(_, _) if tt.symbol.isClass =>
classFieldsAccessors(tt)
case _ =>
outOfSubsetError(un, s"Unsupported pattern: ${un.show}")
}
subPatTps.map(extractType(_)(using dctx, un.sourcePos))
}
un.fun.tpe match {
case mt: MethodType => resolve(mt.resultType)
case tr: TermRef =>
// If we have a TermRef, this `unapply` method does not take type parameter.
// We can unveil its underlying type with `info` (and cry if we get something else...)
tr.info match {
case mt: MethodType => resolve(mt.resultType)
case _ =>
outOfSubsetError(un, s"Unsupported pattern: ${un.show}")
}
case _ =>
outOfSubsetError(un, s"Unsupported pattern: ${un.show}")
}
}

private def extractMatchCase(cd: tpd.CaseDef)(using dctx: DefContext): xt.MatchCase = {
val (recPattern, ndctx) = extractPattern(cd.pat)
val (recPattern, ndctx) = extractPattern(cd.pat, None)
val recBody = extractTree(cd.body)(using ndctx)

if (cd.guard == tpd.EmptyTree) {
Expand Down Expand Up @@ -1578,7 +1622,7 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using
case e => (xt.TupleSelect(e, 1).setPos(e), xt.TupleSelect(e, 2).setPos(e))
}, extractType(tpt))

case ExClassConstruction(tpe, args) =>
case ExClassConstruction(tpe, args) =>
extractType(tpe)(using dctx, tr.sourcePos) match {
case lct: xt.LocalClassType => xt.LocalClassConstructor(lct, args map extractTree)
case ct: xt.ClassType => xt.ClassConstructor(ct, args map extractTree)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ trait ASTExtractors {

protected lazy val arraySym = classFromName("scala.Array")
protected lazy val someClassSym = classFromName("scala.Some")
protected lazy val optionClassSym = classFromName("scala.Option")
protected lazy val byNameSym = classFromName("scala.<byname>")
protected lazy val bigIntSym = classFromName("scala.math.BigInt")
protected lazy val stringSym = classFromName("java.lang.String")
Expand Down
Loading