Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Round 2 of AntiAliasing changes #1252

Merged
merged 3 commits into from
Apr 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
972 changes: 772 additions & 200 deletions core/src/main/scala/stainless/extraction/imperative/AntiAliasing.scala

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ trait EffectsAnalyzer extends oo.CachingPhase {
rec(path, that.path)
}

def maybeProperPrefixOf(that: Path): Boolean = path.length < that.length && maybePrefixOf(that)

// can return `false` even if `this` is a prefix of `that`
def definitelyPrefixOf(that: Path): Boolean = {
def rec(p1: Seq[Accessor], p2: Seq[Accessor]): Boolean = (p1, p2) match {
Expand Down Expand Up @@ -357,6 +359,9 @@ trait EffectsAnalyzer extends oo.CachingPhase {
def maybePrefixOf(that: Target): Boolean =
receiver == that.receiver && (path maybePrefixOf that.path)

def maybeProperPrefixOf(that: Target): Boolean =
receiver == that.receiver && (path maybeProperPrefixOf that.path)

def definitelyPrefixOf(that: Target): Boolean =
receiver == that.receiver && (path definitelyPrefixOf that.path)

Expand Down Expand Up @@ -430,7 +435,7 @@ trait EffectsAnalyzer extends oo.CachingPhase {
else this
}

def on(that: Expr)(using symbols: Symbols): Seq[(Effect, Option[Expr])] = {
def on(that: Expr)(using Symbols): Set[(Effect, Option[Expr])] = {
val res = try {
getTargets(that, kind, path.path).map(t => (t.toEffect(kind), t.condition))
} catch {
Expand All @@ -446,13 +451,21 @@ trait EffectsAnalyzer extends oo.CachingPhase {
res
}

def on(that: Target): (Effect, Option[Expr]) = {
if (kind == CombinedKind && that.path.isEmpty && !this.path.isEmpty) {
context.reporter.fatalError(that.receiver.getPos,
s"Ambiguous effect ${this.asString} on ${that.asString}")
}
(that.appendPath(this.path).toEffect(kind), that.condition)
}

def maybePrefixOf(that: Effect): Boolean =
receiver == that.receiver && (path maybePrefixOf that.path)

def definitelyPrefixOf(that: Effect): Boolean =
receiver == that.receiver && (path definitelyPrefixOf that.path)

def toTarget: Target = Target(receiver, None, path)
def toTarget(cond: Option[Expr] = None): Target = Target(receiver, cond, path)

def wrap(using Symbols): Option[Expr] = path.wrap(receiver)

Expand Down Expand Up @@ -481,13 +494,13 @@ trait EffectsAnalyzer extends oo.CachingPhase {
* effects (with `kind`) on `x` (field assignments, array updates, etc.) result in effects on
* these targets.
*/
def getTargets(expr: Expr, kind: EffectKind, path: Seq[Accessor] = Seq.empty)(using symbols: Symbols): Seq[Target] = expr match {
case _ if variablesOf(expr).forall(v => !symbols.isMutableType(v.tpe)) => Seq.empty
case _ if isExpressionFresh(expr) => Seq.empty
case _ if !symbols.isMutableType(expr.getType) => Seq.empty
case _ if kind == ReplacementKind && path.isEmpty => Seq.empty
def getTargets(expr: Expr, kind: EffectKind, path: Seq[Accessor] = Seq.empty)(using symbols: Symbols): Set[Target] = expr match {
case _ if variablesOf(expr).forall(v => !symbols.isMutableType(v.tpe)) => Set.empty
case _ if isExpressionFresh(expr) => Set.empty
case _ if !symbols.isMutableType(expr.getType) => Set.empty
case _ if kind == ReplacementKind && path.isEmpty => Set.empty

case v: Variable => Seq(Target(v, None, Path(path)))
case v: Variable => Set(Target(v, None, Path(path)))
case ADTSelector(e, id) => getTargets(e, kind, ADTFieldAccessor(id) +: path)
case ClassSelector(e, id) => getTargets(e, kind, ClassFieldAccessor(id) +: path)
case TupleSelect(e, idx) => getTargets(e, kind, TupleFieldAccessor(idx) +: path)
Expand All @@ -501,7 +514,7 @@ trait EffectsAnalyzer extends oo.CachingPhase {
case _ =>
if (kind != ReplacementKind)
throw MalformedStainlessCode(expr, s"Couldn't compute effect targets in ADT ${expr.asString}")
else Seq.empty
else Set.empty
}

case ClassConstructor(ct, args) => path match {
Expand All @@ -510,7 +523,7 @@ trait EffectsAnalyzer extends oo.CachingPhase {
case _ =>
if (kind != ReplacementKind)
throw MalformedStainlessCode(expr, s"Couldn't compute effect targets in class constructor ${expr.asString}")
else Seq.empty
else Set.empty
}

case Tuple(exprs) => path match {
Expand All @@ -519,7 +532,7 @@ trait EffectsAnalyzer extends oo.CachingPhase {
case _ =>
if (kind != ReplacementKind)
throw MalformedStainlessCode(expr, s"Couldn't compute effect targets in tuple ${expr.asString}")
else Seq.empty
else Set.empty
}

case FiniteArray(elems, _) => path match {
Expand All @@ -528,16 +541,17 @@ trait EffectsAnalyzer extends oo.CachingPhase {
if (i < elems.size) getTargets(elems(i), kind, rest)
else throw MalformedStainlessCode(expr, s"Out of bound array access in ${expr.asString}")
case Seq(UnknownArrayAccessor) if kind == ReplacementKind =>
Seq.empty
Set.empty
case _ if kind == ReplacementKind && path.isEmpty =>
Seq.empty
Set.empty
case _ if kind == ReplacementKind && !path.head.isInstanceOf[ArrayAccessor] && path.head != UnknownArrayAccessor =>
Seq.empty
Set.empty
case _ =>
throw MalformedStainlessCode(expr, s"Couldn't compute effect targets in finite array ${expr.asString}")
}

case Assert(_, _, e) => getTargets(e, kind, path)
case Assume(_, e) => getTargets(e, kind, path)
case Annotated(e, _) => getTargets(e, kind, path)

case m: MatchExpr =>
Expand All @@ -564,49 +578,83 @@ trait EffectsAnalyzer extends oo.CachingPhase {
specced.bodyOpt
.map(specced.wrapLets)
.map(getTargets(_, kind, path))
.getOrElse(Seq.empty)

case fi: FunctionInvocation => Seq.empty
case (_: ApplyLetRec | _: Application) => Seq.empty
.getOrElse(Set.empty)
// EffectsChecker will reject inner fn calls, recursive fn calls and lambda calls that do not return fresh expression.
// So we can simply return Set.empty.
case fi: FunctionInvocation => Set.empty
case (_: ApplyLetRec | _: Application) => Set.empty
case _: LargeArray | _: ArrayUpdated if kind == ReplacementKind && path.isEmpty =>
Seq.empty
Set.empty
case _: LargeArray | _: ArrayUpdated if kind == ReplacementKind && !path.head.isInstanceOf[ArrayAccessor] && path.head != UnknownArrayAccessor =>
Seq.empty
// TODO: These two cases are incorrect, but removing them breaks existing codebase...
case _: MutableMapUpdated => Seq.empty
case _: ArrayUpdated => Seq.empty
Set.empty
case au: ArrayUpdated => au.getType match {
case ArrayType(base) if !symbols.isMutableType(base) => Set.empty
case _ =>
throw MalformedStainlessCode(expr, s"Couldn't compute effect targets of array copy update ${au.asString}")
}
Comment on lines +592 to +594
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This (and the corresponding MapUpdated logic) seems to strict. This will also prevent users from using the snapshot(a).updated(i, v) construct which we were discussing.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's true, for instance, the SortedArray benchmark needs to be updated to account for that restriction (I'm realizing that we can also return Seq.empty if the effect kind is replacement).
For ModfiyingKind though, I do not think there is a way to express the concerned targets.

case mu: MapUpdated => mu.getType match {
case MapType(from, to) if !symbols.isMutableType(from) && !symbols.isMutableType(to) => Set.empty
case _ =>
throw MalformedStainlessCode(expr, s"Couldn't compute effect targets of map copy update ${mu.asString}")
}
case IsInstanceOf(e, _) => getTargets(e, kind, path)
case AsInstanceOf(e, _) => getTargets(e, kind, path)
case Old(_) => Seq.empty
case Snapshot(_) => Seq.empty
case FreshCopy(_) => Seq.empty

case ArrayLength(_) => Seq.empty

case FiniteSet(elements, tpe) => Seq.empty
case SetUnion(s1, s2) => Seq.empty
case SetIntersection(s1, s2) => Seq.empty
case SetDifference(s1, s2) => Seq.empty
case SubsetOf(s1, s2) => Seq.empty
case ElementOfSet(element, set) => Seq.empty
case SetAdd(bag, element) => Seq.empty

case FiniteBag(elements, tpe) => Seq.empty
case BagUnion(s1, s2) => Seq.empty
case BagIntersection(s1, s2) => Seq.empty
case BagDifference(s1, s2) => Seq.empty
case MultiplicityInBag(element, bag) => Seq.empty
case BagAdd(bag, element) => Seq.empty
case Old(_) => Set.empty
case Snapshot(_) => Set.empty
case FreshCopy(_) => Set.empty

case ArrayLength(_) => Set.empty

case FiniteSet(elements, tpe) => Set.empty
case SetUnion(s1, s2) => Set.empty
case SetIntersection(s1, s2) => Set.empty
case SetDifference(s1, s2) => Set.empty
case SubsetOf(s1, s2) => Set.empty
case ElementOfSet(element, set) => Set.empty
case SetAdd(bag, element) => Set.empty

case FiniteBag(elements, tpe) => Set.empty
case BagUnion(s1, s2) => Set.empty
case BagIntersection(s1, s2) => Set.empty
case BagDifference(s1, s2) => Set.empty
case MultiplicityInBag(element, bag) => Set.empty
case BagAdd(bag, element) => Set.empty

case Block(_, last) => getTargets(last, kind, path)

case Let(vd, e, b) if !symbols.isMutableType(vd.tpe) =>
getTargets(b, kind, path).map(_.bind(vd, e))

case Let(vd, e, b) =>
getTargets(b, kind, path).map(_.bind(vd, e)).flatMap { be =>
if (be.receiver == vd.toVariable) getTargets(e, kind, be.path.path)
else Seq(be)
val targs0 = getTargets(b, kind, path)
// If `e` is referentially transparent (such as `i + 3`, assuming `i` is a val),
// we can bind and substitute it anywhere we want within b.
// For instance, if we have:
// val vd = i + 3
// val x = f(vd + 2)
// y.field = vd
// Then, the following are equivalent:
// val x = f(i + 5)
// y.field = i + 3
// and
// val x = f({val tmp = vd + 2; tmp})
// y.field = i + 3
// Such operations may be performed by target.bind(vd, e)
// On the other hand, we cannot apply these transformation for non-referentially transparent
// expressions, as the resulting expression may not be equivalent.
// For example, assuming `ii` is declared as a `var` and if we have:
// val vd = ii + 3
// ii += 1
// y.field = vd
// Then, it is clear that replacing `vd` in the assignment by `ii + 3` is incorrect.
// As such, we do not rebind or substitute `vd` by `e` within the targets
// (`vd` will appear as-is, i.e. as a variable, "forgetting" its definition).
val targs = if (isReferentiallyTransparent(e)) targs0.map(_.bind(vd, e)) else targs0

if (!symbols.isMutableType(vd.tpe)) {
targs
} else {
targs.flatMap { be =>
if (be.receiver == vd.toVariable) getTargets(e, kind, be.path.path)
else Set(be)
}
}

case _ =>
Expand Down Expand Up @@ -700,6 +748,18 @@ trait EffectsAnalyzer extends oo.CachingPhase {
case _ => throw FatalError(s"Cannot have accessors over type $tpe")
}

def isReferentiallyTransparent(e: Expr)(using syms: Symbols): Boolean = e match {
case Variable(_, tpe, flags) => !flags.contains(IsVar) && !syms.isMutableType(tpe)
case ClassSelector(expr, field) =>
val c @ ClassType(_, _) = expr.getType
!c.getField(field).get.flags.contains(IsVar) && isReferentiallyTransparent(expr)
case _: (ArraySelect | MutableMapApply) => false
case _: (Literal[t] | Lambda) => true
case fi @ FunctionInvocation(_, _, _) => functionTypeEffects(fi.tfd.functionType).isEmpty
case _: (Application | ApplyLetRec | Swap | ArrayUpdate | MutableMapUpdate | FieldAssignment | Assignment) => false
case Operator(es, _) => es.forall(isReferentiallyTransparent)
}

/** Return all effects of expr
*
* Effects of expr are any free variables in scope (either local vars
Expand All @@ -717,33 +777,33 @@ trait EffectsAnalyzer extends oo.CachingPhase {
import symbols._
val freeVars = variablesOf(expr).filter(vd => isMutableType(vd.tpe) || vd.flags.contains(IsVar))

def inEnv(effect: Effect, env: Map[Variable, Effect]): Option[Effect] =
env.get(effect.receiver).map(e => Effect(effect.kind, e.receiver, e.path ++ effect.path))
def inEnv(effect: Effect, env: Map[Variable, Set[Effect]]): Set[Effect] =
env.getOrElse(effect.receiver, Set.empty).map(e => Effect(effect.kind, e.receiver, e.path ++ effect.path))

def effect(expr: Expr, env: Map[Variable, Effect]): Seq[Effect] =
def effect(expr: Expr, env: Map[Variable, Set[Effect]]): Set[Effect] =
getAllTargets(expr) flatMap { (target: Target) =>
inEnv(target.toEffect(ModifyingKind), env)
}

def rec(expr: Expr, env: Map[Variable, Effect]): Set[Effect] = expr match {
def rec(expr: Expr, env: Map[Variable, Set[Effect]]): Set[Effect] = expr match {
case Let(vd, e, b) if symbols.isMutableType(vd.tpe) =>

if ((variablesOf(e) & variablesOf(b)).forall(v => !isMutableType(v.tpe))) {
val effe = rec(e, env)
val newEnv = (variablesOf(b) ++ freeVars).map(v => v -> ModifyingEffect(v, Path.empty)).toMap
val newEnv = (variablesOf(b) ++ freeVars).map(v => v -> Set(ModifyingEffect(v, Path.empty): Effect)).toMap
val effb = rec(b, newEnv)
effe ++ effb.flatMap { ef =>
if (ef.receiver == vd.toVariable) ef.on(e).map(_._1)
else Set(ef)
}.flatMap(inEnv(_, env))
}
else
rec(e, env) ++ rec(b, env ++ effect(e, env).map(vd.toVariable -> _))
rec(e, env) ++ rec(b, env + (vd.toVariable -> effect(e, env)))

case MatchExpr(scrut, cses) if symbols.isMutableType(scrut.getType) =>
rec(scrut, env) ++ cses.flatMap { case MatchCase(pattern, guard, rhs) =>
val newEnv = env ++ mapForPattern(scrut, pattern).flatMap {
case (v, e) => effect(e, env).map(v.toVariable -> _)
val newEnv = env ++ mapForPattern(scrut, pattern).map {
case (v, e) => v.toVariable -> effect(e, env)
}
guard.toSeq.flatMap(rec(_, newEnv)).toSet ++ rec(rhs, newEnv)
}
Expand Down Expand Up @@ -782,7 +842,7 @@ trait EffectsAnalyzer extends oo.CachingPhase {
.filter(p => effects contains p._2)
.flatMap(_._1)

case Assignment(v, value) => rec(value, env) ++ env.get(v)
case Assignment(v, value) => rec(value, env) ++ env.getOrElse(v, Set.empty)

case IfExpr(cnd, thn, els) =>
rec(cnd, env) ++ rec(thn, env) ++ rec(els, env)
Expand Down Expand Up @@ -851,7 +911,7 @@ trait EffectsAnalyzer extends oo.CachingPhase {
effect.withPath(newPath).withKind(newKind)
}

val mutated = try (rec(expr, freeVars.map(v => v -> ModifyingEffect(v, Path.empty)).toMap))
val mutated = try (rec(expr, freeVars.map(v => v -> Set(ModifyingEffect(v, Path.empty): Effect)).toMap))
catch {
case _: MalformedStainlessCode =>
freeVars.map(v => ModifyingEffect(v, Path.empty)).toSet
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,38 @@ trait EffectsChecker { self: EffectsAnalyzer =>

super.traverse(ct)

case st @ SetType(elemTp) =>
if (isMutableType(elemTp)) {
throw ImperativeEliminationException(tpe,
s"Cannot instantiate a set ${tpe.asString} with a mutable type ${elemTp.asString}")
}

super.traverse(st)

case bt @ BagType(elemTp) =>
if (isMutableType(elemTp)) {
throw ImperativeEliminationException(tpe,
s"Cannot instantiate a bag ${tpe.asString} with a mutable type ${elemTp.asString}")
}

super.traverse(bt)

case mt @ MapType(from, _) =>
if (isMutableType(from)) {
throw ImperativeEliminationException(tpe,
s"Cannot instantiate a map ${tpe.asString} with a mutable key type ${from.asString}")
}

super.traverse(mt)

case mt @ MutableMapType(from, _) =>
if (isMutableType(from)) {
throw ImperativeEliminationException(tpe,
s"Cannot instantiate a mutable map ${tpe.asString} with a mutable key type ${from.asString}")
}

super.traverse(mt)

case _ => super.traverse(tpe)
}

Expand All @@ -87,8 +119,8 @@ trait EffectsChecker { self: EffectsAnalyzer =>
super.traverse(l)

case l @ LetVar(vd, e, b) =>
if (!isExpressionFresh(e) && isMutableType(vd.tpe))
throw ImperativeEliminationException(e, "Illegal aliasing: " + e.asString)
if (isMutableType(vd.tpe))
throw ImperativeEliminationException(e, "Cannot bind expression of a mutable type to a `var`: " + e.asString)

super.traverse(l)

Expand All @@ -98,6 +130,12 @@ trait EffectsChecker { self: EffectsAnalyzer =>

super.traverse(au)

case au @ ArrayUpdated(a, i, e) =>
if (isMutableType(e.getType) && !isExpressionFresh(e))
throw ImperativeEliminationException(e, "Illegal aliasing: " + e.asString)

super.traverse(au)

case mu @ MapUpdated(m, k, e) =>
if (isMutableType(e.getType) && !isExpressionFresh(e))
throw ImperativeEliminationException(e, "Illegal aliasing: " + e.asString)
Expand Down Expand Up @@ -162,6 +200,15 @@ trait EffectsChecker { self: EffectsAnalyzer =>

super.traverse(dup)

case la @ LargeArray(_, default, _, _) =>
// The `default` expression is the one that is going to be repeated n times, so it must be referentially transparent.
if (!isReferentiallyTransparent(default)) {
throw ImperativeEliminationException(e,
s"Cannot use effectfull computations within Array.fill (${default.asString})")
}

super.traverse(la)

case _ => super.traverse(e)
}
}
Expand Down
Loading