// Reconstructed from the patch that adds `nekolib-src/ds/rs01_dict/src/lib.rs`.
//
// Rank/select dictionary over a static bit vector (work in progress).
//
// * `SimpleBitVec` is the raw bit storage.
// * `RankIndex` answers inclusive prefix popcount queries with a two-level
//   directory plus a per-word lookup table.
// * `SelectIndex` splits the bit vector into blocks holding `LARGE_POPCNT`
//   target bits each; long ("sparse") blocks store positions explicitly,
//   short ("dense") blocks are meant to use a small counting tree.  The dense
//   query path is not implemented yet.

use std::ops::{Range, RangeInclusive};

/// Number of bits in one storage word.
const W: usize = u64::BITS as usize;

const RANK_LARGE_LEN: usize = 1024; // (1/4) log(n)^2
const RANK_SMALL_LEN: usize = 16; // (1/2) log(n)/2
const RANK_BIT_PATTERNS: usize = 1 << RANK_SMALL_LEN;

const SELECT_SMALL_LEN: usize = 16; // (1/2) log(n)/2
const SELECT_LARGE_SPARSE_LEN: usize = 12946;
const SELECT_LARGE_POPCNT: usize = 60;
const SELECT_LARGE_NODE_LEN: usize = 4;
const SELECT_LARGE_BRANCH: usize = 4;
const SELECT_WORD_BIT_PATTERNS: usize = 1 << SELECT_SMALL_LEN;
const SELECT_TREE_BIT_PATTERNS: usize =
    1 << (SELECT_LARGE_NODE_LEN * SELECT_LARGE_BRANCH);

/// Compile-time sanity check of the default select parameters: the largest
/// value a `SELECT_LARGE_NODE_LEN`-bit node can hold, times the branching
/// factor, must be at least the number of target bits per block.
const ASSERTION: () = {
    let node_popcnt = !(!0_usize << SELECT_LARGE_NODE_LEN);
    assert!(node_popcnt * SELECT_LARGE_BRANCH >= SELECT_LARGE_POPCNT);
};

/// Bit vector bundled with a rank index and one select index per bit value.
///
/// All block sizes and lookup-table sizes are const parameters; the
/// module-level constants of the same names hold the default values.
pub struct Rs01DictGenerics<
    const RANK_LARGE_LEN: usize,
    const RANK_SMALL_LEN: usize,
    const RANK_BIT_PATTERNS: usize,
    const SELECT_SMALL_LEN: usize,
    const SELECT_LARGE_SPARSE_LEN: usize,
    const SELECT_LARGE_POPCNT: usize,
    const SELECT_LARGE_NODE_LEN: usize,
    const SELECT_LARGE_BRANCH: usize,
    const SELECT_WORD_BIT_PATTERNS: usize,
    const SELECT_TREE_BIT_PATTERNS: usize,
> {
    buf: SimpleBitVec,
    rank_index: RankIndex<RANK_LARGE_LEN, RANK_SMALL_LEN, RANK_BIT_PATTERNS>,
    select1_index: SelectIndex<
        SELECT_SMALL_LEN,
        SELECT_LARGE_SPARSE_LEN,
        SELECT_LARGE_POPCNT,
        SELECT_LARGE_NODE_LEN,
        SELECT_LARGE_BRANCH,
        SELECT_WORD_BIT_PATTERNS,
        SELECT_TREE_BIT_PATTERNS,
    >,
    select0_index: SelectIndex<
        SELECT_SMALL_LEN,
        SELECT_LARGE_SPARSE_LEN,
        SELECT_LARGE_POPCNT,
        SELECT_LARGE_NODE_LEN,
        SELECT_LARGE_BRANCH,
        SELECT_WORD_BIT_PATTERNS,
        SELECT_TREE_BIT_PATTERNS,
    >,
}

/// Growable bit sequence packed little-endian into `u64` words: bit `i`
/// lives in `buf[i / W]` at bit position `i % W`.
struct SimpleBitVec {
    buf: Vec<u64>,
    len: usize,
}

