Skip to content

Commit

Permalink
feat: switching List lookup normal forms to L[n] and L[n]? (leanp…
Browse files Browse the repository at this point in the history
…rover#4400)

This is presumably going to have significant breakage downstream.
  • Loading branch information
kim-em authored Jun 15, 2024
1 parent fe0cb97 commit e10a37d
Show file tree
Hide file tree
Showing 6 changed files with 426 additions and 149 deletions.
55 changes: 30 additions & 25 deletions src/Init/Data/Array/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import Init.TacticsExtra
/-!
## Bootstrapping theorems about arrays
This file contains some theorems about `Array` and `List` needed for `Std.List.Basic`.
This file contains some theorems about `Array` and `List` needed for `Init.Data.List.Impl`.
-/

namespace Array
Expand All @@ -34,9 +34,13 @@ attribute [simp] data_toArray uset

@[simp] theorem size_mk (as : List α) : (Array.mk as).size = as.length := by simp [size]

theorem getElem_eq_data_get (a : Array α) (h : i < a.size) : a[i] = a.data.get ⟨i, h⟩ := by
theorem getElem_eq_data_getElem (a : Array α) (h : i < a.size) : a[i] = a.data[i] := by
by_cases i < a.size <;> (try simp [*]) <;> rfl

@[deprecated getElem_eq_data_getElem (since := "2024-06-12")]
theorem getElem_eq_data_get (a : Array α) (h : i < a.size) : a[i] = a.data.get ⟨i, h⟩ := by
simp [getElem_eq_data_getElem]

theorem foldlM_eq_foldlM_data.aux [Monad m]
(f : β → α → m β) (arr : Array α) (i j) (H : arr.size ≤ i + j) (b) :
foldlM.loop f arr arr.size (Nat.le_refl _) i j b = (arr.data.drop j).foldlM f b := by
Expand Down Expand Up @@ -114,11 +118,11 @@ theorem foldr_push (f : α → β → β) (init : β) (arr : Array α) (a : α)
theorem get_push_lt (a : Array α) (x : α) (i : Nat) (h : i < a.size) :
have : i < (a.push x).size := by simp [*, Nat.lt_succ_of_le, Nat.le_of_lt]
(a.push x)[i] = a[i] := by
simp only [push, getElem_eq_data_get, List.concat_eq_append, List.get_append_left, h]
simp only [push, getElem_eq_data_getElem, List.concat_eq_append, List.getElem_append_left, h]

@[simp] theorem get_push_eq (a : Array α) (x : α) : (a.push x)[a.size] = x := by
simp only [push, getElem_eq_data_get, List.concat_eq_append]
rw [List.get_append_right] <;> simp [getElem_eq_data_get, Nat.zero_lt_one]
simp only [push, getElem_eq_data_getElem, List.concat_eq_append]
rw [List.getElem_append_right] <;> simp [getElem_eq_data_getElem, Nat.zero_lt_one]

theorem get_push (a : Array α) (x : α) (i : Nat) (h : i < (a.push x).size) :
(a.push x)[i] = if h : i < a.size then a[i] else x := by
Expand Down Expand Up @@ -233,11 +237,11 @@ theorem get!_eq_getD [Inhabited α] (a : Array α) : a.get! n = a.getD n default
@[simp] theorem getElem_set_eq (a : Array α) (i : Fin a.size) (v : α) {j : Nat}
(eq : i.val = j) (p : j < (a.set i v).size) :
(a.set i v)[j]'p = v := by
simp [set, getElem_eq_data_get, ←eq]
simp [set, getElem_eq_data_getElem, ←eq]

@[simp] theorem getElem_set_ne (a : Array α) (i : Fin a.size) (v : α) {j : Nat} (pj : j < (a.set i v).size)
(h : i.val ≠ j) : (a.set i v)[j]'pj = a[j]'(size_set a i v ▸ pj) := by
simp only [set, getElem_eq_data_get, List.get_set_ne _ h]
simp only [set, getElem_eq_data_getElem, List.getElem_set_ne _ h]

theorem getElem_set (a : Array α) (i : Fin a.size) (v : α) (j : Nat)
(h : j < (a.set i v).size) :
Expand Down Expand Up @@ -321,7 +325,7 @@ termination_by n - i
@[simp] theorem mkArray_data (n : Nat) (v : α) : (mkArray n v).data = List.replicate n v := rfl

@[simp] theorem getElem_mkArray (n : Nat) (v : α) (h : i < (mkArray n v).size) :
(mkArray n v)[i] = v := by simp [Array.getElem_eq_data_get]
(mkArray n v)[i] = v := by simp [Array.getElem_eq_data_getElem]

/-- # mem -/

Expand All @@ -332,7 +336,7 @@ theorem not_mem_nil (a : α) : ¬ a ∈ #[] := nofun
/-- # get lemmas -/

theorem getElem?_mem {l : Array α} {i : Fin l.size} : l[i] ∈ l := by
erw [Array.mem_def, getElem_eq_data_get]
erw [Array.mem_def, getElem_eq_data_getElem]
apply List.get_mem

theorem getElem_fin_eq_data_get (a : Array α) (i : Fin _) : a[i] = a.data.get i := rfl
Expand All @@ -347,7 +351,7 @@ theorem get?_len_le (a : Array α) (i : Nat) (h : a.size ≤ i) : a[i]? = none :
simp [getElem?_neg, h]

theorem getElem_mem_data (a : Array α) (h : i < a.size) : a[i] ∈ a.data := by
simp only [getElem_eq_data_get, List.get_mem]
simp only [getElem_eq_data_getElem, List.getElem_mem]

theorem getElem?_eq_data_get? (a : Array α) (i : Nat) : a[i]? = a.data.get? i := by
by_cases i < a.size <;> simp_all [getElem?_pos, getElem?_neg, List.get?_eq_get, eq_comm]; rfl
Expand Down Expand Up @@ -395,7 +399,7 @@ theorem get?_push {a : Array α} : (a.push x)[i]? = if i = a.size then some x el

theorem get_set_eq (a : Array α) (i : Fin a.size) (v : α) :
(a.set i v)[i.1] = v := by
simp only [set, getElem_eq_data_get, List.get_set_eq]
simp only [set, getElem_eq_data_getElem, List.getElem_set_eq]

theorem get?_set_eq (a : Array α) (i : Fin a.size) (v : α) :
(a.set i v)[i.1]? = v := by simp [getElem?_pos, i.2]
Expand All @@ -414,7 +418,7 @@ theorem get_set (a : Array α) (i : Fin a.size) (j : Nat) (hj : j < a.size) (v :

@[simp] theorem get_set_ne (a : Array α) (i : Fin a.size) {j : Nat} (v : α) (hj : j < a.size)
(h : i.1 ≠ j) : (a.set i v)[j]'(by simp [*]) = a[j] := by
simp only [set, getElem_eq_data_get, List.get_set_ne _ h]
simp only [set, getElem_eq_data_getElem, List.getElem_set_ne _ h]

theorem getElem_setD (a : Array α) (i : Nat) (v : α) (h : i < (setD a i v).size) :
(setD a i v)[i] = v := by
Expand Down Expand Up @@ -452,7 +456,7 @@ theorem swapAt!_def (a : Array α) (i : Nat) (v : α) (h : i < a.size) :

@[simp] theorem getElem_pop (a : Array α) (i : Nat) (hi : i < a.pop.size) :
a.pop[i] = a[i]'(Nat.lt_of_lt_of_le (a.size_pop ▸ hi) (Nat.sub_le _ _)) :=
List.get_dropLast ..
List.getElem_dropLast ..

theorem eq_empty_of_size_eq_zero {as : Array α} (h : as.size = 0) : as = #[] := by
apply ext
Expand Down Expand Up @@ -500,6 +504,7 @@ theorem size_eq_length_data (as : Array α) : as.size = as.data.length := rfl
simp only [mkEmpty_eq, size_push] at *
omega

set_option linter.deprecated false in
@[simp] theorem reverse_data (a : Array α) : a.reverse.data = a.data.reverse := by
let rec go (as : Array α) (i j hj)
(h : i + j + 1 = a.size) (h₂ : as.size = a.size)
Expand All @@ -517,10 +522,10 @@ theorem size_eq_length_data (as : Array α) : as.size = as.data.length := rfl
simp only [H, getElem_eq_data_get, ← List.get?_eq_get, Nat.le_of_lt h₁, getElem?_eq_data_get?]
split <;> rename_i h₂
· simp only [← h₂, Nat.not_le.2 (Nat.lt_succ_self _), Nat.le_refl, and_false]
exact (List.get?_reverse' _ _ (Eq.trans (by simp_arith) h)).symm
exact (List.get?_reverse' (j+1) i (Eq.trans (by simp_arith) h)).symm
split <;> rename_i h₃
· simp only [← h₃, Nat.not_le.2 (Nat.lt_succ_self _), Nat.le_refl, false_and]
exact (List.get?_reverse' _ _ (Eq.trans (by simp_arith) h)).symm
exact (List.get?_reverse' i (j+1) (Eq.trans (by simp_arith) h)).symm
simp only [Nat.succ_le, Nat.lt_iff_le_and_ne.trans (and_iff_left h₃),
Nat.lt_succ.symm.trans (Nat.lt_iff_le_and_ne.trans (and_iff_left (Ne.symm h₂)))]
· rw [H]; split <;> rename_i h₂
Expand All @@ -533,7 +538,7 @@ theorem size_eq_length_data (as : Array α) : as.size = as.data.length := rfl
split
· match a with | ⟨[]⟩ | ⟨[_]⟩ => rfl
· have := Nat.sub_add_cancel (Nat.le_of_not_le ‹_›)
refine List.ext <| go _ _ _ _ (by simp [this]) rfl fun k => ?_
refine List.ext_get? <| go _ _ _ _ (by simp [this]) rfl fun k => ?_
split
· rfl
· rename_i h
Expand Down Expand Up @@ -769,17 +774,17 @@ theorem size_append (as bs : Array α) : (as ++ bs).size = as.size + bs.size :=

theorem get_append_left {as bs : Array α} {h : i < (as ++ bs).size} (hlt : i < as.size) :
(as ++ bs)[i] = as[i] := by
simp only [getElem_eq_data_get]
simp only [getElem_eq_data_getElem]
have h' : i < (as.data ++ bs.data).length := by rwa [← data_length, append_data] at h
conv => rhs; rw [← List.get_append_left (bs:=bs.data) (h':=h')]
conv => rhs; rw [← List.getElem_append_left (bs := bs.data) (h' := h')]
apply List.get_of_eq; rw [append_data]

theorem get_append_right {as bs : Array α} {h : i < (as ++ bs).size} (hle : as.size ≤ i)
(hlt : i - as.size < bs.size := Nat.sub_lt_left_of_lt_add hle (size_append .. ▸ h)) :
(as ++ bs)[i] = bs[i - as.size] := by
simp only [getElem_eq_data_get]
simp only [getElem_eq_data_getElem]
have h' : i < (as.data ++ bs.data).length := by rwa [← data_length, append_data] at h
conv => rhs; rw [← List.get_append_right (h':=h') (h:=Nat.not_lt_of_ge hle)]
conv => rhs; rw [← List.getElem_append_right (h' := h') (h := Nat.not_lt_of_ge hle)]
apply List.get_of_eq; rw [append_data]

@[simp] theorem append_nil (as : Array α) : as ++ #[] = as := by
Expand Down Expand Up @@ -987,13 +992,13 @@ theorem all_eq_true (p : α → Bool) (as : Array α) : all as p ↔ ∀ i : Fin
simp [all_iff_forall, Fin.isLt]

theorem all_def {p : α → Bool} (as : Array α) : as.all p = as.data.all p := by
rw [Bool.eq_iff_iff, all_eq_true, List.all_eq_true]; simp only [List.mem_iff_get]
rw [Bool.eq_iff_iff, all_eq_true, List.all_eq_true]; simp only [List.mem_iff_getElem]
constructor
· rintro w x ⟨r, rfl⟩
rw [← getElem_eq_data_get]
apply w
· rintro w x ⟨r, h, rfl⟩
rw [← getElem_eq_data_getElem]
exact w ⟨r, h⟩
· intro w i
exact w as[i] ⟨i, (getElem_eq_data_get as i.2).symm⟩
exact w as[i] ⟨i, i.2, (getElem_eq_data_getElem as i.2).symm⟩

theorem all_eq_true_iff_forall_mem {l : Array α} : l.all p ↔ ∀ x, x ∈ l → p x := by
simp only [all_def, List.all_eq_true, mem_def]
Expand Down
10 changes: 6 additions & 4 deletions src/Init/Data/List/BasicAux.lean
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,15 @@ See also `get?` and `get!`.
def getD (as : List α) (i : Nat) (fallback : α) : α :=
(as.get? i).getD fallback

@[ext] theorem ext : ∀ {l₁ l₂ : List α}, (∀ n, l₁.get? n = l₂.get? n) → l₁ = l₂
theorem ext_get? : ∀ {l₁ l₂ : List α}, (∀ n, l₁.get? n = l₂.get? n) → l₁ = l₂
| [], [], _ => rfl
| a :: l₁, [], h => nomatch h 0
| [], a' :: l₂, h => nomatch h 0
| a :: l₁, a' :: l₂, h => by
have h0 : some a = some a' := h 0
injection h0 with aa; simp only [aa, ext fun n => h (n+1)]
injection h0 with aa; simp only [aa, ext_get? fun n => h (n+1)]

@[deprecated (since := "2024-06-07")] abbrev ext := @ext_get?

/--
Returns the first element in the list.
Expand Down Expand Up @@ -191,15 +193,15 @@ def rotateRight (xs : List α) (n : Nat := 1) : List α :=
let e := xs.drop n
e ++ b

theorem get_append_left (as bs : List α) (h : i < as.length) {h'} : (as ++ bs).get ⟨i, h'⟩ = as.get ⟨i, h⟩ := by
theorem getElem_append_left (as bs : List α) (h : i < as.length) {h'} : (as ++ bs)[i] = as[i] := by
induction as generalizing i with
| nil => trivial
| cons a as ih =>
cases i with
| zero => rfl
| succ i => apply ih

theorem get_append_right (as bs : List α) (h : ¬ i < as.length) {h' h''} : (as ++ bs).get ⟨i, h'⟩ = bs.get ⟨i - as.length, h'' := by
theorem getElem_append_right (as bs : List α) (h : ¬ i < as.length) {h' h''} : (as ++ bs)[i]'h' = bs[i - as.length]'h'' := by
induction as generalizing i with
| nil => trivial
| cons a as ih =>
Expand Down
Loading

0 comments on commit e10a37d

Please sign in to comment.