diff --git a/src/theory/arith/nl/ext/arith_nl_compare_proof_gen.cpp b/src/theory/arith/nl/ext/arith_nl_compare_proof_gen.cpp index d77c1fba21e..be0c8d9d941 100644 --- a/src/theory/arith/nl/ext/arith_nl_compare_proof_gen.cpp +++ b/src/theory/arith/nl/ext/arith_nl_compare_proof_gen.cpp @@ -69,7 +69,8 @@ std::shared_ptr ArithNlCompareProofGenerator::getProofFor(Node fact) if (e != ec) { Node eeq = e.eqNode(ec); - cdp.addTrustedStep(eeq, TrustId::ARITH_NL_COMPARE_LIT_TRANSFORM, {}, {}); + Node eeqSym = ec.eqNode(e); + cdp.addTrustedStep(eeqSym, TrustId::ARITH_NL_COMPARE_LIT_TRANSFORM, {}, {}); cdp.addStep(ec, ProofRule::EQ_RESOLVE, {e, eeq}, {}); } // add to product @@ -166,17 +167,33 @@ Kind ArithNlCompareProofGenerator::decomposeCompareLit(const Node& lit, { return Kind::UNDEFINED_KIND; } - a.emplace_back(lit[0][0]); - b.emplace_back(lit[1][0]); + addProduct(lit[0][0], a); + addProduct(lit[1][0], b); } else { - a.emplace_back(lit[0]); - b.emplace_back(lit[1]); + addProduct(lit[0], a); + addProduct(lit[1], b); } return k; } +void ArithNlCompareProofGenerator::addProduct(const Node& n, std::vector& vec) +{ + if (n.getKind()==Kind::NONLINEAR_MULT) + { + vec.insert(vec.end(), n.begin(), n.end()); + } + else if (n.isConst() && n.getConst().isOne()) + { + // do nothing + } + else + { + vec.emplace_back(n); + } +} + Kind ArithNlCompareProofGenerator::combineRelation(Kind k1, Kind k2) { if (k2 == Kind::EQUAL) @@ -204,6 +221,24 @@ Kind ArithNlCompareProofGenerator::combineRelation(Kind k1, Kind k2) return Kind::UNDEFINED_KIND; } +bool ArithNlCompareProofGenerator::diffProduct(const std::vector& a, const std::vector& b, std::map& diff) +{ + size_t indexb = 0; + for (size_t i=0, nmona=a.size(); i& b); /** */ static Kind combineRelation(Kind k1, Kind k2); + /** */ + static void addProduct(const Node& n, std::vector& vec); + /** */ + static bool diffProduct(const std::vector& a, const std::vector& b, std::map& diff); }; } // namespace nl diff --git a/src/theory/arith/nl/ext/monomial_check.cpp b/src/theory/arith/nl/ext/monomial_check.cpp index 67f9ef0952a..c58439b02a3 100644 --- a/src/theory/arith/nl/ext/monomial_check.cpp +++ b/src/theory/arith/nl/ext/monomial_check.cpp @@ -423,7 +423,7 @@ bool MonomialCheck::compareMonomial( << "infer : " << oa << " <" << status << "> " << ob << std::endl; if (status == 2) { - // must state that all variables are non-zero + // must state that all variables are non-zero and not absolute for (const Node& v : vla) { exp.push_back(v.eqNode(mkZero(v.getType())).negate()); @@ -437,7 +437,7 @@ bool MonomialCheck::compareMonomial( Node conc = mkAndNotifyLit(oa, ob, status, true); Node clem = nm->mkNode(Kind::IMPLIES, nm->mkAnd(exp), conc); Trace("nl-ext-comp-lemma") << "comparison lemma : " << clem << std::endl; - // use special proof generator + // use dedicated proof generator d_ancPfGen lem.emplace_back(InferenceId::ARITH_NL_COMPARISON, clem, LemmaProperty::NONE, diff --git a/src/theory/arith/nl/ext/proof_checker.cpp b/src/theory/arith/nl/ext/proof_checker.cpp index e92d8bbef2c..3cf227e806e 100644 --- a/src/theory/arith/nl/ext/proof_checker.cpp +++ b/src/theory/arith/nl/ext/proof_checker.cpp @@ -161,13 +161,14 @@ Node ExtProofRuleChecker::checkInternal(ProofRule id, std::vector eproda; std::vector eprodb; Kind k = Kind::EQUAL; - std::vector deq; + std::unordered_set deqZero; for (const Node& c : children) { Kind ck = c.getKind(); - if (ck == Kind::NOT && c[0].getKind() == Kind::EQUAL) + // it may be a disequality with zero + if (ck == Kind::NOT && c[0].getKind() == Kind::EQUAL && c[0][1].isConst() && c[0][1].getConst().isZero()) { - deq.emplace_back(c); + deqZero.insert(c[0][0]); continue; } ck = ArithNlCompareProofGenerator::decomposeCompareLit( @@ -191,7 +192,33 @@ Node ExtProofRuleChecker::checkInternal(ProofRule id, { return Node::null(); } - + // now ensure that the products align + std::sort(eproda.begin(), eproda.end()); + std::sort(eprodb.begin(), eprodb.end()); + std::sort(cproda.begin(), cproda.end()); + std::sort(cprodb.begin(), cprodb.end()); + std::map diffa; + if (!ArithNlCompareProofGenerator::diffProduct(cproda, eproda, diffa)) + { + // explained monomomials are not a subset of the conclusion for LHS + return Node::null(); + } + std::map diffb; + if (!ArithNlCompareProofGenerator::diffProduct(cprodb, eprodb, diffb)) + { + // explained monomomials are not a subset of the conclusion for RHS + return Node::null(); + } + if (diffa!=diffb) + { + // the conclusion is not a product of what was proven + return Node::null(); + } + // variables must be non-zero if strict + if (k==Kind::GT || k==Kind::LT) + { + + } return args[0]; } return Node::null();