From 504399b7b7264497e33f7b41f50d94f44873afff Mon Sep 17 00:00:00 2001
From: "Lei, HUANG" <mrsatangel@gmail.com>
Date: Mon, 6 May 2024 08:52:16 +0000
Subject: [PATCH] feat: use smallvec

---
 src/mito2/src/compaction/run.rs | 185 +++++---------------------------
 1 file changed, 28 insertions(+), 157 deletions(-)

diff --git a/src/mito2/src/compaction/run.rs b/src/mito2/src/compaction/run.rs
index 4730aa4be951..812ef47db6b2 100644
--- a/src/mito2/src/compaction/run.rs
+++ b/src/mito2/src/compaction/run.rs
@@ -16,11 +16,11 @@
 //! along with the best way to merge these items to satisfy the desired run count.
 
 use std::cmp::Ordering;
-use std::marker::PhantomData;
 
 use common_base::BitVec;
 use common_time::Timestamp;
 use itertools::Itertools;
+use smallvec::{smallvec, SmallVec};
 
 use crate::sst::file::FileHandle;
 
@@ -76,7 +76,7 @@ impl Item for FileHandle {
 
 #[derive(Debug, Clone)]
 struct MergeItems<T: Item> {
-    items: Vec<T>,
+    items: SmallVec<[T; 4]>,
     start: T::BoundType,
     end: T::BoundType,
     size: usize,
@@ -96,7 +96,7 @@ impl<T: Item> MergeItems<T> {
         let (start, end) = val.range();
         let size = val.size();
         Self {
-            items: vec![val],
+            items: smallvec![val],
             start,
             end,
             size,
@@ -114,7 +114,7 @@ impl<T: Item> MergeItems<T> {
         let end = self.end.max(other.end);
         let size = self.size + other.size;
 
-        let mut items = Vec::with_capacity(self.items.len() + other.items.len());
+        let mut items = SmallVec::with_capacity(self.items.len() + other.items.len());
         items.extend(self.items);
         items.extend(other.items);
         Self {
@@ -125,13 +125,6 @@ impl<T: Item> MergeItems<T> {
         }
     }
 
-    /// Returns the size of current item.
-    // If current item is merged from two or more items, then the size will be the sum of all
-    // items merged.
-    fn size(&self) -> usize {
-        self.size
-    }
-
     /// Returns true if current item is merged from two items.
     pub fn merged(&self) -> bool {
         self.items.len() > 1
@@ -171,143 +164,13 @@ where
 {
     fn push_item(&mut self, t: MergeItems<T>) {
         let (file_start, file_end) = t.range();
+        if t.merged() {
+            self.penalty += t.size;
+        }
         self.items.push(t);
         self.start = Some(self.start.map_or(file_start, |v| v.min(file_start)));
         self.end = Some(self.end.map_or(file_end, |v| v.max(file_end)));
     }
-
-    fn merge(self, other: Self) -> Self {
-        let (lhs, rhs) = if self.start < other.start {
-            (self, other)
-        } else {
-            (other, self)
-        };
-
-        #[derive(Default)]
-        struct Selection<T: Ranged> {
-            lhs_selection: BitVec,
-            rhs_selection: BitVec,
-            start: Option<T::BoundType>,
-            end: Option<T::BoundType>,
-            _phantom_data: PhantomData<T>,
-        }
-        impl<T: Ranged> Ranged for Selection<T> {
-            type BoundType = T::BoundType;
-
-            fn range(&self) -> (Self::BoundType, Self::BoundType) {
-                (self.start.unwrap(), self.end.unwrap())
-            }
-        }
-
-        impl<T: Ranged> Selection<T> {
-            fn new(lhs_size: usize, rhs_size: usize) -> Self {
-                Self {
-                    lhs_selection: BitVec::repeat(false, lhs_size),
-                    rhs_selection: BitVec::repeat(false, rhs_size),
-                    start: None,
-                    end: None,
-                    _phantom_data: Default::default(),
-                }
-            }
-
-            fn select_item(&mut self, lhs: bool, idx: usize, item: &T) {
-                let selection = if lhs {
-                    &mut self.lhs_selection
-                } else {
-                    &mut self.rhs_selection
-                };
-
-                selection.set(idx, true);
-                let (start, end) = item.range();
-                self.start = Some(self.start.map_or(start, |e| e.min(start)));
-                self.end = Some(self.end.map_or(end, |e| e.max(end)));
-            }
-        }
-
-        let mut overlapping_item: Vec<Selection<MergeItems<T>>> = vec![];
-        let mut current_overlapping: Option<Selection<MergeItems<T>>> = None;
-
-        let mut lhs_start_offset = None;
-        let mut lhs_remain = BitVec::repeat(true, lhs.items.len());
-
-        for (rhs_idx, rhs_item) in rhs.items.iter().enumerate() {
-            if let Some(current) = &current_overlapping {
-                // it's a new round
-                if !rhs_item.overlap(current) {
-                    overlapping_item.push(std::mem::take(&mut current_overlapping).unwrap())
-                }
-            }
-
-            for lhs_idx in lhs_start_offset.unwrap_or(0)..lhs.items.len() {
-                let lhs_item = &lhs.items[lhs_idx];
-                if !lhs_item.overlap(rhs_item) {
-                    continue;
-                }
-
-                let overlapping = current_overlapping
-                    .get_or_insert_with(|| Selection::new(lhs.items.len(), rhs.items.len()));
-                overlapping.select_item(true, lhs_idx, lhs_item);
-                overlapping.select_item(false, rhs_idx, rhs_item);
-                // lhs_item is selected in current overlapping, then it won't remain
-                lhs_remain.set(lhs_idx, false);
-                lhs_start_offset.get_or_insert(lhs_idx);
-            }
-        }
-
-        if let Some(o) = std::mem::take(&mut current_overlapping) {
-            overlapping_item.push(o);
-        }
-
-        let mut penalty = 0;
-        let mut result = SortedRun::default();
-
-        for overlapping in overlapping_item {
-            let mut item: Option<MergeItems<T>> = None;
-            for (selected, lhs_item) in overlapping
-                .lhs_selection
-                .iter()
-                .by_vals()
-                .zip(lhs.items.iter())
-            {
-                if selected {
-                    // lhs_item in current overlapping.
-                    penalty += lhs_item.size();
-                    item = Some(match item {
-                        None => lhs_item.clone(),
-                        Some(e) => e.merge(lhs_item.clone()),
-                    });
-                }
-            }
-
-            for (selected, rhs_item) in overlapping
-                .rhs_selection
-                .iter()
-                .by_vals()
-                .zip(rhs.items.iter())
-            {
-                if selected {
-                    penalty += rhs_item.size();
-                    item = Some(match item {
-                        None => rhs_item.clone(),
-                        Some(e) => e.merge(rhs_item.clone()),
-                    });
-                }
-            }
-            // safety: for each overlapping there must be at least one item.
-            result.push_item(item.unwrap());
-        }
-
-        for (remain, lhs_item) in lhs_remain.iter().by_vals().zip(lhs.items.into_iter()) {
-            if remain {
-                // lhs item remains unmerged
-                result.push_item(lhs_item);
-            }
-        }
-
-        sort_ranged_items(&mut result.items);
-        result.penalty = penalty;
-        result
-    }
 }
 
 /// Finds sorted runs in given items.
@@ -356,15 +219,26 @@ where
     runs
 }
 
-fn merge_all_runs<T: Item>(mut runs: Vec<SortedRun<T>>) -> SortedRun<T> {
+fn merge_all_runs<T: Item>(runs: Vec<SortedRun<T>>) -> SortedRun<T> {
     assert!(!runs.is_empty());
-    if runs.len() == 1 {
-        return runs.pop().unwrap();
-    }
-    let mut res = runs.pop().unwrap();
-    while let Some(next) = runs.pop() {
-        res = res.merge(next);
+    let mut all_items = runs
+        .into_iter()
+        .flat_map(|r| r.items.into_iter())
+        .collect::<Vec<_>>();
+
+    all_items.sort_unstable_by(|l, r| l.start.cmp(&r.start).then(l.end.cmp(&r.end).reverse()));
+
+    let mut res = SortedRun::default();
+    let mut current_item = all_items[0].clone();
+
+    for item in all_items.into_iter().skip(1) {
+        if current_item.overlap(&item) {
+            current_item = current_item.merge(item);
+        } else {
+            res.push_item(std::mem::replace(&mut current_item, item));
+        }
     }
+    res.push_item(current_item);
     res
 }
 
@@ -386,7 +260,7 @@ pub(crate) fn reduce_runs<T: Item>(runs: Vec<SortedRun<T>>, target: usize) -> Ve
         .items
         .into_iter()
         .filter(|m| m.merged()) // find all files to merge in that solution
-        .map(|m| m.items)
+        .map(|m| m.items.to_vec())
         .collect()
 }
 
@@ -501,11 +375,9 @@ mod tests {
         expected: &[Vec<(i64, i64)>],
     ) {
         let mut items = build_items(items);
-        let mut runs = find_sorted_runs(&mut items);
+        let runs = find_sorted_runs(&mut items);
         assert_eq!(2, runs.len());
-        let lhs = runs.pop().unwrap();
-        let rhs = runs.pop().unwrap();
-        let res = lhs.merge(rhs);
+        let res = merge_all_runs(runs);
         let penalty = res.penalty;
         let ranges = res
             .items
@@ -519,7 +391,6 @@ mod tests {
             })
             .collect::<Vec<_>>();
         assert_eq!(expected, &ranges);
-
         assert_eq!(expected_penalty, penalty);
     }