Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Equivalence checker: allows for let-binding in tests and output 'expected but got' results #1485

Merged
merged 1 commit into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 81 additions & 80 deletions core/src/main/scala/stainless/equivchk/EquivalenceChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class EquivalenceChecker(override val trees: Trees,

enum Classification {
case Valid(directModel: Identifier)
case Invalid(ctex: Seq[Seq[(ValDef, Expr)]])
case Invalid(ctex: Seq[Ctex])
case Unknown
}

Expand All @@ -115,11 +115,12 @@ class EquivalenceChecker(override val trees: Trees,
// Incorrect signature
wrongs: Set[Identifier],
weights: Map[Identifier, Int])
case class Ctex(mapping: Seq[(ValDef, Expr)], expected: Expr, got: Expr)
// Result of evaluating a model and a candidate on the same arguments:
// `expected` is the value produced by the model, `got` the value produced by the candidate.
case class Eval(expected: Expr, got: Expr)
// A counter-example: `mapping` pairs each parameter with a concrete argument value;
// `eval` carries the expected/got evaluation results when they are available, None otherwise.
case class Ctex(mapping: Seq[(ValDef, Expr)], eval: Option[Eval])
case class ValidData(path: Seq[Identifier], solvingInfo: SolvingInfo)
// The list of counter-examples can be empty; the candidate is still invalid but a ctex could not be extracted
// If the solvingInfo is None, the candidate has been pruned.
case class UnequivalentData(ctexs: Seq[Seq[(ValDef, Expr)]], solvingInfo: Option[SolvingInfo])
case class UnequivalentData(ctexs: Seq[Ctex], solvingInfo: Option[SolvingInfo])
case class UnsafeData(self: Seq[UnsafeCtex], auxiliaries: Map[Identifier, Seq[UnsafeCtex]])
case class UnsafeCtex(kind: VCKind, pos: Position, ctex: Option[Seq[(ValDef, Expr)]], solvingInfo: SolvingInfo)

Expand Down Expand Up @@ -304,10 +305,10 @@ class EquivalenceChecker(override val trees: Trees,
case EvalCheck.Ok =>
picked = Some(candId)
case EvalCheck.FailsTest(testId, sampleIx, ctex) =>
unequivalent += candId -> UnequivalentData(Seq(ctex.mapping), None)
unequivalent += candId -> UnequivalentData(Seq(ctex), None)
pruned += candId -> PruningReason.ByTest(testId, sampleIx, ctex)
case EvalCheck.FailsCtex(ctex) =>
unequivalent += candId -> UnequivalentData(Seq(ctex.mapping), None)
unequivalent += candId -> UnequivalentData(Seq(ctex), None)
pruned += candId -> PruningReason.ByPreviousCtex(ctex)
}
} else {
Expand All @@ -327,6 +328,8 @@ class EquivalenceChecker(override val trees: Trees,
examinationState = ExaminationState.Examining(candId, RoundState(topN.head, topN.tail, strat, EquivLemmas.ToGenerate, 0L))
NextExamination.NewCandidate(candId, topN.head, strat, pruned.toMap)
} else {
// This candidate has been tested with all models, so put it back into unknowns
unknowns += candId -> UnknownData(SolvingInfo(0L, None, false, false))
pickNextExamination() match {
case d@NextExamination.Done(_, _) => d.copy(pruned = pruned.toMap ++ d.pruned)
case nc@NextExamination.NewCandidate(_, _, _, _) => nc.copy(pruned = pruned.toMap ++ nc.pruned)
Expand Down Expand Up @@ -469,7 +472,10 @@ class EquivalenceChecker(override val trees: Trees,
val candFd = symbols.functions(cand)
// Take all ctex for `cand`, `eqLemma` and `proof`
val ctexOrderedArgs = (Seq(cand, eqLemma) ++ proof.toSeq).flatMap(id => allCtexs.getOrElse(id, Seq.empty))
val ctexsMap = ctexOrderedArgs.map(ctex => candFd.params.zip(ctex))
val ctexsMap = ctexOrderedArgs.map { ctex =>
val eval = evalOn(symbols.functions(model), candFd, ctex)
Ctex(candFd.params.zip(ctex), eval)
}
unequivalent += cand -> UnequivalentData(ctexsMap, Some(solvingInfo.withAddedTime(currCumulativeSolvingTime)))
examinationState = ExaminationState.PickNext
RoundConclusion.CandidateClassified(cand, Classification.Invalid(ctexsMap), Set.empty)
Expand Down Expand Up @@ -622,8 +628,8 @@ class EquivalenceChecker(override val trees: Trees,
val (samples, instParams) = tests(id)
findMap(samples.zipWithIndex) { case (arg, sampleIx) =>
passTestSample(arg, instParams).map(_ -> sampleIx)
}.map { case ((evalArgs, expected, got), sampleIx) =>
EvalCheck.FailsTest(id, sampleIx, Ctex(cand.params.zip(evalArgs), expected, got))
}.map { case ((evalArgs, eval), sampleIx) =>
EvalCheck.FailsTest(id, sampleIx, Ctex(cand.params.zip(evalArgs), Some(eval)))
}
}

Expand All @@ -638,9 +644,10 @@ class EquivalenceChecker(override val trees: Trees,
loop(tests.keys.toSeq)
}

def passTestSample(arg: Expr, instTparams: Seq[Type]): Option[(Seq[Expr], Expr, Expr)] = {
def passTestSample(arg: Expr, instTparams: Seq[Type]): Option[(Seq[Expr], Eval)] = {
val evaluator = mkEvaluator()
val evalArg = try {
evaluate(arg) match {
evaluator.eval(arg) match {
case inox.evaluators.EvaluationResults.Successful(evalArg) => evalArg
case _ =>
return None // If we cannot evaluate the argument, then we consider this test to be "successful"
Expand All @@ -661,40 +668,17 @@ class EquivalenceChecker(override val trees: Trees,
val invocationCand = FunctionInvocation(cand.id, instTparams, argsSplit)
val invocationModel = FunctionInvocation(allModels.head, instTparams, argsSplit) // any model will do
try {
(evaluate(invocationCand), evaluate(invocationModel)) match {
(evaluator.eval(invocationCand), evaluator.eval(invocationModel)) match {
case (inox.evaluators.EvaluationResults.Successful(output), inox.evaluators.EvaluationResults.Successful(expected)) =>
if (output == expected) None
else Some((argsSplit, expected, output))
else Some((argsSplit, Eval(expected, output)))
case _ => None
}
} catch {
case NonFatal(_) => None
}
}

def evaluate(expr: Expr) = {
val syms: symbols.type = symbols
type ProgramType = inox.Program {val trees: self.trees.type; val symbols: syms.type}
val prog: ProgramType = inox.Program(self.trees)(syms)
val sem = new inox.Semantics {
val trees: self.trees.type = self.trees
val symbols: syms.type = syms
val program: prog.type = prog

def createEvaluator(ctx: inox.Context) = ???

def createSolver(ctx: inox.Context) = ???
}
class EvalImpl(override val program: prog.type, override val context: inox.Context)
(using override val semantics: sem.type)
extends evaluators.RecursiveEvaluator(program, context)
with inox.evaluators.HasDefaultGlobalContext
with inox.evaluators.HasDefaultRecContext

val evaluator = new EvalImpl(prog, self.context)(using sem)
evaluator.eval(expr)
}

val permutation = ArgPermutation(model.params.indices) // No permutation for top-level model and candidate
passAllTests
.orElse(evalCheckCtexOnly(model, cand, permutation).map(EvalCheck.FailsCtex.apply))
Expand All @@ -706,7 +690,7 @@ class EquivalenceChecker(override val trees: Trees,
assert(areSignaturesCompatibleModuloPerm(model, cand, candPerm))
val subst = TyParamSubst(IntegerType(), i => Some(IntegerLiteral(i)))

def passUnordCtex(ctex: UnordCtex): Option[(Seq[Expr], Expr, Expr)] = {
def passUnordCtex(ctex: UnordCtex): Option[(Seq[Expr], Eval)] = {
// From `ctex`, generate all possible ordered permutations of args according to the types
// If the type multiplicity is 1 for all params, then there is only one ordered ctex possible
val ctexSeq = ctex.args.toSeq
Expand All @@ -724,57 +708,69 @@ class EquivalenceChecker(override val trees: Trees,
tpeIxs(vdTpeInst) = tpeIxs(vdTpeInst) + 1
arg
}
passOrdCtex(ordArgs).map { case (exp, got) => (ordArgs, exp, got) }
passOrdCtex(ordArgs).map(ordArgs -> _)
}
}

def passOrdCtex(args: Seq[Expr]): Option[(Expr, Expr)] = {
val syms: symbols.type = symbols
type ProgramType = inox.Program {val trees: self.trees.type; val symbols: syms.type}
val prog: ProgramType = inox.Program(self.trees)(syms)
val sem = new inox.Semantics {
val trees: prog.trees.type = prog.trees
val symbols: syms.type = prog.symbols
val program: prog.type = prog

def createEvaluator(ctx: inox.Context) = ???

def createSolver(ctx: inox.Context) = ???
}
class EvalImpl(override val program: prog.type, override val context: inox.Context)
(using override val semantics: sem.type)
extends evaluators.RecursiveEvaluator(program, context)
with inox.evaluators.HasDefaultGlobalContext
with inox.evaluators.HasDefaultRecContext {
override lazy val maxSteps: Int = maxStepsEval
}
val evaluator = new EvalImpl(prog, self.context)(using sem)

val tparams = model.tparams.map(_ => IntegerType())
val invocationModel = evaluator.program.trees.FunctionInvocation(model.id, tparams, args)
val invocationCand = evaluator.program.trees.FunctionInvocation(cand.id, tparams, candPerm.reverse.m2c.map(args))
try {
(evaluator.eval(invocationCand), evaluator.eval(invocationModel)) match {
case (inox.evaluators.EvaluationResults.Successful(output), inox.evaluators.EvaluationResults.Successful(expected)) =>
if (output == expected) None
else Some((expected, output))
case _ => None
}
} catch {
case NonFatal(_) => None
}
}
// Evaluates model and candidate on `args` (candidate arguments permuted by `candPerm`).
// Yields the (expected, got) pair only when the evaluation produced a result and the two values differ.
def passOrdCtex(args: Seq[Expr]): Option[Eval] = {
  val outcome = evalOn(model, cand, args, candPerm)
  outcome.filter(ev => ev.expected != ev.got)
}

// Substitute tparams with IntegerType()
val argsTpe = model.params.map(vd => substTypeParams(model.tparams, vd.tpe)(using subst))
val unordSig = UnordSig(argsTpe.groupMapReduce(identity)(_ => 1)(_ + _))
val ctexs = ctexsDb.getOrElse(unordSig, mutable.ArrayBuffer.empty)
findMap(ctexs.toSeq)(passUnordCtex)
.map { case (ctex, expected, got) =>
.map { case (ctex, eval) =>
// ctex is ordered according to the model, so we need to reorder cand according to the permutation
val candReorg = candPerm.m2c.map(cand.params)
Ctex(candReorg.zip(ctex), expected, got)
Ctex(candReorg.zip(ctex), Some(eval))
}
}

// Runs `model` on `args` and `cand` on the same arguments reordered through `candPerm`,
// instantiating every type parameter of the model with IntegerType().
// Returns Some(Eval(expected, got)) when both evaluations succeed; returns None when either one
// does not succeed or when evaluation throws a non-fatal exception.
// Note: `args` must already have generic types substituted with integers, as is done in `ctexOrderedArguments`.
private def evalOn(model: FunDef, cand: FunDef, args: Seq[Expr], candPerm: ArgPermutation): Option[Eval] = {
  val ev = mkEvaluator()
  val intTparams = model.tparams.map(_ => IntegerType())
  val modelCall = ev.program.trees.FunctionInvocation(model.id, intTparams, args)
  val permutedArgs = candPerm.reverse.m2c.map(args)
  val candCall = ev.program.trees.FunctionInvocation(cand.id, intTparams, permutedArgs)
  try {
    // Both invocations are always evaluated, model first, before the results are inspected.
    val modelRes = ev.eval(modelCall)
    val candRes = ev.eval(candCall)
    PartialFunction.condOpt((modelRes, candRes)) {
      case (inox.evaluators.EvaluationResults.Successful(expected), inox.evaluators.EvaluationResults.Successful(output)) =>
        Eval(expected, output)
    }
  } catch {
    case NonFatal(_) => None
  }
}

// Overload for the case where the candidate takes its arguments in the same order as the model:
// delegates to the four-argument `evalOn` with the permutation built from the indices of `args`.
private def evalOn(model: FunDef, cand: FunDef, args: Seq[Expr]): Option[Eval] = {
  val samePositions = ArgPermutation(args.indices)
  evalOn(model, cand, args, samePositions)
}

// Builds an evaluator over the checker's `trees` and `symbols`.
// Every call constructs a new program, semantics object and evaluator instance.
private def mkEvaluator() = {
// Bind `symbols` to a local stable identifier so its singleton type can be used below.
val syms: symbols.type = symbols
type ProgramType = inox.Program {val trees: self.trees.type; val symbols: syms.type}
val prog: ProgramType = inox.Program(self.trees)(syms)
// Semantics instance passed to EvalImpl through its `using` clause.
// Its two factory methods are not supported here and raise an error if they are ever called.
val sem = new inox.Semantics {
val trees: prog.trees.type = prog.trees
val symbols: syms.type = prog.symbols
val program: prog.type = prog

def createEvaluator(ctx: inox.Context) = sys.error("Unsupported")

def createSolver(ctx: inox.Context) = sys.error("Unsupported")
}
// Recursive evaluator using the default global and recursive contexts,
// with `maxSteps` overridden to the checker's `maxStepsEval`.
class EvalImpl(override val program: prog.type, override val context: inox.Context)
(using override val semantics: sem.type)
extends evaluators.RecursiveEvaluator(program, context)
with inox.evaluators.HasDefaultGlobalContext
with inox.evaluators.HasDefaultRecContext {
override lazy val maxSteps: Int = maxStepsEval
}
new EvalImpl(prog, self.context)(using sem)
}
//endregion

Expand Down Expand Up @@ -916,6 +912,7 @@ class EquivalenceChecker(override val trees: Trees,

//region Miscellaneous

// Note: expects the ctex to have its type parameters substituted with integer literals (as is done in ctexOrderedArguments).
private def addCtex(ctex: Seq[Expr]): Unit = {
val currNbCtex = ctexsDb.map(_._2.size).sum
if (currNbCtex < maxCtex) {
Expand Down Expand Up @@ -988,7 +985,7 @@ object EquivalenceChecker {
val defaultInitScore = 200
val defaultMaxMatchingPermutation = 16
val defaultMaxCtex = 1024
val defaultMaxStepsEval = 512
val defaultMaxStepsEval = 10000

type Path = Seq[String]

Expand Down Expand Up @@ -1306,15 +1303,19 @@ object EquivalenceChecker {
case _ => return ExtractedTest.Failure(TestExtractionFailure.ReturnTypeMismatch)
}

def peel(e: Expr, acc: Seq[Expr]): Either[Expr, Seq[Expr]] = e match {
type Bdgs = Seq[(ValDef, Expr)]
def peel(e: Expr, bdgs: Bdgs, samplesAcc: Seq[Expr]): Either[Expr, Seq[Expr]] = e match {
case Let(vd, e, body) =>
peel(body, bdgs :+ (vd, e), samplesAcc)
case ADT(id: SymbolIdentifier, _, Seq(head, tail)) if id.symbol.path == Seq("stainless", "collection", "Cons") =>
peel(tail, acc :+ head)
val sample = bdgs.foldRight(head) { case ((vd, e), body) => Let(vd, e, body).copiedFrom(e) }
peel(tail, bdgs, samplesAcc :+ sample)
case ADT(id: SymbolIdentifier, _, Seq()) if id.symbol.path == Seq("stainless", "collection", "Nil") =>
Right(acc)
Right(samplesAcc)
case _ => Left(e)
}

val samples = peel(fd.fullBody, Seq.empty) match {
val samples = peel(fd.fullBody, Seq.empty, Seq.empty) match {
case Left(_) => return ExtractedTest.Failure(TestExtractionFailure.UnknownExpr)
case Right(Seq()) => return ExtractedTest.Failure(TestExtractionFailure.NoData)
case Right(samplesTupled) => samplesTupled
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,24 +242,22 @@ class EquivalenceCheckingRun private(override val component: EquivalenceChecking
}

private def debugPruned(ec: EquivalenceChecker)(pruned: Map[Identifier, ec.PruningReason]): Unit = {
def pretty(fn: Identifier, reason: ec.PruningReason): String = {
def pretty(fn: Identifier, reason: ec.PruningReason): Seq[String] = {
val rsonStr = reason match {
case ec.PruningReason.SignatureMismatch => "signature mismatch"
case ec.PruningReason.SignatureMismatch => Seq("signature mismatch")
case ec.PruningReason.ByTest(testId, sampleIx, ctex) =>
s"""test falsification by ${testId.fullName} sample n°${sampleIx + 1} with ${prettyCtex(ec)(ctex.mapping)}
| Expected: ${ctex.expected} but got: ${ctex.got}""".stripMargin
Seq(s"test falsification by ${testId.fullName} sample n°${sampleIx + 1}:") ++ prettyCtex(ec)(ctex).map(" " ++ _) // add 2 indentation
case ec.PruningReason.ByPreviousCtex(ctex) =>
s"""counter-example falsification with ${prettyCtex(ec)(ctex.mapping)}
| Expected: ${ctex.expected} but got: ${ctex.got}""".stripMargin
Seq(s"counter-example falsification:") ++ prettyCtex(ec)(ctex).map(" " ++ _)
}
s"${fn.fullName}: $rsonStr"
Seq(s"${fn.fullName}:") ++ rsonStr.map(" " ++ _)
}

if (pruned.nonEmpty) {
context.reporter.whenDebug(DebugSectionEquivChk) { debug =>
debug("The following functions were pruned:")
val strs = pruned.toSeq.sortBy(_._1).map(pretty.tupled)
strs.foreach(s => debug(s" $s"))
val lines = pruned.toSeq.sortBy(_._1).flatMap(pretty.tupled)
lines.foreach(s => debug(s" $s"))
}
}
}
Expand All @@ -281,7 +279,7 @@ class EquivalenceCheckingRun private(override val component: EquivalenceChecking
val msg = classification match {
case ec.Classification.Valid(model) => s"valid whose direct model is ${model.fullName}"
case ec.Classification.Invalid(ctexs) =>
val ctexStr = ctexs.map(prettyCtex(ec)(_)).map(s => s" $s").mkString("\n ")
val ctexStr = ctexs.flatMap(prettyCtex(ec)(_)).map(s => s" $s").mkString("\n ")
s"invalid with the following counter-examples:\n $ctexStr"
case ec.Classification.Unknown => "unknown"
}
Expand Down Expand Up @@ -320,7 +318,7 @@ class EquivalenceCheckingRun private(override val component: EquivalenceChecking
info(s"Path for the function ${cand.fullName}: $pathStr")
}
res.unequivalent.foreach { case (cand, data) =>
val ctexsStr = data.ctexs.map(ctex => ctex.map { case (vd, arg) => s"${vd.id.name} -> $arg" }.mkString(", "))
val ctexsStr = data.ctexs.flatMap(prettyCtex(ec)(_))
info(s"Unequivalence counterexample for the function ${cand.fullName}:")
ctexsStr.foreach(s => info(s" $s"))
}
Expand All @@ -343,16 +341,29 @@ class EquivalenceCheckingRun private(override val component: EquivalenceChecking
}
}

private def prettyCtex(ec: EquivalenceChecker)(ctex: Seq[(ec.trees.ValDef, ec.trees.Expr)]): String =
ctex.map { case (vd, e) => s"${vd.id.name} -> $e" }.mkString(", ")
// Renders a counter-example as a sequence of lines, so that callers can indent each line:
// first the comma-separated argument bindings, then the expected/got line when an evaluation is present.
private def prettyCtex(ec: EquivalenceChecker)(ctex: ec.Ctex): Seq[String] = {
  val bindings = ctex.mapping.map { case (vd, e) => s"${vd.id.name} -> $e" }
  val outcome = ctex.eval.toSeq.map(ev => s"Expected ${ev.expected} but got ${ev.got}")
  bindings.mkString(", ") +: outcome
}

private def dumpResultsJson(out: File, ec: EquivalenceChecker)(res: ec.Results): Unit = {
def ctexJson(ctex: Seq[(ec.trees.ValDef, ec.trees.Expr)]): Json =
Json.fromFields(ctex.map { case (vd, expr) => vd.id.name -> Json.fromString(expr.toString) })
// Serializes a counter-example as a JSON object: an "args" object mapping parameter names to
// the printed argument values, followed by "expected" and "got" when an evaluation is present.
def ctexJson(ctex: ec.Ctex): Json = {
  val argFields = ctex.mapping.map { case (vd, expr) => vd.id.name -> Json.fromString(expr.toString) }
  val evalFields = ctex.eval.toSeq.flatMap { ev =>
    Seq(
      "expected" -> Json.fromString(ev.expected.toString),
      "got" -> Json.fromString(ev.got.toString)
    )
  }
  Json.fromFields(("args" -> Json.fromFields(argFields)) +: evalFields)
}

def unsafeCtexJson(data: ec.UnsafeCtex): Json = Json.fromFields(Seq(
"kind" -> Json.fromString(data.kind.name),
"position" -> Json.fromString(s"${data.pos.line}:${data.pos.col}"),
"ctex" -> data.ctex.map(ctexJson).getOrElse(Json.Null)
"ctex" -> data.ctex.map(mapping => ctexJson(ec.Ctex(mapping, None))).getOrElse(Json.Null)
))

val equivs = res.equiv.map { case (m, l) => m.fullName -> l.map(_.fullName).toSeq.sorted }
Expand Down
15 changes: 15 additions & 0 deletions frontends/benchmarks/equivalence/separate/Candidate1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import defs._

// Equivalence-checking benchmark candidate: splits a list of animals into its sheep and its goats.
object Candidate1 {
// Returns the pair (sheep, goats) found in `xs`; each output list keeps the relative order of `xs`.
// NOTE(review): there is no case for an Animal that is neither a Sheep nor a Goat; presumably
// `defs` declares exactly these two subtypes. Confirm.
def separate(xs: List[Animal]): (List[Sheep], List[Goat]) = {
xs match {
case Nil => (Nil, Nil)
case (s: Sheep) :: t =>
// Split the tail first, then put this sheep at the front of the sheep list.
val (s2, g2) = separate(t)
(s :: s2, g2)
case (g: Goat) :: t =>
// Split the tail first, then put this goat at the front of the goat list.
val (s2, g2) = separate(t)
(s2, g :: g2)
}
}
}
Loading