Skip to content

Commit

Permalink
preprocessing/proof: Refactor to not use NodeManager::currentNM() (cv…
Browse files Browse the repository at this point in the history
…c5#11459)

This PR introduces some calls to `NodeManager::currentNM()`, which will
be removed in subsequent PRs.
  • Loading branch information
daniel-larraz authored Dec 18, 2024
1 parent d9249b3 commit da255fd
Show file tree
Hide file tree
Showing 32 changed files with 69 additions and 59 deletions.
7 changes: 3 additions & 4 deletions src/preprocessing/passes/synth_rew_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
}

std::vector<TypeNode> SynthRewRulesPass::getGrammarsFrom(
const std::vector<Node>& assertions, uint64_t nvars)
NodeManager* nm, const std::vector<Node>& assertions, uint64_t nvars)
{
std::vector<TypeNode> ret;
std::map<TypeNode, TypeNode> tlGrammarTypes =
constructTopLevelGrammar(assertions, nvars);
constructTopLevelGrammar(nm, assertions, nvars);
for (std::pair<const TypeNode, TypeNode> ttp : tlGrammarTypes)
{
ret.push_back(ttp.second);
Expand All @@ -64,14 +64,13 @@ std::vector<TypeNode> SynthRewRulesPass::getGrammarsFrom(
}

std::map<TypeNode, TypeNode> SynthRewRulesPass::constructTopLevelGrammar(
const std::vector<Node>& assertions, uint64_t nvars)
NodeManager* nm, const std::vector<Node>& assertions, uint64_t nvars)
{
std::map<TypeNode, TypeNode> tlGrammarTypes;
if (assertions.empty())
{
return tlGrammarTypes;
}
NodeManager* nm = NodeManager::currentNM();
// initialize the candidate rewrite
std::unordered_map<TNode, bool> visited;
std::unordered_map<TNode, bool>::iterator it;
Expand Down
4 changes: 2 additions & 2 deletions src/preprocessing/passes/synth_rew_rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ class SynthRewRulesPass : public PreprocessingPass
SynthRewRulesPass(PreprocessingPassContext* preprocContext);

static std::vector<TypeNode> getGrammarsFrom(
const std::vector<Node>& assertions, uint64_t nvars);
NodeManager* nm, const std::vector<Node>& assertions, uint64_t nvars);

protected:
static std::map<TypeNode, TypeNode> constructTopLevelGrammar(
const std::vector<Node>& assertions, uint64_t nvars);
NodeManager* nm, const std::vector<Node>& assertions, uint64_t nvars);
PreprocessingPassResult applyInternal(
AssertionPipeline* assertionsToPreprocess) override;
};
Expand Down
2 changes: 1 addition & 1 deletion src/proof/lazy_proof.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ void LazyCDProof::addLazyStep(Node expected,
}
Trace("lazy-cdproof") << "LazyCDProof::addLazyStep: " << expected
<< " set (trusted) step " << idNull << "\n";
Node tid = mkTrustId(idNull);
Node tid = mkTrustId(nodeManager(), idNull);
addStep(expected, ProofRule::TRUST, {}, {tid, expected});
return;
}
Expand Down
2 changes: 1 addition & 1 deletion src/proof/lazy_tree_proof_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ void LazyTreeProofGenerator::setCurrentTrust(size_t objectId,
Node proven)
{
std::vector<Node> newArgs;
newArgs.push_back(mkTrustId(tid));
newArgs.push_back(mkTrustId(nodeManager(), tid));
newArgs.push_back(proven);
newArgs.insert(newArgs.end(), args.begin(), args.end());
setCurrent(objectId, ProofRule::TRUST, premise, newArgs, proven);
Expand Down
14 changes: 7 additions & 7 deletions src/proof/method_id.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,9 @@ std::ostream& operator<<(std::ostream& out, MethodId id)
return out;
}

Node mkMethodId(MethodId id)
Node mkMethodId(NodeManager* nm, MethodId id)
{
return NodeManager::currentNM()->mkConstInt(
Rational(static_cast<uint32_t>(id)));
return nm->mkConstInt(Rational(static_cast<uint32_t>(id)));
}