/// Two-level rank directory.
///
/// `large[k]` is the number of set bits before large block `k`
/// (`LARGE_LEN` bits each); `small[k]` is the number of set bits between the
/// start of the enclosing large block and the start of small block `k`
/// (`SMALL_LEN` bits each).  `BIT_PATTERNS` must equal `1 << SMALL_LEN`.
struct RankIndex<
    const LARGE_LEN: usize,
    const SMALL_LEN: usize,
    const BIT_PATTERNS: usize,
> {
    large: Vec<u32>,
    // `u16` rather than `u8`: an in-block prefix can reach
    // `LARGE_LEN - SMALL_LEN`, which exceeds 255 for the default sizes.
    small: Vec<u16>,
}

/// Select directory: one entry per run of `LARGE_POPCNT` target bits.
struct SelectIndex<
    const SMALL_LEN: usize,
    const LARGE_SPARSE_LEN: usize,
    const LARGE_POPCNT: usize,
    const LARGE_NODE_LEN: usize,
    const LARGE_BRANCH: usize,
    const WORD_BIT_PATTERNS: usize,
    const TREE_BIT_PATTERNS: usize,
> {
    inner: Vec<
        SelectIndexInner<
            SMALL_LEN,
            LARGE_SPARSE_LEN,
            LARGE_POPCNT,
            LARGE_NODE_LEN,
            LARGE_BRANCH,
            WORD_BIT_PATTERNS,
            TREE_BIT_PATTERNS,
        >,
    >,
}

/// Per-block select data.
enum SelectIndexInner<
    const SMALL_LEN: usize,
    const LARGE_SPARSE_LEN: usize,
    const LARGE_POPCNT: usize,
    const LARGE_NODE_LEN: usize,
    const LARGE_BRANCH: usize,
    const WORD_BIT_PATTERNS: usize,
    const TREE_BIT_PATTERNS: usize,
> {
    /// Absolute positions of the target bits of the block.
    Sparse(Vec<usize>),
    /// Counting tree (root level first, leaf level last) and the absolute
    /// start position of the block.
    Dense(Vec<SimpleBitVec>, usize),
}

/// Provides the per-word prefix popcount table.
trait RankLookup<const SMALL_LEN: usize, const BIT_PATTERNS: usize> {
    /// `WORD[w][j]` is the number of set bits among bits `0..=j` of `w`.
    const WORD: [[u8; SMALL_LEN]; BIT_PATTERNS];
}

/// Provides the lookup tables intended for dense select blocks.
trait SelectLookup<
    const NODE_LEN: usize,
    const POPCNT: usize,
    const TREE_BIT_PATTERNS: usize,
    const LEAF_LEN: usize,
    const WORD_BIT_PATTERNS: usize,
>
{
    const TREE: [[(u8, u8); POPCNT]; TREE_BIT_PATTERNS];
    const WORD: [[u8; LEAF_LEN]; WORD_BIT_PATTERNS];
}

/// Builds the prefix popcount table: `table[i][j]` is the number of set bits
/// among bits `0..=j` of the pattern `i`.
///
/// Panics (index out of bounds) if `SMALL_LEN == 0` while `BIT_PATTERNS > 0`.
const fn rank_lookup<const SMALL_LEN: usize, const BIT_PATTERNS: usize>()
-> [[u8; SMALL_LEN]; BIT_PATTERNS] {
    let mut table = [[0_u8; SMALL_LEN]; BIT_PATTERNS];
    let mut i = 0;
    while i < BIT_PATTERNS {
        table[i][0] = (i & 1) as u8;
        let mut j = 1;
        while j < SMALL_LEN {
            table[i][j] = table[i][j - 1] + ((i >> j) & 1) as u8;
            j += 1;
        }
        i += 1;
    }
    table
}

