Skip to content

Commit

Permalink
Fix bug with Iff elimination. (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
abdoo8080 authored Jul 12, 2024
1 parent 91ac7f4 commit b332ae4
Show file tree
Hide file tree
Showing 15 changed files with 305 additions and 47 deletions.
8 changes: 8 additions & 0 deletions Smt/Preprocess.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/-
Copyright (c) 2021-2024 by the authors listed in the file AUTHORS and their
institutional affiliations. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Abdalrhman Mohamed, Tomaz Gomes Mascarenhas
-/

import Smt.Preprocess.Iff
62 changes: 62 additions & 0 deletions Smt/Preprocess/Iff.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/-
Copyright (c) 2021-2024 by the authors listed in the file AUTHORS and their
institutional affiliations. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Abdalrhman Mohamed, Tomaz Gomes Mascarenhas
-/

import Lean
import Qq

namespace Smt.Preprocess

open Lean Qq

theorem iff_eq_eq : (p ↔ q) = (p = q) := propext ⟨propext, (· ▸ ⟨(·), (·)⟩)⟩

theorem eq_resolve {p q : Prop} (hp : p) (hpq : p = q) : q := hpq ▸ hp

def replaceIff (e : Expr) : MetaM Expr :=
let f e :=
if let some ((l : Q(Prop)), (r : Q(Prop))) := e.app2? ``Iff then
q($l = $r)
else
none
Meta.mkAppM ``Eq #[e, e.replace f]

def elimIff (mv : MVarId) (hs : List Expr) : MetaM (List Expr × MVarId) := mv.withContext do
let simpTheorems ← #[``eq_self, ``iff_eq_eq].foldlM (·.addConst ·) ({} : Meta.SimpTheorems)
let simpTheorems := #[simpTheorems]
let congrTheorems ← Meta.getSimpCongrTheorems
let ctx := { simpTheorems, congrTheorems }
let (hs, mv) ← elimIffLocalDecls mv hs ctx
let mv ← elimIffTarget mv ctx
return (hs, mv)
where
elimIffLocalDecls mv hs ctx := mv.withContext do
let mut newHs := []
let mut toAssert := #[]
for h in hs do
let type ← Meta.inferType h
let eq ← replaceIff (← instantiateMVars type)
let (_, l, r) := eq.eq?.get!
if l == r then
newHs := h :: newHs
else
let userName ← if h.isFVar then h.fvarId!.getUserName else Lean.mkFreshId
let type := r
let (r, _) ← Meta.simp eq ctx
let value ← Meta.mkAppM ``eq_resolve #[h, ← Meta.mkOfEqTrue (← r.getProof)]
toAssert := toAssert.push { userName, type, value }
let (fvs, mv) ← mv.assertHypotheses toAssert
newHs := newHs.reverse ++ (fvs.map (.fvar ·)).toList
return (newHs, mv)
elimIffTarget mv ctx := mv.withContext do
let eq ← replaceIff (← instantiateMVars (← mv.getType))
let (r, _) ← Meta.simp eq ctx
if r.expr.isTrue then
mv.replaceTargetEq eq.appArg! (← Meta.mkOfEqTrue (← r.getProof))
else
return mv

end Smt.Preprocess
15 changes: 9 additions & 6 deletions Smt/Tactic/Smt.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Smt.Dsl.Sexp
import Smt.Reconstruct
import Smt.Reconstruct.Prop.Lemmas
import Smt.Translate.Query
import Smt.Preprocess
import Smt.Util

namespace Smt
Expand Down Expand Up @@ -39,10 +40,11 @@ where
go mv hs (.fvar fv :: fvs) k

def smt (mv : MVarId) (hs : List Expr) (timeout : Option Nat := none) : MetaM (List MVarId) := mv.withContext do
let mv ← Util.rewriteIffMeta mv
let goalType : Q(Prop) ← mv.getType
-- 1. Process the hints passed to the tactic.
withProcessedHints mv hs fun mv hs => mv.withContext do
let (hs, mv) ← Preprocess.elimIff mv hs
mv.withContext do
let goalType : Q(Prop) ← mv.getType
-- 2. Generate the SMT query.
let cmds ← prepareSmtQuery hs (← mv.getType)
let cmds := .setLogic "ALL" :: cmds
Expand Down Expand Up @@ -122,11 +124,12 @@ def parseTimeout : TSyntax `smtTimeout → TacticM (Option Nat)

@[tactic smtShow] def evalSmtShow : Tactic := fun stx => withMainContext do
let g ← Meta.mkFreshExprMVar (← getMainTarget)
let mv ← Util.rewriteIffMeta g.mvarId!
let goalType ← mv.getType
let mut hs ← parseHints ⟨stx[1]⟩
hs := hs.eraseDups
let mv := g.mvarId!
let hs ← parseHints ⟨stx[1]⟩
withProcessedHints mv hs fun mv hs => mv.withContext do
let (hs, mv) ← Preprocess.elimIff mv hs
mv.withContext do
let goalType ← mv.getType
let cmds ← prepareSmtQuery hs (← mv.getType)
let cmds := cmds ++ [.checkSat]
logInfo m!"goal: {goalType}\n\nquery:\n{Command.cmdsAsQuery cmds}"
Expand Down
31 changes: 1 addition & 30 deletions Smt/Util.lean
Original file line number Diff line number Diff line change
Expand Up @@ -108,33 +108,4 @@ where
| some e => .visit e
| none => .continue

theorem iff_eq_eq : (p ↔ q) = (p = q) := propext ⟨propext, (· ▸ ⟨(·), (·)⟩)⟩

def rewriteIffGoal (mvar : MVarId) : MetaM MVarId :=
mvar.withContext do
let t ← mvar.getType
let r ← mvar.rewrite t (mkConst ``iff_eq_eq)
let mvar' ← mvar.replaceTargetEq r.eNew r.eqProof
pure mvar'

def rewriteIffDecl (decl : LocalDecl) (mvar : MVarId) : MetaM MVarId :=
mvar.withContext do
let rwRes ← mvar.rewrite decl.type (mkConst ``iff_eq_eq)
let repRes ← mvar.replaceLocalDecl decl.fvarId rwRes.eNew rwRes.eqProof
pure repRes.mvarId

partial def fixRewriteIff (mvar : MVarId) (f : MVarId → MetaM MVarId) : MetaM MVarId :=
mvar.withContext do
try
let mvar' ← f mvar
fixRewriteIff mvar' f
catch _ => return mvar

def rewriteIffMeta (mvar : MVarId) : MetaM MVarId :=
mvar.withContext do
let mvar' ← fixRewriteIff mvar rewriteIffGoal
let lctx ← getLCtx
lctx.foldrM
(fun decl mvar'' => fixRewriteIff mvar'' (rewriteIffDecl decl)) mvar'

namespace Smt.Util
end Smt.Util
2 changes: 1 addition & 1 deletion Test/BitVec/XorComm.expected
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ Test/BitVec/XorComm.lean:3:8: warning: declaration uses 'sorry'
goal: x ^^^ y = y ^^^ x

query:
(declare-const x (_ BitVec 8))
(declare-const y (_ BitVec 8))
(declare-const x (_ BitVec 8))
(assert (distinct (bvxor x y) (bvxor y x)))
(check-sat)
Test/BitVec/XorComm.lean:7:8: warning: declaration uses 'sorry'
4 changes: 2 additions & 2 deletions Test/Int/Binders.expected
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ goal: partCurryAdd a b = partCurryAdd b a

query:
(define-fun partCurryAdd ((a Int) ([email protected]._hyg.36 Int)) Int (+ a [email protected]._hyg.36))
(declare-const b Int)
(declare-const a Int)
(declare-const b Int)
(assert (distinct (partCurryAdd a b) (partCurryAdd b a)))
(check-sat)
Test/Int/Binders.lean:11:0: warning: declaration uses 'sorry'
Expand All @@ -29,8 +29,8 @@ goal: mismatchNamesAdd a b = mismatchNamesAdd b a

query:
(define-fun mismatchNamesAdd ((a Int) (b Int)) Int (+ a b))
(declare-const b Int)
(declare-const a Int)
(declare-const b Int)
(assert (distinct (mismatchNamesAdd a b) (mismatchNamesAdd b a)))
(check-sat)
Test/Int/Binders.lean:25:0: warning: declaration uses 'sorry'
9 changes: 9 additions & 0 deletions Test/Int/linarith.expected
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Test/Int/linarith.lean:51:9: warning: unused variable `e`
note: this linter can be disabled with `set_option linter.unusedVariables false`
Test/Int/linarith.lean:76:0: warning: declaration uses 'sorry'
Test/Int/linarith.lean:93:62: warning: unused variable `h3`
note: this linter can be disabled with `set_option linter.unusedVariables false`
Test/Int/linarith.lean:98:36: warning: unused variable `z`
note: this linter can be disabled with `set_option linter.unusedVariables false`
Test/Int/linarith.lean:99:5: warning: unused variable `h5`
note: this linter can be disabled with `set_option linter.unusedVariables false`
199 changes: 199 additions & 0 deletions Test/Int/linarith.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
import Smt

-- example : ∃ (x : Int), x * x = 2 := by
-- smt

example : p ∧ q → p := by
smt

example (a b : Int) : b < 0 → a > 0 → b * (- 2) * a * b* b * (- 3) * a * a < 0 := by
smt
all_goals sorry

example (a b : Int) (hb : b < 0) (ha : a < 0) : b * (- 2) * a * b * (- 3) * a * a < 0 := by
smt [hb, ha]
all_goals sorry

-- example (m n : Int) (h : m > 0) : n % m < m := by
-- smt [h]
-- all_goals sorry

example {x y : Int} {f : Int → Int} : ¬(x ≤ y ∧ y ≤ x ∧ ¬f x = f y) := by
smt
all_goals sorry

example {p q r : Prop} (hp : ¬p) (hq : ¬q) (hr : r) : ¬(p ∨ q ∨ ¬r) := by
smt [hp, hq, hr]

example {p q r : Prop} : ((p ∧ q) ∧ r) = (r ∧ True ∧ q ∧ p ∧ p) := by
smt

example {p q r : Prop} : ((p ∧ q) ∧ r) = (r ∧ True ∧ q ∧ p ∧ p) := by
ac_rfl

example {a b : Int} (h : a < b) (w : b < a) : False := by
smt [h, w]
all_goals sorry

example
{a b c : Int}
(ha : a < 0)
(hb : ¬b = 0)
(hc' : c = 0)
(h₁ : (1 - a) * (b * b) ≤ 0)
(hc : (0 : Int) ≤ 0)
(w : -(a * -b * -b + b * -b + 0) = (1 - a) * (b * b))
(h₂ : (1 - a) * (b * b) ≤ 0) :
0 < 1 - a := by
smt [ha, hb, hc', h₁, hc, w, h₂]
all_goals sorry

example (e b c a v0 v1 : Int) (h1 : v0 = 5*a) (h2 : v1 = 3*b)
(h3 : v0 + v1 + c = 10) : v0 + 5 + (v1 - 3) + (c - 2) = 10 := by
smt [h1, h2, h3]
all_goals sorry

example (h : (1 : Int) < 0) (g : ¬ (37 : Int) < 42) (_k : True) (l : (-7 : Int) < 5): (3 : Int) < 7 := by
smt [h, g, _k, l]
all_goals sorry

example (u v r s t : Int) (h : 0 < u*(t*v + t*r + s)) : 0 < (t*(r + v) + s)*3*u := by
smt [h]
all_goals sorry

example (A B : Int) (h : 0 < 3 * A * B) : 0 < 8*A*B := by
smt [h]
all_goals sorry

example (A B : Int) (h : 0 < 8 * A * B) : 0 < A*B := by
smt [h]
all_goals sorry

example (A B : Int) (h : 0 < A * B) : 0 < A*8*B := by
smt [h]
all_goals sorry

example (x : Int) : 0 ≤ x := by
have h : 0 ≤ x := sorry
smt [h]

example (u v r s t : Int) (h : 0 < u*(t*v + t*r + s)) :
0 < (t*(r + v) + s)*3*u := by
smt [h]
all_goals sorry

example (A B : Int) (h : 0 < A * B) : 0 < 8*A*B := by
smt [h]
all_goals sorry

example (x y z : Int) (h1 : 2*x < 3*y) (h2 : -4*x + 2*z < 0) (h3 : 12*y - 4* z < 0) : False := by
smt [h1, h2, h3]
all_goals sorry

example (x y z : Int) (h1 : 2*x < 3*y) (h2 : -4*x + 2*z < 0) (h3 : x*y < 5) (h3 : 12*y - 4* z < 0) :
False := by
smt [h1, h2, h3]
all_goals sorry

example (prime : Int → Prop) (w x y z : Int) (h1 : 4*x + (-3)*y + 6*w ≤ 0) (h2 : (-1)*x < 0) (h3 : y < 0) (h4 : w ≥ 0)
(h5 : prime x) : False := by
smt [h1, h2, h3, h4]
all_goals sorry

-- set_option maxRecDepth 2000000

example (u v x y A B : Int)
(a : 0 < A)
(a_1 : 0 <= 1 - A)
(a_2 : 0 <= B - 1)
(a_3 : 0 <= B - x)
(a_4 : 0 <= B - y)
(a_5 : 0 <= u)
(a_6 : 0 <= v)
(a_7 : 0 < A - u)
(a_8 : 0 < A - v) :
(0 < A * A)
-> (0 <= A * (1 - A))
-> (0 <= A * (B - 1))
-> (0 <= A * (B - x))
-> (0 <= A * (B - y))
-> (0 <= A * u)
-> (0 <= A * v)
-> (0 < A * (A - u))
-> (0 < A * (A - v))
-> (0 <= (1 - A) * A)
-> (0 <= (1 - A) * (1 - A))
-> (0 <= (1 - A) * (B - 1))
-> (0 <= (1 - A) * (B - x))
-> (0 <= (1 - A) * (B - y))
-> (0 <= (1 - A) * u)
-> (0 <= (1 - A) * v)
-> (0 <= (1 - A) * (A - u))
-> (0 <= (1 - A) * (A - v))
-> (0 <= (B - 1) * A)
-> (0 <= (B - 1) * (1 - A))
-> (0 <= (B - 1) * (B - 1))
-> (0 <= (B - 1) * (B - x))
-> (0 <= (B - 1) * (B - y))
-> (0 <= (B - 1) * u)
-> (0 <= (B - 1) * v)
-> (0 <= (B - 1) * (A - u))
-> (0 <= (B - 1) * (A - v))
-> (0 <= (B - x) * A)
-> (0 <= (B - x) * (1 - A))
-> (0 <= (B - x) * (B - 1))
-> (0 <= (B - x) * (B - x))
-> (0 <= (B - x) * (B - y))
-> (0 <= (B - x) * u)
-> (0 <= (B - x) * v)
-> (0 <= (B - x) * (A - u))
-> (0 <= (B - x) * (A - v))
-> (0 <= (B - y) * A)
-> (0 <= (B - y) * (1 - A))
-> (0 <= (B - y) * (B - 1))
-> (0 <= (B - y) * (B - x))
-> (0 <= (B - y) * (B - y))
-> (0 <= (B - y) * u)
-> (0 <= (B - y) * v)
-> (0 <= (B - y) * (A - u))
-> (0 <= (B - y) * (A - v))
-> (0 <= u * A)
-> (0 <= u * (1 - A))
-> (0 <= u * (B - 1))
-> (0 <= u * (B - x))
-> (0 <= u * (B - y))
-> (0 <= u * u)
-> (0 <= u * v)
-> (0 <= u * (A - u))
-> (0 <= u * (A - v))
-> (0 <= v * A)
-> (0 <= v * (1 - A))
-> (0 <= v * (B - 1))
-> (0 <= v * (B - x))
-> (0 <= v * (B - y))
-> (0 <= v * u)
-> (0 <= v * v)
-> (0 <= v * (A - u))
-> (0 <= v * (A - v))
-> (0 < (A - u) * A)
-> (0 <= (A - u) * (1 - A))
-> (0 <= (A - u) * (B - 1))
-> (0 <= (A - u) * (B - x))
-> (0 <= (A - u) * (B - y))
-> (0 <= (A - u) * u)
-> (0 <= (A - u) * v)
-> (0 < (A - u) * (A - u))
-> (0 < (A - u) * (A - v))
-> (0 < (A - v) * A)
-> (0 <= (A - v) * (1 - A))
-> (0 <= (A - v) * (B - 1))
-> (0 <= (A - v) * (B - x))
-> (0 <= (A - v) * (B - y))
-> (0 <= (A - v) * u)
-> (0 <= (A - v) * v)
-> (0 < (A - v) * (A - u))
-> (0 < (A - v) * (A - v))
->
u * y + v * x + u * v < 3 * A * B := by
smt [a, a_1, a_2, a_3, a_4, a_5, a_6, a_7, a_8]
all_goals sorry
2 changes: 1 addition & 1 deletion Test/Nat/Cong.expected
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ goal: x = y → f x = f y
query:
(define-sort Nat () Int)
(declare-fun f (Nat) Nat)
(assert (forall ((_uniq.1520 Nat)) (=> (>= _uniq.1520 0) (>= (f _uniq.1520) 0))))
(assert (forall ((_uniq.1556 Nat)) (=> (>= _uniq.1556 0) (>= (f _uniq.1556) 0))))
(declare-const x Nat)
(assert (>= x 0))
(declare-const y Nat)
Expand Down
Loading

0 comments on commit b332ae4

Please sign in to comment.