diff --git a/nekolib-src/ds/rs01_dict/src/lib.rs b/nekolib-src/ds/rs01_dict/src/lib.rs index 76a5a4976d..3dd2f549f4 100644 --- a/nekolib-src/ds/rs01_dict/src/lib.rs +++ b/nekolib-src/ds/rs01_dict/src/lib.rs @@ -20,14 +20,16 @@ const POW2_SMALL: usize = 1 << SMALL; const RANK_LOOKUP: [[u16; SMALL]; POW2_SMALL] = rank_lookup::(); -const fn rank_lookup() --> [[u16; SMALL]; POW2_SMALL] { - let mut table = [[0; SMALL]; POW2_SMALL]; +const fn rank_lookup< + const MAX_LEN: usize, // log(n)/2 + const BIT_PATTERNS: usize, // sqrt(n) +>() -> [[u16; MAX_LEN]; BIT_PATTERNS] { + let mut table = [[0; MAX_LEN]; BIT_PATTERNS]; let mut i = 0; - while i < (1 << SMALL) { + while i < BIT_PATTERNS { table[i][0] = (i & 1) as u16; let mut j = 1; - while j < SMALL { + while j < MAX_LEN { table[i][j] = table[i][j - 1] + (i >> j & 1) as u16; j += 1; } @@ -36,6 +38,34 @@ const fn rank_lookup() table } +const fn select_lookup< + const BIT_PATTERNS: usize, // 2^(branch * large) + const BRANCH: usize, // sqrt(log(n)) + const MAX_ONES: usize, // log(n)^2 + const LG2_MAX_ONES: usize, // O(log(log(n))) +>() -> [[u16; MAX_ONES]; BIT_PATTERNS] { + let mut table = [[0; MAX_ONES]; BIT_PATTERNS]; + let mut i = 0; + while i < BIT_PATTERNS { + let mut j = 0; + let mut index = 0; + while j < BRANCH { + // [0011, 0100, 0010] (0b_0010_0100_0011) + // [0, 0, 0, 1, 1, 1, 1, 2, 2, 3, ...] + let count = i >> (j * LG2_MAX_ONES) & !(!0 << LG2_MAX_ONES); + let mut k = 0; + while k < count && index < MAX_ONES { + table[i][index] = j as u16; + index += 1; + k += 1; + } + j += 1; + } + i += 1; + } + table +} + struct RankPreprocess { buf: Vec, @@ -146,6 +176,12 @@ fn test_rank_lookup() { assert_eq!(&table[0b111][0..3], [1, 2, 3]); } +#[test] +fn test_select_lookup() { + let table = select_lookup::<4096, 3, 16, 4>(); + assert_eq!(&table[0b_0010_0100_0011][0..9], [0, 0, 0, 1, 1, 1, 1, 2, 2]); +} + #[test] fn sanity_check() { let a = bitvec!(b"000 010 110 000; 111 001 000 011; 000 000 010 010");