Skip to content

Commit

Permalink
Applying some type widening in ReturnElimination to avoid triggering …
Browse files Browse the repository at this point in the history
…AdtSpecialization
  • Loading branch information
mario-bucev committed Nov 2, 2023
1 parent c2d0492 commit d9b3471
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,8 @@ class ReturnElimination(override val s: Trees, override val t: Trees)
case Seq(e) => transform(e, currentType)

case e +: rest if exprHasReturn(e) =>
val firstType = e.getType
val firstTypeChecked = simpleWhileTransformer.transform(e.getType)
val firstType = widenTp(e.getType)
val firstTypeChecked = simpleWhileTransformer.transform(firstType)
val controlFlowVal =
t.ValDef.fresh("cf",
ControlFlowSort.controlFlow(retTypeChecked, firstTypeChecked)
Expand Down Expand Up @@ -458,13 +458,15 @@ class ReturnElimination(override val s: Trees, override val t: Trees)
case Seq() => recons(ids, tvs, tes, ttps, tflags)
case e +: rest if !exprHasReturn(e) =>
// We use a let-binding here to preserve execution order.
val vd = t.ValDef.fresh("x", simpleWhileTransformer.transform(e.getType), true).copiedFrom(e)
val eTpe = widenTp(e.getType)
val vd = t.ValDef.fresh("x", simpleWhileTransformer.transform(eTpe), true).copiedFrom(e)
t.Let(vd, simpleWhileTransformer.transform(e), rec(rest, tes :+ vd.toVariable)).copiedFrom(e)
case e +: rest =>
val firstType = simpleWhileTransformer.transform(e.getType)
val eTpe = widenTp(e.getType)
val firstType = simpleWhileTransformer.transform(eTpe)
ControlFlowSort.andThen(
retTypeChecked, firstType, currentTypeChecked,
transform(e, e.getType),
transform(e, eTpe),
(v: t.Variable) => {
val transformedRest = rec(rest, tes :+ v)
if (rest.exists(exprHasReturn))
Expand Down Expand Up @@ -495,6 +497,19 @@ class ReturnElimination(override val s: Trees, override val t: Trees)
(res, fnSum)
}

// Recursively widen class types to their top ancestor, and strips sigma, pi and refinement types.
// This is used to type the control flow type parameters in order to avoid triggering AdtSpecialization
// generation of refinement types.
private def widenTp(tp: s.Type)(using s.Symbols): s.Type = {
s.typeOps.postMap {
case ct @ s.ClassType(_, _) =>
val ancestors = ct.tcd.ancestors
if (ancestors.isEmpty) None else Some(ancestors.last.toType)
case tp @ (s.SigmaType(_, _) | s.PiType(_, _) | s.RefinementType(_, _)) => Some(tp.getType)
case _ => None
} (tp)
}

override def combineSummaries(summaries: AllSummaries): ExtractionSummary = {
val (retFns, whileFns) = summaries.fnsSummary.foldLeft((Set.empty[Identifier], Set.empty[Identifier])) {
case ((retAcc, whileAcc), FunctionSummary.ReturnOnlyTransformed(fid)) => (retAcc + fid, whileAcc)
Expand Down
39 changes: 39 additions & 0 deletions frontends/benchmarks/imperative/valid/RefnChecksWithReturn.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import stainless.lang._

object RefnChecksWithReturn {

def fun0(x: BigInt, arr: Array[BigInt]): Option[BigInt] = {
require(arr.length >= 10)
if (x <= 0){
return Some(x)
}
arr(0) = 0
Some(x)
}.ensuring {
case Some(xx) => (xx == x) && ((x > 0) ==> (arr(0) == 0))
}

def fun1(x: BigInt): (BigInt, Option[BigInt]) = {
(x, (return (x + 1, Some(x))) : (BigInt, Option[BigInt]))._2
}.ensuring {
case (xx, Some(xx2)) => xx == x + 1 && xx2 == x
}

def fun2(x: BigInt): (BigInt, Option[BigInt]) = {
(x, if (x == 0) return (x + 1, Some(x + 1)) else (x + 2, Some(x + 2)))._2
}.ensuring {
case (xx, Some(xx2)) =>
val delta = if (x == 0) BigInt(1) else BigInt(2)
xx == x + delta && xx2 == x + delta
}

def fun3(x: BigInt): (BigInt, Option[BigInt]) = {
smth(0, if (x <= 0) return (0, Some(x)) else Some(0))
(x, Some(x))
}.ensuring {
case (xx, Some(xx2)) =>
if (x <= 0) xx == 0 && xx2 == x
else xx == x && xx2 == x
}
def smth(x: BigInt, tpl: Option[BigInt]): Unit = ()
}

0 comments on commit d9b3471

Please sign in to comment.