Skip to content

Commit

Permalink
POC: Make ensuring inline
Browse files Browse the repository at this point in the history
  • Loading branch information
mbovel committed Dec 12, 2024
1 parent f341bcf commit fa18858
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import core.Annotations._
import util.SourcePosition

import scala.collection.mutable.{Map => MutableMap}
import dotty.tools.dotc.transform.Inlining

trait ASTExtractors {
val dottyCtx: DottyContext
Expand Down Expand Up @@ -527,6 +528,32 @@ trait ASTExtractors {
}
}

object ExInlinedCall {
/** Extracts an inlined function or method call, returning a 4-tuple
* containing the receiver, function or method symbol, type arguments,
* and term arguments.
*
* Unlift the receiver and arguments to their definitions if they are
* inline proxies.
*/
def unapply(tree: tpd.Tree): Option[(Option[tpd.Tree], Symbol, Seq[tpd.Tree], Seq[tpd.Tree])] = tree match {
// TODO(mbovel): should probably use InlineOrProxy instead of
// Synthetic. Why is this not working?
case Block(stats, Inlined(ExCall(rec, sym, tps, args), _, _)) if stats.forall(_.symbol.is(Synthetic)) =>
def unliftArg(arg: tpd.Tree) = arg match {
case arg @ Ident(_) if arg.symbol.is(Synthetic) =>
stats.find(_.symbol == arg.symbol) match
case Some(proxyDef: tpd.ValOrDefDef) => proxyDef.rhs
case _ => throw new IllegalStateException(s"Could not find inline proxy definition for $arg")
case arg => arg
}
Some(rec.map(unliftArg), sym, tps, args.map(unliftArg))
case Inlined(ExCall(rec, sym, tps, args), _, _) =>
Some(rec, sym, tps, args)
case _ => None
}
}

object ExClassConstruction {
def unapply(tree: tpd.Tree): Option[(Type, Seq[tpd.Tree])] = tree match {
case Apply(Select(New(tpt), nme.CONSTRUCTOR), args) =>
Expand Down Expand Up @@ -1180,17 +1207,25 @@ trait ASTExtractors {
}

object ExEnsuredExpression {
/** Extracts an `ensuring` call.
*
* When matching, returngs a triple containing the receiver, the contract
* and a boolean indicating if the check is static.
*/
def unapply(tree: tpd.Tree): Option[(tpd.Tree, tpd.Tree, Boolean)] = tree match {
// Dynamic check (Predef.ensuring)
// An optional message may comes after `contract`, but we do not make use of it.
case ExCall(Some(rec),
ExSymbol("scala", "Predef$", "Ensuring", "ensuring"),
_, contract +: _
) => Some((rec, contract, false))

// Ditto
case ExCall(Some(rec),
ExSymbol("stainless", "lang", "StaticChecks$", "Ensuring", "ensuring"),
_, contract +: _
// Static check (stainless.lang.StaticChecks.ensuring)
case ExInlinedCall(
_,
ExSymbol("stainless", "lang", "StaticChecks$", "ensuring"),
_,
Seq(rec, contract, message)
) => Some((rec, contract, true))

case _ => None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1327,10 +1327,6 @@ class CodeExtraction(inoxCtx: inox.Context,
case (v, vd) => v.symbol -> (() => vd.toVariable)
})))

case Block(es, e) =>
val b = extractBlock(es :+ e)
xt.exprOps.flattenBlocks(b)

case Try(body, cses, fin) =>
val rb = extractTree(body)
val rc = cses.map(extractMatchCase)
Expand Down Expand Up @@ -1382,6 +1378,11 @@ class CodeExtraction(inoxCtx: inox.Context,
})).setPos(post)
})

// Needs to be after `ExEnsuredExpression`, as it matches blocks.
case Block(es, e) =>
val b = extractBlock(es :+ e)
xt.exprOps.flattenBlocks(b)

case ExThrowingExpression(body, contract) =>
val pred = extractTree(contract)
val b = extractTreeOrNoTree(body)
Expand All @@ -1393,7 +1394,7 @@ class CodeExtraction(inoxCtx: inox.Context,
val vd = xt.ValDef.fresh("res", tpe).setPos(other)
xt.Lambda(Seq(vd), xt.Application(other, Seq(vd.toVariable)).setPos(other)).setPos(other)
})

case t @ ExHoldsWithProofExpression(body, ExMaybeBecauseExpressionWrapper(proof)) =>
val vd = xt.ValDef.fresh("holds", xt.BooleanType().setPos(tr.sourcePos)).setPos(tr.sourcePos)
val p = extractTreeOrNoTree(proof)
Expand Down
16 changes: 10 additions & 6 deletions frontends/library/stainless/lang/StaticChecks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@ import stainless.annotation._

object StaticChecks {

@library
implicit class Ensuring[A](val x: A) extends AnyVal {
def ensuring(@ghost cond: A => Boolean): A = x

def ensuring(@ghost cond: A => Boolean, msg: => String): A = x
}
extension [A](x: A)
@library
inline def ensuring(@ghost inline cond: A => Boolean, inline msg: String = ""): A = x

//@library
//implicit class Ensuring[A](val x: A) extends AnyVal {
// def ensuring(@ghost cond: A => Boolean): A = x
//
// def ensuring(@ghost cond: A => Boolean, msg: => String): A = x
//}

@library @ignore
implicit class WhileDecorations(val u: Unit) {
Expand Down

0 comments on commit fa18858

Please sign in to comment.