From 17f737022d933582cb16bdd329313c363594d791 Mon Sep 17 00:00:00 2001 From: Stepan Koltsov Date: Thu, 13 Jun 2024 17:11:58 +0100 Subject: [PATCH] all: improve perf of memchr fallback (v2) Resubmit of PR #151. That PR was reverted because it broke big endian implementation and CI did not catch it (see the revert PR #153 for details). Andrew, thank you for new test cases which made it easy to fix the issue. The fix is: ``` --- a/src/arch/all/memchr.rs +++ b/src/arch/all/memchr.rs @@ -1019,7 +1019,7 @@ fn find_zero_in_chunk(x: usize) -> Option { if cfg!(target_endian = "little") { lowest_zero_byte(x) } else { - highest_zero_byte(x) + Some(USIZE_BYTES - 1 - highest_zero_byte(x)?) } } @@ -1028,7 +1028,7 @@ fn rfind_zero_in_chunk(x: usize) -> Option { if cfg!(target_endian = "little") { highest_zero_byte(x) } else { - lowest_zero_byte(x) + Some(USIZE_BYTES - 1 - lowest_zero_byte(x)?) } } ``` Original description: Current generic ("all") implementation checks that a chunk (`usize`) contains a zero byte, and if it is, iterates over bytes of this chunk to find the index of zero byte. Instead, we can use more bit operations to find the index without loops. Context: we use `memchr`, but many of our strings are short. Currently SIMD-optimized `memchr` processes bytes one by one when the string length is shorter than SIMD register. I suspect it can be made faster if we take `usize` bytes a chunk which does not fit into SIMD register and process it with such utility, similarly to how AVX2 implementation falls back to SSE2. So I looked at generic implementation to reuse it in SIMD-optimized version, but there were none. So here is it. --- src/arch/all/memchr.rs | 347 ++++++++++++++++++++++++++++++++++------- 1 file changed, 291 insertions(+), 56 deletions(-) diff --git a/src/arch/all/memchr.rs b/src/arch/all/memchr.rs index 62fe2a3..7558615 100644 --- a/src/arch/all/memchr.rs +++ b/src/arch/all/memchr.rs @@ -141,8 +141,8 @@ impl One { // The start of the search may not be aligned to `*const usize`, // so we do an unaligned load here. let chunk = start.cast::().read_unaligned(); - if self.has_needle(chunk) { - return generic::fwd_byte_by_byte(start, end, confirm); + if let Some(index) = self.index_of_needle(chunk) { + return Some(start.add(index)); } // And now we start our search at a guaranteed aligned position. @@ -153,21 +153,33 @@ impl One { let mut cur = start.add(USIZE_BYTES - (start.as_usize() & USIZE_ALIGN)); debug_assert!(cur > start); - if len <= One::LOOP_BYTES { - return generic::fwd_byte_by_byte(cur, end, confirm); - } - debug_assert!(end.sub(One::LOOP_BYTES) >= start); - while cur <= end.sub(One::LOOP_BYTES) { + while end.distance(cur) >= One::LOOP_BYTES { debug_assert_eq!(0, cur.as_usize() % USIZE_BYTES); let a = cur.cast::().read(); let b = cur.add(USIZE_BYTES).cast::().read(); - if self.has_needle(a) || self.has_needle(b) { - break; + if let Some(index) = self.index_of_needle(a) { + return Some(cur.add(index)); + } + if let Some(index) = self.index_of_needle(b) { + return Some(cur.add(USIZE_BYTES + index)); } cur = cur.add(One::LOOP_BYTES); } - generic::fwd_byte_by_byte(cur, end, confirm) + if end.distance(cur) > USIZE_BYTES { + let chunk = cur.cast::().read(); + if let Some(index) = self.index_of_needle(chunk) { + return Some(cur.add(index)); + } + cur = cur.add(USIZE_BYTES); + } + debug_assert!(cur >= end.sub(USIZE_BYTES)); + cur = end.sub(USIZE_BYTES); + let chunk = cur.cast::().read_unaligned(); + if let Some(index) = self.index_of_needle(chunk) { + return Some(cur.add(index)); + } + None } /// Like `rfind`, but accepts and returns raw pointers. @@ -209,26 +221,39 @@ impl One { } let chunk = end.sub(USIZE_BYTES).cast::().read_unaligned(); - if self.has_needle(chunk) { - return generic::rev_byte_by_byte(start, end, confirm); + if let Some(index) = self.rindex_of_needle(chunk) { + return Some(end.sub(USIZE_BYTES).add(index)); } let mut cur = end.sub(end.as_usize() & USIZE_ALIGN); debug_assert!(start <= cur && cur <= end); - if len <= One::LOOP_BYTES { - return generic::rev_byte_by_byte(start, cur, confirm); - } - while cur >= start.add(One::LOOP_BYTES) { + while cur.distance(start) >= One::LOOP_BYTES { debug_assert_eq!(0, cur.as_usize() % USIZE_BYTES); let a = cur.sub(2 * USIZE_BYTES).cast::().read(); let b = cur.sub(1 * USIZE_BYTES).cast::().read(); - if self.has_needle(a) || self.has_needle(b) { - break; + if let Some(index) = self.rindex_of_needle(b) { + return Some(cur.sub(1 * USIZE_BYTES).add(index)); + } + if let Some(index) = self.rindex_of_needle(a) { + return Some(cur.sub(2 * USIZE_BYTES).add(index)); } cur = cur.sub(One::LOOP_BYTES); } - generic::rev_byte_by_byte(start, cur, confirm) + if cur > start.add(USIZE_BYTES) { + let chunk = cur.sub(USIZE_BYTES).cast::().read(); + if let Some(index) = self.rindex_of_needle(chunk) { + return Some(cur.sub(USIZE_BYTES).add(index)); + } + cur = cur.sub(USIZE_BYTES); + } + debug_assert!(start.add(USIZE_BYTES) >= cur); + cur = start; + let chunk = cur.cast::().read_unaligned(); + if let Some(index) = self.rindex_of_needle(chunk) { + return Some(cur.add(index)); + } + None } /// Counts all occurrences of this byte in the given haystack represented @@ -278,8 +303,13 @@ impl One { } #[inline(always)] - fn has_needle(&self, chunk: usize) -> bool { - has_zero_byte(self.v1 ^ chunk) + fn index_of_needle(&self, chunk: usize) -> Option { + find_zero_in_chunk(self.v1 ^ chunk) + } + + #[inline(always)] + fn rindex_of_needle(&self, chunk: usize) -> Option { + rfind_zero_in_chunk(self.v1 ^ chunk) } #[inline(always)] @@ -451,8 +481,8 @@ impl Two { // The start of the search may not be aligned to `*const usize`, // so we do an unaligned load here. let chunk = start.cast::().read_unaligned(); - if self.has_needle(chunk) { - return generic::fwd_byte_by_byte(start, end, confirm); + if let Some(index) = self.index_of_needle(chunk) { + return Some(start.add(index)); } // And now we start our search at a guaranteed aligned position. @@ -464,16 +494,22 @@ impl Two { start.add(USIZE_BYTES - (start.as_usize() & USIZE_ALIGN)); debug_assert!(cur > start); debug_assert!(end.sub(USIZE_BYTES) >= start); - while cur <= end.sub(USIZE_BYTES) { + while cur < end.sub(USIZE_BYTES) { debug_assert_eq!(0, cur.as_usize() % USIZE_BYTES); let chunk = cur.cast::().read(); - if self.has_needle(chunk) { - break; + if let Some(index) = self.index_of_needle(chunk) { + return Some(cur.add(index)); } cur = cur.add(USIZE_BYTES); } - generic::fwd_byte_by_byte(cur, end, confirm) + debug_assert!(cur >= end.sub(USIZE_BYTES) && cur <= end); + cur = end.sub(USIZE_BYTES); + let chunk = cur.cast::().read_unaligned(); + if let Some(index) = self.index_of_needle(chunk) { + return Some(cur.add(index)); + } + None } /// Like `rfind`, but accepts and returns raw pointers. @@ -515,22 +551,28 @@ impl Two { } let chunk = end.sub(USIZE_BYTES).cast::().read_unaligned(); - if self.has_needle(chunk) { - return generic::rev_byte_by_byte(start, end, confirm); + if let Some(index) = self.rindex_of_needle(chunk) { + return Some(end.sub(USIZE_BYTES).add(index)); } let mut cur = end.sub(end.as_usize() & USIZE_ALIGN); debug_assert!(start <= cur && cur <= end); - while cur >= start.add(USIZE_BYTES) { + while cur > start.add(USIZE_BYTES) { debug_assert_eq!(0, cur.as_usize() % USIZE_BYTES); let chunk = cur.sub(USIZE_BYTES).cast::().read(); - if self.has_needle(chunk) { - break; + if let Some(index) = self.rindex_of_needle(chunk) { + return Some(cur.sub(USIZE_BYTES).add(index)); } cur = cur.sub(USIZE_BYTES); } - generic::rev_byte_by_byte(start, cur, confirm) + debug_assert!(cur >= start && start.add(USIZE_BYTES) >= cur); + cur = start; + let chunk = cur.cast::().read_unaligned(); + if let Some(index) = self.rindex_of_needle(chunk) { + return Some(cur.add(index)); + } + None } /// Returns an iterator over all occurrences of one of the needle bytes in @@ -543,8 +585,29 @@ impl Two { } #[inline(always)] - fn has_needle(&self, chunk: usize) -> bool { - has_zero_byte(self.v1 ^ chunk) || has_zero_byte(self.v2 ^ chunk) + fn index_of_needle(&self, chunk: usize) -> Option { + match ( + find_zero_in_chunk(self.v1 ^ chunk), + find_zero_in_chunk(self.v2 ^ chunk), + ) { + (Some(a), Some(b)) => Some(a.min(b)), + (Some(a), None) => Some(a), + (None, Some(b)) => Some(b), + (None, None) => None, + } + } + + #[inline(always)] + fn rindex_of_needle(&self, chunk: usize) -> Option { + match ( + rfind_zero_in_chunk(self.v1 ^ chunk), + rfind_zero_in_chunk(self.v2 ^ chunk), + ) { + (Some(a), Some(b)) => Some(a.max(b)), + (Some(a), None) => Some(a), + (None, Some(b)) => Some(b), + (None, None) => None, + } } #[inline(always)] @@ -715,8 +778,8 @@ impl Three { // The start of the search may not be aligned to `*const usize`, // so we do an unaligned load here. let chunk = start.cast::().read_unaligned(); - if self.has_needle(chunk) { - return generic::fwd_byte_by_byte(start, end, confirm); + if let Some(index) = self.index_of_needle(chunk) { + return Some(start.add(index)); } // And now we start our search at a guaranteed aligned position. @@ -728,16 +791,22 @@ impl Three { start.add(USIZE_BYTES - (start.as_usize() & USIZE_ALIGN)); debug_assert!(cur > start); debug_assert!(end.sub(USIZE_BYTES) >= start); - while cur <= end.sub(USIZE_BYTES) { + while cur < end.sub(USIZE_BYTES) { debug_assert_eq!(0, cur.as_usize() % USIZE_BYTES); let chunk = cur.cast::().read(); - if self.has_needle(chunk) { - break; + if let Some(index) = self.index_of_needle(chunk) { + return Some(cur.add(index)); } cur = cur.add(USIZE_BYTES); } - generic::fwd_byte_by_byte(cur, end, confirm) + debug_assert!(cur >= end.sub(USIZE_BYTES) && cur <= end); + cur = end.sub(USIZE_BYTES); + let chunk = cur.cast::().read_unaligned(); + if let Some(index) = self.index_of_needle(chunk) { + return Some(cur.add(index)); + } + None } /// Like `rfind`, but accepts and returns raw pointers. @@ -779,22 +848,28 @@ impl Three { } let chunk = end.sub(USIZE_BYTES).cast::().read_unaligned(); - if self.has_needle(chunk) { - return generic::rev_byte_by_byte(start, end, confirm); + if let Some(index) = self.rindex_of_needle(chunk) { + return Some(end.sub(USIZE_BYTES).add(index)); } let mut cur = end.sub(end.as_usize() & USIZE_ALIGN); debug_assert!(start <= cur && cur <= end); - while cur >= start.add(USIZE_BYTES) { + while cur > start.add(USIZE_BYTES) { debug_assert_eq!(0, cur.as_usize() % USIZE_BYTES); let chunk = cur.sub(USIZE_BYTES).cast::().read(); - if self.has_needle(chunk) { - break; + if let Some(index) = self.rindex_of_needle(chunk) { + return Some(cur.sub(USIZE_BYTES).add(index)); } cur = cur.sub(USIZE_BYTES); } - generic::rev_byte_by_byte(start, cur, confirm) + debug_assert!(cur >= start && start.add(USIZE_BYTES) >= cur); + cur = start; + let chunk = cur.cast::().read_unaligned(); + if let Some(index) = self.rindex_of_needle(chunk) { + return Some(cur.add(index)); + } + None } /// Returns an iterator over all occurrences of one of the needle bytes in @@ -807,10 +882,45 @@ impl Three { } #[inline(always)] - fn has_needle(&self, chunk: usize) -> bool { - has_zero_byte(self.v1 ^ chunk) - || has_zero_byte(self.v2 ^ chunk) - || has_zero_byte(self.v3 ^ chunk) + fn index_of_needle(&self, chunk: usize) -> Option { + #[inline(always)] + fn min_index(a: Option, b: Option) -> Option { + match (a, b) { + (Some(a), Some(b)) => Some(a.min(b)), + (Some(a), None) => Some(a), + (None, Some(b)) => Some(b), + (None, None) => None, + } + } + + min_index( + min_index( + find_zero_in_chunk(self.v1 ^ chunk), + find_zero_in_chunk(self.v2 ^ chunk), + ), + find_zero_in_chunk(self.v3 ^ chunk), + ) + } + + #[inline(always)] + fn rindex_of_needle(&self, chunk: usize) -> Option { + #[inline(always)] + fn max_index(a: Option, b: Option) -> Option { + match (a, b) { + (Some(a), Some(b)) => Some(a.max(b)), + (Some(a), None) => Some(a), + (None, Some(b)) => Some(b), + (None, None) => None, + } + } + + max_index( + max_index( + rfind_zero_in_chunk(self.v1 ^ chunk), + rfind_zero_in_chunk(self.v2 ^ chunk), + ), + rfind_zero_in_chunk(self.v3 ^ chunk), + ) } #[inline(always)] @@ -867,21 +977,59 @@ impl<'a, 'h> DoubleEndedIterator for ThreeIter<'a, 'h> { } } -/// Return `true` if `x` contains any zero byte. +/// Return the index of the least significant zero byte in `x`. /// /// That is, this routine treats `x` as a register of 8-bit lanes and returns -/// true when any of those lanes is `0`. +/// the index of the least significant lane that is `0`. /// -/// From "Matters Computational" by J. Arndt. +/// Based on "Matters Computational" by J. Arndt. #[inline(always)] -fn has_zero_byte(x: usize) -> bool { +fn lowest_zero_byte(x: usize) -> Option { // "The idea is to subtract one from each of the bytes and then look for // bytes where the borrow propagated all the way to the most significant // bit." const LO: usize = splat(0x01); const HI: usize = splat(0x80); - (x.wrapping_sub(LO) & !x & HI) != 0 + let y = x.wrapping_sub(LO) & !x & HI; + if y == 0 { + None + } else { + Some(y.trailing_zeros() as usize / 8) + } +} + +/// Return the index of the most significant zero byte in `x`. +/// +/// That is, this routine treats `x` as a register of 8-bit lanes and returns +/// the index of the most significant lane that is `0`. +/// +/// Based on "Hacker's Delight" by Henry S. Warren. +#[inline(always)] +fn highest_zero_byte(x: usize) -> Option { + const SEVEN_F: usize = splat(0x7F); + + let y = (x & SEVEN_F).wrapping_add(SEVEN_F); + let y = !(y | x | SEVEN_F); + (USIZE_BYTES - 1).checked_sub(y.leading_zeros() as usize / 8) +} + +#[inline(always)] +fn find_zero_in_chunk(x: usize) -> Option { + if cfg!(target_endian = "little") { + lowest_zero_byte(x) + } else { + Some(USIZE_BYTES - 1 - highest_zero_byte(x)?) + } +} + +#[inline(always)] +fn rfind_zero_in_chunk(x: usize) -> Option { + if cfg!(target_endian = "little") { + highest_zero_byte(x) + } else { + Some(USIZE_BYTES - 1 - lowest_zero_byte(x)?) + } } /// Repeat the given byte into a word size number. That is, every 8 bits @@ -897,6 +1045,7 @@ const fn splat(b: u8) -> usize { #[cfg(test)] mod tests { use super::*; + use std::cfg; define_memchr_quickcheck!(super, try_new); @@ -1019,4 +1168,90 @@ mod tests { let data = [0, 0, 0, 0, 0, 0, 0, 0]; assert_eq!(One::new(b'\x00').find(&data), Some(0)); } + + /// Generate 500K values. + fn special_values() -> impl Iterator { + fn all_bytes() -> impl Iterator { + 0..=0xff + } + + fn some_bytes() -> impl Iterator { + [0x00, 0x01, 0x02, 0x10, 0x11, 0x8f, 0xff].into_iter() + } + + all_bytes().flat_map(move |first_byte| { + some_bytes().flat_map(move |middle_byte| { + all_bytes().map(move |last_byte| { + splat(middle_byte) & !0xff & !(0xff << (usize::BITS - 8)) + | ((first_byte as usize) << (usize::BITS - 8)) + | (last_byte as usize) + }) + }) + }) + } + + fn lowest_zero_byte_simple(value: usize) -> Option { + value.to_le_bytes().iter().position(|&b| b == 0) + } + + fn highest_zero_byte_simple(value: usize) -> Option { + value.to_le_bytes().iter().rposition(|&b| b == 0) + } + + #[test] + fn test_lowest_zero_byte() { + assert_eq!(Some(0), lowest_zero_byte(0x00000000)); + assert_eq!(Some(0), lowest_zero_byte(0x01000000)); + assert_eq!(Some(1), lowest_zero_byte(0x00000001)); + assert_eq!(Some(1), lowest_zero_byte(0x00000010)); + assert_eq!(Some(1), lowest_zero_byte(0x00220010)); + assert_eq!(Some(1), lowest_zero_byte(0xff220010)); + assert_eq!(Some(USIZE_BYTES - 1), lowest_zero_byte(usize::MAX >> 8)); + assert_eq!(Some(USIZE_BYTES - 1), lowest_zero_byte(usize::MAX >> 9)); + assert_eq!(Some(USIZE_BYTES - 2), lowest_zero_byte(usize::MAX >> 16)); + assert_eq!(None, lowest_zero_byte(usize::MAX >> 7)); + assert_eq!(None, lowest_zero_byte(usize::MAX)); + } + + #[test] + fn test_highest_zero_byte() { + assert_eq!(Some(USIZE_BYTES - 1), highest_zero_byte(0x00000000)); + assert_eq!(Some(USIZE_BYTES - 1), highest_zero_byte(0x00345678)); + assert_eq!(Some(USIZE_BYTES - 1), highest_zero_byte(usize::MAX >> 8)); + assert_eq!(Some(USIZE_BYTES - 1), highest_zero_byte(usize::MAX >> 9)); + assert_eq!(Some(USIZE_BYTES - 1), highest_zero_byte(usize::MAX >> 9)); + assert_eq!( + Some(USIZE_BYTES - 1), + highest_zero_byte((usize::MAX >> 9) & !0xff) + ); + assert_eq!(None, highest_zero_byte(usize::MAX >> 3)); + } + + #[test] + fn test_lowest_zero_bytes_special_values() { + if cfg!(miri) { + return; + } + + for value in special_values() { + assert_eq!( + lowest_zero_byte_simple(value), + lowest_zero_byte(value) + ); + } + } + + #[test] + fn test_highest_zero_bytes_special_values() { + if cfg!(miri) { + return; + } + + for value in special_values() { + assert_eq!( + highest_zero_byte_simple(value), + highest_zero_byte(value) + ); + } + } }