Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transformation for return keyword #923

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/src/main/scala/stainless/MainHelpers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ trait MainHelpers extends inox.MainHelpers { self =>
optNoColors -> Description(General, "Disable colored output"),
frontend.optPersistentCache -> Description(General, "Enable caching of program extraction & analysis"),
frontend.optBatchedProgram -> Description(General, "Process the whole program together, skip dependency analysis"),
frontend.optKeep -> Description(General, "Keep library objects marked by @keep(g) for some g in g1,g2,... (implies --batched)"),
frontend.optKeep -> Description(General, "Keep library objects marked by @keepFor(g) for some g in g1,g2,... (implies --batched)"),
frontend.optExtraDeps -> Description(General, "Fetch the specified extra source dependencies and add their source files to the session"),
frontend.optExtraResolvers -> Description(General, "Extra resolvers to use to fetch extra source dependencies"),
utils.Caches.optCacheDir -> Description(General, "Specify the directory in which cache files should be stored")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/* Copyright 2009-2021 EPFL, Lausanne */

package stainless
package extraction
package imperative

trait ReturnElimination
extends oo.CachingPhase
with IdentitySorts
with SimplyCachedFunctions
with SimpleFunctions
with oo.IdentityTypeDefs
with oo.IdentityClasses
with utils.SyntheticSorts { self =>

val s: Trees
val t: s.type
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your transformation isn't checking all trees (e.g. a return inside a type refinement wouldn't be rejected).

In general, it's safer to keep s and t types potentially distinct so the (Scala) type checker will make sure you visit all trees in your transformation.


import s._

protected class TransformerContext(val symbols: Symbols) {
// we precompute the set of expressions that contain a return
val exprHasReturn = collection.mutable.Set[Expr]()
for (fd <- symbols.functions.values) {
exprOps.postTraversal {
case e @ Return(_) => exprHasReturn += e
case e @ Operator(es, _) if (es.exists(exprHasReturn)) => exprHasReturn += e
case _ => ()
}(fd.fullBody)
}

val funHasReturn: Set[Identifier] = symbols.functions.values.collect {
case fd if exprHasReturn(fd.fullBody) => fd.id
}.toSet
}

override protected def getContext(symbols: Symbols) = new TransformerContext(symbols)

protected def extractFunction(tc: TransformerContext, fd: FunDef): FunDef = {
implicit val symboms = tc.symbols

if (tc.funHasReturn(fd.id)) {
val specced = exprOps.BodyWithSpecs(fd.fullBody)
val retType = fd.returnType

object ReturnTransformer extends TransformerWithType {
override val s: self.s.type = self.s
override val t: self.s.type = self.s
override val symbols: s.Symbols = tc.symbols

private def proceedOrTransform(expr: Expr, currentType: Type): Expr = {
if (tc.exprHasReturn(expr)) transform(expr, currentType)
else ControlFlowSort.proceed(retType, currentType, expr)
}

private def proceedOrTransform(mc: MatchCase, currentType: Type): MatchCase = {
val MatchCase(pattern, optGuard, rhs) = mc
MatchCase(pattern, optGuard, proceedOrTransform(rhs, currentType))
}

override def transform(expr: Expr, currentType: Type): Expr = expr match {
case _ if !tc.exprHasReturn(expr) => expr

case Return(e) if !tc.exprHasReturn(e) => ControlFlowSort.ret(retType, currentType, e)

case IfExpr(cond, e1, e2) if !tc.exprHasReturn(cond) =>
IfExpr(cond, proceedOrTransform(e1, currentType), proceedOrTransform(e2, currentType))

case MatchExpr(scrut, cases) if !tc.exprHasReturn(scrut) =>
MatchExpr(scrut,
cases.map(proceedOrTransform(_, currentType))
)

case Let(vd, e, body) if tc.exprHasReturn(e) =>
val firstType = vd.tpe
val controlFlowVal =
ValDef.fresh("cf", ControlFlowSort.controlFlow(retType, firstType)).setPos(e)

Let(
controlFlowVal,
transform(e, firstType),
ControlFlowSort.andThen(
retType, firstType, currentType,
controlFlowVal.toVariable,
v => exprOps.replaceFromSymbols(Map(vd -> v), proceedOrTransform(body, currentType)),
body.getPos
)
).setPos(expr)

case Let(vd, e, body) =>
Let(vd, e, transform(body, currentType))

case Block(es, last) =>
def processBlockExpressions(es: Seq[Expr]): Expr = es match {
case Seq(e) => transform(e, currentType)

case e +: rest if (tc.exprHasReturn(e)) =>
val firstType = e.getType
val controlFlowVal =
ValDef.fresh("cf", ControlFlowSort.controlFlow(retType, firstType)).setPos(e)
val transformedRest = processBlockExpressions(rest)

if (rest.exists(tc.exprHasReturn)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: this would probably be more readable if you pushed the if-expression down to the ControlFlowSort.andThen lambda argument.

Let(
controlFlowVal,
transform(e, firstType),
ControlFlowSort.andThen(
retType, firstType, currentType,
controlFlowVal.toVariable,
_ => transformedRest,
rest.head.getPos
)
).setPos(e)
} else {
Let(
controlFlowVal,
transform(e, firstType),
ControlFlowSort.andThen(
retType, firstType, currentType,
controlFlowVal.toVariable,
_ => ControlFlowSort.proceed(retType, currentType, transformedRest),
rest.head.getPos
)
).setPos(e)
}

case e +: rest =>
val unusedVal = ValDef.fresh("unused", e.getType)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could preserve the Block, something like

case es =>
  val nonReturnEs = es.takeWhile(e => !tc.hasExprReturn(e))
  Block(nonReturnEs, processBlockExpressions(es.drop(nonReturnEs.size()))).copiedFrom(expr)

You'll probably need to handle the case where es.drop(nonReturnEs.size()) is empty specially.

Let(unusedVal, e, processBlockExpressions(rest))
}
processBlockExpressions(es :+ last)

case _ =>
context.reporter.fatalError(expr.getPos, s"Keyword `return` is not supported in expression ${expr.asString}")
}
}

val newBody = specced.bodyOpt.map { body =>
val topLevelCF = ValDef.fresh("topLevelCF", ControlFlowSort.controlFlow(retType, retType)).setPos(fd.fullBody)
Let(topLevelCF, ReturnTransformer.transform(body),
ControlFlowSort.buildMatch(retType, retType, topLevelCF.toVariable,
v => v,
v => v,
body.getPos
)
).setPos(body)
}
fd.copy(fullBody = specced.withBody(newBody, retType).reconstructed).setPos(fd)
}
else fd
}

override protected def extractSymbols(context: TransformerContext, symbols: s.Symbols): t.Symbols = {
if (symbols.functions.values.exists(fd => context.funHasReturn(fd.id)))
super.extractSymbols(context, symbols)
.withSorts(Seq(ControlFlowSort.syntheticControlFlow))
else
super.extractSymbols(context, symbols)
}
}

