From 3e52b7de8451603c3aaf31d522572c5259787930 Mon Sep 17 00:00:00 2001 From: Redzic Date: Sat, 2 Jul 2022 00:28:49 -0500 Subject: [PATCH 1/3] Add initial aarch64 neon support --- build.rs | 8 +- src/memchr/aarch64/mod.rs | 79 +++++++ src/memchr/aarch64/neon.rs | 432 +++++++++++++++++++++++++++++++++++++ src/memchr/mod.rs | 44 ++++ src/memchr/x86/mod.rs | 2 +- 5 files changed, 563 insertions(+), 2 deletions(-) create mode 100644 src/memchr/aarch64/mod.rs create mode 100644 src/memchr/aarch64/neon.rs diff --git a/build.rs b/build.rs index 584a608..761fa0f 100644 --- a/build.rs +++ b/build.rs @@ -10,7 +10,7 @@ fn main() { // This can be disabled with RUSTFLAGS="--cfg memchr_disable_auto_simd", but // this is generally only intended for testing. // -// On targets which don't feature SSE2, this is disabled, as LLVM wouln't know +// On targets which don't feature SSE2, this is disabled, as LLVM wouldn't know // how to work with SSE2 operands. Enabling SSE4.2 and AVX on SSE2-only targets // is not a problem. In that case, the fastest option will be chosen at // runtime. @@ -29,6 +29,12 @@ fn enable_simd_optimizations() { println!("cargo:rustc-cfg=memchr_runtime_sse42"); println!("cargo:rustc-cfg=memchr_runtime_avx"); } + "aarch64" => { + if !target_has_feature("neon") { + return; + } + println!("cargo:rustc-cfg=memchr_runtime_neon"); + } "wasm32" | "wasm64" => { if !target_has_feature("simd128") { return; diff --git a/src/memchr/aarch64/mod.rs b/src/memchr/aarch64/mod.rs new file mode 100644 index 0000000..775f2b7 --- /dev/null +++ b/src/memchr/aarch64/mod.rs @@ -0,0 +1,79 @@ +use super::fallback; + +mod neon; + +/// AArch64 is a 64-bit architecture introduced with ARMv8. NEON is required +/// in all standard ARMv8 implementations, so no runtime detection is required +/// to call NEON functions. +/// +/// # Safety +/// +/// There are no safety requirements for this definition of the macro. It is +/// safe for all inputs since it is restricted to either the fallback routine +/// or the NEON routine, which is always safe to call on AArch64 as explained +/// previously. +macro_rules! 
unsafe_ifunc { + ($fnty:ty, $name:ident, $haystack:ident, $($needle:ident),+) => {{ + if cfg!(memchr_runtime_neon) { + unsafe { neon::$name($($needle),+, $haystack) } + } else { + fallback::$name($($needle),+, $haystack) + } + }} +} + +#[inline(always)] +pub fn memchr(n1: u8, haystack: &[u8]) -> Option { + unsafe_ifunc!(fn(u8, &[u8]) -> Option, memchr, haystack, n1) +} + +#[inline(always)] +pub fn memchr2(n1: u8, n2: u8, haystack: &[u8]) -> Option { + unsafe_ifunc!( + fn(u8, u8, &[u8]) -> Option, + memchr2, + haystack, + n1, + n2 + ) +} + +#[inline(always)] +pub fn memchr3(n1: u8, n2: u8, n3: u8, haystack: &[u8]) -> Option { + unsafe_ifunc!( + fn(u8, u8, u8, &[u8]) -> Option, + memchr3, + haystack, + n1, + n2, + n3 + ) +} + +#[inline(always)] +pub fn memrchr(n1: u8, haystack: &[u8]) -> Option { + unsafe_ifunc!(fn(u8, &[u8]) -> Option, memrchr, haystack, n1) +} + +#[inline(always)] +pub fn memrchr2(n1: u8, n2: u8, haystack: &[u8]) -> Option { + unsafe_ifunc!( + fn(u8, u8, &[u8]) -> Option, + memrchr2, + haystack, + n1, + n2 + ) +} + +#[inline(always)] +pub fn memrchr3(n1: u8, n2: u8, n3: u8, haystack: &[u8]) -> Option { + unsafe_ifunc!( + fn(u8, u8, u8, &[u8]) -> Option, + memrchr3, + haystack, + n1, + n2, + n3 + ) +} diff --git a/src/memchr/aarch64/neon.rs b/src/memchr/aarch64/neon.rs new file mode 100644 index 0000000..2a4eff9 --- /dev/null +++ b/src/memchr/aarch64/neon.rs @@ -0,0 +1,432 @@ +use std::arch::aarch64::*; +use std::mem::transmute; + +const VEC_SIZE: usize = 16; + +/// Unroll size for mem{r}chr. +const UNROLL_SIZE_1: usize = 4; +/// Unroll size for mem{r}chr{2,3}. +const UNROLL_SIZE_23: usize = 2; + +#[target_feature(enable = "neon")] +pub unsafe fn memchr(n1: u8, haystack: &[u8]) -> Option { + memchr_generic_neon::( + [n1], + haystack, + ) +} + +#[target_feature(enable = "neon")] +pub unsafe fn memchr2(n1: u8, n2: u8, haystack: &[u8]) -> Option { + memchr_generic_neon::( + [n1, n2], + haystack, + ) +} + +#[target_feature(enable = "neon")] +pub unsafe fn memchr3( + n1: u8, + n2: u8, + n3: u8, + haystack: &[u8], +) -> Option { + memchr_generic_neon::( + [n1, n2, n3], + haystack, + ) +} + +#[target_feature(enable = "neon")] +pub unsafe fn memrchr(n1: u8, haystack: &[u8]) -> Option { + memchr_generic_neon::( + [n1], + haystack, + ) +} + +#[target_feature(enable = "neon")] +pub unsafe fn memrchr2(n1: u8, n2: u8, haystack: &[u8]) -> Option { + memchr_generic_neon::( + [n1, n2], + haystack, + ) +} + +#[target_feature(enable = "neon")] +pub(crate) unsafe fn memrchr3( + n1: u8, + n2: u8, + n3: u8, + haystack: &[u8], +) -> Option { + memchr_generic_neon::( + [n1, n2, n3], + haystack, + ) +} + +const fn generate_mask32() -> u32 { + let mut mask = 0; + let mut byte = 0b0000_0011; + + let mut i = 0; + while i < 4 { + mask |= byte; + byte <<= 8 + 2; + + i += 1; + } + + mask +} + +const fn generate_mask64() -> u64 { + let mut mask = 0; + let mut byte = 0b0000_0001; + + let mut i = 0; + while i < 8 { + mask |= byte; + byte <<= 8 + 1; + + i += 1; + } + + mask +} + +/// Returns true if the all bits in the register are set to 0. +#[inline(always)] +unsafe fn eq0(x: uint8x16_t) -> bool { + low64(vpmaxq_u8(x, x)) == 0 +} + +#[inline(always)] +unsafe fn low64(x: uint8x16_t) -> u64 { + vgetq_lane_u64(vreinterpretq_u64_u8(x), 0) +} + +// .fold() and .reduce() cause LLVM to generate a huge dependency chain, +// so we need a custom function to explicitly parallelize the bitwise OR +// reduction to better take advantage of modern superscalar CPUs. 
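+//
+// For example, with four masks [a, b, c, d], the first pass of the loop
+// below computes a | b and c | d independently, and the second pass ORs
+// those two partial results together, so the reduction is a balanced
+// two-level tree rather than the serial chain ((a | b) | c) | d.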
+#[inline(always)] +unsafe fn parallel_reduce( + mut masks: [uint8x16_t; N], +) -> uint8x16_t { + let mut len = masks.len(); + + while len != 1 { + for i in 0..len / 2 { + masks[i] = vorrq_u8(masks[i * 2], masks[i * 2 + 1]); + } + if len & 1 != 0 { + masks[0] = vorrq_u8(masks[0], masks[len - 1]); + } + len /= 2; + } + + masks[0] +} + +/// Search 64 bytes +#[inline(always)] +unsafe fn search64< + const IS_FWD: bool, + const N: usize, + const N2: usize, + const N4: usize, +>( + n: [u8; N], + ptr: *const u8, + start_ptr: *const u8, +) -> Option { + assert!(N4 == 4 * N); + assert!(N2 == 2 * N); + + const MASK4: u64 = generate_mask64(); + + let repmask4 = vreinterpretq_u8_u64(vdupq_n_u64(MASK4)); + + let x1 = vld1q_u8(ptr); + let x2 = vld1q_u8(ptr.add(1 * VEC_SIZE)); + let x3 = vld1q_u8(ptr.add(2 * VEC_SIZE)); + let x4 = vld1q_u8(ptr.add(3 * VEC_SIZE)); + + let mut nv: [uint8x16_t; N] = [vdupq_n_u8(0); N]; + for i in 0..N { + nv[i] = vdupq_n_u8(n[i]); + } + + let mut masks1 = [vdupq_n_u8(0); N]; + let mut masks2 = [vdupq_n_u8(0); N]; + let mut masks3 = [vdupq_n_u8(0); N]; + let mut masks4 = [vdupq_n_u8(0); N]; + + for i in 0..N { + masks1[i] = vceqq_u8(x1, nv[i]); + masks2[i] = vceqq_u8(x2, nv[i]); + masks3[i] = vceqq_u8(x3, nv[i]); + masks4[i] = vceqq_u8(x4, nv[i]); + } + + let cmpmask = parallel_reduce({ + let mut mask1234 = [vdupq_n_u8(0); N4]; + mask1234[..N].copy_from_slice(&masks1); + mask1234[N..2 * N].copy_from_slice(&masks2); + mask1234[2 * N..3 * N].copy_from_slice(&masks3); + mask1234[3 * N..4 * N].copy_from_slice(&masks4); + mask1234 + }); + + if !eq0(cmpmask) { + let cmp1 = parallel_reduce(masks1); + let cmp2 = parallel_reduce(masks2); + let cmp3 = parallel_reduce(masks3); + let cmp4 = parallel_reduce(masks4); + + let cmp1 = vandq_u8(repmask4, cmp1); + let cmp2 = vandq_u8(repmask4, cmp2); + let cmp3 = vandq_u8(repmask4, cmp3); + let cmp4 = vandq_u8(repmask4, cmp4); + + let reduce1 = vpaddq_u8(cmp1, cmp2); + let reduce2 = vpaddq_u8(cmp3, cmp4); + let reduce3 = vpaddq_u8(reduce1, reduce2); + let reduce4 = vpaddq_u8(reduce3, reduce3); + + let low64: u64 = low64(reduce4); + + let offset = ptr as usize - start_ptr as usize; + + if IS_FWD { + return Some(offset + low64.trailing_zeros() as usize); + } else { + return Some( + offset + (4 * VEC_SIZE - 1) - (low64.leading_zeros() as usize), + ); + } + } + + None +} + +/// Search 32 bytes +#[inline(always)] +unsafe fn search32< + const IS_FWD: bool, + const N: usize, + const N2: usize, + const N4: usize, +>( + n: [u8; N], + ptr: *const u8, + start_ptr: *const u8, +) -> Option { + assert!(N2 == 2 * N); + + const MASK: u32 = generate_mask32(); + let repmask2 = vdupq_n_u32(MASK); + + let x1 = vld1q_u8(ptr); + let x2 = vld1q_u8(ptr.add(VEC_SIZE)); + + let mut nv: [uint8x16_t; N] = [vdupq_n_u8(0); N]; + for i in 0..N { + nv[i] = vdupq_n_u8(n[i]); + } + + let mut masks1 = [vdupq_n_u8(0); N]; + let mut masks2 = [vdupq_n_u8(0); N]; + + for i in 0..N { + masks1[i] = vceqq_u8(x1, nv[i]); + masks2[i] = vceqq_u8(x2, nv[i]); + } + + let cmpmask = parallel_reduce({ + let mut mask12 = [vdupq_n_u8(0); N2]; + mask12[..N].copy_from_slice(&masks1); + mask12[N..2 * N].copy_from_slice(&masks2); + mask12 + }); + + if !eq0(cmpmask) { + let cmp1 = parallel_reduce(masks1); + let cmp2 = parallel_reduce(masks2); + + let cmp1 = vandq_u8(transmute(repmask2), cmp1); + let cmp2 = vandq_u8(transmute(repmask2), cmp2); + + let reduce1 = vpaddq_u8(cmp1, cmp2); + let reduce2 = vpaddq_u8(reduce1, reduce1); + + let low64: u64 = low64(reduce2); + + let offset = ptr as usize - 
start_ptr as usize; + + if IS_FWD { + return Some(offset + low64.trailing_zeros() as usize / 2); + } else { + return Some( + offset + (2 * VEC_SIZE - 1) + - (low64.leading_zeros() as usize / 2), + ); + } + } + + None +} + +/// Search 16 bytes +#[inline(always)] +unsafe fn search16< + const IS_FWD: bool, + const N: usize, + const N2: usize, + const N4: usize, +>( + n: [u8; N], + ptr: *const u8, + start_ptr: *const u8, +) -> Option { + let repmask1 = vreinterpretq_u8_u16(vdupq_n_u16(0xF00F)); + + let mut nv: [uint8x16_t; N] = [vdupq_n_u8(0); N]; + for i in 0..N { + nv[i] = vdupq_n_u8(n[i]); + } + + let x1 = vld1q_u8(ptr); + + let mut cmp_masks = [vdupq_n_u8(0); N]; + + for i in 0..N { + cmp_masks[i] = vceqq_u8(x1, nv[i]); + } + + let cmpmask = parallel_reduce(cmp_masks); + + if !eq0(cmpmask) { + let cmpmask = vandq_u8(cmpmask, repmask1); + let combined = vpaddq_u8(cmpmask, cmpmask); + let comb_low: u64 = low64(combined); + + let offset = ptr as usize - start_ptr as usize; + + if IS_FWD { + return Some(offset + comb_low.trailing_zeros() as usize / 4); + } else { + return Some( + offset + (VEC_SIZE - 1) + - (comb_low.leading_zeros() as usize / 4), + ); + } + } + + None +} + +#[inline] +#[target_feature(enable = "neon")] +unsafe fn memchr_generic_neon< + const IS_FWD: bool, + const N: usize, + const N2: usize, + const N4: usize, + const UNROLL: usize, +>( + n: [u8; N], + haystack: &[u8], +) -> Option { + assert!(UNROLL <= 4 && UNROLL.is_power_of_two()); + + let is_match = |x: u8| -> bool { n.iter().any(|&y| y == x) }; + + let start_ptr = haystack.as_ptr(); + + if haystack.len() < VEC_SIZE { + if IS_FWD { + // For whatever reason, LLVM generates significantly worse + // code when using .copied() on the forward search, but + // generates very good code for the reverse search (even + // better than manual pointer arithmetic). 
+ return haystack.iter().position(|&x| is_match(x)); + } else { + return haystack.iter().copied().rposition(is_match); + } + } + + // dynamic trait object devirtualized by LLVM upon monomorphization + let iter: &mut dyn Iterator; + + let mut x1; + let mut x2; + let remainder; + + if IS_FWD { + let temp = haystack.chunks_exact(UNROLL * VEC_SIZE); + remainder = temp.remainder(); + x1 = temp; + iter = &mut x1; + } else { + let temp = haystack.rchunks_exact(UNROLL * VEC_SIZE); + remainder = temp.remainder(); + x2 = temp; + iter = &mut x2; + } + + let loop_search = match UNROLL { + 1 => search16::, + 2 => search32::, + 4 => search64::, + _ => unreachable!(), + }; + + for chunk in iter { + if let Some(idx) = loop_search(n, chunk.as_ptr(), start_ptr) { + return Some(idx); + } + } + + let mut ptr = if IS_FWD { + remainder.as_ptr() + } else { + remainder.as_ptr().add(remainder.len()).offset(-(VEC_SIZE as isize)) + }; + + if UNROLL > 1 { + for _ in 0..remainder.len() / VEC_SIZE { + if let Some(idx) = if IS_FWD { + let ret = search16::(n, ptr, start_ptr); + + ptr = ptr.add(VEC_SIZE); + + ret + } else { + let ret = search16::(n, ptr, start_ptr); + + ptr = ptr.offset(-(VEC_SIZE as isize)); + + ret + } { + return Some(idx); + } + } + } + + if haystack.len() % VEC_SIZE != 0 { + // overlapped search of remainder + if IS_FWD { + return search16::( + n, + start_ptr.add(haystack.len() - VEC_SIZE), + start_ptr, + ); + } else { + return search16::(n, start_ptr, start_ptr); + } + } + + None +} diff --git a/src/memchr/mod.rs b/src/memchr/mod.rs index 09ce6ef..9a4cdeb 100644 --- a/src/memchr/mod.rs +++ b/src/memchr/mod.rs @@ -2,6 +2,8 @@ use core::iter::Rev; pub use self::iter::{Memchr, Memchr2, Memchr3}; +#[cfg(all(not(miri), target_arch = "aarch64", memchr_runtime_neon))] +mod aarch64; // N.B. If you're looking for the cfg knobs for libc, see build.rs. 
#[cfg(memchr_libc)] mod c; @@ -107,9 +109,16 @@ pub fn memchr(needle: u8, haystack: &[u8]) -> Option { c::memchr(n1, haystack) } + #[cfg(all(target_arch = "aarch64", memchr_runtime_neon, not(miri)))] + #[inline(always)] + fn imp(n1: u8, haystack: &[u8]) -> Option { + aarch64::memchr(n1, haystack) + } + #[cfg(all( not(memchr_libc), not(all(target_arch = "x86_64", memchr_runtime_simd)), + not(all(target_arch = "aarch64", memchr_runtime_neon)), not(miri), ))] #[inline(always)] @@ -161,8 +170,15 @@ pub fn memchr2(needle1: u8, needle2: u8, haystack: &[u8]) -> Option { x86::memchr2(n1, n2, haystack) } + #[cfg(all(target_arch = "aarch64", memchr_runtime_neon, not(miri)))] + #[inline(always)] + fn imp(n1: u8, n2: u8, haystack: &[u8]) -> Option { + aarch64::memchr2(n1, n2, haystack) + } + #[cfg(all( not(all(target_arch = "x86_64", memchr_runtime_simd)), + not(all(target_arch = "aarch64", memchr_runtime_neon)), not(miri), ))] #[inline(always)] @@ -219,8 +235,15 @@ pub fn memchr3( x86::memchr3(n1, n2, n3, haystack) } + #[cfg(all(target_arch = "aarch64", memchr_runtime_neon, not(miri)))] + #[inline(always)] + fn imp(n1: u8, n2: u8, n3: u8, haystack: &[u8]) -> Option { + aarch64::memchr3(n1, n2, n3, haystack) + } + #[cfg(all( not(all(target_arch = "x86_64", memchr_runtime_simd)), + not(all(target_arch = "aarch64", memchr_runtime_neon)), not(miri), ))] #[inline(always)] @@ -281,9 +304,16 @@ pub fn memrchr(needle: u8, haystack: &[u8]) -> Option { c::memrchr(n1, haystack) } + #[cfg(all(target_arch = "aarch64", memchr_runtime_neon, not(miri)))] + #[inline(always)] + fn imp(n1: u8, haystack: &[u8]) -> Option { + aarch64::memrchr(n1, haystack) + } + #[cfg(all( not(all(memchr_libc, target_os = "linux")), not(all(target_arch = "x86_64", memchr_runtime_simd)), + not(all(target_arch = "aarch64", memchr_runtime_neon)), not(miri), ))] #[inline(always)] @@ -335,7 +365,14 @@ pub fn memrchr2(needle1: u8, needle2: u8, haystack: &[u8]) -> Option { x86::memrchr2(n1, n2, haystack) } + #[cfg(all(target_arch = "aarch64", memchr_runtime_neon, not(miri)))] + #[inline(always)] + fn imp(n1: u8, n2: u8, haystack: &[u8]) -> Option { + aarch64::memrchr2(n1, n2, haystack) + } + #[cfg(all( + not(all(target_arch = "aarch64", memchr_runtime_neon)), not(all(target_arch = "x86_64", memchr_runtime_simd)), not(miri), ))] @@ -393,7 +430,14 @@ pub fn memrchr3( x86::memrchr3(n1, n2, n3, haystack) } + #[cfg(all(target_arch = "aarch64", memchr_runtime_neon, not(miri)))] + #[inline(always)] + fn imp(n1: u8, n2: u8, n3: u8, haystack: &[u8]) -> Option { + aarch64::memrchr3(n1, n2, n3, haystack) + } + #[cfg(all( + not(all(target_arch = "aarch64", memchr_runtime_neon)), not(all(target_arch = "x86_64", memchr_runtime_simd)), not(miri), ))] diff --git a/src/memchr/x86/mod.rs b/src/memchr/x86/mod.rs index aec35db..7cd7d43 100644 --- a/src/memchr/x86/mod.rs +++ b/src/memchr/x86/mod.rs @@ -72,7 +72,7 @@ macro_rules! unsafe_ifunc { /// When std isn't available to provide runtime CPU feature detection, or if /// runtime CPU feature detection has been explicitly disabled, then just -/// call our optimized SSE2 routine directly. SSE2 is avalbale on all x86_64 +/// call our optimized SSE2 routine directly. SSE2 is available on all x86_64 /// targets, so no CPU feature detection is necessary. 
/// /// # Safety From 9815fdef90f13b6d0ca3de0b7ad06266205eaf3b Mon Sep 17 00:00:00 2001 From: Redzic Date: Fri, 8 Jul 2022 19:46:25 -0500 Subject: [PATCH 2/3] fix inconsistency --- src/memchr/aarch64/neon.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memchr/aarch64/neon.rs b/src/memchr/aarch64/neon.rs index 2a4eff9..29f960c 100644 --- a/src/memchr/aarch64/neon.rs +++ b/src/memchr/aarch64/neon.rs @@ -54,7 +54,7 @@ pub unsafe fn memrchr2(n1: u8, n2: u8, haystack: &[u8]) -> Option { } #[target_feature(enable = "neon")] -pub(crate) unsafe fn memrchr3( +pub unsafe fn memrchr3( n1: u8, n2: u8, n3: u8, From 564a24ae0522bd362dc5bb464ece6acfd43fe338 Mon Sep 17 00:00:00 2001 From: Redzic Date: Fri, 8 Jul 2022 19:55:34 -0500 Subject: [PATCH 3/3] update docs --- README.md | 2 +- src/lib.rs | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 77a7a0f..cc49913 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ memchr links to the standard library by default, but you can disable the memchr = { version = "2", default-features = false } ``` -On x86 platforms, when the `std` feature is disabled, the SSE2 accelerated +On x86-64 platforms, when the `std` feature is disabled, the SSE2 accelerated implementations will be used. When `std` is enabled, AVX accelerated implementations will be used if the CPU is determined to support it at runtime. diff --git a/src/lib.rs b/src/lib.rs index e0b4ce3..9097ce6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -133,9 +133,9 @@ library haven't quite been worked out yet. **NOTE:** Currently, only `x86_64` targets have highly accelerated implementations of substring search. For `memchr`, all targets have -somewhat-accelerated implementations, while only `x86_64` targets have highly -accelerated implementations. This limitation is expected to be lifted once the -standard library exposes a platform independent SIMD API. +somewhat-accelerated implementations, while `x86_64` and `aarch64` targets +have highly accelerated implementations. This limitation is expected to be +lifted once the standard library exposes a platform independent SIMD API. # Crate features @@ -144,7 +144,7 @@ standard library exposes a platform independent SIMD API. from the standard library is runtime SIMD CPU feature detection. This means that this feature must be enabled to get AVX accelerated routines. When `std` is not enabled, this crate will still attempt to use SSE2 accelerated - routines on `x86_64`. + routines on `x86_64` and NEON accelerated routines on `aarch64`. * **libc** - When enabled (**not** the default), this library will use your platform's libc implementation of `memchr` (and `memrchr` on Linux). This can be useful on non-`x86_64` targets where the fallback implementation in
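
The NEON routines added in patch 1/3 (`search16`, `search32`, `search64`) all play the role that
`_mm_movemask_epi8` plays in the x86 implementation: they mask the `vceqq_u8` comparison result and
fold it with `vpaddq_u8` into a 64-bit integer whose set bits locate the matching bytes. The sketch
below is a condensed, standalone illustration of that step for a single needle over one 16-byte
block; the function name `first_match_in_16` and the simplified control flow are illustrative only
and do not appear in the patch.

```rust
// Find the first occurrence of `needle` in the 16 bytes at `block`, using
// the same mask-and-fold trick as `search16` in src/memchr/aarch64/neon.rs.
//
// Safety: `block` must point to at least 16 readable bytes.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn first_match_in_16(needle: u8, block: *const u8) -> Option<usize> {
    use std::arch::aarch64::*;

    // 0xFF in every lane equal to `needle`, 0x00 elsewhere.
    let eq = vceqq_u8(vld1q_u8(block), vdupq_n_u8(needle));
    // Keep the low nibble of even bytes and the high nibble of odd bytes
    // (0xF00F per u16 lane, i.e. bytes 0x0F, 0xF0 in little-endian order),
    // so the pairwise add below packs two byte flags into one byte with no
    // carries.
    let masked = vandq_u8(eq, vreinterpretq_u8_u16(vdupq_n_u16(0xF00F)));
    // vpaddq_u8 sums adjacent byte pairs; the low 64 bits now hold one
    // nibble per input byte: 0xF on a match, 0x0 otherwise.
    let folded = vpaddq_u8(masked, masked);
    let bits = vgetq_lane_u64(vreinterpretq_u64_u8(folded), 0);
    if bits == 0 {
        None
    } else {
        Some(bits.trailing_zeros() as usize / 4)
    }
}
```

This is why `search16` divides `trailing_zeros()` by 4: each input byte ends up as one nibble of the
folded value. `search32` keeps 2 bits per byte (hence the division by 2), and `search64` compacts all
the way down to 1 bit per byte, so its index math needs no division at all.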