Skip to content

Commit

Permalink
Support more arithmetic rewrites. (#111)
Browse files Browse the repository at this point in the history
  • Loading branch information
abdoo8080 authored Jun 7, 2024
1 parent 95747e9 commit 64d482b
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 33 deletions.
56 changes: 55 additions & 1 deletion Smt/Reconstruct/Arith.lean
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,28 @@ where

def reconstructRewrite (pf : cvc5.Proof) : ReconstructM (Option Expr) := do
match pf.getRewriteRule with
| .ARITH_DIV_BY_CONST_ELIM =>
let t : Q(Real) ← reconstructTerm pf.getResult[0]![0]!
let r := pf.getResult[0]![1]!.getRationalValue
if r.den == 1 then
let c : Q(Real) := reconstructArith.mkRealLit r.num.natAbs
if r.num ≥ 0 then
if r.num == 1 then
addThm q($t / 1 = $t * 1) q(@Arith.arith_div_by_const_elim_1_pos $t)
else
addThm q($t / $c = $t * (1 / $c)) q(@Arith.arith_div_by_const_elim_num_pos $t $c)
else
if r.num == -1 then
addThm q($t / -1 = $t * -1) q(@Arith.arith_div_by_const_elim_1_neg $t)
else
addThm q($t / -$c = $t * (-1 / $c)) q(@Arith.arith_div_by_const_elim_num_neg $t $c)
else
let c₁ : Q(Real) := reconstructArith.mkRealLit r.num.natAbs
let c₂ : Q(Real) := reconstructArith.mkRealLit r.den
if r.num ≥ 0 then
addThm q($t / ($c₁ / $c₂) = $t * ($c₂ / $c₁)) q(@Arith.arith_div_by_const_elim_real_pos $t $c₁ $c₂)
else
addThm q($t / (-$c₁ / $c₂) = $t * (-$c₂ / $c₁)) q(@Arith.arith_div_by_const_elim_real_neg $t $c₁ $c₂)
| .ARITH_PLUS_ZERO =>
let ⟨α, h⟩ := getTypeAndInst pf.getArguments[1]![0]!.getSort
let args ← reconstructArgs pf.getArguments[1:]
Expand All @@ -152,6 +174,33 @@ def reconstructRewrite (pf : cvc5.Proof) : ReconstructM (Option Expr) := do
let ⟨α, h⟩ := getTypeAndInst pf.getArguments[1]![0]!.getSort
let args ← reconstructArgs pf.getArguments[1:]
addTac (← reconstructTerm pf.getResult) (Tactic.smtRw · q(@mul_assoc $α _) q(@mul_one $α _) q(@Arith.arith_mul_zero $α $h) args)
| .ARITH_DIV_TOTAL =>
let t : Q(Real) ← reconstructTerm pf.getArguments[1]!
let s : Q(Real) ← reconstructTerm pf.getArguments[2]!
let h : Q($s ≠ 0) ← reconstructProof pf.getChildren[0]!
addThm q($t / $s = $t / $s) q(@Arith.arith_div_total $t $s $h)
| .ARITH_INT_DIV_TOTAL =>
let t : Q(Int) ← reconstructTerm pf.getArguments[1]!
let s : Q(Int) ← reconstructTerm pf.getArguments[2]!
let h : Q($s ≠ 0) ← reconstructProof pf.getChildren[0]!
addThm q($t / $s = $t / $s) q(@Arith.arith_int_div_total $t $s $h)
| .ARITH_INT_DIV_TOTAL_ONE =>
let t : Q(Int) ← reconstructTerm pf.getArguments[1]!
addThm q($t / 1 = $t) q(@Arith.arith_int_div_total_one $t)
| .ARITH_INT_DIV_TOTAL_ZERO =>
let t : Q(Int) ← reconstructTerm pf.getArguments[1]!
addThm q($t / 0 = 0) q(@Arith.arith_int_div_total_zero $t)
| .ARITH_INT_MOD_TOTAL =>
let t : Q(Int) ← reconstructTerm pf.getArguments[1]!
let s : Q(Int) ← reconstructTerm pf.getArguments[2]!
let h : Q($s ≠ 0) ← reconstructProof pf.getChildren[0]!
addThm q($t % $s = $t % $s) q(@Arith.arith_int_mod_total $t $s $h)
| .ARITH_INT_MOD_TOTAL_ONE =>
let t : Q(Int) ← reconstructTerm pf.getArguments[1]!
addThm q($t % 1 = 0) q(@Arith.arith_int_mod_total_one $t)
| .ARITH_INT_MOD_TOTAL_ZERO =>
let t : Q(Int) ← reconstructTerm pf.getArguments[1]!
addThm q($t % 0 = $t) q(@Arith.arith_int_mod_total_zero $t)
| .ARITH_NEG_NEG_ONE =>
let ⟨α, h⟩ := getTypeAndInst pf.getArguments[1]!.getSort
let t : Q($α) ← reconstructTerm pf.getArguments[1]!
Expand Down Expand Up @@ -234,6 +283,10 @@ def reconstructRewrite (pf : cvc5.Proof) : ReconstructM (Option Expr) := do
let ⟨α, h⟩ := getTypeAndInst pf.getArguments[2]!.getSort
let args ← reconstructArgs pf.getArguments[1:]
addTac (← reconstructTerm pf.getResult) (Tactic.smtRw · q(@add_assoc $α _) q(@add_zero $α _) q(@Arith.arith_plus_cancel2 $α $h) args)
| .ARITH_ABS_ELIM =>
let ⟨α, h⟩ := getTypeAndInst pf.getArguments[1]!.getSort
let x : Q($α) ← reconstructTerm pf.getArguments[1]!
addThm q(|$x| = if $x < 0 then -$x else $x) q(@Arith.arith_abs_elim $α $h $x)
| _ => return none
where
reconstructArgs (args : Array cvc5.Term) : ReconstructM (Array (Array Expr)) := do
Expand All @@ -254,7 +307,8 @@ where
if !(h.getUsedConstants.any (isNoncomputable (← getEnv))) then
return none
addTac q($t = $t') Arith.normNum
| .DSL_REWRITE => reconstructRewrite pf
| .DSL_REWRITE
| .THEORY_REWRITE => reconstructRewrite pf
| .ARITH_SUM_UB =>
addTac (← reconstructTerm pf.getResult) (Arith.sumBounds · (← pf.getChildren.mapM reconstructProof))
| .INT_TIGHT_UB =>
Expand Down
53 changes: 49 additions & 4 deletions Smt/Reconstruct/Arith/Rewrites.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,56 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Abdalrhman Mohamed
-/

import Mathlib.Algebra.Order.Ring.Defs
import Mathlib.Data.Real.Archimedean

namespace Smt.Reconstruct.Arith

-- https://github.com/cvc5/cvc5/blob/proof-new/src/theory/arith/rewrites

open Function

theorem arith_div_by_const_elim_1_pos {t : Real} : t / 1 = t * 1 :=
div_eq_mul_one_div t 1 ▸ Eq.symm (@one_div_one Real _) ▸ rfl
theorem arith_div_by_const_elim_1_neg {t : Real} : t / -1 = t * -1 :=
div_eq_mul_one_div t (-1) ▸ div_neg (1 : Real) ▸ Eq.symm (@one_div_one Real _) ▸ rfl
theorem arith_div_by_const_elim_num_pos {t c : Real} : t / c = t * (1 / c) :=
div_eq_mul_one_div t c
theorem arith_div_by_const_elim_num_neg {t c : Real} : t / -c = t * (-1 / c) :=
div_eq_mul_one_div t (-c) ▸ div_neg (1 : Real) ▸ neg_div c 1 ▸ rfl
theorem arith_div_by_const_elim_real_pos {t c₁ c₂ : Real} : t / (c₁ / c₂) = t * (c₂ / c₁) :=
div_eq_mul_one_div t (c₁ / c₂) ▸ one_div_div c₁ c₂ ▸ rfl
theorem arith_div_by_const_elim_real_neg {t c₁ c₂ : Real} : t / (-c₁ / c₂) = t * (-c₂ / c₁) :=
div_eq_mul_one_div t (-c₁ / c₂) ▸ one_div_div (-c₁) c₂ ▸ div_neg c₂ ▸ neg_div c₁ c₂ ▸ rfl

-- https://github.com/cvc5/cvc5/blob/proof-new/src/theory/arith/rewrites

variable {α : Type} [h : LinearOrderedRing α]

variable {t ts x xs : α}

theorem arith_plus_zero : ts + 0 + ss = ts + ss :=
(add_zero ts).symm ▸ rfl

theorem arith_mul_one : ts * 1 * ss = ts * ss :=
(mul_one ts).symm ▸ rfl
theorem arith_mul_zero : ts * 0 * ss = 0 :=
(mul_zero ts).symm ▸ (zero_mul ss).symm ▸ rfl

theorem arith_div_total {t s : Real} : s ≠ 0 → t / s = t / s :=
const _ rfl

theorem arith_int_div_total {t s : Int} : s ≠ 0 → t / s = t / s :=
const _ rfl
theorem arith_int_div_total_one {t : Int} : t / 1 = t :=
Int.ediv_one t
theorem arith_int_div_total_zero {t : Int} : t / 0 = 0 :=
Int.ediv_zero t

theorem arith_int_mod_total {t s : Int} : s ≠ 0 → t % s = t % s :=
const _ rfl
theorem arith_int_mod_total_one {t : Int} : t % 1 = 0 :=
Int.emod_one t
theorem arith_int_mod_total_zero {t : Int} : t % 0 = t :=
Int.emod_zero t

theorem arith_neg_neg_one : -1 * (-1 * t) = t :=
neg_mul _ t ▸ (one_mul t).symm ▸ neg_mul_neg _ t ▸ (one_mul t).symm ▸ rfl

Expand All @@ -37,6 +68,10 @@ theorem arith_elim_gt : (t > s) = ¬(t ≤ s) :=
propext not_le.symm
theorem arith_elim_lt : (t < s) = ¬(t ≥ s) :=
propext not_le.symm
theorem arith_elim_int_gt {t s : Int} : (t > s) = (t ≥ s + 1) :=
propext (Int.lt_iff_add_one_le s t)
theorem arith_elim_int_lt {t s : Int} : (t < s) = (s ≥ t + 1) :=
propext (Int.lt_iff_add_one_le t s)
theorem arith_elim_leq : (t ≤ s) = (s ≥ t) :=
propext ge_iff_le

Expand All @@ -62,6 +97,11 @@ theorem arith_refl_geq : (t ≥ t) = True :=
theorem arith_refl_gt : (t > t) = False :=
propext ⟨(lt_irrefl t), False.elim⟩

theorem arith_real_eq_elim {t s : Real} : (t = s) = (t ≥ s ∧ t ≤ s) :=
propext (Iff.symm antisymm_iff)
theorem arith_int_eq_elim {t s : Int} : (t = s) = (t ≥ s ∧ t ≤ s) :=
propext (Iff.symm antisymm_iff)

theorem arith_plus_flatten : xs + (w + ys) + zs = xs + w + ys + zs :=
add_assoc xs w ys ▸ rfl

Expand All @@ -74,9 +114,14 @@ theorem arith_mult_dist : x * (y + z + ws) = x * y + x * (z + ws) :=
theorem arith_plus_cancel1 : ts + x + ss + (-1 * x) + rs = ts + ss + rs :=
neg_eq_neg_one_mul x ▸ add_assoc ts x ss ▸ add_assoc ts (x + ss) (-x) ▸
add_comm x ss ▸ (add_neg_cancel_right ss x).symm ▸ rfl

theorem arith_plus_cancel2 : ts + (-1 * x) + ss + x + rs = ts + ss + rs :=
neg_eq_neg_one_mul x ▸ add_assoc ts (-x) ss ▸ add_assoc ts (-x + ss) x ▸
add_comm (-x) ss ▸ (neg_add_cancel_right ss x).symm ▸ rfl

theorem arith_abs_elim : |x| = if x < 0 then -x else x :=
if h : x < 0 then
if_pos h ▸ abs_of_neg h
else
if_neg h ▸ abs_of_nonneg (le_of_not_lt h)

end Smt.Reconstruct.Arith
25 changes: 9 additions & 16 deletions Test/linarith.expected
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
Test/linarith.lean:33:9: warning: unused variable `e` [linter.unusedVariables]
Test/linarith.lean:52:0: warning: declaration uses 'sorry'
Test/linarith.lean:63:0: warning: declaration uses 'sorry'
Test/linarith.lean:67:0: warning: declaration uses 'sorry'
Test/linarith.lean:71:0: warning: declaration uses 'sorry'
Test/linarith.lean:75:0: warning: declaration uses 'sorry'
Test/linarith.lean:80:0: warning: declaration uses 'sorry'
Test/linarith.lean:84:0: warning: declaration uses 'sorry'
Test/linarith.lean:88:0: warning: declaration uses 'sorry'
Test/linarith.lean:112:9: warning: unused variable `a` [linter.unusedVariables]
Test/linarith.lean:112:13: warning: unused variable `c` [linter.unusedVariables]
Test/linarith.lean:117:9: warning: unused variable `a` [linter.unusedVariables]
Test/linarith.lean:117:13: warning: unused variable `c` [linter.unusedVariables]
Test/linarith.lean:129:60: warning: unused variable `h3` [linter.unusedVariables]
Test/linarith.lean:136:9: warning: unused variable `a` [linter.unusedVariables]
Test/linarith.lean:136:13: warning: unused variable `c` [linter.unusedVariables]
Test/linarith.lean:179:34: warning: unused variable `z` [linter.unusedVariables]
Test/linarith.lean:180:5: warning: unused variable `h5` [linter.unusedVariables]
Test/linarith.lean:100:9: warning: unused variable `a` [linter.unusedVariables]
Test/linarith.lean:100:13: warning: unused variable `c` [linter.unusedVariables]
Test/linarith.lean:105:9: warning: unused variable `a` [linter.unusedVariables]
Test/linarith.lean:105:13: warning: unused variable `c` [linter.unusedVariables]
Test/linarith.lean:117:60: warning: unused variable `h3` [linter.unusedVariables]
Test/linarith.lean:124:9: warning: unused variable `a` [linter.unusedVariables]
Test/linarith.lean:124:13: warning: unused variable `c` [linter.unusedVariables]
Test/linarith.lean:167:34: warning: unused variable `z` [linter.unusedVariables]
Test/linarith.lean:168:5: warning: unused variable `h5` [linter.unusedVariables]
12 changes: 0 additions & 12 deletions Test/linarith.lean
Original file line number Diff line number Diff line change
Expand Up @@ -62,52 +62,40 @@ example (A B : Int) (h : 0 < A * B) : 0 < 8*A*B := by

example (A B : Real) (h : 0 < A * B) : 0 < A*B/8 := by
smt [h]
all_goals sorry

example (A B : Real) (h : 0 < A * B) : 0 < A/8*B := by
smt [h]
all_goals sorry

example (ε : Real) (h1 : ε > 0) : ε / 2 + ε / 3 + ε / 7 < ε := by
smt [h1]
all_goals sorry

example (x y z : Real) (h1 : 2*x < 3*y) (h2 : -4*x + z/2 < 0)
(h3 : 12*y - z < 0) : False := by
smt [h1, h2, h3]
all_goals sorry

example (ε : Real) (h1 : ε > 0) : ε / 2 < ε := by
smt [h1]
all_goals sorry

example (ε : Real) (h1 : ε > 0) : ε / 3 + ε / 3 + ε / 3 = ε := by
smt [h1]
all_goals sorry

example (x : Real) (h : 0 < x) : 0 < x/1 := by
smt [h]
all_goals sorry

example (x : Real) (h : x < 0) : 0 < x/(-1) := by
smt [h]
all_goals (ring_nf; simp)

example (x : Real) (h : x < 0) : 0 < x/(-2) := by
smt [h]
all_goals (ring_nf; simp)

example (x : Real) (h : x < 0) : 0 < x/(-4/5) := by
smt [h]
all_goals (ring_nf; simp)

example (x : Real) (h : 0 < x) : 0 < x/2/3 := by
smt [h]
all_goals (ring_nf; simp)

example (x : Real) (h : 0 < x) : 0 < x/(2/3) := by
smt [h]
all_goals (ring_nf; simp)

example (a b c : Real) (h2 : b + 2 > 3 + b) : False := by
smt [h2]
Expand Down

0 comments on commit 64d482b

Please sign in to comment.