-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Transformation for return keyword #923
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
/* Copyright 2009-2021 EPFL, Lausanne */ | ||
|
||
package stainless | ||
package extraction | ||
package imperative | ||
|
||
trait ReturnElimination | ||
extends oo.CachingPhase | ||
with IdentitySorts | ||
with SimplyCachedFunctions | ||
with SimpleFunctions | ||
with oo.IdentityTypeDefs | ||
with oo.IdentityClasses | ||
with utils.SyntheticSorts { self => | ||
|
||
val s: Trees | ||
val t: s.type | ||
|
||
import s._ | ||
|
||
protected class TransformerContext(val symbols: Symbols) { | ||
// we precompute the set of expressions that contain a return | ||
val exprHasReturn = collection.mutable.Set[Expr]() | ||
for (fd <- symbols.functions.values) { | ||
exprOps.postTraversal { | ||
case e @ Return(_) => exprHasReturn += e | ||
case e @ Operator(es, _) if (es.exists(exprHasReturn)) => exprHasReturn += e | ||
case _ => () | ||
}(fd.fullBody) | ||
} | ||
|
||
val funHasReturn: Set[Identifier] = symbols.functions.values.collect { | ||
case fd if exprHasReturn(fd.fullBody) => fd.id | ||
}.toSet | ||
} | ||
|
||
override protected def getContext(symbols: Symbols) = new TransformerContext(symbols) | ||
|
||
protected def extractFunction(tc: TransformerContext, fd: FunDef): FunDef = { | ||
implicit val symboms = tc.symbols | ||
|
||
if (tc.funHasReturn(fd.id)) { | ||
val specced = exprOps.BodyWithSpecs(fd.fullBody) | ||
val retType = fd.returnType | ||
|
||
object ReturnTransformer extends TransformerWithType { | ||
override val s: self.s.type = self.s | ||
override val t: self.s.type = self.s | ||
override val symbols: s.Symbols = tc.symbols | ||
|
||
private def proceedOrTransform(expr: Expr, currentType: Type): Expr = { | ||
if (tc.exprHasReturn(expr)) transform(expr, currentType) | ||
else ControlFlowSort.proceed(retType, currentType, expr) | ||
} | ||
|
||
private def proceedOrTransform(mc: MatchCase, currentType: Type): MatchCase = { | ||
val MatchCase(pattern, optGuard, rhs) = mc | ||
MatchCase(pattern, optGuard, proceedOrTransform(rhs, currentType)) | ||
} | ||
|
||
override def transform(expr: Expr, currentType: Type): Expr = expr match { | ||
case _ if !tc.exprHasReturn(expr) => expr | ||
|
||
case Return(e) if !tc.exprHasReturn(e) => ControlFlowSort.ret(retType, currentType, e) | ||
|
||
case IfExpr(cond, e1, e2) if !tc.exprHasReturn(cond) => | ||
IfExpr(cond, proceedOrTransform(e1, currentType), proceedOrTransform(e2, currentType)) | ||
|
||
case MatchExpr(scrut, cases) if !tc.exprHasReturn(scrut) => | ||
MatchExpr(scrut, | ||
cases.map(proceedOrTransform(_, currentType)) | ||
) | ||
|
||
case Let(vd, e, body) if tc.exprHasReturn(e) => | ||
val firstType = vd.tpe | ||
val controlFlowVal = | ||
ValDef.fresh("cf", ControlFlowSort.controlFlow(retType, firstType)).setPos(e) | ||
|
||
Let( | ||
controlFlowVal, | ||
transform(e, firstType), | ||
ControlFlowSort.andThen( | ||
retType, firstType, currentType, | ||
controlFlowVal.toVariable, | ||
v => exprOps.replaceFromSymbols(Map(vd -> v), proceedOrTransform(body, currentType)), | ||
body.getPos | ||
) | ||
).setPos(expr) | ||
|
||
case Let(vd, e, body) => | ||
Let(vd, e, transform(body, currentType)) | ||
|
||
case Block(es, last) => | ||
def processBlockExpressions(es: Seq[Expr]): Expr = es match { | ||
case Seq(e) => transform(e, currentType) | ||
|
||
case e +: rest if (tc.exprHasReturn(e)) => | ||
val firstType = e.getType | ||
val controlFlowVal = | ||
ValDef.fresh("cf", ControlFlowSort.controlFlow(retType, firstType)).setPos(e) | ||
val transformedRest = processBlockExpressions(rest) | ||
|
||
if (rest.exists(tc.exprHasReturn)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: this would probably be more readable if you pushed the if-expression down to the |
||
Let( | ||
controlFlowVal, | ||
transform(e, firstType), | ||
ControlFlowSort.andThen( | ||
retType, firstType, currentType, | ||
controlFlowVal.toVariable, | ||
_ => transformedRest, | ||
rest.head.getPos | ||
) | ||
).setPos(e) | ||
} else { | ||
Let( | ||
controlFlowVal, | ||
transform(e, firstType), | ||
ControlFlowSort.andThen( | ||
retType, firstType, currentType, | ||
controlFlowVal.toVariable, | ||
_ => ControlFlowSort.proceed(retType, currentType, transformedRest), | ||
rest.head.getPos | ||
) | ||
).setPos(e) | ||
} | ||
|
||
case e +: rest => | ||
val unusedVal = ValDef.fresh("unused", e.getType) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You could preserve the case es =>
val nonReturnEs = es.takeWhile(e => !tc.hasExprReturn(e))
Block(nonReturnEs, processBlockExpressions(es.drop(nonReturnEs.size()))).copiedFrom(expr) You'll probably need to handle the case where |
||
Let(unusedVal, e, processBlockExpressions(rest)) | ||
} | ||
processBlockExpressions(es :+ last) | ||
|
||
case _ => | ||
context.reporter.fatalError(expr.getPos, s"Keyword `return` is not supported in expression ${expr.asString}") | ||
} | ||
} | ||
|
||
val newBody = specced.bodyOpt.map { body => | ||
val topLevelCF = ValDef.fresh("topLevelCF", ControlFlowSort.controlFlow(retType, retType)).setPos(fd.fullBody) | ||
Let(topLevelCF, ReturnTransformer.transform(body), | ||
ControlFlowSort.buildMatch(retType, retType, topLevelCF.toVariable, | ||
v => v, | ||
v => v, | ||
body.getPos | ||
) | ||
).setPos(body) | ||
} | ||
fd.copy(fullBody = specced.withBody(newBody, retType).reconstructed).setPos(fd) | ||
} | ||
else fd | ||
} | ||
|
||
override protected def extractSymbols(context: TransformerContext, symbols: s.Symbols): t.Symbols = { | ||
if (symbols.functions.values.exists(fd => context.funHasReturn(fd.id))) | ||
super.extractSymbols(context, symbols) | ||
.withSorts(Seq(ControlFlowSort.syntheticControlFlow)) | ||
else | ||
super.extractSymbols(context, symbols) | ||
} | ||
} | ||
|
||
object ReturnElimination { | ||
def apply(trees: Trees)(implicit ctx: inox.Context): ExtractionPipeline { | ||
val s: trees.type | ||
val t: trees.type | ||
} = new ReturnElimination { | ||
override val s: trees.type = trees | ||
override val t: trees.type = trees | ||
override val context = ctx | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,14 @@ trait Trees extends oo.Trees with Definitions { self => | |
|
||
/* XLang imperative trees to desugar */ | ||
|
||
/** Return an [[ast.Expressions.Expr]]. | ||
* | ||
* @param expr The expression to return | ||
*/ | ||
sealed case class Return(expr: Expr) extends Expr with CachingTyped { | ||
override protected def computeType(implicit s: Symbols): Type = NothingType() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You should check that |
||
} | ||
|
||
/** $encodingof `{ expr1; expr2; ...; exprn; last }` */ | ||
case class Block(exprs: Seq[Expr], last: Expr) extends Expr with CachingTyped { | ||
protected def computeType(implicit s: Symbols): Type = if (exprs.forall(_.isTyped)) last.getType else Untyped | ||
|
@@ -193,6 +201,9 @@ trait Printer extends oo.Printer { | |
case Block(exprs, last) => | ||
p"${nary(exprs :+ last, "\n")}" | ||
|
||
case Return(e) => | ||
p"return $e" | ||
|
||
case LetVar(vd, value, expr) => | ||
p"""|var $vd = $value | ||
|$expr""" | ||
|
@@ -259,9 +270,15 @@ trait Printer extends oo.Printer { | |
|
||
override protected def noBracesSub(e: Tree): Seq[Expr] = e match { | ||
case LetVar(_, _, bd) => Seq(bd) | ||
case Return(e) => Seq(e) | ||
case _ => super.noBracesSub(e) | ||
} | ||
|
||
override protected def requiresParentheses(ex: Tree, within: Option[Tree]): Boolean = (ex, within) match { | ||
case (_, Some(_: Return)) => false | ||
case _ => super.requiresParentheses(ex, within) | ||
} | ||
|
||
override protected def requiresBraces(ex: Tree, within: Option[Tree]): Boolean = (ex, within) match { | ||
case (_: Expr, Some(_: Block)) => false | ||
case (_: Block, Some(_: While)) => false | ||
|
@@ -306,6 +323,9 @@ trait TreeDeconstructor extends oo.TreeDeconstructor { | |
case s.Old(e) => | ||
(Seq(), Seq(), Seq(e), Seq(), Seq(), (_, _, es, _, _) => t.Old(es.head)) | ||
|
||
case s.Return(e) => | ||
(Seq(), Seq(), Seq(e), Seq(), Seq(), (_, _, es, _, _) => t.Return(es(0))) | ||
|
||
case s.MutableMapWithDefault(from, to, default) => | ||
(Seq(), Seq(), Seq(default), Seq(from, to), Seq(), (_, _, es, tps, _) => t.MutableMapWithDefault(tps(0), tps(1), es(0))) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -98,4 +98,66 @@ trait SyntheticSorts extends ExtractionCaches { self: CachingPhase => | |
symbols.lookup.get[s.FunDef]("stainless.internal.Option.get").map(FunctionKey(_)) | ||
) | ||
} | ||
|
||
|
||
// ControlFlowSort represents the following class: | ||
// sealed abstract class ControlFlow[Ret, Cur] { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: remove trailing brace |
||
// case class Return[Ret, Cur](value: Ret) extends ControlFlow[Ret, Cur] | ||
// case class Proceed[Ret, Cur](value: Cur) extends ControlFlow[Ret, Cur] | ||
protected object ControlFlowSort { | ||
import t._ | ||
import t.dsl._ | ||
|
||
val syntheticControlFlow: t.ADTSort = { | ||
val Seq(controlFlow, ret, proceed) = | ||
Seq("ControlFlow", "Return", "Proceed").map(name => ast.SymbolIdentifier("stainless.internal." + name)) | ||
val retValue = FreshIdentifier("value") | ||
val proceedValue = FreshIdentifier("value") | ||
mkSort(controlFlow)("Ret", "Cur") { case Seq(retT, curT) => | ||
Seq( | ||
(ret, Seq(t.ValDef(retValue, retT))), | ||
(proceed, Seq(t.ValDef(proceedValue, curT))) | ||
) | ||
} | ||
} | ||
|
||
val controlFlowId: Identifier = syntheticControlFlow.id | ||
val retId: Identifier = syntheticControlFlow.constructors.find(_.id.name == "Return").get.id | ||
val proceedId: Identifier = syntheticControlFlow.constructors.find(_.id.name == "Proceed").get.id | ||
|
||
def controlFlow(retT: Type, curT: Type): Type = ADTType(controlFlowId, Seq(retT, curT)) | ||
def ret(retT: Type, curT: Type, e: Expr) = ADT(retId, Seq(retT, curT), Seq(e)).setPos(e) | ||
def proceed(retT: Type, curT: Type, e: Expr) = ADT(proceedId, Seq(retT, curT), Seq(e)).setPos(e) | ||
|
||
def buildMatch( | ||
retT: Type, curT: Type, | ||
scrut: Expr, | ||
retCase: Variable => Expr, | ||
proceedCase: Variable => Expr, | ||
pos: inox.utils.Position | ||
): Expr = { | ||
val retVal = ValDef.fresh("retValue", retT).setPos(pos) | ||
val proceedVal = ValDef.fresh("proceedValue", curT).setPos(pos) | ||
MatchExpr(scrut, Seq( | ||
MatchCase( | ||
ADTPattern(None, retId, Seq(retT, curT), Seq(WildcardPattern(Some(retVal)))).setPos(pos), | ||
None, | ||
retCase(retVal.toVariable) | ||
).setPos(pos), | ||
MatchCase( | ||
ADTPattern(None, proceedId, Seq(retT, curT), Seq(WildcardPattern(Some(proceedVal)))).setPos(pos), | ||
None, | ||
proceedCase(proceedVal.toVariable) | ||
).setPos(pos), | ||
)).setPos(pos) | ||
} | ||
|
||
def andThen(retT: Type, curT: Type, nextT: Type, previous: Expr, next: Variable => Expr, pos: inox.utils.Position): Expr = { | ||
buildMatch(retT, curT, previous, | ||
rv => ret(retT, nextT, rv), | ||
next, | ||
pos | ||
) | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import stainless.lang._ | ||
|
||
object ControlFlow2 { | ||
|
||
def foo(x: Option[BigInt], a: Boolean, b: Boolean): BigInt = { | ||
if (a && b) { | ||
return 1 | ||
} | ||
|
||
val y = x match { | ||
case None() => return 0 | ||
case Some(x) if a => return x + 1 | ||
case Some(x) if b => return x + 2 | ||
case Some(x) => x | ||
}; | ||
|
||
-y | ||
} | ||
|
||
def testFoo: Unit = { | ||
assert(foo(None(), false, false) == 0) | ||
assert(foo(Some(10), true, true) == 1) | ||
assert(foo(Some(10), true, false) == 11) | ||
assert(foo(Some(10), false, true) == 12) | ||
assert(foo(Some(10), false, false) == -10) | ||
} | ||
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Your transformation isn't checking all trees (e.g. a
return
inside a type refinement wouldn't be rejected).In general, it's safer to keep
s
andt
types potentially distinct so the (Scala) type checker will make sure you visit all trees in your transformation.