Skip to content

Commit

Permalink
Equivalence checker: allows for let-binding in tests and output 'expe…
Browse files Browse the repository at this point in the history
…cted but got' results (#1485)
  • Loading branch information
mario-bucev authored Nov 13, 2023
1 parent 2fa1c55 commit 87f90a0
Show file tree
Hide file tree
Showing 8 changed files with 275 additions and 96 deletions.
161 changes: 81 additions & 80 deletions core/src/main/scala/stainless/equivchk/EquivalenceChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class EquivalenceChecker(override val trees: Trees,

enum Classification {
case Valid(directModel: Identifier)
case Invalid(ctex: Seq[Seq[(ValDef, Expr)]])
case Invalid(ctex: Seq[Ctex])
case Unknown
}

Expand All @@ -115,11 +115,12 @@ class EquivalenceChecker(override val trees: Trees,
// Incorrect signature
wrongs: Set[Identifier],
weights: Map[Identifier, Int])
case class Ctex(mapping: Seq[(ValDef, Expr)], expected: Expr, got: Expr)
case class Eval(expected: Expr, got: Expr)
case class Ctex(mapping: Seq[(ValDef, Expr)], eval: Option[Eval])
case class ValidData(path: Seq[Identifier], solvingInfo: SolvingInfo)
// The list of counter-examples can be empty; the candidate is still invalid but a ctex could not be extracted
// If the solvingInfo is None, the candidate has been pruned.
case class UnequivalentData(ctexs: Seq[Seq[(ValDef, Expr)]], solvingInfo: Option[SolvingInfo])
case class UnequivalentData(ctexs: Seq[Ctex], solvingInfo: Option[SolvingInfo])
case class UnsafeData(self: Seq[UnsafeCtex], auxiliaries: Map[Identifier, Seq[UnsafeCtex]])
case class UnsafeCtex(kind: VCKind, pos: Position, ctex: Option[Seq[(ValDef, Expr)]], solvingInfo: SolvingInfo)

Expand Down Expand Up @@ -304,10 +305,10 @@ class EquivalenceChecker(override val trees: Trees,
case EvalCheck.Ok =>
picked = Some(candId)
case EvalCheck.FailsTest(testId, sampleIx, ctex) =>
unequivalent += candId -> UnequivalentData(Seq(ctex.mapping), None)
unequivalent += candId -> UnequivalentData(Seq(ctex), None)
pruned += candId -> PruningReason.ByTest(testId, sampleIx, ctex)
case EvalCheck.FailsCtex(ctex) =>
unequivalent += candId -> UnequivalentData(Seq(ctex.mapping), None)
unequivalent += candId -> UnequivalentData(Seq(ctex), None)
pruned += candId -> PruningReason.ByPreviousCtex(ctex)
}
} else {
Expand All @@ -327,6 +328,8 @@ class EquivalenceChecker(override val trees: Trees,
examinationState = ExaminationState.Examining(candId, RoundState(topN.head, topN.tail, strat, EquivLemmas.ToGenerate, 0L))
NextExamination.NewCandidate(candId, topN.head, strat, pruned.toMap)
} else {
// This candidate has been tested with all models, so put it pack into unknowns
unknowns += candId -> UnknownData(SolvingInfo(0L, None, false, false))
pickNextExamination() match {
case d@NextExamination.Done(_, _) => d.copy(pruned = pruned.toMap ++ d.pruned)
case nc@NextExamination.NewCandidate(_, _, _, _) => nc.copy(pruned = pruned.toMap ++ nc.pruned)
Expand Down Expand Up @@ -469,7 +472,10 @@ class EquivalenceChecker(override val trees: Trees,
val candFd = symbols.functions(cand)
// Take all ctex for `cand`, `eqLemma` and `proof`
val ctexOrderedArgs = (Seq(cand, eqLemma) ++ proof.toSeq).flatMap(id => allCtexs.getOrElse(id, Seq.empty))
val ctexsMap = ctexOrderedArgs.map(ctex => candFd.params.zip(ctex))
val ctexsMap = ctexOrderedArgs.map { ctex =>
val eval = evalOn(symbols.functions(model), candFd, ctex)
Ctex(candFd.params.zip(ctex), eval)
}
unequivalent += cand -> UnequivalentData(ctexsMap, Some(solvingInfo.withAddedTime(currCumulativeSolvingTime)))
examinationState = ExaminationState.PickNext
RoundConclusion.CandidateClassified(cand, Classification.Invalid(ctexsMap), Set.empty)
Expand Down Expand Up @@ -622,8 +628,8 @@ class EquivalenceChecker(override val trees: Trees,
val (samples, instParams) = tests(id)
findMap(samples.zipWithIndex) { case (arg, sampleIx) =>
passTestSample(arg, instParams).map(_ -> sampleIx)
}.map { case ((evalArgs, expected, got), sampleIx) =>
EvalCheck.FailsTest(id, sampleIx, Ctex(cand.params.zip(evalArgs), expected, got))
}.map { case ((evalArgs, eval), sampleIx) =>
EvalCheck.FailsTest(id, sampleIx, Ctex(cand.params.zip(evalArgs), Some(eval)))
}
}

Expand All @@ -638,9 +644,10 @@ class EquivalenceChecker(override val trees: Trees,
loop(tests.keys.toSeq)
}

def passTestSample(arg: Expr, instTparams: Seq[Type]): Option[(Seq[Expr], Expr, Expr)] = {
def passTestSample(arg: Expr, instTparams: Seq[Type]): Option[(Seq[Expr], Eval)] = {
val evaluator = mkEvaluator()
val evalArg = try {
evaluate(arg) match {
evaluator.eval(arg) match {
case inox.evaluators.EvaluationResults.Successful(evalArg) => evalArg
case _ =>
return None // If we cannot evaluate the argument, then we consider this test to be "successful"
Expand All @@ -661,40 +668,17 @@ class EquivalenceChecker(override val trees: Trees,
val invocationCand = FunctionInvocation(cand.id, instTparams, argsSplit)
val invocationModel = FunctionInvocation(allModels.head, instTparams, argsSplit) // any model will do
try {
(evaluate(invocationCand), evaluate(invocationModel)) match {
(evaluator.eval(invocationCand), evaluator.eval(invocationModel)) match {
case (inox.evaluators.EvaluationResults.Successful(output), inox.evaluators.EvaluationResults.Successful(expected)) =>
if (output == expected) None
else Some((argsSplit, expected, output))
else Some((argsSplit, Eval(expected, output)))
case _ => None
}
} catch {
case NonFatal(_) => None
}
}

def evaluate(expr: Expr) = {
val syms: symbols.type = symbols
type ProgramType = inox.Program {val trees: self.trees.type; val symbols: syms.type}
val prog: ProgramType = inox.Program(self.trees)(syms)
val sem = new inox.Semantics {
val trees: self.trees.type = self.trees
val symbols: syms.type = syms
val program: prog.type = prog

def createEvaluator(ctx: inox.Context) = ???

def createSolver(ctx: inox.Context) = ???
}
class EvalImpl(override val program: prog.type, override val context: inox.Context)
(using override val semantics: sem.type)
extends evaluators.RecursiveEvaluator(program, context)
with inox.evaluators.HasDefaultGlobalContext
with inox.evaluators.HasDefaultRecContext

val evaluator = new EvalImpl(prog, self.context)(using sem)
evaluator.eval(expr)
}

val permutation = ArgPermutation(model.params.indices) // No permutation for top-level model and candidate
passAllTests
.orElse(evalCheckCtexOnly(model, cand, permutation).map(EvalCheck.FailsCtex.apply))
Expand All @@ -706,7 +690,7 @@ class EquivalenceChecker(override val trees: Trees,
assert(areSignaturesCompatibleModuloPerm(model, cand, candPerm))
val subst = TyParamSubst(IntegerType(), i => Some(IntegerLiteral(i)))

def passUnordCtex(ctex: UnordCtex): Option[(Seq[Expr], Expr, Expr)] = {
def passUnordCtex(ctex: UnordCtex): Option[(Seq[Expr], Eval)] = {
// From `ctex`, generate all possible ordered permutations of args according to the types
// If the type multiplicity is 1 for all params, then there is only one ordered ctex possible
val ctexSeq = ctex.args.toSeq
Expand All @@ -724,57 +708,69 @@ class EquivalenceChecker(override val trees: Trees,
tpeIxs(vdTpeInst) = tpeIxs(vdTpeInst) + 1
arg
}
passOrdCtex(ordArgs).map { case (exp, got) => (ordArgs, exp, got) }
passOrdCtex(ordArgs).map(ordArgs -> _)
}
}

def passOrdCtex(args: Seq[Expr]): Option[(Expr, Expr)] = {
val syms: symbols.type = symbols
type ProgramType = inox.Program {val trees: self.trees.type; val symbols: syms.type}
val prog: ProgramType = inox.Program(self.trees)(syms)
val sem = new inox.Semantics {
val trees: prog.trees.type = prog.trees
val symbols: syms.type = prog.symbols
val program: prog.type = prog

def createEvaluator(ctx: inox.Context) = ???

def createSolver(ctx: inox.Context) = ???
}
class EvalImpl(override val program: prog.type, override val context: inox.Context)
(using override val semantics: sem.type)
extends evaluators.RecursiveEvaluator(program, context)
with inox.evaluators.HasDefaultGlobalContext
with inox.evaluators.HasDefaultRecContext {
override lazy val maxSteps: Int = maxStepsEval
}
val evaluator = new EvalImpl(prog, self.context)(using sem)

val tparams = model.tparams.map(_ => IntegerType())
val invocationModel = evaluator.program.trees.FunctionInvocation(model.id, tparams, args)
val invocationCand = evaluator.program.trees.FunctionInvocation(cand.id, tparams, candPerm.reverse.m2c.map(args))
try {
(evaluator.eval(invocationCand), evaluator.eval(invocationModel)) match {
case (inox.evaluators.EvaluationResults.Successful(output), inox.evaluators.EvaluationResults.Successful(expected)) =>
if (output == expected) None
else Some((expected, output))
case _ => None
}
} catch {
case NonFatal(_) => None
}
}
// Returns an Option of (expected, got) if evaluation succeeds and got is different from expected
def passOrdCtex(args: Seq[Expr]): Option[Eval] =
evalOn(model, cand, args, candPerm)
.filter { case Eval(expected, got) => expected != got }

// Substitute tparams with IntegerType()
val argsTpe = model.params.map(vd => substTypeParams(model.tparams, vd.tpe)(using subst))
val unordSig = UnordSig(argsTpe.groupMapReduce(identity)(_ => 1)(_ + _))
val ctexs = ctexsDb.getOrElse(unordSig, mutable.ArrayBuffer.empty)
findMap(ctexs.toSeq)(passUnordCtex)
.map { case (ctex, expected, got) =>
.map { case (ctex, eval) =>
// ctex is ordered according to the model, so we need to reorder cand according to the permutation
val candReorg = candPerm.m2c.map(cand.params)
Ctex(candReorg.zip(ctex), expected, got)
Ctex(candReorg.zip(ctex), Some(eval))
}
}

// Evaluate `model` and `cand` with the given `args` and whose candidate argument permutation is given by `candPerm`.
// Note: this expects `args` to have generic type substituted to integers, as it is done in `ctexOrderedArguments`.
private def evalOn(model: FunDef, cand: FunDef, args: Seq[Expr], candPerm: ArgPermutation): Option[Eval] = {
val evaluator = mkEvaluator()
val tparams = model.tparams.map(_ => IntegerType())
val invocationModel = evaluator.program.trees.FunctionInvocation(model.id, tparams, args)
val invocationCand = evaluator.program.trees.FunctionInvocation(cand.id, tparams, candPerm.reverse.m2c.map(args))
try {
(evaluator.eval(invocationModel), evaluator.eval(invocationCand)) match {
case (inox.evaluators.EvaluationResults.Successful(expected), inox.evaluators.EvaluationResults.Successful(output)) =>
Some(Eval(expected, output))
case _ => None
}
} catch {
case NonFatal(_) => None
}
}

private def evalOn(model: FunDef, cand: FunDef, args: Seq[Expr]): Option[Eval] =
evalOn(model, cand, args, ArgPermutation(args.indices))

private def mkEvaluator() = {
val syms: symbols.type = symbols
type ProgramType = inox.Program {val trees: self.trees.type; val symbols: syms.type}
val prog: ProgramType = inox.Program(self.trees)(syms)
val sem = new inox.Semantics {
val trees: prog.trees.type = prog.trees
val symbols: syms.type = prog.symbols
val program: prog.type = prog

def createEvaluator(ctx: inox.Context) = sys.error("Unsupported")

def createSolver(ctx: inox.Context) = sys.error("Unsupported")
}
class EvalImpl(override val program: prog.type, override val context: inox.Context)
(using override val semantics: sem.type)
extends evaluators.RecursiveEvaluator(program, context)
with inox.evaluators.HasDefaultGlobalContext
with inox.evaluators.HasDefaultRecContext {
override lazy val maxSteps: Int = maxStepsEval
}
new EvalImpl(prog, self.context)(using sem)
}
//endregion

Expand Down Expand Up @@ -916,6 +912,7 @@ class EquivalenceChecker(override val trees: Trees,

//region Miscellaneous

// Note: expects the ctex to have type parameter substituted with integer literals (as it is done in ctexOrderedArguments).
private def addCtex(ctex: Seq[Expr]): Unit = {
val currNbCtex = ctexsDb.map(_._2.size).sum
if (currNbCtex < maxCtex) {
Expand Down Expand Up @@ -988,7 +985,7 @@ object EquivalenceChecker {
val defaultInitScore = 200
val defaultMaxMatchingPermutation = 16
val defaultMaxCtex = 1024
val defaultMaxStepsEval = 512
val defaultMaxStepsEval = 10000

type Path = Seq[String]

Expand Down Expand Up @@ -1306,15 +1303,19 @@ object EquivalenceChecker {
case _ => return ExtractedTest.Failure(TestExtractionFailure.ReturnTypeMismatch)
}

def peel(e: Expr, acc: Seq[Expr]): Either[Expr, Seq[Expr]] = e match {
type Bdgs = Seq[(ValDef, Expr)]
def peel(e: Expr, bdgs: Bdgs, samplesAcc: Seq[Expr]): Either[Expr, Seq[Expr]] = e match {
case Let(vd, e, body) =>
peel(body, bdgs :+ (vd, e), samplesAcc)
case ADT(id: SymbolIdentifier, _, Seq(head, tail)) if id.symbol.path == Seq("stainless", "collection", "Cons") =>
peel(tail, acc :+ head)
val sample = bdgs.foldRight(head) { case ((vd, e), body) => Let(vd, e, body).copiedFrom(e) }
peel(tail, bdgs, samplesAcc :+ sample)
case ADT(id: SymbolIdentifier, _, Seq()) if id.symbol.path == Seq("stainless", "collection", "Nil") =>
Right(acc)
Right(samplesAcc)
case _ => Left(e)
}

val samples = peel(fd.fullBody, Seq.empty) match {
val samples = peel(fd.fullBody, Seq.empty, Seq.empty) match {
case Left(_) => return ExtractedTest.Failure(TestExtractionFailure.UnknownExpr)
case Right(Seq()) => return ExtractedTest.Failure(TestExtractionFailure.NoData)
case Right(samplesTupled) => samplesTupled
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,24 +242,22 @@ class EquivalenceCheckingRun private(override val component: EquivalenceChecking
}

private def debugPruned(ec: EquivalenceChecker)(pruned: Map[Identifier, ec.PruningReason]): Unit = {
def pretty(fn: Identifier, reason: ec.PruningReason): String = {
def pretty(fn: Identifier, reason: ec.PruningReason): Seq[String] = {
val rsonStr = reason match {
case ec.PruningReason.SignatureMismatch => "signature mismatch"
case ec.PruningReason.SignatureMismatch => Seq("signature mismatch")
case ec.PruningReason.ByTest(testId, sampleIx, ctex) =>
s"""test falsification by ${testId.fullName} sample n°${sampleIx + 1} with ${prettyCtex(ec)(ctex.mapping)}
| Expected: ${ctex.expected} but got: ${ctex.got}""".stripMargin
Seq(s"test falsification by ${testId.fullName} sample n°${sampleIx + 1}:") ++ prettyCtex(ec)(ctex).map(" " ++ _) // add 2 indentation
case ec.PruningReason.ByPreviousCtex(ctex) =>
s"""counter-example falsification with ${prettyCtex(ec)(ctex.mapping)}
| Expected: ${ctex.expected} but got: ${ctex.got}""".stripMargin
Seq(s"counter-example falsification:") ++ prettyCtex(ec)(ctex).map(" " ++ _)
}
s"${fn.fullName}: $rsonStr"
Seq(s"${fn.fullName}:") ++ rsonStr.map(" " ++ _)
}

if (pruned.nonEmpty) {
context.reporter.whenDebug(DebugSectionEquivChk) { debug =>
debug("The following functions were pruned:")
val strs = pruned.toSeq.sortBy(_._1).map(pretty.tupled)
strs.foreach(s => debug(s" $s"))
val lines = pruned.toSeq.sortBy(_._1).flatMap(pretty.tupled)
lines.foreach(s => debug(s" $s"))
}
}
}
Expand All @@ -281,7 +279,7 @@ class EquivalenceCheckingRun private(override val component: EquivalenceChecking
val msg = classification match {
case ec.Classification.Valid(model) => s"valid whose direct model is ${model.fullName}"
case ec.Classification.Invalid(ctexs) =>
val ctexStr = ctexs.map(prettyCtex(ec)(_)).map(s => s" $s").mkString("\n ")
val ctexStr = ctexs.flatMap(prettyCtex(ec)(_)).map(s => s" $s").mkString("\n ")
s"invalid with the following counter-examples:\n $ctexStr"
case ec.Classification.Unknown => "unknown"
}
Expand Down Expand Up @@ -320,7 +318,7 @@ class EquivalenceCheckingRun private(override val component: EquivalenceChecking
info(s"Path for the function ${cand.fullName}: $pathStr")
}
res.unequivalent.foreach { case (cand, data) =>
val ctexsStr = data.ctexs.map(ctex => ctex.map { case (vd, arg) => s"${vd.id.name} -> $arg" }.mkString(", "))
val ctexsStr = data.ctexs.flatMap(prettyCtex(ec)(_))
info(s"Unequivalence counterexample for the function ${cand.fullName}:")
ctexsStr.foreach(s => info(s" $s"))
}
Expand All @@ -343,16 +341,29 @@ class EquivalenceCheckingRun private(override val component: EquivalenceChecking
}
}

private def prettyCtex(ec: EquivalenceChecker)(ctex: Seq[(ec.trees.ValDef, ec.trees.Expr)]): String =
ctex.map { case (vd, e) => s"${vd.id.name} -> $e" }.mkString(", ")
// This returns a Seq due to being multiline and to allow easier indentation
private def prettyCtex(ec: EquivalenceChecker)(ctex: ec.Ctex): Seq[String] = {
val args = ctex.mapping.map { case (vd, e) => s"${vd.id.name} -> $e" }.mkString(", ")
val eval = ctex.eval.map(ev => Seq(s"Expected ${ev.expected} but got ${ev.got}")).getOrElse(Seq.empty)
Seq(args) ++ eval
}

private def dumpResultsJson(out: File, ec: EquivalenceChecker)(res: ec.Results): Unit = {
def ctexJson(ctex: Seq[(ec.trees.ValDef, ec.trees.Expr)]): Json =
Json.fromFields(ctex.map { case (vd, expr) => vd.id.name -> Json.fromString(expr.toString) })
def ctexJson(ctex: ec.Ctex): Json = {
val args = Json.fromFields(ctex.mapping.map { case (vd, expr) => vd.id.name -> Json.fromString(expr.toString) })
Json.fromFields(
Seq("args" -> args) ++
ctex.eval.map(ev => Seq(
"expected" -> Json.fromString(ev.expected.toString),
"got" -> Json.fromString(ev.got.toString)
)).getOrElse(Seq.empty)
)
}

def unsafeCtexJson(data: ec.UnsafeCtex): Json = Json.fromFields(Seq(
"kind" -> Json.fromString(data.kind.name),
"position" -> Json.fromString(s"${data.pos.line}:${data.pos.col}"),
"ctex" -> data.ctex.map(ctexJson).getOrElse(Json.Null)
"ctex" -> data.ctex.map(mapping => ctexJson(ec.Ctex(mapping, None))).getOrElse(Json.Null)
))

val equivs = res.equiv.map { case (m, l) => m.fullName -> l.map(_.fullName).toSeq.sorted }
Expand Down
15 changes: 15 additions & 0 deletions frontends/benchmarks/equivalence/separate/Candidate1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import defs._

object Candidate1 {
def separate(xs: List[Animal]): (List[Sheep], List[Goat]) = {
xs match {
case Nil => (Nil, Nil)
case (s: Sheep) :: t =>
val (s2, g2) = separate(t)
(s :: s2, g2)
case (g: Goat) :: t =>
val (s2, g2) = separate(t)
(s2, g :: g2)
}
}
}
Loading

0 comments on commit 87f90a0

Please sign in to comment.