/// Builds the tree-node select table.
///
/// A pattern `i` is read as `BRANCH` counters of `NODE_LEN` bits each, lowest
/// bits first.  For every rank `r` below the sum of the counters (capped at
/// `POPCNT`), `table[i][r]` is `(child, before)`: the index of the child that
/// contains the `r`-th element and the number of elements held by the
/// children preceding it.  Unused entries stay `(0, 0)`.
const fn select_tree_lookup<
    const NODE_LEN: usize,
    const POPCNT: usize,
    const BRANCH: usize,
    const BIT_PATTERNS: usize,
>() -> [[(u8, u8); POPCNT]; BIT_PATTERNS] {
    let mut table = [[(0_u8, 0_u8); POPCNT]; BIT_PATTERNS];
    let mut i = 0;
    while i < BIT_PATTERNS {
        let mut j = 0;
        let mut index = 0;
        while j < BRANCH {
            // counters [011, 100, 010] (pattern 0b_010_100_011)
            // yield children [0, 0, 0, 1, 1, 1, 1, 2, 2]
            let count = (i >> (j * NODE_LEN)) & !(!0_usize << NODE_LEN);
            let mut k = 0;
            while k < count && index < POPCNT {
                // `index - k` is the rank at which child `j` starts.
                table[i][index] = (j as u8, (index - k) as u8);
                index += 1;
                k += 1;
            }
            j += 1;
        }
        i += 1;
    }
    table
}

/// Builds the in-word select table: `table[i][r]` is the position of the
/// `r`-th (0-based) set bit of the pattern `i`; entries past the popcount of
/// `i` stay `0`.
const fn select_word_lookup<
    const SMALL_LEN: usize,
    const BIT_PATTERNS: usize,
>() -> [[u8; SMALL_LEN]; BIT_PATTERNS] {
    let mut table = [[0_u8; SMALL_LEN]; BIT_PATTERNS];
    let mut i = 0;
    while i < BIT_PATTERNS {
        let mut j = 0;
        let mut count = 0;
        while j < SMALL_LEN {
            if (i >> j) & 1 != 0 {
                table[i][count] = j as u8;
                count += 1;
            }
            j += 1;
        }
        i += 1;
    }
    table
}

impl From<(Vec<u64>, usize)> for SimpleBitVec {
    /// Wraps already packed words; `len` is the number of valid bits.
    fn from((buf, len): (Vec<u64>, usize)) -> Self { Self { buf, len } }
}

impl From<&[bool]> for SimpleBitVec {
    /// Packs a slice of booleans, element `i` becoming bit `i`.
    fn from(a: &[bool]) -> Self {
        let len = a.len();
        let mut buf = vec![0_u64; (len + W - 1) / W];
        for (i, _) in a.iter().enumerate().filter(|&(_, &bit)| bit) {
            buf[i / W] |= 1_u64 << (i % W);
        }
        Self { buf, len }
    }
}

impl SimpleBitVec {
    /// Creates an empty bit vector.
    fn new() -> Self { Self { buf: vec![], len: 0 } }

    /// Number of valid bits.
    fn len(&self) -> usize { self.len }

    /// Returns bit `i`.
    fn get_single(&self, i: usize) -> bool {
        debug_assert!(i < self.len);
        // Mask with 1: without it every set bit above position `i` in the
        // same word would make the result `true`.
        (self.buf[i / W] >> (i % W)) & 1 != 0
    }

    /// Returns bits `start..end` (at most `W` of them) in the low bits of the
    /// result.  With `X == false` the bits are complemented first, so that
    /// "target" bits always read as 1.
    fn get<const X: bool>(&self, Range { start, end }: Range<usize>) -> u64 {
        debug_assert!(start <= end);
        debug_assert!(end - start <= W);
        debug_assert!(end <= self.len);

        let width = end - start;
        if width == 0 {
            return 0;
        }
        // `!0 >> (W - width)` stays in range for `width == W`, unlike
        // `!(!0 << width)`, which would shift by the full word size.
        let mask = !0_u64 >> (W - width);
        let (word, offset) = (start / W, start % W);
        let mut res = self.buf[word] >> offset;
        if offset + width > W {
            // The range straddles a word boundary; here `offset > 0`, so the
            // shift amount is in `1..W`.
            res |= self.buf[word + 1] << (W - offset);
        }
        (if X { res } else { !res }) & mask
    }

