From d9b34710709d91fdbb391813982db513fc95da24 Mon Sep 17 00:00:00 2001 From: Mario Bucev Date: Tue, 31 Oct 2023 16:42:49 +0100 Subject: [PATCH] Applying some type widening in ReturnElimination to avoid triggering AdtSpecialization --- .../imperative/ReturnElimination.scala | 25 +++++++++--- .../valid/RefnChecksWithReturn.scala | 39 +++++++++++++++++++ 2 files changed, 59 insertions(+), 5 deletions(-) create mode 100644 frontends/benchmarks/imperative/valid/RefnChecksWithReturn.scala diff --git a/core/src/main/scala/stainless/extraction/imperative/ReturnElimination.scala b/core/src/main/scala/stainless/extraction/imperative/ReturnElimination.scala index 8ec54cb037..385e094b9e 100644 --- a/core/src/main/scala/stainless/extraction/imperative/ReturnElimination.scala +++ b/core/src/main/scala/stainless/extraction/imperative/ReturnElimination.scala @@ -411,8 +411,8 @@ class ReturnElimination(override val s: Trees, override val t: Trees) case Seq(e) => transform(e, currentType) case e +: rest if exprHasReturn(e) => - val firstType = e.getType - val firstTypeChecked = simpleWhileTransformer.transform(e.getType) + val firstType = widenTp(e.getType) + val firstTypeChecked = simpleWhileTransformer.transform(firstType) val controlFlowVal = t.ValDef.fresh("cf", ControlFlowSort.controlFlow(retTypeChecked, firstTypeChecked) @@ -458,13 +458,15 @@ class ReturnElimination(override val s: Trees, override val t: Trees) case Seq() => recons(ids, tvs, tes, ttps, tflags) case e +: rest if !exprHasReturn(e) => // We use a let-binding here to preserve execution order. - val vd = t.ValDef.fresh("x", simpleWhileTransformer.transform(e.getType), true).copiedFrom(e) + val eTpe = widenTp(e.getType) + val vd = t.ValDef.fresh("x", simpleWhileTransformer.transform(eTpe), true).copiedFrom(e) t.Let(vd, simpleWhileTransformer.transform(e), rec(rest, tes :+ vd.toVariable)).copiedFrom(e) case e +: rest => - val firstType = simpleWhileTransformer.transform(e.getType) + val eTpe = widenTp(e.getType) + val firstType = simpleWhileTransformer.transform(eTpe) ControlFlowSort.andThen( retTypeChecked, firstType, currentTypeChecked, - transform(e, e.getType), + transform(e, eTpe), (v: t.Variable) => { val transformedRest = rec(rest, tes :+ v) if (rest.exists(exprHasReturn)) @@ -495,6 +497,19 @@ class ReturnElimination(override val s: Trees, override val t: Trees) (res, fnSum) } + // Recursively widen class types to their top ancestor, and strips sigma, pi and refinement types. + // This is used to type the control flow type parameters in order to avoid triggering AdtSpecialization + // generation of refinement types. + private def widenTp(tp: s.Type)(using s.Symbols): s.Type = { + s.typeOps.postMap { + case ct @ s.ClassType(_, _) => + val ancestors = ct.tcd.ancestors + if (ancestors.isEmpty) None else Some(ancestors.last.toType) + case tp @ (s.SigmaType(_, _) | s.PiType(_, _) | s.RefinementType(_, _)) => Some(tp.getType) + case _ => None + } (tp) + } + override def combineSummaries(summaries: AllSummaries): ExtractionSummary = { val (retFns, whileFns) = summaries.fnsSummary.foldLeft((Set.empty[Identifier], Set.empty[Identifier])) { case ((retAcc, whileAcc), FunctionSummary.ReturnOnlyTransformed(fid)) => (retAcc + fid, whileAcc) diff --git a/frontends/benchmarks/imperative/valid/RefnChecksWithReturn.scala b/frontends/benchmarks/imperative/valid/RefnChecksWithReturn.scala new file mode 100644 index 0000000000..e8776d24f6 --- /dev/null +++ b/frontends/benchmarks/imperative/valid/RefnChecksWithReturn.scala @@ -0,0 +1,39 @@ +import stainless.lang._ + +object RefnChecksWithReturn { + + def fun0(x: BigInt, arr: Array[BigInt]): Option[BigInt] = { + require(arr.length >= 10) + if (x <= 0){ + return Some(x) + } + arr(0) = 0 + Some(x) + }.ensuring { + case Some(xx) => (xx == x) && ((x > 0) ==> (arr(0) == 0)) + } + + def fun1(x: BigInt): (BigInt, Option[BigInt]) = { + (x, (return (x + 1, Some(x))) : (BigInt, Option[BigInt]))._2 + }.ensuring { + case (xx, Some(xx2)) => xx == x + 1 && xx2 == x + } + + def fun2(x: BigInt): (BigInt, Option[BigInt]) = { + (x, if (x == 0) return (x + 1, Some(x + 1)) else (x + 2, Some(x + 2)))._2 + }.ensuring { + case (xx, Some(xx2)) => + val delta = if (x == 0) BigInt(1) else BigInt(2) + xx == x + delta && xx2 == x + delta + } + + def fun3(x: BigInt): (BigInt, Option[BigInt]) = { + smth(0, if (x <= 0) return (0, Some(x)) else Some(0)) + (x, Some(x)) + }.ensuring { + case (xx, Some(xx2)) => + if (x <= 0) xx == 0 && xx2 == x + else xx == x && xx2 == x + } + def smth(x: BigInt, tpl: Option[BigInt]): Unit = () +} \ No newline at end of file