From fbbf9a7416842456bf30e6ce983e46cbdeda8058 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Tue, 17 Dec 2024 10:34:39 -0600 Subject: [PATCH] Add Eunoia definitions for 3 simple theory rewrites (#11449) Includes minor refactoring of the sets insert elimination rewrite to make the Eunoia definition simpler. --- proofs/eo/cpc/Cpc.eo | 6 ++-- proofs/eo/cpc/programs/Arith.eo | 26 ++++++++++++++++- proofs/eo/cpc/programs/BitVectors.eo | 22 +++++++------- proofs/eo/cpc/rules/Arith.eo | 15 ++++++++++ proofs/eo/cpc/rules/BitVectors.eo | 18 ++++++++++-- proofs/eo/cpc/rules/Sets.eo | 37 ++++++++++++++++++++++++ src/proof/alf/alf_printer.cpp | 3 ++ src/theory/sets/theory_sets_rewriter.cpp | 12 ++++---- 8 files changed, 116 insertions(+), 23 deletions(-) diff --git a/proofs/eo/cpc/Cpc.eo b/proofs/eo/cpc/Cpc.eo index e55143392c3..2fcbd9a804e 100644 --- a/proofs/eo/cpc/Cpc.eo +++ b/proofs/eo/cpc/Cpc.eo @@ -211,11 +211,11 @@ (($run_evaluate (bvsge xb yb)) (eo::define ((ex ($bv_to_signed_int ($run_evaluate xb)))) (eo::define ((ey ($bv_to_signed_int ($run_evaluate yb)))) (eo::or (eo::gt ex ey) (eo::is_eq ex ey))))) - (($run_evaluate (repeat n xb)) ($bv_eval_repeat ($run_evaluate n) ($run_evaluate xb))) + (($run_evaluate (repeat n xb)) ($run_evaluate ($bv_unfold_repeat ($run_evaluate n) ($run_evaluate xb)))) (($run_evaluate (sign_extend n xb)) (eo::define ((ex ($run_evaluate xb))) - (eo::concat ($bv_eval_repeat ($run_evaluate n) ($bv_sign_bit ex)) ex))) + (eo::concat ($run_evaluate ($bv_unfold_repeat ($run_evaluate n) ($bv_sign_bit ex))) ex))) (($run_evaluate (zero_extend n xb)) (eo::define ((ex ($run_evaluate xb))) - (eo::concat ($bv_eval_repeat ($run_evaluate n) #b0) ex))) + (eo::concat ($run_evaluate ($bv_unfold_repeat ($run_evaluate n) #b0)) ex))) (($run_evaluate (@bv n m)) (eo::to_bin ($run_evaluate m) ($run_evaluate n))) (($run_evaluate (@bvsize x)) ($bv_bitwidth (eo::typeof x))) diff --git a/proofs/eo/cpc/programs/Arith.eo b/proofs/eo/cpc/programs/Arith.eo index af636e8e801..79dec9aa8be 100644 --- a/proofs/eo/cpc/programs/Arith.eo +++ b/proofs/eo/cpc/programs/Arith.eo @@ -106,7 +106,7 @@ ; define: $arith_eval_int_pow_2 ; args: -; - x Int: The term to compute whether it is a power of two. +; - x Int: The term to compute take as the exponent of two. ; return: > ; two raised to the power of x. If x is not a numeral value, we return ; the term (int.pow2 x). @@ -140,3 +140,27 @@ (eo::ite (eo::is_z x) (eo::ite (eo::is_neg x) false ($arith_eval_int_is_pow_2_rec x)) (int.ispow2 x))) + + +; define: $arith_unfold_pow_rec +; args: +; - n Int: The number of times to multiply, expected to be a non-negative numeral. +; - a T: The term to muliply. +; return: The result of multiplying a, n times. +(program $arith_unfold_pow_rec ((T Type) (n Int) (a T)) + (Int T) T + ( + (($arith_unfold_pow_rec 0 a) 1) + (($arith_unfold_pow_rec n a) (eo::cons * a ($arith_unfold_pow_rec (eo::add n -1) a))) + ) +) + +; define: $arith_unfold_pow +; args: +; - n Int: The number of times to multiply. +; - a T: The term to muliply. +; return: The result of multiplying a, n times. If n is not a positive numeral, this returns (^ a n). +(define $arith_unfold_pow ((T Type :implicit) (n Int) (a T)) + (eo::ite (eo::and (eo::is_z n) (eo::not (eo::is_neg n))) + ($arith_unfold_pow_rec n a) + (^ a n))) diff --git a/proofs/eo/cpc/programs/BitVectors.eo b/proofs/eo/cpc/programs/BitVectors.eo index 2dd46304b97..5315db08bfb 100644 --- a/proofs/eo/cpc/programs/BitVectors.eo +++ b/proofs/eo/cpc/programs/BitVectors.eo @@ -23,29 +23,29 @@ (eo::add (eo::neg ($arith_eval_int_pow_2 wm1)) z) z)))))) -; define: $bv_eval_repeat_rec +; define: $bv_unfold_repeat_rec ; args: ; - n Int: The number of times to repeat, expected to be a non-negative numeral. -; - b (BitVec m): The bitvector term, expected to be a binary constant. -; return: The result of repeating b n times. -(program $bv_eval_repeat_rec ((m Int) (n Int) (b (BitVec m))) +; - b (BitVec m): The bitvector term. +; return: The result of concatenating b n times. +(program $bv_unfold_repeat_rec ((m Int) (n Int) (b (BitVec m))) (Int (BitVec m)) (BitVec (eo::mul n m)) ( - (($bv_eval_repeat_rec 0 b) (eo::to_bin 0 0)) - (($bv_eval_repeat_rec n b) (eo::concat b ($bv_eval_repeat_rec (eo::add n -1) b))) + (($bv_unfold_repeat_rec 0 b) (eo::to_bin 0 0)) + (($bv_unfold_repeat_rec n b) (eo::cons concat b ($bv_unfold_repeat_rec (eo::add n -1) b))) ) ) -; define: $bv_eval_repeat +; define: $bv_unfold_repeat ; args: ; - n Int: The number of times to repeat, expected to be a non-negative numeral. -; - b (BitVec m): The bitvector term, expected to be a binary constant. +; - b (BitVec m): The bitvector term. ; return: > -; The result of repeating b n times. If n is not a numeral or is negative, +; The result of concatenating b n times. If n is not a numeral or is negative, ; this returns the term (repeat n b). -(define $bv_eval_repeat ((m Int :implicit) (n Int) (b (BitVec m))) +(define $bv_unfold_repeat ((m Int :implicit) (n Int) (b (BitVec m))) (eo::ite (eo::and (eo::is_z n) (eo::not (eo::is_neg n))) - ($bv_eval_repeat_rec n b) + ($bv_unfold_repeat_rec n b) (repeat n b))) ; program: $bv_get_first_const_child diff --git a/proofs/eo/cpc/rules/Arith.eo b/proofs/eo/cpc/rules/Arith.eo index e6e44fdf524..4607d0f6990 100644 --- a/proofs/eo/cpc/rules/Arith.eo +++ b/proofs/eo/cpc/rules/Arith.eo @@ -348,3 +348,18 @@ :args (t) :conclusion ($arith_reduction_pred t) ) + +;;;;; ProofRewriteRule::ARITH_POW_ELIM + +; rule: arith-pow-elim +; implements: ProofRewriteRule::ARITH_POW_ELIM +; args: +; - eq Bool: The equality to prove with this rule. +; requires: n is integral, and b is the multiplication of a, n times. +; conclusion: the equality between a and b. +(declare-rule arith-pow-elim ((T Type) (a T) (n T) (b T)) + :args ((= (^ a n) b)) + :requires (((eo::to_q (eo::to_z n)) (eo::to_q n)) + (($singleton_elim ($arith_unfold_pow (eo::to_z n) a)) b)) + :conclusion (= (^ a n) b) +) diff --git a/proofs/eo/cpc/rules/BitVectors.eo b/proofs/eo/cpc/rules/BitVectors.eo index 1d6638ac180..96a9e64a840 100644 --- a/proofs/eo/cpc/rules/BitVectors.eo +++ b/proofs/eo/cpc/rules/BitVectors.eo @@ -1,6 +1,20 @@ (include "../programs/BitVectors.eo") -; ---------------- ProofRewriteRule::BV_BITWISE_SLICING +;;;;; ProofRewriteRule::BV_REPEAT_ELIM + +; rule: bv-repeat-elim +; implements: ProofRewriteRule::BV_REPEAT_ELIM +; args: +; - eq Bool: The equality to prove with this rule. +; requires: b is the concatenation of a, n times. +; conclusion: the equality between a and b. +(declare-rule bv-repeat-elim ((k1 Int) (k2 Int) (n Int) (a (BitVec k1)) (b (BitVec k2))) + :args ((= (repeat n a) b)) + :requires ((($singleton_elim ($bv_unfold_repeat n a)) b)) + :conclusion (= (repeat n a) b) +) + +;;;;; ProofRewriteRule::BV_BITWISE_SLICING ; program: $bv_mk_bitwise_slicing_rec ; args: @@ -84,7 +98,7 @@ :conclusion (= a b) ) -; ---------------- ProofRewriteRule::BV_BITBLAST_STEP +;;;;; ProofRule::BV_BITBLAST_STEP ; program: $bv_mk_bitblast_step_eq ; args: diff --git a/proofs/eo/cpc/rules/Sets.eo b/proofs/eo/cpc/rules/Sets.eo index 168ead31e8b..29e2468efb4 100644 --- a/proofs/eo/cpc/rules/Sets.eo +++ b/proofs/eo/cpc/rules/Sets.eo @@ -1,5 +1,7 @@ (include "../theories/Sets.eo") +;;;;; ProofRewriteRule::SETS_IS_EMPTY_EVAL + ; define: $set_is_empty_eval ; args: ; - t (Set T): The set to check. @@ -31,6 +33,8 @@ :conclusion (= (set.is_empty t) b) ) +;;;;; ProofRule::SETS_SINGLETON_INJ + ; rule: sets_singleton_inj ; implements: ProofRule::SETS_SINGLETON_INJ ; premises: @@ -41,6 +45,8 @@ :conclusion (= a b) ) +;;;;; ProofRule::SETS_EXT + ; rule: sets_ext ; implements: ProofRule::SETS_EXT ; premises: @@ -53,6 +59,35 @@ :conclusion (not (= (set.member (@sets_deq_diff a b) a) (set.member (@sets_deq_diff a b) b))) ) +;;;;; ProofRewriteRule::SETS_INSERT_ELIM + +; program: $set_eval_insert +; args: +; - es @List: The list of elements +; - t (Set T): The set to insert into. +; return: > +; The union of elements in es with t. +(program $set_eval_insert ((T Type) (x T) (xs @List :list) (t (Set T))) + (@List (Set T)) (Set T) + ( + (($set_eval_insert (@list x xs) t) (set.union (set.singleton x) ($set_eval_insert xs t))) + (($set_eval_insert @list.nil t) t) + ) +) + +; rule: sets-insert-elim +; implements: ProofRewriteRule::SETS_INSERT_ELIM +; args: +; - eq Bool: The equality to prove with this rule. +; requires: the union of the elements in the first argument with the last argument equal the right hand side. +; conclusion: the equality between a and b. +(declare-rule sets-insert-elim ((T Type) (s (Set T)) (es @List) (t (Set T))) + :args ((= (set.insert es s) t)) + :requires ((($set_eval_insert es s) t)) + :conclusion (= (set.insert es s) t) +) + +;;;;; ProofRewriteRule::SETS_FILTER_DOWN ; rule: sets_filter_down ; implements: ProofRewriteRule::SETS_FILTER_DOWN @@ -66,6 +101,8 @@ :conclusion (and (set.member x S) (P x)) ) +;;;;; ProofRewriteRule::SETS_FILTER_UP + ; rule: sets_filter_up ; implements: ProofRewriteRule::SETS_FILTER_UP ; args: diff --git a/src/proof/alf/alf_printer.cpp b/src/proof/alf/alf_printer.cpp index 1aabb0d34e8..151873738c4 100644 --- a/src/proof/alf/alf_printer.cpp +++ b/src/proof/alf/alf_printer.cpp @@ -264,6 +264,7 @@ bool AlfPrinter::isHandledTheoryRewrite(ProofRewriteRule id, const Node& n) case ProofRewriteRule::DISTINCT_ELIM: case ProofRewriteRule::BETA_REDUCE: case ProofRewriteRule::LAMBDA_ELIM: + case ProofRewriteRule::ARITH_POW_ELIM: case ProofRewriteRule::ARITH_STRING_PRED_ENTAIL: case ProofRewriteRule::ARITH_STRING_PRED_SAFE_APPROX: case ProofRewriteRule::EXISTS_ELIM: @@ -280,11 +281,13 @@ bool AlfPrinter::isHandledTheoryRewrite(ProofRewriteRule id, const Node& n) case ProofRewriteRule::QUANT_VAR_ELIM_EQ: case ProofRewriteRule::RE_LOOP_ELIM: case ProofRewriteRule::SETS_IS_EMPTY_EVAL: + case ProofRewriteRule::SETS_INSERT_ELIM: case ProofRewriteRule::STR_IN_RE_CONCAT_STAR_CHAR: case ProofRewriteRule::STR_IN_RE_SIGMA: case ProofRewriteRule::STR_IN_RE_SIGMA_STAR: case ProofRewriteRule::STR_IN_RE_CONSUME: case ProofRewriteRule::RE_INTER_UNION_INCLUSION: + case ProofRewriteRule::BV_REPEAT_ELIM: case ProofRewriteRule::BV_BITWISE_SLICING: return true; case ProofRewriteRule::STR_IN_RE_EVAL: Assert(n[0].getKind() == Kind::STRING_IN_REGEXP && n[0][0].isConst()); diff --git a/src/theory/sets/theory_sets_rewriter.cpp b/src/theory/sets/theory_sets_rewriter.cpp index 3e1d5ce862c..78c09c55094 100644 --- a/src/theory/sets/theory_sets_rewriter.cpp +++ b/src/theory/sets/theory_sets_rewriter.cpp @@ -62,14 +62,14 @@ Node TheorySetsRewriter::rewriteViaRule(ProofRewriteRule id, const Node& n) { NodeManager* nm = nodeManager(); size_t setNodeIndex = n.getNumChildren() - 1; - Node elems = nm->mkNode(Kind::SET_SINGLETON, n[0]); - - for (size_t i = 1; i < setNodeIndex; ++i) + Node elems = n[setNodeIndex]; + for (size_t i = 0; i < setNodeIndex; ++i) { - Node singleton = nm->mkNode(Kind::SET_SINGLETON, n[i]); - elems = nm->mkNode(Kind::SET_UNION, elems, singleton); + size_t ii = (setNodeIndex-i)-1; + Node singleton = nm->mkNode(Kind::SET_SINGLETON, n[ii]); + elems = nm->mkNode(Kind::SET_UNION, singleton, elems); } - return nm->mkNode(Kind::SET_UNION, elems, n[setNodeIndex]); + return elems; } } break;