    /// Appends the low `len` bits of `w`.
    ///
    /// Panics if `len > W` or if `w` has a bit set at position `len` or above.
    fn push(&mut self, w: u64, len: usize) {
        assert!(len <= W);
        // Guard the shift: `w >> W` would overflow.
        assert!(len == W || w >> len == 0);

        if len == 0 {
            return;
        }
        let offset = self.len % W;
        if offset == 0 {
            // including the case `self.buf.is_empty()`
            self.buf.push(w);
        } else {
            self.buf[self.len / W] |= w << offset;
            if offset + len > W {
                self.buf.push(w >> (W - offset));
            }
        }
        self.len += len;
    }

    /// Iterates over consecutive `size`-bit chunks (the last one may be
    /// shorter), each read through [`SimpleBitVec::get`].
    fn chunks<const X: bool>(
        &self,
        size: usize,
    ) -> impl Iterator<Item = u64> + '_ {
        (0..self.len)
            .step_by(size)
            .map(move |i| self.get::<X>(i..self.len.min(i + size)))
    }
}

impl<
    const LARGE_LEN: usize,
    const SMALL_LEN: usize,
    const BIT_PATTERNS: usize,
> RankLookup<SMALL_LEN, BIT_PATTERNS>
    for RankIndex<LARGE_LEN, SMALL_LEN, BIT_PATTERNS>
{
    const WORD: [[u8; SMALL_LEN]; BIT_PATTERNS] =
        rank_lookup::<SMALL_LEN, BIT_PATTERNS>();
}

impl<
    const LARGE_LEN: usize,
    const SMALL_LEN: usize,
    const BIT_PATTERNS: usize,
> RankIndex<LARGE_LEN, SMALL_LEN, BIT_PATTERNS>
{
    /// Builds the directory for `a`.
    ///
    /// Expects `LARGE_LEN` to be a positive multiple of `SMALL_LEN`,
    /// `BIT_PATTERNS == 1 << SMALL_LEN`, and `LARGE_LEN` small enough for an
    /// in-block prefix count to fit in `u16`.
    fn new(a: &SimpleBitVec) -> Self {
        let table = &Self::WORD;
        let per = LARGE_LEN / SMALL_LEN;
        debug_assert!(per > 0);

        let mut small = vec![];
        let mut large = vec![];
        let mut small_acc = 0_u16;
        let mut large_acc = 0_u32;
        for (k, w) in a.chunks::<true>(SMALL_LEN).enumerate() {
            if k % per == 0 {
                // A new large block starts: record the global prefix and
                // restart the in-block prefix.
                large.push(large_acc);
                small_acc = 0;
            }
            // Exclusive prefix: set bits of this large block that precede
            // the current small block.
            small.push(small_acc);

            let c = table[w as usize][SMALL_LEN - 1];
            small_acc += u16::from(c);
            large_acc += u32::from(c);
        }

        Self { large, small }
    }

    /// Number of set bits of `b` among positions `0..=n`.
    ///
    /// `b` must be the bit vector the index was built from and `n` must be
    /// less than `b.len()`.
    fn rank(&self, n: usize, b: &SimpleBitVec) -> usize {
        let block = n / SMALL_LEN;
        let lo = block * SMALL_LEN;
        // The last small block may be shorter than `SMALL_LEN`.
        let hi = b.len().min(lo + SMALL_LEN);
        let w = b.get::<true>(lo..hi);

        let table = &Self::WORD;
        let large_acc = self.large[n / LARGE_LEN] as usize;
        let small_acc = self.small[block] as usize;
        let small = table[w as usize][n % SMALL_LEN] as usize;
        large_acc + small_acc + small
    }
}

