diff --git a/src/theory/arith/arith_proof_rcons.cpp b/src/theory/arith/arith_proof_rcons.cpp index 9e61922e6dc..226989606d8 100644 --- a/src/theory/arith/arith_proof_rcons.cpp +++ b/src/theory/arith/arith_proof_rcons.cpp @@ -18,10 +18,37 @@ #include "proof/proof.h" #include "theory/arith/arith_msum.h" #include "theory/arith/arith_subs.h" +#include "proof/conv_proof_generator.h" +#include "expr/term_context.h" namespace cvc5::internal { namespace theory { namespace arith { + +/** + * Arithmetic substitution term context. + */ +class ArithSubsTermContext : public TermContext +{ + public: + ArithSubsTermContext() {} + /** The initial value: valid. */ + uint32_t initialValue() const override { return 0; } + /** Compute the value of the index^th child of t whose hash is tval */ + uint32_t computeValue(TNode t, uint32_t tval, size_t index) const override + { + if (tval==0) + { + // if we should not traverse, return 1 + if (!ArithSubs::shouldTraverse(t)) + { + return 1; + } + return 0; + } + return tval; + } +}; ArithProofRCons::ArithProofRCons(Env& env, TrustId id) : EnvObj(env), d_id(id) { @@ -50,7 +77,9 @@ std::shared_ptr<ProofNode> ArithProofRCons::getProofFor(Node fact) } ArithSubs asubs; std::vector<Node> assumpsNoSolve; - std::vector<Node> assumpsSolve; + ArithSubsTermContext astc; + TConvProofGenerator tcnv(d_env, nullptr, TConvPolicy::FIXPOINT, TConvCachePolicy::NEVER, + "ArithRConsTConv", &astc); Node tgtAssump; // prove false for (const Node& a : assumps) @@ -63,17 +92,21 @@ std::shared_ptr<ProofNode> ArithProofRCons::getProofFor(Node fact) continue; } Node as = asubs.applyArith(a); - as = rewrite(as); + Node asr = rewrite(as); Trace("arith-proof-rcons") - << "...under subs+rewrite: " << as << std::endl; - if (as == d_false) + << "...under subs+rewrite: " << asr << std::endl; + if (asr == d_false) { Trace("arith-proof-rcons") << "...success!" << std::endl; - std::vector<Node> pargs; - pargs.push_back(a); - pargs.insert(pargs.end(), assumpsSolve.begin(), assumpsSolve.end()); + if (a!=as) + { + std::shared_ptr<ProofNode> pfn = tcnv.getProofForRewriting(a); + Assert (pfn.getResult()[1]==as); + cdp.addProof(pfn); + cdp.addStep(as, ProofRule::EQ_RESOLVE, {a, a.eqNode(as)}, {}); + } cdp.addStep( - d_false, ProofRule::MACRO_SR_PRED_TRANSFORM, pargs, {d_false}); + d_false, ProofRule::MACRO_SR_PRED_TRANSFORM, {as}, {d_false}); success = true; break; } @@ -96,12 +129,19 @@ std::shared_ptr<ProofNode> ArithProofRCons::getProofFor(Node fact) Trace("arith-proof-rcons") << "...solved " << m.first << " = " << val << std::endl; Node eq = m.first.eqNode(val); - std::vector<Node> pargs; - pargs.push_back(a); - pargs.insert(pargs.end(), assumpsSolve.begin(), assumpsSolve.end()); - cdp.addStep(eq, ProofRule::MACRO_SR_PRED_TRANSFORM, pargs, {eq}); - assumpsSolve.push_back(eq); + if (a!=as) + { + std::shared_ptr<ProofNode> pfn = tcnv.getProofForRewriting(a); + Assert (pfn.getResult()[1]==as); + cdp.addProof(pfn); + cdp.addStep(as, ProofRule::EQ_RESOLVE, {a, a.eqNode(as)}, {}); + } + if (as!=eq) + { + cdp.addStep(eq, ProofRule::MACRO_SR_PRED_TRANSFORM, {as}, {eq}); + } asubs.add(m.first, val); + tcnv.addRewriteStep(m.first, val, &cdp); break; } } @@ -114,7 +154,8 @@ std::shared_ptr<ProofNode> ArithProofRCons::getProofFor(Node fact) } if (!success) { - Trace("arith-proof-rcons") << "Not solved by rewriting single literal" << std::endl; + Trace("arith-proof-rcons") + << "Not solved by rewriting single literal" << std::endl; // check if two unsolved literals rewrite to the negation of one another std::vector<Node> sassumps; std::map<Node, bool> pols; @@ -122,18 +163,25 @@ std::shared_ptr<ProofNode> ArithProofRCons::getProofFor(Node fact) for (const Node& a : assumpsNoSolve) { Node as = asubs.applyArith(a); - as = rewrite(as); - Trace("arith-proof-rcons") << "...have " << as << std::endl; - std::vector<Node> pargs; - pargs.push_back(a); - pargs.insert(pargs.end(), assumpsSolve.begin(), assumpsSolve.end()); - cdp.addStep(as, ProofRule::MACRO_SR_PRED_TRANSFORM, pargs, {as}); - bool pol = as.getKind()!=Kind::NOT; - Node aslit = pol ? as : as[0]; + Node asr = rewrite(as); + Trace("arith-proof-rcons") << "...have " << asr << std::endl; + if (a!=as) + { + std::shared_ptr<ProofNode> pfn = tcnv.getProofForRewriting(a); + Assert (pfn.getResult()[1]==as); + cdp.addProof(pfn); + cdp.addStep(as, ProofRule::EQ_RESOLVE, {a, a.eqNode(as)}, {}); + } + if (as!=asr) + { + cdp.addStep(asr, ProofRule::MACRO_SR_PRED_TRANSFORM, {as}, {asr}); + } + bool pol = asr.getKind() != Kind::NOT; + Node aslit = pol ? asr : asr[0]; itp = pols.find(aslit); - if (itp!=pols.end()) + if (itp != pols.end()) { - if (itp->second!=pol) + if (itp->second != pol) { Node asn = aslit.notNode(); cdp.addStep(d_false, ProofRule::CONTRA, {aslit, asn}, {}); diff --git a/src/theory/arith/arith_subs.cpp b/src/theory/arith/arith_subs.cpp index 2731a309dbf..c401e35316c 100644 --- a/src/theory/arith/arith_subs.cpp +++ b/src/theory/arith/arith_subs.cpp @@ -31,6 +31,10 @@ void ArithSubs::addArith(const Node& v, const Node& s) Node ArithSubs::applyArith(const Node& n, bool traverseNlMult) const { + if (d_vars.empty()) + { + return n; + } NodeManager* nm = NodeManager::currentNM(); std::unordered_map<TNode, Node> visited; std::vector<TNode> visit; @@ -44,7 +48,6 @@ Node ArithSubs::applyArith(const Node& n, bool traverseNlMult) const if (it == visited.end()) { visited[cur] = Node::null(); - Kind ck = cur.getKind(); auto s = find(cur); if (s) { @@ -56,12 +59,7 @@ Node ArithSubs::applyArith(const Node& n, bool traverseNlMult) const } else { - TheoryId ctid = theory::kindToTheoryId(ck); - if ((ctid != THEORY_ARITH && ctid != THEORY_BOOL - && ctid != THEORY_BUILTIN) - || isTranscendentalKind(ck) - || (!traverseNlMult - && (ck == Kind::NONLINEAR_MULT || ck == Kind::IAND))) + if (!shouldTraverse(cur)) { // Do not traverse beneath applications that belong to another theory // besides (core) arithmetic. Notice that transcendental function @@ -107,6 +105,21 @@ Node ArithSubs::applyArith(const Node& n, bool traverseNlMult) const return visited[n]; } +bool ArithSubs::shouldTraverse(const Node& n, bool traverseNlMult) +{ + Kind k = n.getKind(); + TheoryId ctid = theory::kindToTheoryId(k); + if ((ctid != THEORY_ARITH && ctid != THEORY_BOOL + && ctid != THEORY_BUILTIN) + || isTranscendentalKind(k) + || (!traverseNlMult + && (k == Kind::NONLINEAR_MULT || k == Kind::IAND))) + { + return false; + } + return true; +} + } // namespace arith } // namespace theory } // namespace cvc5::internal diff --git a/src/theory/arith/arith_subs.h b/src/theory/arith/arith_subs.h index 08b375242f0..be3160d37d6 100644 --- a/src/theory/arith/arith_subs.h +++ b/src/theory/arith/arith_subs.h @@ -47,6 +47,10 @@ class ArithSubs : public Subs * @param traverseNlMult Whether to traverse applications of NONLINEAR_MULT. */ Node applyArith(const Node& n, bool traverseNlMult = true) const; + /** + * Should traverse, returns true if the above method traverses n. + */ + static bool shouldTraverse(const Node& n, bool traverseNlMult = true); }; } // namespace arith