Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Equivalence checking as a component #1378

Merged
merged 4 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 3 additions & 23 deletions core/src/main/scala/stainless/Component.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
package stainless

import utils.{CheckFilter, DefinitionIdFinder, DependenciesFinder}
import extraction.xlang.trees as xt
import io.circe.*
import extraction.xlang.{trees => xt}
import io.circe._
import stainless.extraction.ExtractionSummary

import java.io.File
import scala.concurrent.Future

trait Component { self =>
Expand All @@ -31,27 +32,6 @@ object optFunctions extends inox.OptionDef[Seq[String]] {
val usageRhs = "f1,f2,..."
}

object optCompareFuns extends inox.OptionDef[Seq[String]] {
val name = "comparefuns"
val default = Seq[String]()
val parser = inox.OptionParsers.seqParser(inox.OptionParsers.stringParser)
val usageRhs = "f1,f2,..."
}

object optModels extends inox.OptionDef[Seq[String]] {
val name = "models"
val default = Seq[String]()
val parser = inox.OptionParsers.seqParser(inox.OptionParsers.stringParser)
val usageRhs = "f1,f2,..."
}

object optNorm extends inox.OptionDef[String] {
val name = "norm"
val default = ""
val parser = inox.OptionParsers.stringParser
val usageRhs = "f"
}

trait ComponentRun { self =>
val component: Component
val trees: ast.Trees
Expand Down
20 changes: 12 additions & 8 deletions core/src/main/scala/stainless/MainHelpers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ trait MainHelpers extends inox.MainHelpers { self =>
case object TestsGeneration extends Category {
override def toString: String = "Tests Generation"
}
case object EquivChk extends Category {
override def toString: String = "Equivalence checking"
}

override protected def getOptions: Map[inox.OptionDef[_], Description] = super.getOptions - inox.solvers.optAssumeChecked ++ Map(
optVersion -> Description(General, "Display the version number"),
optConfigFile -> Description(General, "Path to configuration file, set to false to disable (default: stainless.conf or .stainless.conf)"),
optFunctions -> Description(General, "Only consider functions f1,f2,..."),
optCompareFuns -> Description(General, "Only consider functions f1,f2,... for equivalence checking"),
optModels -> Description(General, "Consider functions f1, f2, ... as model functions for equivalence checking"),
optNorm -> Description(General, "Use function f as normalization function for equivalence checking"),
extraction.utils.optDebugObjects -> Description(General, "Only print debug output for functions/adts named o1,o2,..."),
extraction.utils.optDebugPhases -> Description(General, {
// f interpolator does not process escape sequence, we workaround that with the following trick.
Expand All @@ -44,6 +44,7 @@ trait MainHelpers extends inox.MainHelpers { self =>
evaluators.optCodeGen -> Description(Evaluators, "Use code generating evaluator"),
codegen.optInstrumentFields -> Description(Evaluators, "Instrument ADT field access during code generation"),
codegen.optSmallArrays -> Description(Evaluators, "Assume all arrays fit into memory during code generation"),
verification.optSilent -> Description(Verification, "Do not print any message when a verification condition fails due to invalidity or timeout"),
verification.optFailEarly -> Description(Verification, "Halt verification as soon as a check fails (invalid or unknown)"),
verification.optFailInvalid -> Description(Verification, "Halt verification as soon as a check is invalid"),
verification.optVCCache -> Description(Verification, "Enable caching of verification conditions"),
Expand Down Expand Up @@ -77,6 +78,13 @@ trait MainHelpers extends inox.MainHelpers { self =>
utils.Caches.optCacheDir -> Description(General, "Specify the directory in which cache files should be stored"),
testgen.optOutputFile -> Description(TestsGeneration, "Specify the output file"),
testgen.optGenCIncludes -> Description(TestsGeneration, "(GenC variant only) Specify header includes"),
equivchk.optCompareFuns -> Description(EquivChk, "Only consider functions f1,f2,... for equivalence checking"),
equivchk.optModels -> Description(EquivChk, "Consider functions f1, f2, ... as model functions for equivalence checking"),
equivchk.optNorm -> Description(EquivChk, "Use function f as normalization function for equivalence checking"),
equivchk.optEquivalenceOutput -> Description(EquivChk, "JSON output file for equivalence checking"),
equivchk.optN -> Description(EquivChk, "Consider the top N models"),
equivchk.optInitScore -> Description(EquivChk, "Initial score for models, must be positive"),
equivchk.optMaxPerm -> Description(EquivChk, "Maximum number of permutations to be tested when matching auxiliary functions"),
) ++ MainHelpers.components.map { component =>
val option = inox.FlagOptionDef(component.name, default = false)
option -> Description(Pipelines, component.description)
Expand Down Expand Up @@ -108,6 +116,7 @@ trait MainHelpers extends inox.MainHelpers { self =>
frontend.DebugSectionRecovery,
frontend.DebugSectionExtraDeps,
genc.DebugSectionGenC,
equivchk.DebugSectionEquivChk
)

override protected def displayVersion(reporter: inox.Reporter): Unit = {
Expand Down Expand Up @@ -186,11 +195,6 @@ trait MainHelpers extends inox.MainHelpers { self =>
}

import ctx.{reporter, timers}

if (extraction.trace.Trace.optionsError) {
reporter.fatalError(s"Equivalence checking for --comparefuns and --models only works in batched mode.")
}

if (!useParallelism) {
reporter.warning(s"Parallelism is disabled.")
}
Expand Down
61 changes: 61 additions & 0 deletions core/src/main/scala/stainless/ast/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,65 @@ trait TypeOps extends inox.ast.TypeOps {
}.transform(tpe)
}

protected class Unsolvable extends Exception
protected def unsolvable = throw new Unsolvable

/** Collects the constraints that need to be solved for [[unify]].
* Note: this is an override point. */
protected def unificationConstraints(t1: Type, t2: Type, free: Seq[TypeParameter]): List[(TypeParameter, Type)] = (t1, t2) match {
case (adt: ADTType, _) if adt.lookupSort.isEmpty => unsolvable
case (_, adt: ADTType) if adt.lookupSort.isEmpty => unsolvable

case _ if t1 == t2 => Nil

case (adt1: ADTType, adt2: ADTType) if adt1.id == adt2.id =>
(adt1.tps zip adt2.tps).toList flatMap (p => unificationConstraints(p._1, p._2, free))

case (rt: RefinementType, _) => unificationConstraints(rt.getType, t2, free)
case (_, rt: RefinementType) => unificationConstraints(t1, rt.getType, free)

case (pi: PiType, _) => unificationConstraints(pi.getType, t2, free)
case (_, pi: PiType) => unificationConstraints(t1, pi.getType, free)

case (sigma: SigmaType, _) => unificationConstraints(sigma.getType, t2, free)
case (_, sigma: SigmaType) => unificationConstraints(t1, sigma.getType, free)

case (tp: TypeParameter, _) if !(typeOps.typeParamsOf(t2) contains tp) && (free contains tp) => List(tp -> t2)
case (_, tp: TypeParameter) if !(typeOps.typeParamsOf(t1) contains tp) && (free contains tp) => List(tp -> t1)
case (_: TypeParameter, _) => unsolvable
case (_, _: TypeParameter) => unsolvable

case typeOps.Same(NAryType(ts1, _), NAryType(ts2, _)) if ts1.size == ts2.size =>
(ts1 zip ts2).toList flatMap (p => unificationConstraints(p._1, p._2, free))
case _ => unsolvable
}

/** Solves the constraints collected by [[unificationConstraints]].
* Note: this is an override point. */
protected def unificationSolution(const: List[(Type, Type)]): List[(TypeParameter, Type)] = const match {
case Nil => Nil
case (tp: TypeParameter, t) :: tl =>
val replaced = tl map { case (t1, t2) =>
(typeOps.instantiateType(t1, Map(tp -> t)), typeOps.instantiateType(t2, Map(tp -> t)))
}
(tp -> t) :: unificationSolution(replaced)
case (adt: ADTType, _) :: tl if adt.lookupSort.isEmpty => unsolvable
case (_, adt: ADTType) :: tl if adt.lookupSort.isEmpty => unsolvable
case (ADTType(id1, tps1), ADTType(id2, tps2)) :: tl if id1 == id2 =>
unificationSolution((tps1 zip tps2).toList ++ tl)
case typeOps.Same(NAryType(ts1, _), NAryType(ts2, _)) :: tl if ts1.size == ts2.size =>
unificationSolution((ts1 zip ts2).toList ++ tl)
case _ =>
unsolvable
}

/** Unifies two types, under a set of free variables */
def unify(t1: Type, t2: Type, free: Seq[TypeParameter]): Option[List[(TypeParameter, Type)]] = {
try {
Some(unificationSolution(unificationConstraints(t1, t2, free)))
} catch {
case _: Unsolvable => None
}
}

}
Loading