impl<
    const SMALL_LEN: usize,
    const LARGE_SPARSE_LEN: usize,
    const LARGE_POPCNT: usize,
    const LARGE_NODE_LEN: usize,
    const LARGE_BRANCH: usize,
    const WORD_BIT_PATTERNS: usize,
    const TREE_BIT_PATTERNS: usize,
>
    SelectIndex<
        SMALL_LEN,
        LARGE_SPARSE_LEN,
        LARGE_POPCNT,
        LARGE_NODE_LEN,
        LARGE_BRANCH,
        WORD_BIT_PATTERNS,
        TREE_BIT_PATTERNS,
    >
{
    /// Builds the select directory for the bits of `b` equal to `X`.
    ///
    /// The positions are cut into consecutive blocks; a block is closed as
    /// soon as it holds `LARGE_POPCNT` target bits, or at the end of `b`.
    fn new<const X: bool>(b: &SimpleBitVec) -> Self {
        let n = b.len();

        let mut cur = vec![];
        let mut inner = vec![];
        let mut start = 0;
        for i in 0..n {
            if b.get_single(i) == X {
                cur.push(i);
            }
            if cur.len() >= LARGE_POPCNT || i == n - 1 {
                let positions = std::mem::take(&mut cur);
                inner.push(SelectIndexInner::new::<X>(positions, start..=i, b));
                start = i + 1;
            }
        }
        Self { inner }
    }

    /// Position of the `i`-th (0-based) bit of `b` equal to `X`.
    ///
    /// Panics if `i` is out of range, and currently also whenever the
    /// containing block is dense (see [`SelectIndexInner::select`]).
    fn select<const X: bool>(&self, i: usize, b: &SimpleBitVec) -> usize {
        self.inner[i / LARGE_POPCNT].select::<X>(i % LARGE_POPCNT, b)
    }
}

impl<
    const SMALL_LEN: usize,
    const LARGE_SPARSE_LEN: usize,
    const LARGE_POPCNT: usize,
    const LARGE_NODE_LEN: usize,
    const LARGE_BRANCH: usize,
    const WORD_BIT_PATTERNS: usize,
    const TREE_BIT_PATTERNS: usize,