bool getMethodId(TNode n, MethodId& i)
Expand Down Expand Up @@ -98,7 +97,8 @@ bool getMethodIds(const std::vector<Node>& args,
return true;
}

void addMethodIds(std::vector<Node>& args,
void addMethodIds(NodeManager* nm,
std::vector<Node>& args,
MethodId ids,
MethodId ida,
MethodId idr)
Expand All @@ -107,15 +107,15 @@ void addMethodIds(std::vector<Node>& args,
bool ndefApply = (ida != MethodId::SBA_SEQUENTIAL);
if (ids != MethodId::SB_DEFAULT || ndefRewriter || ndefApply)
{
args.push_back(mkMethodId(ids));
args.push_back(mkMethodId(nm, ids));
}
if (ndefApply || ndefRewriter)
{
args.push_back(mkMethodId(ida));
args.push_back(mkMethodId(nm, ida));
}
if (ndefRewriter)
{
args.push_back(mkMethodId(idr));
args.push_back(mkMethodId(nm, idr));
}
}

Expand Down
5 changes: 3 additions & 2 deletions src/proof/method_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ const char* toString(MethodId id);
/** Write a rewriter id to out */
std::ostream& operator<<(std::ostream& out, MethodId id);
/** Make a method id node */
Node mkMethodId(MethodId id);
Node mkMethodId(NodeManager* nm, MethodId id);

/** get a method identifier from a node, return false if we fail */
bool getMethodId(TNode n, MethodId& i);
Expand All @@ -102,7 +102,8 @@ bool getMethodIds(const std::vector<Node>& args,
* Add method identifiers ids, ida and idr as nodes to args. This does not add
* ids, ida or idr if their values are the default ones.
*/
void addMethodIds(std::vector<Node>& args,
void addMethodIds(NodeManager* nm,
std::vector<Node>& args,
MethodId ids,
MethodId ida,
MethodId idr);
Expand Down
2 changes: 1 addition & 1 deletion src/proof/proof.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ bool CDProof::addTrustedStep(Node expected,
CDPOverwrite opolicy)
{
std::vector<Node> sargs;
sargs.push_back(mkTrustId(id));
sargs.push_back(mkTrustId(nodeManager(), id));
sargs.push_back(expected);
sargs.insert(sargs.end(), args.begin(), args.end());
return addStep(
Expand Down
3 changes: 2 additions & 1 deletion src/proof/proof_node_algorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ ProofRule getCongRule(const Node& n, std::vector<Node>& args)
break;
}
// Add the arguments
args.push_back(ProofRuleChecker::mkKindNode(k));
NodeManager* nm = NodeManager::currentNM();
args.push_back(ProofRuleChecker::mkKindNode(nm, k));
if (kind::metaKindOf(k) == kind::metakind::PARAMETERIZED)
{
args.push_back(n.getOperator());
Expand Down
2 changes: 1 addition & 1 deletion src/proof/proof_node_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ std::shared_ptr<ProofNode> ProofNodeManager::mkTrustedNode(
const Node& conc)
{
std::vector<Node> sargs;
sargs.push_back(mkTrustId(id));
sargs.push_back(mkTrustId(NodeManager::currentNM(), id));
sargs.push_back(conc);
sargs.insert(sargs.end(), args.begin(), args.end());
return mkNode(ProofRule::TRUST, children, sargs);
Expand Down
5 changes: 2 additions & 3 deletions src/proof/proof_rule_checker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,14 @@ bool ProofRuleChecker::getKind(TNode n, Kind& k)
return true;
}

Node ProofRuleChecker::mkKindNode(Kind k)
Node ProofRuleChecker::mkKindNode(NodeManager* nm, Kind k)
{
if (k == Kind::UNDEFINED_KIND)
{
// UNDEFINED_KIND is negative, hence return null to avoid cast
return Node::null();
}
return NodeManager::currentNM()->mkConstInt(
Rational(static_cast<uint32_t>(k)));
return nm->mkConstInt(Rational(static_cast<uint32_t>(k)));
}

NodeManager* ProofRuleChecker::nodeManager() const { return d_nm; }
Expand Down
2 changes: 1 addition & 1 deletion src/proof/proof_rule_checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class ProofRuleChecker
/** get a Kind from a node, return false if we fail */
static bool getKind(TNode n, Kind& k);
/** Make a Kind into a node */
static Node mkKindNode(Kind k);
static Node mkKindNode(NodeManager* nm, Kind k);

/** Register all rules owned by this rule checker into pc. */
virtual void registerTo(ProofChecker* pc) {}
Expand Down
2 changes: 1 addition & 1 deletion src/proof/proof_step_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ bool ProofStepBuffer::addTrustedStep(TrustId id,
Node conc)
{
std::vector<Node> sargs;
sargs.push_back(mkTrustId(id));
sargs.push_back(mkTrustId(NodeManager::currentNM(), id));
sargs.push_back(conc);
sargs.insert(sargs.end(), args.begin(), args.end());
return addStep(ProofRule::TRUST, children, sargs, conc);
Expand Down
4 changes: 2 additions & 2 deletions src/proof/resolution_proofs_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ std::ostream& operator<<(std::ostream& out, CrowdingLitInfo info)
return out;
}

Node eliminateCrowdingLits(bool reorderPremises,
Node eliminateCrowdingLits(NodeManager* nm,
bool reorderPremises,
const std::vector<Node>& clauseLits,
const std::vector<Node>& targetClauseLits,
const std::vector<Node>& children,
Expand All @@ -76,7 +77,6 @@ Node eliminateCrowdingLits(bool reorderPremises,
Trace("crowding-lits") << "Clause lits: " << clauseLits << "\n";
Trace("crowding-lits") << "Target lits: " << targetClauseLits << "\n\n";
std::vector<Node> newChildren{children}, newArgs{args};
NodeManager* nm = NodeManager::currentNM();
Node trueNode = nm->mkConst(true);
// get crowding lits and the position of the last clause that includes
// them. The factoring step must be added after the last inclusion and before
Expand Down
3 changes: 2 additions & 1 deletion src/proof/resolution_proofs_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ namespace proof {
* @return The resulting node of transforming MACRO_RESOLUTION into
* CHAIN_RESOLUTION according to the above idea.
*/
Node eliminateCrowdingLits(bool reorderPremises,
Node eliminateCrowdingLits(NodeManager* nm,
bool reorderPremises,
const std::vector<Node>& clauseLits,
const std::vector<Node>& targetClauseLits,
const std::vector<Node>& children,
Expand Down
6 changes: 5 additions & 1 deletion src/proof/rewrite_proof_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ RewriteProofGenerator::RewriteProofGenerator(Env& env, MethodId id)
: EnvObj(env), ProofGenerator(), d_id(id)
{
// initialize the proof args
addMethodIds(d_pargs, MethodId::SB_DEFAULT, MethodId::SBA_SEQUENTIAL, d_id);
addMethodIds(nodeManager(),
d_pargs,
MethodId::SB_DEFAULT,
MethodId::SBA_SEQUENTIAL,
d_id);
}
RewriteProofGenerator::~RewriteProofGenerator() {}

Expand Down
4 changes: 2 additions & 2 deletions src/proof/subtype_elim_proof_converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ bool SubtypeElimConverterCallback::prove(const Node& src,
Node csrc = nm->mkNode(src.getKind(), conv[0], conv[1]);
if (tgt.getKind() == Kind::EQUAL)
{
Node nk = ProofRuleChecker::mkKindNode(Kind::TO_REAL);
Node nk = ProofRuleChecker::mkKindNode(nm, Kind::TO_REAL);
cdp->addStep(csrc, ProofRule::CONG, {src}, {nk});
Trace("pf-subtype-elim") << "...via " << csrc << std::endl;
if (csrc != tgt)
Expand Down Expand Up @@ -295,7 +295,7 @@ bool SubtypeElimConverterCallback::prove(const Node& src,
if (csrc != tgt)
{
Node congEq = csrc.eqNode(tgt);
Node nk = ProofRuleChecker::mkKindNode(csrc.getKind());
Node nk = ProofRuleChecker::mkKindNode(nm, csrc.getKind());
cdp->addStep(congEq, ProofRule::CONG, {convEq[0], convEq[1]}, {nk});
cdp->addStep(fullEq, ProofRule::TRANS, {rewriteEq, congEq}, {});
}
Expand Down
10 changes: 5 additions & 5 deletions src/proof/theory_proof_step_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ bool TheoryProofStepBuffer::applyEqIntro(Node src,
{
std::vector<Node> args;
args.push_back(src);
addMethodIds(args, ids, ida, idr);
addMethodIds(NodeManager::currentNM(), args, ids, ida, idr);
bool added;
Node expected = src.eqNode(tgt);
Node res = tryStep(added,
Expand Down Expand Up @@ -84,7 +84,7 @@ bool TheoryProofStepBuffer::applyPredTransform(Node src,
// try to prove that tgt rewrites to src
children.insert(children.end(), exp.begin(), exp.end());
args.push_back(tgt);
addMethodIds(args, ids, ida, idr);
addMethodIds(NodeManager::currentNM(), args, ids, ida, idr);
Node res = tryStep(ProofRule::MACRO_SR_PRED_TRANSFORM,
children,
args,
Expand All @@ -108,7 +108,7 @@ bool TheoryProofStepBuffer::applyPredIntro(Node tgt,
{
std::vector<Node> args;
args.push_back(tgt);
addMethodIds(args, ids, ida, idr);
addMethodIds(NodeManager::currentNM(), args, ids, ida, idr);
Node res = tryStep(ProofRule::MACRO_SR_PRED_INTRO,
exp,
args,
Expand All @@ -131,7 +131,7 @@ Node TheoryProofStepBuffer::applyPredElim(Node src,
children.push_back(src);
children.insert(children.end(), exp.begin(), exp.end());
std::vector<Node> args;
addMethodIds(args, ids, ida, idr);
addMethodIds(NodeManager::currentNM(), args, ids, ida, idr);
bool added;
Node srcRew = tryStep(added, ProofRule::MACRO_SR_PRED_ELIM, children, args);
if (d_autoSym && added && CDProof::isSame(src, srcRew))
Expand Down Expand Up @@ -198,7 +198,7 @@ Node TheoryProofStepBuffer::factorReorderElimDoubleNeg(Node n)
Node congEq = oldn.eqNode(n);
addStep(ProofRule::NARY_CONG,
childrenEqs,
{ProofRuleChecker::mkKindNode(Kind::OR)},
{ProofRuleChecker::mkKindNode(nm, Kind::OR)},
congEq);
// add an equality resolution step to derive normalize clause
addStep(ProofRule::EQ_RESOLVE, {oldn, congEq}, {}, n);
Expand Down
5 changes: 2 additions & 3 deletions src/proof/trust_id.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,9 @@ std::ostream& operator<<(std::ostream& out, TrustId id)
return out;
}

Node mkTrustId(TrustId id)
Node mkTrustId(NodeManager* nm, TrustId id)
{
return NodeManager::currentNM()->mkConstInt(
Rational(static_cast<uint32_t>(id)));
return nm->mkConstInt(Rational(static_cast<uint32_t>(id)));
}

bool getTrustId(TNode n, TrustId& i)
Expand Down
2 changes: 1 addition & 1 deletion src/proof/trust_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ const char* toString(TrustId id);
/** Write a trust id to out */
std::ostream& operator<<(std::ostream& out, TrustId id);
/** Make a trust id node */
Node mkTrustId(TrustId id);
Node mkTrustId(NodeManager* nm, TrustId id);
/** get a trust identifier from a node, return false if we fail */
bool getTrustId(TNode n, TrustId& i);

Expand Down
2 changes: 1 addition & 1 deletion src/rewriter/basic_rewrite_rcons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1221,7 +1221,7 @@ bool BasicRewriteRCons::ensureProofArithPolyNormRel(CDProof* cdp,
Trace("brc-macro") << "...fail premise" << std::endl;
return false;
}
Node kn = ProofRuleChecker::mkKindNode(eq[0].getKind());
Node kn = ProofRuleChecker::mkKindNode(nodeManager(), eq[0].getKind());
if (!cdp->addStep(eq, ProofRule::ARITH_POLY_NORM_REL, {premise}, {kn}))
{
Trace("brc-macro") << "...fail application" << std::endl;
Expand Down
2 changes: 1 addition & 1 deletion src/rewriter/rewrite_db_proof_cons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1217,7 +1217,7 @@ bool RewriteDbProofCons::ensureProofInternal(
cdp->addStep(cur,
ProofRule::ARITH_POLY_NORM_REL,
{pcur.d_vars[0]},
{ProofRuleChecker::mkKindNode(cur[0].getKind())});
{ProofRuleChecker::mkKindNode(nm, cur[0].getKind())});
}
}
else if (pcur.d_id == RewriteProofStatus::DSL
Expand Down
5 changes: 3 additions & 2 deletions src/smt/proof_post_processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,8 @@ Node ProofPostprocessCallback::expandMacros(ProofRule id,
Trace("crowding-lits") << "..premises: " << children << "\n";
Trace("crowding-lits") << "..args: " << args << "\n";
chainConclusion =
proof::eliminateCrowdingLits(d_env.getOptions().proof.optResReconSize,
proof::eliminateCrowdingLits(nm,
d_env.getOptions().proof.optResReconSize,
chainConclusionLits,
conclusionLits,
children,
Expand Down Expand Up @@ -927,7 +928,7 @@ Node ProofPostprocessCallback::expandMacros(ProofRule id,
{
// will expand this as a default rewrite if needed
Node eqd = retCurr.eqNode(retDef);
Node mid = mkMethodId(midi);
Node mid = mkMethodId(nodeManager(), midi);
cdp->addStep(eqd, ProofRule::MACRO_REWRITE, {}, {retCurr, mid});
transEq.push_back(eqd);
}
Expand Down
5 changes: 3 additions & 2 deletions src/smt/solver_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1022,8 +1022,9 @@ Node SolverEngine::findSynth(modes::FindSynthTarget fst, const TypeNode& gtn)
}
uint64_t nvars = options().quantifiers.sygusRewSynthInputNVars;
std::vector<Node> asserts = getAssertionsInternal();
gtnu = preprocessing::passes::SynthRewRulesPass::getGrammarsFrom(asserts,
nvars);
NodeManager* nm = d_env->getNodeManager();
gtnu = preprocessing::passes::SynthRewRulesPass::getGrammarsFrom(
nm, asserts, nvars);
if (gtnu.empty())
{
Warning() << "Could not find grammar in find-synth :rewrite_input"
Expand Down
2 changes: 1 addition & 1 deletion src/theory/arrays/inference_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ void InferenceManager::convert(ProofRule& id,
Assert(false) << "Unknown rule " << id << "\n";
}
children.push_back(exp);
args.push_back(mkTrustId(TrustId::THEORY_INFERENCE));
args.push_back(mkTrustId(nodeManager(), TrustId::THEORY_INFERENCE));
args.push_back(conc);
args.push_back(
builtin::BuiltinProofRuleChecker::mkTheoryIdNode(THEORY_ARRAYS));
Expand Down
2 changes: 1 addition & 1 deletion src/theory/datatypes/infer_proof_cons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ void InferProofCons::convert(InferenceId infer, TNode conc, TNode exp, CDProof*
// s(exp[0]) = s(exp[1]) s(exp[1]) = r
// --------------------------------------------------- TRANS
// s(exp[0]) = r
Node asn = ProofRuleChecker::mkKindNode(Kind::APPLY_SELECTOR);
Node asn = ProofRuleChecker::mkKindNode(nm, Kind::APPLY_SELECTOR);
Node seq = sl.eqNode(sr);
cdp->addStep(seq, ProofRule::CONG, {exp}, {asn, sop});
Node sceq = sr.eqNode(concEq[1]);
Expand Down
3 changes: 2 additions & 1 deletion src/theory/quantifiers/alpha_equivalence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,8 @@ TrustNode AlphaEquivalence::reduceQuantifier(Node q)
// sret = q
std::vector<Node> pfArgs2;
pfArgs2.push_back(eq2);
addMethodIds(pfArgs2,
addMethodIds(nodeManager(),
pfArgs2,
MethodId::SB_DEFAULT,
MethodId::SBA_SEQUENTIAL,
MethodId::RW_EXT_REWRITE);
Expand Down
Loading

0 comments on commit da255fd

Please sign in to comment.