Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AntiAliasing: avoid rebuilding mutated objects when possible #1507

Merged
merged 1 commit into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 181 additions & 32 deletions core/src/main/scala/stainless/extraction/imperative/AntiAliasing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -256,39 +256,187 @@ class AntiAliasing(override val s: Trees)(override val t: s.type)(using override
}

// NOTE: `args` must refer to the arguments of the function invocation before transformation (the original args)
def mapApplication(formalArgs: Seq[ValDef], args: Seq[Expr], nfi: Expr, nfiType: Type, fiEffects: Set[Effect], env: Env): Expr = {
def mapApplication(formalArgs: Seq[ValDef], args: Seq[Expr], nfi: Expr, nfiType: Type, fiEffects: Set[Effect], isOpaqueOrExtern: Boolean, env: Env): Expr = {

def affectedBindings(updTarget: Target, isReplacement: Boolean): Map[ValDef, Set[Target]] = {
def isAffected(t: Target): Boolean = {
if (isReplacement) t.maybeProperPrefixOf(updTarget)
else t.maybePrefixOf(updTarget) || updTarget.maybePrefixOf(t)
}
env.targets.map {
case (vd, targets) =>
val affected = targets.filter(isAffected)
vd -> affected
}.filter(_._2.nonEmpty)
}

if (fiEffects.exists(e => formalArgs contains e.receiver.toVal)) {
val localEffects: Seq[Set[(Effect, Set[(Effect, Option[Expr])])]] = (formalArgs zip args)
.map { case (vd, arg) => (fiEffects.filter(_.receiver == vd.toVariable), arg) }
.filter { case (effects, _) => effects.nonEmpty }
.map { case (effects, arg) => effects map (e => (e, e on arg)) }
val localEffects = formalArgs.zip(args)
.map { case (vd, arg) =>
// Effects for each parameter
(vd.toVariable, fiEffects.filter(_.receiver == vd.toVariable), arg)
}.filter { case (_, effects, _) => effects.nonEmpty }

val freshRes = ValDef(FreshIdentifier("res"), nfiType).copiedFrom(nfi)

val assgns = (for {
(effects, index) <- localEffects.zipWithIndex
(outerEffect0, innerEffects) <- effects
(effect0, effectCond) <- innerEffects
} yield {
val outerEffect = outerEffect0.removeUnknownAccessor
val effect = effect0.removeUnknownAccessor
val pos = args(index).getPos
val resSelect = TupleSelect(freshRes.toVariable, index + 2)
// Follow all aliases of the updated target (may include self if it has no alias)
val primaryTargs = dealiasTarget(effect.toTarget(effectCond), env)
val assignedPrimaryTargs = primaryTargs.toSeq
.map(t => makeAssignment(pos, resSelect, outerEffect.path.path, t, dropVcs = true))
val updAliasingValDefs = updatedAliasingValDefs(primaryTargs, env, pos)

assignedPrimaryTargs ++ updAliasingValDefs
}).flatten
val extractResults = Block(assgns, TupleSelect(freshRes.toVariable, 1))

if (isMutableType(nfiType)) {
LetVar(freshRes, nfi, extractResults)
} else {
Let(freshRes, nfi, extractResults)
val assgns = localEffects.zipWithIndex.flatMap {
case ((vd, effects, arg), effIndex) =>
// +1 because we are a tuple and +1 because the first component is for the result of the function
val resSelect = TupleSelect(freshRes.toVariable, effIndex + 2)
// All effects on the given parameter, applied to the given argument
val paramWithArgsEffect = for {
outerEffect0 <- effects
(effect0, effectCond) <- outerEffect0 on arg
} yield {
val outerEffect = outerEffect0.removeUnknownAccessor
val effect = effect0.removeUnknownAccessor
val primaryTargs = dealiasTarget(effect.toTarget(effectCond), env)
(outerEffect, primaryTargs)
}
// Suppose we have the following definitions:
// case class Ref(var x: Int, var y: Int)
// case class RefRef(var lhs: Ref, var rhs: Ref)
//
// def modifyLhs(rr: RefRef, v: Int): Unit = {
// rr.lhs.x = v
// rr.lhs.y = v
// }
// def test1(testRR: RefRef): Unit = {
// val rrAlias = testRR
// val lhsAlias = testRR.lhs
// modifyLhs(testRR, 123)
// // ...
// }
// `modifyLhs` is (essentially) transformed as follows by `AntiAliasing` (not here in `mapApplication`):
// def modifyLhs(rr: RefRef, v: Int): (Unit, RefRef) = {
// ((), RefRef(Ref(v, v), rr.rhs)
// }
// The transformed `modifyLhs` returns a copy of the "updated" `rr`.
//
// Our task here in `mapApplication` is to transform the call to `modifyLhs`.
// Intuitively, in this case, we can "update" `testRR` to point to the "updated" version
// returned by `modifyLhs`, and update the aliases accordingly:
// def test1(testRR: RefRef): Unit = {
// val rrAlias = testRR
// val lhsAlias = testRR.lhs
// val res = modifyLhs(testRR, 123)
// testRR = res._2
// rrAlias = testRR
// lhsAlias = testRR.lhs
// // ...
// }
// We can do so because we know precisely the `Targets` of the argument, namely `testRR`
// and we can update its aliases accordingly.
// This correspond to the `Success` case of having a `ModifyingEffect` on `vd` (here: `rr`)
// applied on `arg` (here: `testRR`).
//
// However, sometimes, we may not always succeed in computing the precise targets,
// as in the following example:
// def test2(testRR: RefRef): Unit = {
// val lhsAlias = testRR.lhs
// val rhsAlias = testRR.rhs
// modifyLhs(RefRef(lhsAlias, rhsAlias), 123)
// }
// Here, we are not able to compute the targets of `RefRef(lhsAlias, rhsAlias)`,
// which corresponds to the `Failure` case. As such, we cannot simply "update"
// the `testRR` variable using the returned result as-is (as we did for `test1`).
//
// Instead, we need to apply each effect of `modifyLhs` *individually* on the argument.
// The effects for `modifyLhs` are (stored in `localEffects`):
// rr -> Set(ReplacementEffect(rr.lhs.x), ReplacementEffect(rr.lhs.y)))
// So we need to apply two `ReplacementEffect`, one on `rr.lhs.x` and one on `rr.lhs.y` on the argument.
// Doing so with `paramWithArgsEffect` gives us:
// ReplacementEffect(rr.lhs.x) -> Set(Target(testRR, None, .lhs.x))
// ReplacementEffect(rr.lhs.y) -> Set(Target(testRR, None, .lhs.y))
// which we can then use to update `testRR` (alongside their aliases):
// def test2(testRR: RefRef): Unit = {
// var lhsAlias: Ref = testRR.lhs
// val rhsAlias: Ref = testRR.rhs
// val res: (Unit, RefRef) = modifyLhs(RefRef(lhsAlias, rhsAlias), 123)
// // Note that we "update" each field individually, this is due to
// // having each effect applied separately!
// testRR = RefRef(Ref(res._2.lhs.x, testRR.lhs.y), testRR.rhs)
// lhsAlias = testRR.lhs
// testRR = RefRef(Ref(testRR.lhs.x, res._2.lhs.y), testRR.rhs)
// lhsAlias = testRR.lhs
// // ...
// }
//
// Note that we can always apply this second technique even if we have precise aliases.
// However, this tends to "rebuild" the object instead of reusing the "updated" result
// which can lead to verification inefficiency (and does not work well in presence of
// @opaque or @extern functions).
Try(ModifyingEffect(vd, Path.empty).on(arg)) match {
case Success(modEffect) =>
// Update everything that the argument is aliasing
val primaryTargs = modEffect.flatMap { case (eff, cond) => dealiasTarget(eff.toTarget(cond), env) }
val assignedPrimaryTargs = primaryTargs
// The order of assignments does not matter between "primary targets"
// but it must precede the update of aliases (`updAliasingValDefs`)
.toSeq
.map(t => makeAssignment(arg.getPos, resSelect, Seq.empty, t, dropVcs = true))
// We need to be careful with what we are updating here.
// If we expand on the above example with the following function:
// def t3(refref: RefRef): Unit = {
// val lhs = refref.lhs
// val oldLhs = lhs.x
// replaceLhs(refref, 123)
// assert(lhs.x == oldLhs)
// assert(refref.lhs.x == 123)
// }
// In `replaceLhs`, we have a ReplacementEffect on `rr.lhs`, this means
// that `rr.lhs` is replaced with a new `Ref`, leaving all aliases of `rr.lhs`
// (in `t3`, the `val lhs`) untouched. So, after the call to `replaceLhs`,
// any modification to `rr.lhs` do not alter the other aliases (here, `lhs`).
// The function `t3` should be transformed as follows:
// def t3(refref: RefRef): Unit = {
// val lhs = refref.lhs
// val oldLhs = lhs.x
// val res = replaceLhs(refref, 123)
// refref = res._2
// assert(lhs.x == oldLhs)
// assert(refref.lhs.x == 123)
// }
// In particular, note that we *do not* touch `lhs`: the following transformation is incorrect:
// var lhs = refref.lhs
// val oldLhs = lhs.x
// val res = replaceLhs(refref, 123)
// refref = res._2
// lhs = refref.lhs
// because after the call to `replaceLhs`, `lhs` and `refref.lhs` become unrelated.
// Note that, for @opaque and @extern function, we assume the object was mutated in each of its field
// and therefore update all aliases.
val aliasingVds = {
if (isOpaqueOrExtern) {
primaryTargs.flatMap(affectedBindings(_, false))
} else {
paramWithArgsEffect.flatMap {
case (eff, targs) =>
targs.flatMap(affectedBindings(_, eff.kind == ReplacementKind))
}
}
}
val updAliasingValDefs = aliasingVds
.toSeq // See comment on `assignedPrimaryTargs`
.flatMap { case (vd, targs) =>
targs.map(t => makeAssignment(arg.getPos, t.wrap.get, Seq.empty, Target(vd.toVariable, t.condition, Path.empty), true))
}
assignedPrimaryTargs ++ updAliasingValDefs
case Failure(_) =>
paramWithArgsEffect.toSeq.flatMap { case (outerEffect, primaryTargs) =>
// Update everything that the argument is aliasing
val assignedPrimaryTargs = primaryTargs
.toSeq
.map(t => makeAssignment(arg.getPos, resSelect, outerEffect.path.path, t, dropVcs = true))
// Update everything aliasing the argument
val updAliasingValDefs = updatedAliasingValDefs(primaryTargs, env, arg.getPos)
assignedPrimaryTargs ++ updAliasingValDefs
}
}
}

val extractResults = Block(assgns, TupleSelect(freshRes.toVariable, 1))
Let(freshRes, nfi, extractResults)
} else {
nfi
}
Expand Down Expand Up @@ -724,8 +872,8 @@ class AntiAliasing(override val s: Trees)(override val t: s.type)(using override
val nfi = FunctionInvocation(
id, tps, args.map(transform(_, env))
).copiedFrom(fi)

mapApplication(fd.params, args, nfi, fi.tfd.instantiate(analysis.getReturnType(fd)), effects(fd), env)
val isExternOrOpaque = symbols.getFunction(id).flags.exists(f => f == Extern || f == Opaque)
mapApplication(fd.params, args, nfi, fi.tfd.instantiate(analysis.getReturnType(fd)), effects(fd), isExternOrOpaque, env)

case alr @ ApplyLetRec(id, tparams, tpe, tps, args) =>
val fd = Inner(env.locals(id))
Expand All @@ -752,7 +900,8 @@ class AntiAliasing(override val s: Trees)(override val t: s.type)(using override
).copiedFrom(alr)

val resultType = typeOps.instantiateType(analysis.getReturnType(fd), (tparams zip tps).toMap)
mapApplication(fd.params, args, nfi, resultType, effects(fd), env)
val isExternOrOpaque = env.locals(id).flags.exists(f => f == Extern || f == Opaque)
mapApplication(fd.params, args, nfi, resultType, effects(fd), isExternOrOpaque, env)

case app @ Application(callee, args) =>
checkAliasing(app, args, env)
Expand All @@ -770,7 +919,7 @@ class AntiAliasing(override val s: Trees)(override val t: s.type)(using override
case (vd, i) if ftEffects(i) => ModifyingEffect(vd.toVariable, Path.empty)
}
val to = makeFunctionTypeExplicit(ft).asInstanceOf[FunctionType].to
mapApplication(params, args, nfi, to, appEffects.toSet, env)
mapApplication(params, args, nfi, to, appEffects.toSet, false, env)
} else {
Application(transform(callee, env), args.map(transform(_, env))).copiedFrom(app)
}
Expand Down
4 changes: 3 additions & 1 deletion frontends/benchmarks/imperative/invalid/ExternMutation.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import stainless.lang._
import stainless.annotation._
import StaticChecks._

object ExternMutation {
case class Box(var value: BigInt)
Expand All @@ -8,7 +10,7 @@ object ExternMutation {
def f2(b: Container[Box]): Unit = ???

def g2(b: Container[Box]) = {
val b0 = b
@ghost val b0 = snapshot(b)
f2(b)
assert(b == b0) // fails because `Container` is mutable
}
Expand Down
32 changes: 32 additions & 0 deletions frontends/benchmarks/imperative/invalid/OpaqueMutation1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import stainless.lang.{ghost => ghostExpr, _}
import stainless.proof._
import stainless.annotation._
import StaticChecks._

object OpaqueMutation1 {

case class Box(var cnt: BigInt, var other: BigInt) {
@opaque // Note the opaque
def secretSauce(x: BigInt): BigInt = cnt + x // Nobody thought of it!

@opaque // Note the opaque here as well
def increment(): Unit = {
@ghost val oldBox = snapshot(this)
cnt += 1
ghostExpr {
unfold(secretSauce(other))
unfold(oldBox.secretSauce(other))
check(oldBox.secretSauce(other) + 1 == this.secretSauce(other))
}
}.ensuring(_ => old(this).secretSauce(other) + 1 == this.secretSauce(other))
}

def test(b: Box): Unit = {
@ghost val oldBox = snapshot(b)
b.increment()
// Note that, even though the implementation of `increment` does not alter `other`,
// we do not have that knowledge here since the function is marked as opaque.
// Therefore, the following is incorrect (but it holds for `b.other`, see the other `valid/OpaqueMutation`)
assert(oldBox.secretSauce(oldBox.other) + 1 == b.secretSauce(oldBox.other))
}
}
33 changes: 33 additions & 0 deletions frontends/benchmarks/imperative/invalid/OpaqueMutation2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import stainless.lang.{ghost => ghostExpr, _}
import stainless.proof._
import stainless.annotation._
import StaticChecks._

object OpaqueMutation2 {
case class SmallerBox(var otherCnt: BigInt)

case class Box(var cnt: BigInt, var smallerBox: SmallerBox) {
@opaque // Note the opaque
def secretSauce(x: BigInt): BigInt = cnt + x // Nobody thought of it!

@opaque // Note the opaque here as well
def increment(): Unit = {
@ghost val oldBox = snapshot(this)
cnt += 1
ghostExpr {
unfold(secretSauce(smallerBox.otherCnt))
unfold(oldBox.secretSauce(smallerBox.otherCnt))
check(oldBox.secretSauce(smallerBox.otherCnt) + 1 == this.secretSauce(smallerBox.otherCnt))
}
}.ensuring(_ => old(this).secretSauce(smallerBox.otherCnt) + 1 == this.secretSauce(smallerBox.otherCnt))
}

def test(b: Box): Unit = {
@ghost val oldBox = snapshot(b)
b.increment()
// Note that, even though the implementation of `increment` does not alter `smallerBox`,
// we do not have that knowledge here since the function is marked as opaque.
// Therefore, the following is incorrect (but it holds for `b.other`, see the other `valid/OpaqueMutation`)
assert(oldBox.secretSauce(oldBox.smallerBox.otherCnt) + 1 == b.secretSauce(oldBox.smallerBox.otherCnt))
}
}
15 changes: 15 additions & 0 deletions frontends/benchmarks/imperative/valid/ExternMutation.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import stainless.annotation._

object ExternMutation {
case class Box(var value: BigInt)
case class Container[@mutable T](t: T)

@extern
def f2(b: Container[Box]): Unit = ???

def g2(b: Container[Box]) = {
val b0 = b
f2(b)
assert(b == b0) // Ok, even though `b` is assumed to be modified because `b0` is an alias of `b`
}
}
4 changes: 2 additions & 2 deletions frontends/benchmarks/imperative/valid/MutableTuple.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ object MutableTuple {
}

def t3(): (Foo, Bar) = {
val bar = Bar(1)
val foo = Foo(2)
val bar = Bar(10)
val foo = Foo(20)
(foo, bar)
}

Expand Down
Loading