>
    SelectIndexInner<
        SMALL_LEN,
        LARGE_SPARSE_LEN,
        LARGE_POPCNT,
        LARGE_NODE_LEN,
        LARGE_BRANCH,
        WORD_BIT_PATTERNS,
        TREE_BIT_PATTERNS,
    >
{
    /// Builds the data of one block covering the inclusive `range` of `b`.
    ///
    /// `a` holds the absolute positions of the target bits inside `range`.
    /// Blocks spanning at least `LARGE_SPARSE_LEN` bits keep `a` as is;
    /// shorter blocks get a counting tree.
    fn new<const X: bool>(
        a: Vec<usize>,
        range: RangeInclusive<usize>,
        b: &SimpleBitVec,
    ) -> Self {
        let start = *range.start();
        let end = range.end() + 1;
        if end - start >= LARGE_SPARSE_LEN {
            Self::Sparse(a)
        } else {
            Self::new_dense::<X>(b, start..end)
        }
    }

    /// Builds the counting tree for the bits `start..end` of `b`.
    ///
    /// The leaf level stores, per `SMALL_LEN`-bit chunk, its number of target
    /// bits in a `SMALL_LEN`-bit field.  Every level above stores the sum of
    /// `LARGE_BRANCH` children in a `LARGE_NODE_LEN`-bit field.
    ///
    /// NOTE(review): an inner node whose subtree holds `1 << LARGE_NODE_LEN`
    /// or more target bits does not fit its field and makes
    /// `SimpleBitVec::push` panic; the node layout still has to be settled.
    fn new_dense<const X: bool>(
        b: &SimpleBitVec,
        Range { start, end }: Range<usize>,
    ) -> Self {
        let rl = &<RankIndex<SMALL_LEN, SMALL_LEN, WORD_BIT_PATTERNS> as RankLookup<
            SMALL_LEN,
            WORD_BIT_PATTERNS,
        >>::WORD;
        let len = end - start;

        let leaf = {
            let mut leaf = SimpleBitVec::new();
            for i in 0..(len + SMALL_LEN - 1) / SMALL_LEN {
                let il = start + i * SMALL_LEN;
                // Clamp against the absolute end of the block, not against
                // its length.
                let ir = end.min(il + SMALL_LEN);
                let w = b.get::<X>(il..ir);
                leaf.push(u64::from(rl[w as usize][SMALL_LEN - 1]), SMALL_LEN);
            }
            leaf
        };

        let mut nodes = leaf.len() / SMALL_LEN;
        let mut tree = vec![];
        let mut last = leaf;
        let mut width = SMALL_LEN;
        while nodes / LARGE_BRANCH > 1 {
            let mut cur = SimpleBitVec::new();
            let tmp = last;
            {
                // `it` is advanced both by the loop header and inside the
                // body, so a plain `for` loop cannot be used here.
                let mut it = tmp.chunks::<true>(width);
                while let Some(mut sum) = it.next() {
                    sum += (1..LARGE_BRANCH)
                        .filter_map(|_| it.next())
                        .sum::<u64>();
                    cur.push(sum, LARGE_NODE_LEN);
                }
            }
            tree.push(tmp);
            last = cur;
            width = LARGE_NODE_LEN;
            nodes /= LARGE_BRANCH;
        }
        tree.push(last);
        // Levels were collected bottom-up; store them root first.
        tree.reverse();
        Self::Dense(tree, start)
    }

    /// Position of the `i`-th (0-based) target bit of this block.
    ///
    /// Only sparse blocks are supported so far; a dense block panics with
    /// `todo!()`.
    fn select<const X: bool>(&self, i: usize, b: &SimpleBitVec) -> usize {
        // `b` is only needed by the dense path, which is not written yet.
        let _ = b;
        match self {
            Self::Sparse(index) => index[i],
            Self::Dense(_tree, _start) => {
                // Only the leaf level uses a different bit width per node, so
                // it needs special care.  Would a separate field for the
                // leaves be better?  This needs some more thought.
                todo!()
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rank_lookup() {
        let table = rank_lookup::<3, 8>();

        assert_eq!(&table[0b000][..3], [0, 0, 0]);
        assert_eq!(&table[0b100][..3], [0, 0, 1]);
        assert_eq!(&table[0b010][..3], [0, 1, 1]);
        assert_eq!(&table[0b110][..3], [0, 1, 2]);
        assert_eq!(&table[0b001][..3], [1, 1, 1]);
        assert_eq!(&table[0b101][..3], [1, 1, 2]);
        assert_eq!(&table[0b011][..3], [1, 2, 2]);
        assert_eq!(&table[0b111][..3], [1, 2, 3]);
    }

    // NOTE(review): disabled in the original change; kept disabled here.
    #[test]
    #[cfg(any())]
    fn test_select_tree_lookup() {
        let table = select_tree_lookup::<3, 9, 3, 512>();
        // [3, 4, 2]
        let tmp: [_; 9] = table[0b_010_100_011][..9].try_into().unwrap();

        assert_eq!(tmp.map(|x| x.0), [0, 0, 0, 1, 1, 1, 1, 2, 2]);
        assert_eq!(tmp.map(|x| x.1), [0, 0, 0, 3, 3, 3, 3, 7, 7]);
    }

    #[test]
    fn test_select_word_lookup() {
        let table = select_word_lookup::<3, 8>();

        assert_eq!(&table[0b001][..1], [0]);
        assert_eq!(&table[0b010][..1], [1]);
        assert_eq!(&table[0b011][..2], [0, 1]);
        assert_eq!(&table[0b100][..1], [2]);
        assert_eq!(&table[0b101][..2], [0, 2]);
        assert_eq!(&table[0b110][..2], [1, 2]);
        assert_eq!(&table[0b111][..3], [0, 1, 2]);
    }
}