diff --git a/src/Init/Data/Bool.lean b/src/Init/Data/Bool.lean index f3de8ba4ad53..3eb9b4c1cf7f 100644 --- a/src/Init/Data/Bool.lean +++ b/src/Init/Data/Bool.lean @@ -549,3 +549,19 @@ export Bool (cond_eq_if) @[simp] theorem true_eq_decide_iff {p : Prop} [h : Decidable p] : true = decide p ↔ p := by cases h with | _ q => simp [q] + +/-! ### coercions -/ + +/-- +This should not be turned on globally as an instance because it degrades performance in Mathlib, +but may be used locally. +-/ +def boolPredToPred : Coe (α → Bool) (α → Prop) where + coe r := fun a => Eq (r a) true + +/-- +This should not be turned on globally as an instance because it degrades performance in Mathlib, +but may be used locally. +-/ +def boolRelToRel : Coe (α → α → Bool) (α → α → Prop) where + coe r := fun a b => Eq (r a b) true diff --git a/src/Init/Data/List.lean b/src/Init/Data/List.lean index e8dbe7776f3a..94e8c22a6e47 100644 --- a/src/Init/Data/List.lean +++ b/src/Init/Data/List.lean @@ -22,3 +22,4 @@ import Init.Data.List.Sublist import Init.Data.List.TakeDrop import Init.Data.List.Zip import Init.Data.List.Perm +import Init.Data.List.Sort diff --git a/src/Init/Data/List/Basic.lean b/src/Init/Data/List/Basic.lean index 9df9667a6dbd..cdefad5646ca 100644 --- a/src/Init/Data/List/Basic.lean +++ b/src/Init/Data/List/Basic.lean @@ -962,6 +962,26 @@ def IsInfix (l₁ : List α) (l₂ : List α) : Prop := Exists fun s => Exists f @[inherit_doc] infixl:50 " <:+: " => IsInfix +/-! ### splitAt -/ + +/-- +Split a list at an index. +``` +splitAt 2 [a, b, c] = ([a, b], [c]) +``` +-/ +def splitAt (n : Nat) (l : List α) : List α × List α := go l n [] where + /-- + Auxiliary for `splitAt`: + `splitAt.go l xs n acc = (acc.reverse ++ take n xs, drop n xs)` if `n < xs.length`, + and `(l, [])` otherwise. + -/ + go : List α → Nat → List α → List α × List α + | [], _, _ => (l, []) -- This branch ensures the pointer equality of the result with the input + -- without any runtime branching cost. 
+ | x :: xs, n+1, acc => go xs n (x :: acc) + | xs, _, acc => (acc.reverse, xs) + /-! ### rotateLeft -/ /-- diff --git a/src/Init/Data/List/Lemmas.lean b/src/Init/Data/List/Lemmas.lean index c04d8646af0f..53b86a0cbb58 100644 --- a/src/Init/Data/List/Lemmas.lean +++ b/src/Init/Data/List/Lemmas.lean @@ -276,6 +276,9 @@ theorem getElem?_cons_zero {l : List α} : (a::l)[0]? = some a := by simp simp only [← get?_eq_getElem?] rfl +theorem getElem?_cons : (a :: l)[i]? = if i = 0 then some a else l[i-1]? := by + cases i <;> simp + theorem getElem?_len_le : ∀ {l : List α} {n}, length l ≤ n → l[n]? = none | [], _, _ => rfl | _ :: l, _+1, h => by @@ -2368,6 +2371,27 @@ theorem dropLast_append {l₁ l₂ : List α} : dropLast (a :: replicate n a) = replicate n a := by rw [← replicate_succ, dropLast_replicate, Nat.add_sub_cancel] +/-! +### splitAt + +We don't provide any API for `splitAt`, beyond the `@[simp]` lemma +`splitAt n l = (l.take n, l.drop n)`, +which is proved in `Init.Data.List.TakeDrop`. +-/ + +theorem splitAt_go (n : Nat) (l acc : List α) : + splitAt.go l xs n acc = + if n < xs.length then (acc.reverse ++ xs.take n, xs.drop n) else (l, []) := by + induction xs generalizing n acc with + | nil => simp [splitAt.go] + | cons x xs ih => + cases n with + | zero => simp [splitAt.go] + | succ n => + rw [splitAt.go, take_succ_cons, drop_succ_cons, ih n (x :: acc), + reverse_cons, append_assoc, singleton_append, length_cons] + simp only [Nat.succ_lt_succ_iff] + /-! ## Manipulating elements -/ /-! 
### replace -/ diff --git a/src/Init/Data/List/Nat/Range.lean b/src/Init/Data/List/Nat/Range.lean index ec2145659f1b..3a3dfb029bcf 100644 --- a/src/Init/Data/List/Nat/Range.lean +++ b/src/Init/Data/List/Nat/Range.lean @@ -260,9 +260,24 @@ theorem enumFrom_map_snd : ∀ (n) (l : List α), map Prod.snd (enumFrom n l) = theorem snd_mem_of_mem_enumFrom {x : Nat × α} {n : Nat} {l : List α} (h : x ∈ enumFrom n l) : x.2 ∈ l := enumFrom_map_snd n l ▸ mem_map_of_mem _ h -theorem mem_enumFrom {x : α} {i j : Nat} (xs : List α) (h : (i, x) ∈ xs.enumFrom j) : - j ≤ i ∧ i < j + xs.length ∧ x ∈ xs := - ⟨le_fst_of_mem_enumFrom h, fst_lt_add_of_mem_enumFrom h, snd_mem_of_mem_enumFrom h⟩ +theorem snd_eq_of_mem_enumFrom {x : Nat × α} {n : Nat} {l : List α} (h : x ∈ enumFrom n l) : + x.2 = l[x.1 - n]'(by have := le_fst_of_mem_enumFrom h; have := fst_lt_add_of_mem_enumFrom h; omega) := by + induction l generalizing n with + | nil => cases h + | cons hd tl ih => + cases h with + | head h => simp + | tail h m => + specialize ih m + have : x.1 - n = x.1 - (n + 1) + 1 := by + have := le_fst_of_mem_enumFrom m + omega + simp [this, ih] + +theorem mem_enumFrom {x : α} {i j : Nat} {xs : List α} (h : (i, x) ∈ xs.enumFrom j) : + j ≤ i ∧ i < j + xs.length ∧ + x = xs[i - j]'(by have := le_fst_of_mem_enumFrom h; have := fst_lt_add_of_mem_enumFrom h; omega) := + ⟨le_fst_of_mem_enumFrom h, fst_lt_add_of_mem_enumFrom h, snd_eq_of_mem_enumFrom h⟩ theorem enumFrom_cons' (n : Nat) (x : α) (xs : List α) : enumFrom n (x :: xs) = (n, x) :: (enumFrom n xs).map (Prod.map (· + 1) id) := by @@ -329,6 +344,14 @@ theorem fst_lt_of_mem_enum {x : Nat × α} {l : List α} (h : x ∈ enum l) : x. 
theorem snd_mem_of_mem_enum {x : Nat × α} {l : List α} (h : x ∈ enum l) : x.2 ∈ l := snd_mem_of_mem_enumFrom h +theorem snd_eq_of_mem_enum {x : Nat × α} {l : List α} (h : x ∈ enum l) : + x.2 = l[x.1]'(fst_lt_of_mem_enum h) := + snd_eq_of_mem_enumFrom h + +theorem mem_enum {x : α} {i : Nat} {xs : List α} (h : (i, x) ∈ xs.enum) : + i < xs.length ∧ x = xs[i]'(fst_lt_of_mem_enum h) := + by simpa using mem_enumFrom h + theorem map_enum (f : α → β) (l : List α) : map (Prod.map id f) (enum l) = enum (map f l) := map_enumFrom f 0 l diff --git a/src/Init/Data/List/Nat/TakeDrop.lean b/src/Init/Data/List/Nat/TakeDrop.lean index 4a0da8468ce3..7c7825afeafc 100644 --- a/src/Init/Data/List/Nat/TakeDrop.lean +++ b/src/Init/Data/List/Nat/TakeDrop.lean @@ -70,20 +70,20 @@ theorem get?_take_eq_none {l : List α} {n m : Nat} (h : n ≤ m) : (l.take n).get? m = none := by simp [getElem?_take_eq_none h] -theorem getElem?_take_eq_if {l : List α} {n m : Nat} : +theorem getElem?_take {l : List α} {n m : Nat} : (l.take n)[m]? = if m < n then l[m]? else none := by split - · next h => exact getElem?_take h + · next h => exact getElem?_take_of_lt h · next h => exact getElem?_take_eq_none (Nat.le_of_not_lt h) -@[deprecated getElem?_take_eq_if (since := "2024-06-12")] +@[deprecated getElem?_take (since := "2024-06-12")] theorem get?_take_eq_if {l : List α} {n m : Nat} : (l.take n).get? m = if m < n then l.get? m else none := by - simp [getElem?_take_eq_if] + simp [getElem?_take] theorem head?_take {l : List α} {n : Nat} : (l.take n).head? = if n = 0 then none else l.head? := by - simp [head?_eq_getElem?, getElem?_take_eq_if] + simp [head?_eq_getElem?, getElem?_take] split · rw [if_neg (by omega)] · rw [if_pos (by omega)] @@ -95,7 +95,7 @@ theorem head_take {l : List α} {n : Nat} (h : l.take n ≠ []) : simp_all theorem getLast?_take {l : List α} : (l.take n).getLast? = if n = 0 then none else l[n - 1]?.or l.getLast? 
:= by - rw [getLast?_eq_getElem?, getElem?_take_eq_if, length_take] + rw [getLast?_eq_getElem?, getElem?_take, length_take] split · rw [if_neg (by omega)] rw [Nat.min_def] @@ -128,7 +128,7 @@ theorem take_take : ∀ (n m) (l : List α), take n (take m l) = take (min n m) theorem take_set_of_lt (a : α) {n m : Nat} (l : List α) (h : m < n) : (l.set n a).take m = l.take m := List.ext_getElem? fun i => by - rw [getElem?_take_eq_if, getElem?_take_eq_if] + rw [getElem?_take, getElem?_take] split · next h' => rw [getElem?_set_ne (by omega)] · rfl @@ -203,7 +203,7 @@ theorem map_eq_append_split {f : α → β} {l : List α} {s₁ s₂ : List β} theorem take_prefix_take_left (l : List α) {m n : Nat} (h : m ≤ n) : take m l <+: take n l := by rw [isPrefix_iff] intro i w - rw [getElem?_take, getElem_take', getElem?_eq_getElem] + rw [getElem?_take_of_lt, getElem_take', getElem?_eq_getElem] simp only [length_take] at w exact Nat.lt_of_lt_of_le (Nat.lt_of_lt_of_le w (Nat.min_le_left _ _)) h @@ -334,7 +334,7 @@ theorem set_eq_take_append_cons_drop {l : List α} {n : Nat} {a : α} : · ext1 m by_cases h' : m < n · rw [getElem?_append_left (by simp [length_take]; omega), getElem?_set_ne (by omega), - getElem?_take h'] + getElem?_take_of_lt h'] · by_cases h'' : m = n · subst h'' rw [getElem?_set_eq ‹_›, getElem?_append_right, length_take, @@ -373,40 +373,67 @@ theorem drop_take : ∀ (m n : Nat) (l : List α), drop n (take m l) = take (m - congr 1 omega -theorem take_reverse {α} {xs : List α} {n : Nat} (h : n ≤ xs.length) : +theorem take_reverse {α} {xs : List α} {n : Nat} : xs.reverse.take n = (xs.drop (xs.length - n)).reverse := by - induction xs generalizing n <;> - simp only [reverse_cons, drop, reverse_nil, Nat.zero_sub, length, take_nil] - next xs_hd xs_tl xs_ih => - cases Nat.lt_or_eq_of_le h with - | inl h' => - have h' := Nat.le_of_succ_le_succ h' - rw [take_append_of_le_length, xs_ih h'] - rw [show xs_tl.length + 1 - n = succ (xs_tl.length - n) from _, drop] - · rwa [succ_eq_add_one, 
Nat.sub_add_comm] - · rwa [length_reverse] - | inr h' => - subst h' - rw [length, Nat.sub_self, drop] - suffices xs_tl.length + 1 = (xs_tl.reverse ++ [xs_hd]).length by - rw [this, take_length, reverse_cons] - rw [length_append, length_reverse] - rfl - -@[deprecated (since := "2024-06-15")] abbrev reverse_take := @take_reverse - -theorem drop_reverse {α} {xs : List α} {n : Nat} (h : n ≤ xs.length) : + by_cases h : n ≤ xs.length + · induction xs generalizing n <;> + simp only [reverse_cons, drop, reverse_nil, Nat.zero_sub, length, take_nil] + next xs_hd xs_tl xs_ih => + cases Nat.lt_or_eq_of_le h with + | inl h' => + have h' := Nat.le_of_succ_le_succ h' + rw [take_append_of_le_length, xs_ih h'] + rw [show xs_tl.length + 1 - n = succ (xs_tl.length - n) from _, drop] + · rwa [succ_eq_add_one, Nat.sub_add_comm] + · rwa [length_reverse] + | inr h' => + subst h' + rw [length, Nat.sub_self, drop] + suffices xs_tl.length + 1 = (xs_tl.reverse ++ [xs_hd]).length by + rw [this, take_length, reverse_cons] + rw [length_append, length_reverse] + rfl + · have w : xs.length - n = 0 := by omega + rw [take_of_length_le, w, drop_zero] + simp + omega + +theorem drop_reverse {α} {xs : List α} {n : Nat} : xs.reverse.drop n = (xs.take (xs.length - n)).reverse := by - conv => - rhs - rw [← reverse_reverse xs] - rw [← reverse_reverse xs] at h - generalize xs.reverse = xs' at h ⊢ - rw [take_reverse] - · simp only [length_reverse, reverse_reverse] at * + by_cases h : n ≤ xs.length + · conv => + rhs + rw [← reverse_reverse xs] + rw [← reverse_reverse xs] at h + generalize xs.reverse = xs' at h ⊢ + rw [take_reverse] + · simp only [length_reverse, reverse_reverse] at * + congr + omega + · have w : xs.length - n = 0 := by omega + rw [drop_of_length_le, w, take_zero, reverse_nil] + simp + omega + +theorem reverse_take {l : List α} {n : Nat} : + (l.take n).reverse = l.reverse.drop (l.length - n) := by + by_cases h : n ≤ l.length + · rw [drop_reverse] + congr + omega + · have w : l.length - n = 0 
:= by omega + rw [w, drop_zero, take_of_length_le] + omega + +theorem reverse_drop {l : List α} {n : Nat} : + (l.drop n).reverse = l.reverse.take (l.length - n) := by + by_cases h : n ≤ l.length + · rw [take_reverse] congr omega - · simp only [length_reverse, sub_le] + · have w : l.length - n = 0 := by omega + rw [w, take_zero, drop_of_length_le, reverse_nil] + omega /-! ### rotateLeft -/ diff --git a/src/Init/Data/List/Pairwise.lean b/src/Init/Data/List/Pairwise.lean index 8d21edf4b885..db4686c55fdc 100644 --- a/src/Init/Data/List/Pairwise.lean +++ b/src/Init/Data/List/Pairwise.lean @@ -226,6 +226,18 @@ theorem pairwise_iff_forall_sublist : l.Pairwise R ↔ (∀ {a b}, [a,b] <+ l intro a b hab apply h; exact hab.cons _ +theorem Pairwise.rel_of_mem_take_of_mem_drop + {l : List α} (h : l.Pairwise R) (hx : x ∈ l.take n) (hy : y ∈ l.drop n) : R x y := by + apply pairwise_iff_forall_sublist.mp h + rw [← take_append_drop n l, sublist_append_iff] + refine ⟨[x], [y], rfl, by simpa, by simpa⟩ + +theorem Pairwise.rel_of_mem_append + {l₁ l₂ : List α} (h : (l₁ ++ l₂).Pairwise R) (hx : x ∈ l₁) (hy : y ∈ l₂) : R x y := by + apply pairwise_iff_forall_sublist.mp h + rw [sublist_append_iff] + exact ⟨[x], [y], rfl, by simpa, by simpa⟩ + theorem pairwise_of_forall_mem_list {l : List α} {r : α → α → Prop} (h : ∀ a ∈ l, ∀ b ∈ l, r a b) : l.Pairwise r := by rw [pairwise_iff_forall_sublist] diff --git a/src/Init/Data/List/Perm.lean b/src/Init/Data/List/Perm.lean index 788bafa57c11..3d2019673cea 100644 --- a/src/Init/Data/List/Perm.lean +++ b/src/Init/Data/List/Perm.lean @@ -400,6 +400,40 @@ theorem Pairwise.perm {R : α → α → Prop} {l l' : List α} (hR : l.Pairwise theorem Perm.pairwise {R : α → α → Prop} {l l' : List α} (hl : l ~ l') (hR : l.Pairwise R) (hsymm : ∀ {x y}, R x y → R y x) : l'.Pairwise R := hR.perm hl hsymm +/-- +If two lists are sorted by an antisymmetric relation, and permutations of each other, +they must be equal. 
+-/ +theorem Perm.eq_of_sorted : ∀ {l₁ l₂ : List α} + (_ : ∀ a b, a ∈ l₁ → b ∈ l₂ → le a b → le b a → a = b) + (_ : l₁.Pairwise le) (_ : l₂.Pairwise le) (_ : l₁ ~ l₂), l₁ = l₂ + | [], [], _, _, _, _ => rfl + | [], b :: l₂, _, _, _, h => by simp_all + | a :: l₁, [], _, _, _, h => by simp_all + | a :: l₁, b :: l₂, w, h₁, h₂, h => by + have am : a ∈ b :: l₂ := h.subset (mem_cons_self _ _) + have bm : b ∈ a :: l₁ := h.symm.subset (mem_cons_self _ _) + have ab : a = b := by + simp only [mem_cons] at am + rcases am with rfl | am + · rfl + · simp only [mem_cons] at bm + rcases bm with rfl | bm + · rfl + · exact w _ _ (mem_cons_self _ _) (mem_cons_self _ _) + (rel_of_pairwise_cons h₁ bm) (rel_of_pairwise_cons h₂ am) + subst ab + simp only [perm_cons] at h + have := Perm.eq_of_sorted + (fun x y hx hy => w x y (mem_cons_of_mem a hx) (mem_cons_of_mem a hy)) + h₁.tail h₂.tail h + simp_all + +theorem Nodup.perm {l l' : List α} (hR : l.Nodup) (hl : l ~ l') : l'.Nodup := + Pairwise.perm hR hl (by intro x y h h'; simp_all) + +theorem Perm.nodup {l l' : List α} (hl : l ~ l') (hR : l.Nodup) : l'.Nodup := hR.perm hl + theorem Perm.nodup_iff {l₁ l₂ : List α} : l₁ ~ l₂ → (Nodup l₁ ↔ Nodup l₂) := Perm.pairwise_iff <| @Ne.symm α diff --git a/src/Init/Data/List/Sort.lean b/src/Init/Data/List/Sort.lean new file mode 100644 index 000000000000..d26785c28d49 --- /dev/null +++ b/src/Init/Data/List/Sort.lean @@ -0,0 +1,9 @@ +/- +Copyright (c) 2024 Lean FRO. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison +-/ +prelude +import Init.Data.List.Sort.Basic +import Init.Data.List.Sort.Impl +import Init.Data.List.Sort.Lemmas diff --git a/src/Init/Data/List/Sort/Basic.lean b/src/Init/Data/List/Sort/Basic.lean new file mode 100644 index 000000000000..11bbadb2aa5c --- /dev/null +++ b/src/Init/Data/List/Sort/Basic.lean @@ -0,0 +1,81 @@ +/- +Copyright (c) 2024 Lean FRO. All rights reserved. 
+Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison +-/ +prelude +import Init.Data.List.Impl + +/-! +# Definition of `merge` and `mergeSort`. + +These definitions are intended for verification purposes, +and are replaced at runtime by efficient versions in `Init.Data.List.Sort.Impl`. +-/ + +namespace List + +/-- +`O(|l| + |r|)`. Merge two lists using `le` as a switch. + +This version is not tail-recursive, +but it is replaced at runtime by `mergeTR` using a `@[csimp]` lemma. +-/ +def merge (le : α → α → Bool) : List α → List α → List α + | [], ys => ys + | xs, [] => xs + | x :: xs, y :: ys => + if le x y then + x :: merge le xs (y :: ys) + else + y :: merge le (x :: xs) ys + +@[simp] theorem merge_nil_left (ys : List α) : merge le [] ys = ys := by simp [merge] +@[simp] theorem merge_nil_right (xs : List α) : merge le xs [] = xs := by + induction xs with + | nil => simp [merge] + | cons x xs ih => simp [merge, ih] + +/-- +Split a list in two equal parts. If the length is odd, the first part will be one element longer. +-/ +def splitInTwo (l : { l : List α // l.length = n }) : + { l : List α // l.length = (n+1)/2 } × { l : List α // l.length = n/2 } := + let r := splitAt ((n+1)/2) l.1 + (⟨r.1, by simp [r, splitAt_eq, l.2]; omega⟩, ⟨r.2, by simp [r, splitAt_eq, l.2]; omega⟩) + +/-- +Simplified implementation of stable merge sort. + +This function is designed for reasoning about the algorithm, and is not efficient. +(In particular it uses the non tail-recursive `merge` function, +and so can not be run on large lists, but also makes unnecessary traversals of lists.) +It is replaced at runtime in the compiler by `mergeSortTR₂` using a `@[csimp]` lemma. + +Because we want the sort to be stable, +it is essential that we split the list in two contiguous sublists.
+-/ +def mergeSort (le : α → α → Bool) : List α → List α + | [] => [] + | [a] => [a] + | a :: b :: xs => + let lr := splitInTwo ⟨a :: b :: xs, rfl⟩ + have := by simpa using lr.2.2 + have := by simpa using lr.1.2 + merge le (mergeSort le lr.1) (mergeSort le lr.2) +termination_by l => l.length + + +/-- +Given an ordering relation `le : α → α → Bool`, +construct the reverse lexicographic ordering on `Nat × α`, +which first compares the second components using `le`, +but if these are equivalent (in the sense `le a.2 b.2 && le b.2 a.2`) +then compares the first components using `≤`. + +This function is only used in stating the stability properties of `mergeSort`. +-/ +def enumLE (le : α → α → Bool) (a b : Nat × α) : Bool := + if le a.2 b.2 then if le b.2 a.2 then a.1 ≤ b.1 else true else false + +end List diff --git a/src/Init/Data/List/Sort/Impl.lean b/src/Init/Data/List/Sort/Impl.lean new file mode 100644 index 000000000000..8c8aea8b1bac --- /dev/null +++ b/src/Init/Data/List/Sort/Impl.lean @@ -0,0 +1,237 @@ +/- +Copyright (c) 2024 Lean FRO. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison +-/ +prelude +import Init.Data.List.Sort.Lemmas + +/-! +# Replacing `merge` and `mergeSort` at runtime with tail-recursive and faster versions. + +We replace `merge` with `mergeTR` using a `@[csimp]` lemma. + +We replace `mergeSort` in two steps: +* first with `mergeSortTR`, which while not tail-recursive itself (it can't be), + uses `mergeTR` internally. +* second with `mergeSortTR₂`, which achieves an ~20% speed-up over `mergeSortTR` + by avoiding some unnecessary list reversals. + +There is no public API in this file; it solely exists to implement the `@[csimp]` lemmas +affecting runtime behaviour.
+ +## Future work +The current runtime implementation could be further improved in a number of ways, e.g.: +* only walking the list once during splitting, +* using insertion sort for small chunks rather than splitting all the way down to singletons, +* identifying already sorted or reverse sorted chunks and skipping them. + +Because the theory developed for `mergeSort` is independent of the runtime implementation, +as long as such improvements are carefully validated by benchmarking, +they can be done without changing the theory, as long as a `@[csimp]` lemma is provided. +-/ + +open List + +namespace List.MergeSort.Internal + +/-- +`O(|l| + |r|)`. Merge two lists using `le` as a switch. +-/ +def mergeTR (le : α → α → Bool) (l₁ l₂ : List α) : List α := + go l₁ l₂ [] +where go : List α → List α → List α → List α + | [], l₂, acc => reverseAux acc l₂ + | l₁, [], acc => reverseAux acc l₁ + | x :: xs, y :: ys, acc => + if le x y then + go xs (y :: ys) (x :: acc) + else + go (x :: xs) ys (y :: acc) + +theorem mergeTR_go_eq : mergeTR.go le l₁ l₂ acc = acc.reverse ++ merge le l₁ l₂ := by + induction l₁ generalizing l₂ acc with + | nil => simp [mergeTR.go, merge, reverseAux_eq] + | cons x l₁ ih₁ => + induction l₂ generalizing acc with + | nil => simp [mergeTR.go, merge, reverseAux_eq] + | cons y l₂ ih₂ => + simp [mergeTR.go, merge] + split <;> simp [ih₁, ih₂] + +@[csimp] theorem merge_eq_mergeTR : @merge = @mergeTR := by + funext + simp [mergeTR, mergeTR_go_eq] + +/-- +Variant of `splitAt`, that does not reverse the first list, i.e. +`splitRevAt n l = ((l.take n).reverse, l.drop n)`. + +This exists solely as an optimization for `mergeSortTR` and `mergeSortTR₂`, +and should not be used elsewhere. +-/ +def splitRevAt (n : Nat) (l : List α) : List α × List α := go l n [] where + /-- Auxiliary for `splitRevAt`: `splitRevAt.go xs n acc = ((take n xs).reverse ++ acc, drop n xs)`.
-/ + go : List α → Nat → List α → List α × List α + | x :: xs, n+1, acc => go xs n (x :: acc) + | xs, _, acc => (acc, xs) + +theorem splitRevAt_go (xs : List α) (n : Nat) (acc : List α) : + splitRevAt.go xs n acc = ((take n xs).reverse ++ acc, drop n xs) := by + induction xs generalizing n acc with + | nil => simp [splitRevAt.go] + | cons x xs ih => + cases n with + | zero => simp [splitRevAt.go] + | succ n => + rw [splitRevAt.go, ih n (x :: acc), take_succ_cons, reverse_cons, drop_succ_cons, + append_assoc, singleton_append] + +theorem splitRevAt_eq (n : Nat) (l : List α) : splitRevAt n l = ((l.take n).reverse, l.drop n) := by + rw [splitRevAt, splitRevAt_go, append_nil] + +/-- +An intermediate speed-up for `mergeSort`. +This version uses the tail-recursive `mergeTR` function as a subroutine. + +This is not the final version we use at runtime, as `mergeSortTR₂` is faster. +This definition is useful as an intermediate step in proving the `@[csimp]` lemma for `mergeSortTR₂`. +-/ +def mergeSortTR (le : α → α → Bool) (l : List α) : List α := + run ⟨l, rfl⟩ +where run : {n : Nat} → { l : List α // l.length = n } → List α + | 0, ⟨[], _⟩ => [] + | 1, ⟨[a], _⟩ => [a] + | n+2, xs => + let (l, r) := splitInTwo xs + mergeTR le (run l) (run r) + +/-- +Split a list in two equal parts, reversing the first part. +If the length is odd, the first part will be one element longer. +-/ +def splitRevInTwo (l : { l : List α // l.length = n }) : + { l : List α // l.length = (n+1)/2 } × { l : List α // l.length = n/2 } := + let r := splitRevAt ((n+1)/2) l.1 + (⟨r.1, by simp [r, splitRevAt_eq, l.2]; omega⟩, ⟨r.2, by simp [r, splitRevAt_eq, l.2]; omega⟩) + +/-- +Split a list in two equal parts, reversing the first part. +If the length is odd, the second part will be one element longer.
+-/ +def splitRevInTwo' (l : { l : List α // l.length = n }) : + { l : List α // l.length = n/2 } × { l : List α // l.length = (n+1)/2 } := + let r := splitRevAt (n/2) l.1 + (⟨r.1, by simp [r, splitRevAt_eq, l.2]; omega⟩, ⟨r.2, by simp [r, splitRevAt_eq, l.2]; omega⟩) + +/-- +Faster version of `mergeSortTR`, which avoids unnecessary list reversals. +-/ +-- Per the benchmark in `tests/bench/mergeSort/` +-- (which averages over 4 use cases: already sorted lists, reverse sorted lists, almost sorted lists, and random lists), +-- for lists of length 10^6, `mergeSortTR₂` is about 20% faster than `mergeSortTR`. +def mergeSortTR₂ (le : α → α → Bool) (l : List α) : List α := + run ⟨l, rfl⟩ +where + run : {n : Nat} → { l : List α // l.length = n } → List α + | 0, ⟨[], _⟩ => [] + | 1, ⟨[a], _⟩ => [a] + | n+2, xs => + let (l, r) := splitRevInTwo xs + mergeTR le (run' l) (run r) + run' : {n : Nat} → { l : List α // l.length = n } → List α + | 0, ⟨[], _⟩ => [] + | 1, ⟨[a], _⟩ => [a] + | n+2, xs => + let (l, r) := splitRevInTwo' xs + mergeTR le (run' r) (run l) + +theorem splitRevInTwo'_fst (l : { l : List α // l.length = n }) : + (splitRevInTwo' l).1 = ⟨(splitInTwo ⟨l.1.reverse, by simpa using l.2⟩).2.1, by have := l.2; simp; omega⟩ := by + simp only [splitRevInTwo', splitRevAt_eq, reverse_take, splitInTwo_snd] + congr + have := l.2 + omega +theorem splitRevInTwo'_snd (l : { l : List α // l.length = n }) : + (splitRevInTwo' l).2 = ⟨(splitInTwo ⟨l.1.reverse, by simpa using l.2⟩).1.1.reverse, by have := l.2; simp; omega⟩ := by + simp only [splitRevInTwo', splitRevAt_eq, reverse_take, splitInTwo_fst, reverse_reverse] + congr 2 + have := l.2 + simp + omega +theorem splitRevInTwo_fst (l : { l : List α // l.length = n }) : + (splitRevInTwo l).1 = ⟨(splitInTwo l).1.1.reverse, by have := l.2; simp; omega⟩ := by + simp only [splitRevInTwo, splitRevAt_eq, reverse_take, splitInTwo_fst] +theorem splitRevInTwo_snd (l : { l : List α // l.length = n }) : + (splitRevInTwo l).2 = ⟨(splitInTwo 
l).2.1, by have := l.2; simp; omega⟩ := by + simp only [splitRevInTwo, splitRevAt_eq, reverse_take, splitInTwo_snd] + +theorem mergeSortTR_run_eq_mergeSort : {n : Nat} → (l : { l : List α // l.length = n }) → mergeSortTR.run le l = mergeSort le l.1 + | 0, ⟨[], _⟩ + | 1, ⟨[a], _⟩ => by simp [mergeSortTR.run, mergeSort] + | n+2, ⟨a :: b :: l, h⟩ => by + cases h + simp only [mergeSortTR.run, mergeSortTR.run, mergeSort] + rw [merge_eq_mergeTR] + rw [mergeSortTR_run_eq_mergeSort, mergeSortTR_run_eq_mergeSort] + +-- We don't make this a `@[csimp]` lemma because `mergeSort_eq_mergeSortTR₂` is faster. +theorem mergeSort_eq_mergeSortTR : @mergeSort = @mergeSortTR := by + funext + rw [mergeSortTR, mergeSortTR_run_eq_mergeSort] + +-- This mutual block is unfortunately quite slow to elaborate. +set_option maxHeartbeats 400000 in +mutual +theorem mergeSortTR₂_run_eq_mergeSort : {n : Nat} → (l : { l : List α // l.length = n }) → mergeSortTR₂.run le l = mergeSort le l.1 + | 0, ⟨[], _⟩ + | 1, ⟨[a], _⟩ => by simp [mergeSortTR₂.run, mergeSort] + | n+2, ⟨a :: b :: l, h⟩ => by + cases h + simp only [mergeSortTR₂.run, mergeSort] + rw [splitRevInTwo_fst, splitRevInTwo_snd] + rw [mergeSortTR₂_run_eq_mergeSort, mergeSortTR₂_run'_eq_mergeSort] + rw [merge_eq_mergeTR] + rw [reverse_reverse] +termination_by n => n + +theorem mergeSortTR₂_run'_eq_mergeSort : {n : Nat} → (l : { l : List α // l.length = n }) → (w : l' = l.1.reverse) → mergeSortTR₂.run' le l = mergeSort le l' + | 0, ⟨[], _⟩, w + | 1, ⟨[a], _⟩, w => by simp_all [mergeSortTR₂.run', mergeSort] + | n+2, ⟨a :: b :: l, h⟩, w => by + cases h + simp only [mergeSortTR₂.run', mergeSort] + rw [splitRevInTwo'_fst, splitRevInTwo'_snd] + rw [mergeSortTR₂_run_eq_mergeSort, mergeSortTR₂_run'_eq_mergeSort _ rfl] + rw [← merge_eq_mergeTR] + have w' := congrArg length w + simp at w' + cases l' with + | nil => simp at w' + | cons x l' => + cases l' with + | nil => simp at w'; omega + | cons y l' => + rw [mergeSort] + congr 2 + · dsimp at w + simp 
only [w] + simp only [splitInTwo_fst, splitInTwo_snd, reverse_take, take_reverse] + congr 1 + rw [w, length_reverse] + simp + · dsimp at w + simp only [w] + simp only [reverse_cons, append_assoc, singleton_append, splitInTwo_snd, length_cons] + congr 1 + simp at w' + omega +termination_by n => n + +end + +@[csimp] theorem mergeSort_eq_mergeSortTR₂ : @mergeSort = @mergeSortTR₂ := by + funext + rw [mergeSortTR₂, mergeSortTR₂_run_eq_mergeSort] + +end List.MergeSort.Internal diff --git a/src/Init/Data/List/Sort/Lemmas.lean b/src/Init/Data/List/Sort/Lemmas.lean new file mode 100644 index 000000000000..b3077d6a0fbc --- /dev/null +++ b/src/Init/Data/List/Sort/Lemmas.lean @@ -0,0 +1,399 @@ +/- +Copyright (c) 2024 Lean FRO. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison +-/ +prelude +import Init.Data.List.Perm +import Init.Data.List.Sort.Basic +import Init.Data.Bool + +/-! +# Basic properties of `mergeSort`. + +* `mergeSort_sorted`: `mergeSort` produces a sorted list. +* `mergeSort_perm`: `mergeSort` is a permutation of the input list. +* `mergeSort_of_sorted`: `mergeSort` does not change a sorted list. +* `mergeSort_cons`: proves `mergeSort le (x :: xs) = l₁ ++ x :: l₂` for some `l₁, l₂` + so that `mergeSort le xs = l₁ ++ l₂`, and no `a ∈ l₁` satisfies `le a x`. +* `mergeSort_stable`: if `c` is a sorted sublist of `l`, then `c` is still a sublist of `mergeSort le l`. + +-/ + +namespace List + +-- We enable this instance locally so we can write `Sorted le` instead of `Sorted (le · ·)` everywhere. +attribute [local instance] boolRelToRel + +variable {le : α → α → Bool} + +/-! 
### splitInTwo -/ + +@[simp] theorem splitInTwo_fst (l : { l : List α // l.length = n }) : (splitInTwo l).1 = ⟨l.1.take ((n+1)/2), by simp [splitInTwo, splitAt_eq, l.2]; omega⟩ := by + simp [splitInTwo, splitAt_eq] + +@[simp] theorem splitInTwo_snd (l : { l : List α // l.length = n }) : (splitInTwo l).2 = ⟨l.1.drop ((n+1)/2), by simp [splitInTwo, splitAt_eq, l.2]; omega⟩ := by + simp [splitInTwo, splitAt_eq] + +theorem splitInTwo_fst_append_splitInTwo_snd (l : { l : List α // l.length = n }) : (splitInTwo l).1.1 ++ (splitInTwo l).2.1 = l.1 := by + simp + +theorem splitInTwo_cons_cons_enumFrom_fst (i : Nat) (l : List α) : + (splitInTwo ⟨(i, a) :: (i+1, b) :: l.enumFrom (i+2), rfl⟩).1.1 = + (splitInTwo ⟨a :: b :: l, rfl⟩).1.1.enumFrom i := by + simp only [length_cons, splitInTwo_fst, enumFrom_length] + ext1 j + rw [getElem?_take, getElem?_enumFrom, getElem?_take] + split + · rw [getElem?_cons, getElem?_cons, getElem?_cons, getElem?_cons] + split + · simp; omega + · split + · simp; omega + · simp only [getElem?_enumFrom] + congr + ext <;> simp; omega + · simp + +theorem splitInTwo_cons_cons_enumFrom_snd (i : Nat) (l : List α) : + (splitInTwo ⟨(i, a) :: (i+1, b) :: l.enumFrom (i+2), rfl⟩).2.1 = + (splitInTwo ⟨a :: b :: l, rfl⟩).2.1.enumFrom (i+(l.length+3)/2) := by + simp only [length_cons, splitInTwo_snd, enumFrom_length] + ext1 j + rw [getElem?_drop, getElem?_enumFrom, getElem?_drop] + rw [getElem?_cons, getElem?_cons, getElem?_cons, getElem?_cons] + split + · simp; omega + · split + · simp; omega + · simp only [getElem?_enumFrom] + congr + ext <;> simp; omega + +theorem splitInTwo_fst_sorted (l : { l : List α // l.length = n }) (h : Pairwise le l.1) : Pairwise le (splitInTwo l).1.1 := by + rw [splitInTwo_fst] + exact h.take + +theorem splitInTwo_snd_sorted (l : { l : List α // l.length = n }) (h : Pairwise le l.1) : Pairwise le (splitInTwo l).2.1 := by + rw [splitInTwo_snd] + exact h.drop + +theorem splitInTwo_fst_le_splitInTwo_snd {l : { l : List α // l.length = n 
}} (h : Pairwise le l.1) : + ∀ a b, a ∈ (splitInTwo l).1.1 → b ∈ (splitInTwo l).2.1 → le a b := by + rw [splitInTwo_fst, splitInTwo_snd] + intro a b ma mb + exact h.rel_of_mem_take_of_mem_drop ma mb + +/-! ### enumLE -/ + +theorem enumLE_trans (trans : ∀ a b c, le a b → le b c → le a c) + (a b c : Nat × α) : enumLE le a b → enumLE le b c → enumLE le a c := by + simp only [enumLE] + split <;> split <;> split <;> rename_i ab₂ ba₂ bc₂ + · simp_all + intro ab₁ + intro h + refine ⟨trans _ _ _ ab₂ bc₂, ?_⟩ + rcases h with (cd₂ | bc₁) + · exact Or.inl (Decidable.byContradiction + (fun ca₂ => by simp_all [trans _ _ _ (by simpa using ca₂) ab₂])) + · exact Or.inr (Nat.le_trans ab₁ bc₁) + · simp_all + · simp_all + intro h + refine ⟨trans _ _ _ ab₂ bc₂, ?_⟩ + left + rcases h with (cb₂ | _) + · exact (Decidable.byContradiction + (fun ca₂ => by simp_all [trans _ _ _ (by simpa using ca₂) ab₂])) + · exact (Decidable.byContradiction + (fun ca₂ => by simp_all [trans _ _ _ bc₂ (by simpa using ca₂)])) + · simp_all + · simp_all + · simp_all + · simp_all + · simp_all + +theorem enumLE_total (total : ∀ a b, !le a b → le b a) + (a b : Nat × α) : !enumLE le a b → enumLE le b a := by + simp only [enumLE] + split <;> split + · simpa using Nat.le_of_lt + · simp + · simp + · simp_all [total a.2 b.2] + +/-! ### merge -/ + +theorem merge_stable : ∀ (xs ys) (_ : ∀ x y, x ∈ xs → y ∈ ys → x.1 ≤ y.1), + (merge (enumLE le) xs ys).map (·.2) = merge le (xs.map (·.2)) (ys.map (·.2)) + | [], ys, _ => by simp [merge] + | xs, [], _ => by simp [merge] + | (i, x) :: xs, (j, y) :: ys, h => by + simp only [merge, enumLE, map_cons] + split <;> rename_i w + · rw [if_pos (by simp [h _ _ (mem_cons_self ..) 
(mem_cons_self ..)])] + simp only [map_cons, cons.injEq, true_and] + rw [merge_stable, map_cons] + exact fun x' y' mx my => h x' y' (mem_cons_of_mem (i, x) mx) my + · simp only [↓reduceIte, map_cons, cons.injEq, true_and] + rw [merge_stable, map_cons] + exact fun x' y' mx my => h x' y' mx (mem_cons_of_mem (j, y) my) + +/-- +The elements of `merge le xs ys` are exactly the elements of `xs` and `ys`. +-/ +-- We subsequently prove that `merge_perm_append : merge le xs ys ~ xs ++ ys`. +theorem mem_merge {a : α} {xs ys : List α} : a ∈ merge le xs ys ↔ a ∈ xs ∨ a ∈ ys := by + induction xs generalizing ys with + | nil => simp [merge] + | cons x xs ih => + induction ys with + | nil => simp [merge] + | cons y ys ih => + simp only [merge] + split <;> rename_i h + · simp_all [or_assoc] + · simp only [mem_cons, or_assoc, Bool.not_eq_true, ih, ← or_assoc] + apply or_congr_left + simp only [or_comm (a := a = y), or_assoc] + +/-- +If the ordering relation `le` is transitive and total (i.e. `le a b ∨ le b a` for all `a, b`) +then the `merge` of two sorted lists is sorted.
+-/ +theorem merge_sorted + (trans : ∀ (a b c : α), le a b → le b c → le a c) + (total : ∀ (a b : α), !le a b → le b a) + (l₁ l₂ : List α) (h₁ : l₁.Pairwise le) (h₂ : l₂.Pairwise le) : (merge le l₁ l₂).Pairwise le := by + induction l₁ generalizing l₂ with + | nil => simpa only [merge] + | cons x l₁ ih₁ => + induction l₂ with + | nil => simpa only [merge] + | cons y l₂ ih₂ => + simp only [merge] + split <;> rename_i h + · apply Pairwise.cons + · intro z m + rw [mem_merge, mem_cons] at m + rcases m with (m|rfl|m) + · exact rel_of_pairwise_cons h₁ m + · exact h + · exact trans _ _ _ h (rel_of_pairwise_cons h₂ m) + · exact ih₁ _ h₁.tail h₂ + · apply Pairwise.cons + · intro z m + rw [mem_merge, mem_cons] at m + rcases m with (⟨rfl|m⟩|m) + · exact total _ _ (by simpa using h) + · exact trans _ _ _ (total _ _ (by simpa using h)) (rel_of_pairwise_cons h₁ m) + · exact rel_of_pairwise_cons h₂ m + · exact ih₂ h₂.tail + +theorem merge_of_le : ∀ {xs ys : List α} (_ : ∀ a b, a ∈ xs → b ∈ ys → le a b), + merge le xs ys = xs ++ ys + | [], ys, _ + | xs, [], _ => by simp [merge] + | x :: xs, y :: ys, h => by + simp only [merge, cons_append] + rw [if_pos, merge_of_le] + · intro a b ma mb + exact h a b (mem_cons_of_mem _ ma) mb + · exact h x y (mem_cons_self _ _) (mem_cons_self _ _) + +variable (le) in +theorem merge_perm_append : ∀ {xs ys : List α}, merge le xs ys ~ xs ++ ys + | [], ys => by simp [merge] + | xs, [] => by simp [merge] + | x :: xs, y :: ys => by + simp only [merge] + split + · exact merge_perm_append.cons x + · exact (merge_perm_append.cons y).trans + ((Perm.swap x y _).trans (perm_middle.symm.cons x)) + +/-! 
### mergeSort -/ + +variable (le) in +theorem mergeSort_perm : ∀ (l : List α), mergeSort le l ~ l + | [] => by simp [mergeSort] + | [a] => by simp [mergeSort] + | a :: b :: xs => by + simp only [mergeSort] + have : (splitInTwo ⟨a :: b :: xs, rfl⟩).1.1.length < xs.length + 1 + 1 := by simp [splitInTwo_fst]; omega + have : (splitInTwo ⟨a :: b :: xs, rfl⟩).2.1.length < xs.length + 1 + 1 := by simp [splitInTwo_snd]; omega + exact (merge_perm_append le).trans + (((mergeSort_perm _).append (mergeSort_perm _)).trans + (Perm.of_eq (splitInTwo_fst_append_splitInTwo_snd _))) +termination_by l => l.length + +@[simp] theorem mem_mergeSort {a : α} {l : List α} : a ∈ mergeSort le l ↔ a ∈ l := + (mergeSort_perm le l).mem_iff + +/-- +The result of `mergeSort` is sorted, +as long as the comparison function is transitive (`le a b → le b c → le a c`) +and total in the sense that `le a b ∨ le b a`. + +The comparison function need not be antisymmetric, i.e. `le a b` and `le b a` are both allowed to hold even when `a ≠ b`. +-/ +theorem mergeSort_sorted + (trans : ∀ (a b c : α), le a b → le b c → le a c) + (total : ∀ (a b : α), !le a b → le b a) : + (l : List α) → (mergeSort le l).Pairwise le + | [] => by simp [mergeSort] + | [a] => by simp [mergeSort] + | a :: b :: xs => by + have : (splitInTwo ⟨a :: b :: xs, rfl⟩).1.1.length < xs.length + 1 + 1 := by simp [splitInTwo_fst]; omega + have : (splitInTwo ⟨a :: b :: xs, rfl⟩).2.1.length < xs.length + 1 + 1 := by simp [splitInTwo_snd]; omega + rw [mergeSort] + apply merge_sorted @trans @total + apply mergeSort_sorted trans total + apply mergeSort_sorted trans total +termination_by l => l.length + +/-- +If the input list is already sorted, then `mergeSort` does not change the list. 
+-/ +theorem mergeSort_of_sorted : ∀ {l : List α} (_ : Pairwise le l), mergeSort le l = l + | [], _ => by simp [mergeSort] + | [a], _ => by simp [mergeSort] + | a :: b :: xs, h => by + have : (splitInTwo ⟨a :: b :: xs, rfl⟩).1.1.length < xs.length + 1 + 1 := by simp [splitInTwo_fst]; omega + have : (splitInTwo ⟨a :: b :: xs, rfl⟩).2.1.length < xs.length + 1 + 1 := by simp [splitInTwo_snd]; omega + rw [mergeSort] + rw [mergeSort_of_sorted (splitInTwo_fst_sorted ⟨a :: b :: xs, rfl⟩ h)] + rw [mergeSort_of_sorted (splitInTwo_snd_sorted ⟨a :: b :: xs, rfl⟩ h)] + rw [merge_of_le (splitInTwo_fst_le_splitInTwo_snd h)] + rw [splitInTwo_fst_append_splitInTwo_snd] +termination_by l => l.length + +/-- +This merge sort algorithm is stable, +in the sense that breaking ties in the ordering function using the position in the list +has no effect on the output. + +That is, elements which are equal with respect to the ordering function will remain +in the same order in the output list as they were in the input list. + +See also: +* `mergeSort_stable`: if `c <+ l` and `c.Pairwise le`, then `c <+ mergeSort le l`. 
+* `mergeSort_stable_pair`: if `[a, b] <+ l` and `le a b`, then `[a, b] <+ mergeSort le l`. +-/ +theorem mergeSort_enum {l : List α} : + (mergeSort (enumLE le) (l.enum)).map (·.2) = mergeSort le l := + go 0 l +where go : ∀ (i : Nat) (l : List α), + (mergeSort (enumLE le) (l.enumFrom i)).map (·.2) = mergeSort le l + | _, [] + | _, [a] => by simp [mergeSort] + | _, a :: b :: xs => by + have : (splitInTwo ⟨a :: b :: xs, rfl⟩).1.1.length < xs.length + 1 + 1 := by simp [splitInTwo_fst]; omega + have : (splitInTwo ⟨a :: b :: xs, rfl⟩).2.1.length < xs.length + 1 + 1 := by simp [splitInTwo_snd]; omega + simp only [mergeSort, enumFrom] + rw [splitInTwo_cons_cons_enumFrom_fst] + rw [splitInTwo_cons_cons_enumFrom_snd] + rw [merge_stable] + · rw [go, go] + · simp only [mem_mergeSort, Prod.forall] + intros j x k y mx my + have := mem_enumFrom mx + have := mem_enumFrom my + simp_all + omega +termination_by _ l => l.length + +theorem mergeSort_cons {le : α → α → Bool} + (trans : ∀ (a b c : α), le a b → le b c → le a c) + (total : ∀ (a b : α), !le a b → le b a) + (a : α) (l : List α) : + ∃ l₁ l₂, mergeSort le (a :: l) = l₁ ++ a :: l₂ ∧ mergeSort le l = l₁ ++ l₂ ∧ + ∀ b, b ∈ l₁ → !le a b := by + rw [← mergeSort_enum] + rw [enum_cons] + have nd : Nodup ((a :: l).enum.map (·.1)) := by rw [enum_map_fst]; exact nodup_range _ + have m₁ : (0, a) ∈ mergeSort (enumLE le) ((a :: l).enum) := + mem_mergeSort.mpr (mem_cons_self _ _) + obtain ⟨l₁, l₂, h⟩ := append_of_mem m₁ + have s := mergeSort_sorted (enumLE_trans trans) (enumLE_total total) ((a :: l).enum) + rw [h] at s + have p := mergeSort_perm (enumLE le) ((a :: l).enum) + rw [h] at p + refine ⟨l₁.map (·.2), l₂.map (·.2), ?_, ?_, ?_⟩ + · simpa using congrArg (·.map (·.2)) h + · rw [← mergeSort_enum.go 1, ← map_append] + congr 1 + have q : mergeSort (enumLE le) (enumFrom 1 l) ~ l₁ ++ l₂ := + (mergeSort_perm (enumLE le) (enumFrom 1 l)).trans + (p.symm.trans perm_middle).cons_inv + apply Perm.eq_of_sorted (le := enumLE le) + · rintro ⟨i, a⟩ 
⟨j, b⟩ ha hb + simp only [mem_mergeSort] at ha + simp only [← q.mem_iff, mem_mergeSort] at hb + simp only [enumLE] + simp only [Bool.if_false_right, Bool.and_eq_true, Prod.mk.injEq, and_imp] + intro ab h ba h' + simp only [Bool.decide_eq_true] at ba + replace h : i ≤ j := by simpa [ab, ba] using h + replace h' : j ≤ i := by simpa [ab, ba] using h' + cases Nat.le_antisymm h h' + constructor + · rfl + · have := mem_enumFrom ha + have := mem_enumFrom hb + simp_all + · exact mergeSort_sorted (enumLE_trans trans) (enumLE_total total) .. + · exact s.sublist ((sublist_cons_self (0, a) l₂).append_left l₁) + · exact q + · intro b m + simp only [mem_map, Prod.exists, exists_eq_right] at m + obtain ⟨j, m⟩ := m + replace p := p.map (·.1) + have nd' := nd.perm p.symm + rw [map_append] at nd' + have j0 := nd'.rel_of_mem_append + (mem_map_of_mem (·.1) m) (mem_map_of_mem _ (mem_cons_self _ _)) + simp only [ne_eq] at j0 + have r := s.rel_of_mem_append m (mem_cons_self _ _) + simp_all [enumLE] + +/-- +Another statement of stability of merge sort. +If `c` is a sorted sublist of `l`, +then `c` is still a sublist of `mergeSort le l`. 
+-/ +theorem mergeSort_stable + (trans : ∀ (a b c : α), le a b → le b c → le a c) + (total : ∀ (a b : α), !le a b → le b a) : + ∀ {c : List α} (_ : c.Pairwise le) (_ : c <+ l), + c <+ mergeSort le l + | _, _, .slnil => nil_sublist _ + | c, hc, @Sublist.cons _ _ l a h => by + obtain ⟨l₁, l₂, h₁, h₂, -⟩ := mergeSort_cons trans total a l + rw [h₁] + have h' := mergeSort_stable trans total hc h + rw [h₂] at h' + exact h'.middle a + | _, _, @Sublist.cons₂ _ l₁ l₂ a h => by + rename_i hc + obtain ⟨l₃, l₄, h₁, h₂, h₃⟩ := mergeSort_cons trans total a l₂ + rw [h₁] + have h' := mergeSort_stable trans total hc.tail h + rw [h₂] at h' + simp only [Bool.not_eq_true', tail_cons] at h₃ h' + exact + sublist_append_of_sublist_right (Sublist.cons₂ a + ((fun w => Sublist.of_sublist_append_right w h') fun b m₁ m₃ => + (Bool.eq_not_self true).mp ((rel_of_pairwise_cons hc m₁).symm.trans (h₃ b m₃)))) + +/-- +Another statement of stability of merge sort. +If a pair `[a, b]` is a sublist of `l` and `le a b`, +then `[a, b]` is still a sublist of `mergeSort le l`. 
+-/ +theorem mergeSort_stable_pair + (trans : ∀ (a b c : α), le a b → le b c → le a c) + (total : ∀ (a b : α), !le a b → le b a) + (hab : le a b) (h : [a, b] <+ l) : [a, b] <+ mergeSort le l := + mergeSort_stable trans total (pairwise_pair.mpr hab) h diff --git a/src/Init/Data/List/Sublist.lean b/src/Init/Data/List/Sublist.lean index b734e2c4d540..3c9033431162 100644 --- a/src/Init/Data/List/Sublist.lean +++ b/src/Init/Data/List/Sublist.lean @@ -398,6 +398,27 @@ theorem append_sublist_iff {l₁ l₂ : List α} : · rintro ⟨r₁, r₂, rfl, h₁, h₂⟩ exact Sublist.append h₁ h₂ +theorem Sublist.of_sublist_append_left (w : ∀ a, a ∈ l → a ∉ l₂) (h : l <+ l₁ ++ l₂) : l <+ l₁ := by + rw [sublist_append_iff] at h + obtain ⟨l₁', l₂', rfl, h₁, h₂⟩ := h + have : l₂' = [] := by + rw [eq_nil_iff_forall_not_mem] + exact fun x m => w x (mem_append_of_mem_right l₁' m) (h₂.mem m) + simp_all + +theorem Sublist.of_sublist_append_right (w : ∀ a, a ∈ l → a ∉ l₁) (h : l <+ l₁ ++ l₂) : l <+ l₂ := by + rw [sublist_append_iff] at h + obtain ⟨l₁', l₂', rfl, h₁, h₂⟩ := h + have : l₁' = [] := by + rw [eq_nil_iff_forall_not_mem] + exact fun x m => w x (mem_append_of_mem_left l₂' m) (h₁.mem m) + simp_all + +theorem Sublist.middle {l : List α} (h : l <+ l₁ ++ l₂) (a : α) : l <+ l₁ ++ a :: l₂ := by + rw [sublist_append_iff] at h + obtain ⟨l₁', l₂', rfl, h₁, h₂⟩ := h + exact Sublist.append h₁ (h₂.cons a) + theorem Sublist.reverse : l₁ <+ l₂ → l₁.reverse <+ l₂.reverse | .slnil => Sublist.refl _ | .cons _ h => by rw [reverse_cons]; exact sublist_append_of_sublist_left h.reverse diff --git a/src/Init/Data/List/TakeDrop.lean b/src/Init/Data/List/TakeDrop.lean index 0596f1bffe4a..8499ecbb245a 100644 --- a/src/Init/Data/List/TakeDrop.lean +++ b/src/Init/Data/List/TakeDrop.lean @@ -20,6 +20,11 @@ Further results on `List.take` and `List.drop`, which rely on stronger automatio are given in `Init.Data.List.TakeDrop`. 
-/ +theorem take_cons {l : List α} (h : 0 < n) : take n (a :: l) = a :: take (n - 1) l := by + cases n with + | zero => exact absurd h (Nat.lt_irrefl _) + | succ n => rfl + @[simp] theorem drop_one : ∀ l : List α, drop 1 l = tail l | [] | _ :: _ => rfl @@ -74,7 +79,7 @@ theorem drop_eq_get_cons {n} {l : List α} (h) : drop n l = get l ⟨n, h⟩ :: simp [drop_eq_getElem_cons] @[simp] -theorem getElem?_take {l : List α} {n m : Nat} (h : m < n) : (l.take n)[m]? = l[m]? := by +theorem getElem?_take_of_lt {l : List α} {n m : Nat} (h : m < n) : (l.take n)[m]? = l[m]? := by induction n generalizing l m with | zero => exact absurd h (Nat.not_lt_of_le m.zero_le) @@ -86,13 +91,13 @@ theorem getElem?_take {l : List α} {n m : Nat} (h : m < n) : (l.take n)[m]? = l · simp · simpa using hn (Nat.lt_of_succ_lt_succ h) -@[deprecated getElem?_take (since := "2024-06-12")] +@[deprecated getElem?_take_of_lt (since := "2024-06-12")] theorem get?_take {l : List α} {n m : Nat} (h : m < n) : (l.take n).get? m = l.get? m := by - simp [getElem?_take, h] + simp [getElem?_take_of_lt, h] @[simp] theorem getElem?_take_of_succ {l : List α} {n : Nat} : (l.take (n + 1))[n]? = l[n]? := - getElem?_take (Nat.lt_succ_self n) + getElem?_take_of_lt (Nat.lt_succ_self n) @[simp] theorem drop_drop (n : Nat) : ∀ (m) (l : List α), drop n (drop m l) = drop (n + m) l | m, [] => by simp @@ -433,6 +438,12 @@ theorem take_takeWhile {l : List α} (p : α → Bool) n : | nil => rfl | cons h t ih => by_cases p h <;> simp_all +/-! ### splitAt -/ + +@[simp] theorem splitAt_eq (n : Nat) (l : List α) : splitAt n l = (l.take n, l.drop n) := by + rw [splitAt, splitAt_go, reverse_nil, nil_append] + split <;> simp_all [take_of_length_le, drop_of_length_le] + /-! 
### rotateLeft -/ @[simp] theorem rotateLeft_zero (l : List α) : rotateLeft l 0 = l := by diff --git a/tests/bench/mergeSort/.gitignore b/tests/bench/mergeSort/.gitignore new file mode 100644 index 000000000000..bfb30ec8c762 --- /dev/null +++ b/tests/bench/mergeSort/.gitignore @@ -0,0 +1 @@ +/.lake diff --git a/tests/bench/mergeSort/Bench.lean b/tests/bench/mergeSort/Bench.lean new file mode 100644 index 000000000000..0b9ac2ae803c --- /dev/null +++ b/tests/bench/mergeSort/Bench.lean @@ -0,0 +1,18 @@ +open List.MergeSort.Internal + +def main (args : List String) : IO Unit := do + let k := 5 + let some arg := args[0]? | throw <| IO.userError s!"specify length of test data in multiples of 10^{k}" + let some i := arg.toNat? | throw <| IO.userError s!"specify length of test data in multiples of 10^{k}" + let n := i * (10^k) + let i₁ := List.iota n + let i₂ := List.range n + let i₃ ← (List.range n).mapM (fun _ => IO.rand 0 1000) + let i₄ := (List.range (i * (10^(k-3)))).bind (fun k => (k * 1000 + 1) :: (k * 1000) :: List.range' (k * 1000 + 2) 998) + let start ← IO.monoMsNow + let o₁ := (mergeSortTR₂ (· ≤ ·) i₁).length == n + let o₂ := (mergeSortTR₂ (· ≤ ·) i₂).length == n + let o₃ := (mergeSortTR₂ (· ≤ ·) i₃).length == n + let o₄ := (mergeSortTR₂ (· ≤ ·) i₄).length == n + IO.println (((← IO.monoMsNow) - start)/4) + IO.Process.exit (if o₁ && o₂ && o₃ && o₄ then 0 else 1) diff --git a/tests/bench/mergeSort/README.md b/tests/bench/mergeSort/README.md new file mode 100644 index 000000000000..11b0881cbd69 --- /dev/null +++ b/tests/bench/mergeSort/README.md @@ -0,0 +1,15 @@ +# mergeSortBenchmark + +Benchmarking `List.mergeSort`. + +Run `lake exe mergeSort k` to run a benchmark on lists of size `k * 10^5`. 
+This reports the average time (in milliseconds) to sort: +* an already sorted list +* a reverse sorted list +* an almost sorted list +* and a random list with duplicates + +Run `python3 bench.py` to run this for `k = 1, .., 10`, and calculate a best fit +of the model `A * k + B * k * log k` to the observed runtimes. +(This isn't really what one should do: +fitting a log to data across a single order of magnitude is not helpful.) diff --git a/tests/bench/mergeSort/bench.py b/tests/bench/mergeSort/bench.py new file mode 100644 index 000000000000..37472026e29e --- /dev/null +++ b/tests/bench/mergeSort/bench.py @@ -0,0 +1,38 @@ +import subprocess +import numpy as np +from scipy.optimize import curve_fit +import matplotlib.pyplot as plt + +# Function to run the command and capture the elapsed time from stdout +def benchmark(i): + result = subprocess.run([f'./.lake/build/bin/mergeSort', str(i)], capture_output=True, text=True) + elapsed_time_ms = int(result.stdout.strip()) # Assuming the time is printed as a single integer in ms + return elapsed_time_ms / 1e3 # Convert milliseconds to seconds + +# Benchmark for i = 0.1, 0.2, ..., 1.0 with 5 runs each +i_values = [] +times = [] + +for i in range(1, 11): + run_times = sorted([benchmark(i) for _ in range(5)]) + middle_three_avg = np.mean(run_times[1:4]) # Take the average of the middle 3 times + times.append(middle_three_avg) + i_values.append(i / 1e1) + +# Fit the data to A*i + B*i*log(i) +def model(i, A, B): + return A * i + B * i * np.log(i) + +popt, _ = curve_fit(model, i_values, times) +A, B = popt + +# Print the fit parameters +print(f"Best fit parameters: A = {A}, B = {B}") + +# Plot the results +plt.plot(i_values, times, 'o', label='Benchmark Data (Avg of Middle 3)') +plt.plot(i_values, model(np.array(i_values), *popt), '-', label=f'Fit: A*i + B*i*log(i)\nA={A:.3f}, B={B:.3f}') +plt.xlabel('i') +plt.ylabel('Time (s)') +plt.legend() +plt.show() diff --git a/tests/bench/mergeSort/lakefile.lean 
b/tests/bench/mergeSort/lakefile.lean new file mode 100644 index 000000000000..c3f7c013208c --- /dev/null +++ b/tests/bench/mergeSort/lakefile.lean @@ -0,0 +1,9 @@ +import Lake +open Lake DSL + +package "mergeSortBenchmark" where + -- add package configuration options here + +@[default_target] +lean_exe "mergeSort" where + root := `Bench diff --git a/tests/bench/mergeSort/lean-toolchain b/tests/bench/mergeSort/lean-toolchain new file mode 100644 index 000000000000..dcca6df980de --- /dev/null +++ b/tests/bench/mergeSort/lean-toolchain @@ -0,0 +1 @@ +lean4 diff --git a/tests/lean/run/mergeSort.lean b/tests/lean/run/mergeSort.lean new file mode 100644 index 000000000000..f43d47f118e9 --- /dev/null +++ b/tests/lean/run/mergeSort.lean @@ -0,0 +1,25 @@ +open List MergeSort Internal + +unseal mergeSort merge in +example : mergeSort (· ≤ ·) [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5] = [1, 1, 2, 3, 3, 4, 5, 5, 5, 6, 9] := + rfl + +unseal mergeSort merge in +example : mergeSort (fun x y => x/10 ≤ y/10) [3, 100 + 1, 4, 100 + 1, 5, 100 + 9, 2, 10 + 6, 5, 10 + 3, 5] = [3, 4, 5, 2, 5, 5, 16, 13, 101, 101, 109] := + rfl + +unseal mergeSortTR.run mergeTR.go in +example : mergeSortTR (· ≤ ·) [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5] = [1, 1, 2, 3, 3, 4, 5, 5, 5, 6, 9] := + rfl + +unseal mergeSortTR.run mergeTR.go in +example : mergeSortTR (fun x y => x/10 ≤ y/10) [3, 100 + 1, 4, 100 + 1, 5, 100 + 9, 2, 10 + 6, 5, 10 + 3, 5] = [3, 4, 5, 2, 5, 5, 16, 13, 101, 101, 109] := + rfl + +unseal mergeSortTR₂.run mergeTR.go in +example : mergeSortTR₂ (· ≤ ·) [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5] = [1, 1, 2, 3, 3, 4, 5, 5, 5, 6, 9] := + rfl + +unseal mergeSortTR₂.run mergeTR.go in +example : mergeSortTR₂ (fun x y => x/10 ≤ y/10) [3, 100 + 1, 4, 100 + 1, 5, 100 + 9, 2, 10 + 6, 5, 10 + 3, 5] = [3, 4, 5, 2, 5, 5, 16, 13, 101, 101, 109] := + rfl diff --git a/tests/lean/run/mergeSortCPDT.lean b/tests/lean/run/mergeSortCPDT.lean index 67ba33912789..eb93b33aa53e 100644 --- a/tests/lean/run/mergeSortCPDT.lean +++ 
b/tests/lean/run/mergeSortCPDT.lean @@ -3,7 +3,7 @@ def List.insert' (p : α → α → Bool) (a : α) (bs : List α) : List α := | [] => [a] | b :: bs' => if p a b then a :: bs else b :: bs'.insert' p a -def List.merge (p : α → α → Bool) (as bs : List α) : List α := +def List.merge' (p : α → α → Bool) (as bs : List α) : List α := match as with | [] => bs | a :: as' => insert' p a (merge p as' bs) @@ -36,7 +36,7 @@ theorem List.length_split_of_atLeast2 {as : List α} (h : as.atLeast2) : as.spli have ⟨ih₁, ih₂⟩ := ih exact ⟨Nat.le_trans ih₁ (by simp_arith), Nat.le_trans ih₂ (by simp_arith)⟩ -def List.mergeSort (p : α → α → Bool) (as : List α) : List α := +def List.mergeSort' (p : α → α → Bool) (as : List α) : List α := if h : as.atLeast2 then match he:as.split with | (as', bs') => @@ -44,7 +44,7 @@ def List.mergeSort (p : α → α → Bool) (as : List α) : List α := have ⟨h₁, h₂⟩ := length_split_of_atLeast2 h have : as'.length < as.length := by simp [he] at h₁; assumption have : bs'.length < as.length := by simp [he] at h₂; assumption - merge p (mergeSort p as') (mergeSort p bs') + merge' p (mergeSort' p as') (mergeSort' p bs') else as termination_by as.length