object ReturnElimination {
def apply(trees: Trees)(implicit ctx: inox.Context): ExtractionPipeline {
val s: trees.type
val t: trees.type
} = new ReturnElimination {
override val s: trees.type = trees
override val t: trees.type = trees
override val context = ctx
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package stainless
package extraction
package imperative

// FIXME: @romac
trait TransformerWithType extends oo.TransformerWithType {
val s: Trees
val t: Trees
Expand Down
20 changes: 20 additions & 0 deletions core/src/main/scala/stainless/extraction/imperative/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ trait Trees extends oo.Trees with Definitions { self =>

/* XLang imperative trees to desugar */

/** Return an [[ast.Expressions.Expr]].
*
* @param expr The expression to return
*/
sealed case class Return(expr: Expr) extends Expr with CachingTyped {
override protected def computeType(implicit s: Symbols): Type = NothingType()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should check that expr.isTyped here and return Untyped otherwise.

}

/** $encodingof `{ expr1; expr2; ...; exprn; last }` */
case class Block(exprs: Seq[Expr], last: Expr) extends Expr with CachingTyped {
protected def computeType(implicit s: Symbols): Type = if (exprs.forall(_.isTyped)) last.getType else Untyped
Expand Down Expand Up @@ -193,6 +201,9 @@ trait Printer extends oo.Printer {
case Block(exprs, last) =>
p"${nary(exprs :+ last, "\n")}"

case Return(e) =>
p"return $e"

case LetVar(vd, value, expr) =>
p"""|var $vd = $value
|$expr"""
Expand Down Expand Up @@ -259,9 +270,15 @@ trait Printer extends oo.Printer {

override protected def noBracesSub(e: Tree): Seq[Expr] = e match {
case LetVar(_, _, bd) => Seq(bd)
case Return(e) => Seq(e)
case _ => super.noBracesSub(e)
}

override protected def requiresParentheses(ex: Tree, within: Option[Tree]): Boolean = (ex, within) match {
case (_, Some(_: Return)) => false
case _ => super.requiresParentheses(ex, within)
}

override protected def requiresBraces(ex: Tree, within: Option[Tree]): Boolean = (ex, within) match {
case (_: Expr, Some(_: Block)) => false
case (_: Block, Some(_: While)) => false
Expand Down Expand Up @@ -306,6 +323,9 @@ trait TreeDeconstructor extends oo.TreeDeconstructor {
case s.Old(e) =>
(Seq(), Seq(), Seq(e), Seq(), Seq(), (_, _, es, _, _) => t.Old(es.head))

case s.Return(e) =>
(Seq(), Seq(), Seq(e), Seq(), Seq(), (_, _, es, _, _) => t.Return(es(0)))

case s.MutableMapWithDefault(from, to, default) =>
(Seq(), Seq(), Seq(default), Seq(from, to), Seq(), (_, _, es, tps, _) => t.MutableMapWithDefault(tps(0), tps(1), es(0)))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ package object imperative {

def extractor(implicit ctx: inox.Context) = {
utils.DebugPipeline("AntiAliasing", AntiAliasing(trees)) andThen
utils.DebugPipeline("ReturnElimination", ReturnElimination(trees)) andThen
utils.DebugPipeline("ImperativeCodeElimination", ImperativeCodeElimination(trees)) andThen
utils.DebugPipeline("ImperativeCleanup", ImperativeCleanup(trees, oo.trees))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ package object methods {
utils.DebugPipeline("MergeInvariants", MergeInvariants(trees)) andThen
utils.DebugPipeline("FieldAccessors", FieldAccessors(trees)) andThen
utils.DebugPipeline("ValueClasses", ValueClasses(trees)) andThen
lowering
utils.DebugPipeline("MethodsLowering", lowering)
}

def fullExtractor(implicit ctx: inox.Context) = extractor andThen nextExtractor
Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala/stainless/extraction/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ package object extraction {
"ValueClasses" -> "Erase value classes",
"FieldAccessors" -> "Inline field accessors of concrete classes",
"AntiAliasing" -> "Rewrite field and array mutations",
"ReturnElimination" -> "Eliminate `return` expressions",
"ImperativeCodeElimination" -> "Eliminate while loops and assignments",
"ImperativeCleanup" -> "Cleanup after imperative transformations",
"AdtSpecialization" -> "Specialize classes into ADTs (when possible)",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,66 @@ trait SyntheticSorts extends ExtractionCaches { self: CachingPhase =>
symbols.lookup.get[s.FunDef]("stainless.internal.Option.get").map(FunctionKey(_))
)
}


// ControlFlowSort represents the following class:
// sealed abstract class ControlFlow[Ret, Cur] {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: remove trailing brace

// case class Return[Ret, Cur](value: Ret) extends ControlFlow[Ret, Cur]
// case class Proceed[Ret, Cur](value: Cur) extends ControlFlow[Ret, Cur]
protected object ControlFlowSort {
import t._
import t.dsl._

val syntheticControlFlow: t.ADTSort = {
val Seq(controlFlow, ret, proceed) =
Seq("ControlFlow", "Return", "Proceed").map(name => ast.SymbolIdentifier("stainless.internal." + name))
val retValue = FreshIdentifier("value")
val proceedValue = FreshIdentifier("value")
mkSort(controlFlow)("Ret", "Cur") { case Seq(retT, curT) =>
Seq(
(ret, Seq(t.ValDef(retValue, retT))),
(proceed, Seq(t.ValDef(proceedValue, curT)))
)
}
}

val controlFlowId: Identifier = syntheticControlFlow.id
val retId: Identifier = syntheticControlFlow.constructors.find(_.id.name == "Return").get.id
val proceedId: Identifier = syntheticControlFlow.constructors.find(_.id.name == "Proceed").get.id

def controlFlow(retT: Type, curT: Type): Type = ADTType(controlFlowId, Seq(retT, curT))
def ret(retT: Type, curT: Type, e: Expr) = ADT(retId, Seq(retT, curT), Seq(e)).setPos(e)
def proceed(retT: Type, curT: Type, e: Expr) = ADT(proceedId, Seq(retT, curT), Seq(e)).setPos(e)

def buildMatch(
retT: Type, curT: Type,
scrut: Expr,
retCase: Variable => Expr,
proceedCase: Variable => Expr,
pos: inox.utils.Position
): Expr = {
val retVal = ValDef.fresh("retValue", retT).setPos(pos)
val proceedVal = ValDef.fresh("proceedValue", curT).setPos(pos)
MatchExpr(scrut, Seq(
MatchCase(
ADTPattern(None, retId, Seq(retT, curT), Seq(WildcardPattern(Some(retVal)))).setPos(pos),
None,
retCase(retVal.toVariable)
).setPos(pos),
MatchCase(
ADTPattern(None, proceedId, Seq(retT, curT), Seq(WildcardPattern(Some(proceedVal)))).setPos(pos),
None,
proceedCase(proceedVal.toVariable)
).setPos(pos),
)).setPos(pos)
}

def andThen(retT: Type, curT: Type, nextT: Type, previous: Expr, next: Variable => Expr, pos: inox.utils.Position): Expr = {
buildMatch(retT, curT, previous,
rv => ret(retT, nextT, rv),
next,
pos
)
}
}
}
3 changes: 2 additions & 1 deletion core/src/main/scala/stainless/frontend/BatchedCallBack.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ class BatchedCallBack(components: Seq[Component])(implicit val context: inox.Con
val keepGroups = context.options.findOptionOrDefault(optKeep)

def hasKeepFlag(flags: Seq[xt.Flag]) =
keepGroups.exists(g => flags.contains(xt.Annotation("keep", Seq(xt.StringLiteral(g)))))
flags.exists(_.name == "keep") ||
keepGroups.exists(g => flags.contains(xt.Annotation("keepFor", Seq(xt.StringLiteral(g)))))

def keepDefinition(defn: xt.Definition): Boolean =
hasKeepFlag(defn.flags) || userDependencies.contains(defn.id)
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/stainless/frontend/CallBack.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package frontend

import extraction.xlang.{ trees => xt }

// Always keep library objects marked by @keep(g) for some g in g1,...,gn
// Keep library objects marked by @keepFor(g) for some g in g1,...,gn
object optKeep extends inox.OptionDef[Seq[String]] {
val name = "keep"
val default = Seq[String]()
Expand Down
5 changes: 3 additions & 2 deletions core/src/main/scala/stainless/utils/Serialization.scala
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ class XLangSerializer(override val trees: extraction.xlang.Trees, serializeProdu
/** An extension to the set of registered classes in the `StainlessSerializer`.
* occur within Stainless programs.
*
* The new identifiers in the mapping range from 180 to 256.
* The new identifiers in the mapping range from 180 to 257.
*
* NEXT ID: 257
* NEXT ID: 258
*/
override protected def classSerializers: Map[Class[_], Serializer[_]] =
super.classSerializers ++ Map(
Expand Down Expand Up @@ -181,6 +181,7 @@ class XLangSerializer(override val trees: extraction.xlang.Trees, serializeProdu
classSerializer[Throwing](211),
classSerializer[Throw] (212),
classSerializer[Try] (213),
classSerializer[Return] (257),

// Methods trees
classSerializer[This] (214),
Expand Down
28 changes: 28 additions & 0 deletions frontends/benchmarks/imperative/valid/ControlFlow2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import stainless.lang._

object ControlFlow2 {

def foo(x: Option[BigInt], a: Boolean, b: Boolean): BigInt = {
if (a && b) {
return 1
}

val y = x match {
case None() => return 0
case Some(x) if a => return x + 1
case Some(x) if b => return x + 2
case Some(x) => x
};

-y
}

def testFoo: Unit = {
assert(foo(None(), false, false) == 0)
assert(foo(Some(10), true, true) == 1)
assert(foo(Some(10), true, false) == 11)
assert(foo(Some(10), false, true) == 12)
assert(foo(Some(10), false, false) == -10)
}

}
Loading