From 64d7b55a334890b5a72c0973737fe197f53aa9bc Mon Sep 17 00:00:00 2001 From: ajreynol Date: Fri, 18 Oct 2024 09:48:49 -0500 Subject: [PATCH] More --- src/proof/conv_proof_generator.cpp | 14 ++++ src/proof/conv_proof_generator.h | 8 ++ src/proof/trust_id.cpp | 1 + src/proof/trust_id.h | 2 + .../nl/ext/arith_nl_compare_proof_gen.cpp | 65 ++++++++++++++-- .../arith/nl/ext/arith_nl_compare_proof_gen.h | 10 +-- src/theory/arith/nl/ext/monomial_check.cpp | 76 ++++++++++++------- src/theory/arith/nl/ext/monomial_check.h | 4 +- 8 files changed, 140 insertions(+), 40 deletions(-) diff --git a/src/proof/conv_proof_generator.cpp b/src/proof/conv_proof_generator.cpp index 2db8a60d6b4..780100133eb 100644 --- a/src/proof/conv_proof_generator.cpp +++ b/src/proof/conv_proof_generator.cpp @@ -112,6 +112,20 @@ void TConvProofGenerator::addRewriteStep(Node t, } } +void TConvProofGenerator::addTrustedRewriteStep(Node t, + Node s, + TrustId id, + const std::vector& children, + const std::vector& args, + bool isPre, + uint32_t tctx) +{ + std::vector targs; + targs.push_back(mkTrustId(id)); + targs.emplace_back(t.eqNode(s)); + addRewriteStep(t, s, ProofRule::TRUST, children, targs, isPre, tctx); +} + bool TConvProofGenerator::hasRewriteStep(Node t, uint32_t tctx, bool isPre) const diff --git a/src/proof/conv_proof_generator.h b/src/proof/conv_proof_generator.h index 1b7b428aa8d..58d439e5ca6 100644 --- a/src/proof/conv_proof_generator.h +++ b/src/proof/conv_proof_generator.h @@ -172,6 +172,14 @@ class TConvProofGenerator : protected EnvObj, public ProofGenerator const std::vector& args, bool isPre = false, uint32_t tctx = 0); + /** Same as above, but with a trusted step */ + void addTrustedRewriteStep(Node t, + Node s, + TrustId id, + const std::vector& children, + const std::vector& args, + bool isPre = false, + uint32_t tctx = 0); /** Has rewrite step for term t */ bool hasRewriteStep(Node t, uint32_t tctx = 0, bool isPre = false) const; /** diff --git a/src/proof/trust_id.cpp b/src/proof/trust_id.cpp index d4b5e53079e..7c5f72c042d 100644 --- a/src/proof/trust_id.cpp +++ b/src/proof/trust_id.cpp @@ -38,6 +38,7 @@ const char* toString(TrustId id) case TrustId::ARITH_NL_COVERING_DIRECT: return "ARITH_NL_COVERING_DIRECT"; case TrustId::ARITH_NL_COVERING_RECURSIVE: return "ARITH_NL_COVERING_RECURSIVE"; + case TrustId::ARITH_NL_COMPARE_LIT_TRANSFORM: return "ARITH_NL_COMPARE_LIT_TRANSFORM"; case TrustId::EXT_THEORY_REWRITE: return "EXT_THEORY_REWRITE"; case TrustId::REWRITE_NO_ELABORATE: return "REWRITE_NO_ELABORATE"; case TrustId::FLATTENING_REWRITE: return "FLATTENING_REWRITE"; diff --git a/src/proof/trust_id.h b/src/proof/trust_id.h index 50a63acbebb..9e27574b97f 100644 --- a/src/proof/trust_id.h +++ b/src/proof/trust_id.h @@ -95,6 +95,8 @@ enum class TrustId : uint32_t * no :math:`x_i` exists that extends the cell and satisfies all assumptions. */ ARITH_NL_COVERING_RECURSIVE, + /** */ + ARITH_NL_COMPARE_LIT_TRANSFORM, /** An extended theory rewrite */ EXT_THEORY_REWRITE, /** A rewrite whose proof could not be elaborated */ diff --git a/src/theory/arith/nl/ext/arith_nl_compare_proof_gen.cpp b/src/theory/arith/nl/ext/arith_nl_compare_proof_gen.cpp index b51129a1963..d737d5362c9 100644 --- a/src/theory/arith/nl/ext/arith_nl_compare_proof_gen.cpp +++ b/src/theory/arith/nl/ext/arith_nl_compare_proof_gen.cpp @@ -17,13 +17,15 @@ #include "theory/arith/arith_utilities.h" #include "theory/arith/nl/ext/monomial_check.h" +#include "expr/attribute.h" +#include "proof/proof.h" namespace cvc5::internal { namespace theory { namespace arith { namespace nl { -ArithNlCompareProofGenerator::ArithNlCompareProofGenerator(Env& env) : EnvObj(env), d_absConv(env) {} +ArithNlCompareProofGenerator::ArithNlCompareProofGenerator(Env& env) : EnvObj(env) {} ArithNlCompareProofGenerator::~ArithNlCompareProofGenerator() {} std::shared_ptr ArithNlCompareProofGenerator::getProofFor(Node fact) @@ -39,26 +41,75 @@ std::shared_ptr ArithNlCompareProofGenerator::getProofFor(Node fact) exp.emplace_back(fact[0]); } Node conc = fact[1]; + Trace("arith-nl-compare") << "Comparsion prove: " << exp << " => " << conc << std::endl; + // get the expected form of the literals CDProof cdp(d_env); - cdp.addStep(conc, ProofRule::MACRO_ARITH_NL_COMPARISON, exp, {conc}); + std::vector expc; + for (const Node& e : exp) + { + Node ec = getCompareLit(e); + if (ec.isNull()) + { + expc.emplace_back(e); + continue; + } + expc.emplace_back(ec); + if (e!=ec) + { + Node eeq = e.eqNode(ec); + cdp.addTrustedStep(eeq, TrustId::ARITH_NL_COMPARE_LIT_TRANSFORM, {}, {}); + cdp.addStep(ec, ProofRule::EQ_RESOLVE, {e, eeq}, {}); + } + } + Node concc = getCompareLit(conc); + Assert (!concc.isNull()); + Trace("arith-nl-compare") << "...processed prove: " << expc << " => " << concc << std::endl; + cdp.addStep(concc, ProofRule::MACRO_ARITH_NL_COMPARISON, expc, {concc}); + if (conc!=concc) + { + Node ceq = concc.eqNode(conc); + cdp.addTrustedStep(ceq, TrustId::ARITH_NL_COMPARE_LIT_TRANSFORM, {}, {}); + cdp.addStep(conc, ProofRule::EQ_RESOLVE, {concc, ceq}, {}); + } cdp.addStep(fact, ProofRule::SCOPE, {conc}, exp); + AlwaysAssert(false); return cdp.getProofFor(fact); } std::string ArithNlCompareProofGenerator::identify() const { return "ArithNlCompareProofGenerator"; } -Node ArithNlCompareProofGenerator::mkLit(NodeManager* nm, Kind k, Node a, Node b, bool isAbsolute) +Node ArithNlCompareProofGenerator::mkLit(NodeManager* nm, Kind k, const Node& a, const Node& b, bool isAbsolute) { + Node au = a; + Node bu = b; if (isAbsolute) { - a = nm->mkNode(Kind::ABS, a); - b = nm->mkNode(Kind::ABS, b); + au = nm->mkNode(Kind::ABS, au); + bu = nm->mkNode(Kind::ABS, bu); } if (k==Kind::EQUAL) { - return mkEquality(a, b); + return mkEquality(au, bu); } - return nm->mkNode(k, a, b); + return nm->mkNode(k, au, bu); +} + +struct ArithNlCompareLitAttributeId +{ +}; +using ArithNlCompareLitAttribute = expr::Attribute; + +void ArithNlCompareProofGenerator::setCompareLit(NodeManager* nm, Node olit, Kind k, const Node& a, const Node& b, bool isAbsolute) +{ + Node lit = mkLit(nm, k, a, b, isAbsolute); + ArithNlCompareLitAttribute ancla; + olit.setAttribute(ancla, lit); +} + +Node ArithNlCompareProofGenerator::getCompareLit(const Node& olit) +{ + ArithNlCompareLitAttribute ancla; + return olit.getAttribute(ancla); } } // namespace nl diff --git a/src/theory/arith/nl/ext/arith_nl_compare_proof_gen.h b/src/theory/arith/nl/ext/arith_nl_compare_proof_gen.h index 703cc16481f..4ee6033e47c 100644 --- a/src/theory/arith/nl/ext/arith_nl_compare_proof_gen.h +++ b/src/theory/arith/nl/ext/arith_nl_compare_proof_gen.h @@ -18,7 +18,6 @@ #include "smt/env_obj.h" #include "proof/proof_generator.h" -#include "proof/conv_proof_generator.h" namespace cvc5::internal { namespace theory { @@ -36,10 +35,11 @@ class ArithNlCompareProofGenerator : protected EnvObj, public ProofGenerator /** identify */ std::string identify() const override; /** Make literal */ - static Node mkLit(NodeManager* nm, Kind k, Node a, Node b, bool isAbsolute); -private: - /** Converter for absolute value literals */ - TConvProofGenerator d_absConv; + static Node mkLit(NodeManager* nm, Kind k, const Node& a, const Node& b, bool isAbsolute); + /** */ + static void setCompareLit(NodeManager* nm, Node olit, Kind k, const Node& a, const Node& b, bool isAbsolute); + /** */ + static Node getCompareLit(const Node& olit); }; } // namespace nl diff --git a/src/theory/arith/nl/ext/monomial_check.cpp b/src/theory/arith/nl/ext/monomial_check.cpp index 22d2363eec8..440bb767126 100644 --- a/src/theory/arith/nl/ext/monomial_check.cpp +++ b/src/theory/arith/nl/ext/monomial_check.cpp @@ -304,7 +304,7 @@ int MonomialCheck::compareSign( if (mvaoa.getConst().sgn() != status) { Node zero = mkZero(oa.getType()); - Node lemma = nm->mkAnd(exp).impNode(mkLit(nm, oa, zero, status * 2)); + Node lemma = nm->mkAnd(exp).impNode(mkLit(oa, zero, status * 2)); CDProof* proof = nullptr; if (d_data->isProofEnabled()) { @@ -430,7 +430,7 @@ bool MonomialCheck::compareMonomial( { cob = mkOne(oa.getType()); } - Node conc = mkLit(nm, oa, ob, status, true); + Node conc = mkLit(oa, ob, status, true); Node clem = nm->mkNode( Kind::IMPLIES, nm->mkAnd(exp), conc); Trace("nl-ext-comp-lemma") << "comparison lemma : " << clem << std::endl; @@ -511,7 +511,7 @@ bool MonomialCheck::compareMonomial( Trace("nl-ext-comp-debug") << "...take leading " << bv << std::endl; // can multiply b by <=1 Node one = mkOne(bv.getType()); - exp.push_back(mkLit(nm, one, bv, bvo == ovo ? 0 : 2, true)); + exp.push_back(mkLit(one, bv, bvo == ovo ? 0 : 2, true)); return compareMonomial(oa, a, a_index, @@ -536,7 +536,7 @@ bool MonomialCheck::compareMonomial( Trace("nl-ext-comp-debug") << "...take leading " << av << std::endl; // can multiply a by >=1 Node one = mkOne(av.getType()); - exp.push_back(mkLit(nm, av, one, avo == ovo ? 0 : 2, true)); + exp.push_back(mkLit(av, one, avo == ovo ? 0 : 2, true)); return compareMonomial(oa, a, a_index + 1, @@ -562,7 +562,7 @@ bool MonomialCheck::compareMonomial( Trace("nl-ext-comp-debug") << "...take leading " << av << std::endl; // do avo>=1 instead Node one = mkOne(av.getType()); - exp.push_back(mkLit(nm, av, one, avo == ovo ? 0 : 2, true)); + exp.push_back(mkLit(av, one, avo == ovo ? 0 : 2, true)); return compareMonomial(oa, a, a_index + 1, @@ -581,7 +581,7 @@ bool MonomialCheck::compareMonomial( b_exp_proc[bv] += min_exp; Trace("nl-ext-comp-debug") << "...take leading " << min_exp << " from " << av << " and " << bv << std::endl; - exp.push_back(mkLit(nm, av, bv, avo == bvo ? 0 : 2, true)); + exp.push_back(mkLit(av, bv, avo == bvo ? 0 : 2, true)); bool ret = compareMonomial(oa, a, a_index, @@ -603,7 +603,7 @@ bool MonomialCheck::compareMonomial( Trace("nl-ext-comp-debug") << "...take leading " << bv << std::endl; // try multiply b <= 1 Node one = mkOne(bv.getType()); - exp.push_back(mkLit(nm, one, bv, bvo == ovo ? 0 : 2, true)); + exp.push_back(mkLit(one, bv, bvo == ovo ? 0 : 2, true)); return compareMonomial(oa, a, a_index, @@ -726,39 +726,61 @@ void MonomialCheck::assignOrderIds(std::vector& vars, } } -Node MonomialCheck::mkLit(NodeManager* nm, Node a, Node b, int status, bool isAbsolute) +Node MonomialCheck::mkLit(Node a, Node b, int status, bool isAbsolute) const { + NodeManager * nm = nodeManager(); + if (status<0) + { + status = -status; + Node tmp = a; + a = b; + b = tmp; + } Assert(a.getType().isRealOrInt() && b.getType().isRealOrInt()); + Node ret; + Kind k; if (status == 0) { + k = Kind::EQUAL; Node a_eq_b = mkEquality(a, b); if (!isAbsolute) { - return a_eq_b; + ret = a_eq_b; + } + else + { + Node negate_b = nm->mkNode(Kind::NEG, b); + ret = a_eq_b.orNode(mkEquality(a, negate_b)); } - Node negate_b = nm->mkNode(Kind::NEG, b); - return a_eq_b.orNode(mkEquality(a, negate_b)); } - else if (status < 0) + else { - return mkLit(nm, b, a, -status); + Assert(status == 1 || status == 2); + k = status == 1 ? Kind::GEQ : Kind::GT; + if (!isAbsolute) + { + ret = nm->mkNode(k, a, b); + } + else + { + Node zero = mkZero(a.getType()); + Node a_is_nonnegative = nm->mkNode(Kind::GEQ, a, zero); + Node b_is_nonnegative = nm->mkNode(Kind::GEQ, b, zero); + Node negate_a = nm->mkNode(Kind::NEG, a); + Node negate_b = nm->mkNode(Kind::NEG, b); + ret = a_is_nonnegative.iteNode( + b_is_nonnegative.iteNode(nm->mkNode(k, a, b), + nm->mkNode(k, a, negate_b)), + b_is_nonnegative.iteNode(nm->mkNode(k, negate_a, b), + nm->mkNode(k, negate_a, negate_b))); + } } - Assert(status == 1 || status == 2); - Kind greater_op = status == 1 ? Kind::GEQ : Kind::GT; - if (!isAbsolute) + // if proofs are enabled, we ensure we remember what the literal represents + if (d_ancPfGen!=nullptr) { - return nm->mkNode(greater_op, a, b); + ArithNlCompareProofGenerator::setCompareLit(nm, ret, k, a, b, isAbsolute); } - Node zero = mkZero(a.getType()); - Node a_is_nonnegative = nm->mkNode(Kind::GEQ, a, zero); - Node b_is_nonnegative = nm->mkNode(Kind::GEQ, b, zero); - Node negate_a = nm->mkNode(Kind::NEG, a); - Node negate_b = nm->mkNode(Kind::NEG, b); - return a_is_nonnegative.iteNode( - b_is_nonnegative.iteNode(nm->mkNode(greater_op, a, b), - nm->mkNode(greater_op, a, negate_b)), - b_is_nonnegative.iteNode(nm->mkNode(greater_op, negate_a, b), - nm->mkNode(greater_op, negate_a, negate_b))); + return ret; } void MonomialCheck::setMonomialFactor(Node a, diff --git a/src/theory/arith/nl/ext/monomial_check.h b/src/theory/arith/nl/ext/monomial_check.h index 6e9c550bf8f..943f2dda1b1 100644 --- a/src/theory/arith/nl/ext/monomial_check.h +++ b/src/theory/arith/nl/ext/monomial_check.h @@ -76,7 +76,7 @@ class MonomialCheck : protected EnvObj void checkMagnitude(unsigned c); /** Make literal */ - static Node mkLit(NodeManager* nm, Node a, Node b, int status, bool isAbsolute = false); + static Node mkLit(NodeManager* nm, Kind k, Node a, Node b, bool isAbsolute); private: /** In the following functions, status states a relationship * between two arithmetic terms, where: @@ -178,6 +178,8 @@ class MonomialCheck : protected EnvObj NodeMultiset& d_order, bool isConcrete, bool isAbsolute); + /** Make literal */ + Node mkLit(Node a, Node b, int status, bool isAbsolute = false) const; /** register monomial */ void setMonomialFactor(Node a, Node b, const NodeMultiset& common);