From cdabb70b9ff20793c6c8629e6d65e18b44f741e4 Mon Sep 17 00:00:00 2001
From: Aaron O'Mullan
Date: Tue, 25 Apr 2023 15:54:41 -0300
Subject: [PATCH] perf: SIMD neon support (#133)

First pass at neon support, building off #132
---
 .github/workflows/ci.yml |  38 ++++++
 src/simd/mod.rs          |  13 ++
 src/simd/neon.rs         | 256 +++++++++++++++++++++++++++++++++
 3 files changed, 307 insertions(+)
 create mode 100644 src/simd/neon.rs

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 847be80..3ea589f 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -147,3 +147,41 @@ jobs:
 
       - name: Test
         run: MIRIFLAGS="-Zmiri-tag-raw-pointers -Zmiri-check-number-validity" cargo miri test
+
+  aarch64:
+    name: Test aarch64 (neon)
+    runs-on: ubuntu-latest
+    env:
+      CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER: aarch64-linux-gnu-gcc
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v2
+
+      - name: Setup Rust
+        uses: actions-rs/toolchain@v1
+        with:
+          profile: minimal
+          toolchain: stable
+          override: true
+          target: aarch64-unknown-linux-gnu
+
+      - name: Install QEMU and dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y qemu qemu-user gcc-aarch64-linux-gnu
+
+      - name: Build tests
+        run: cargo build --tests --target aarch64-unknown-linux-gnu
+
+      - name: Run tests with QEMU
+        run: |
+          test_binaries=$(find target/aarch64-unknown-linux-gnu/debug/deps/ -type f -executable -name 'httparse-*')
+          if [ -n "$test_binaries" ]; then
+            for test_binary in $test_binaries
+            do
+              echo "Running tests in ${test_binary}"
+              /usr/bin/qemu-aarch64 -L /usr/aarch64-linux-gnu/ "${test_binary}"
+            done
+          else
+            echo "No test binaries found."
+          fi

diff --git a/src/simd/mod.rs b/src/simd/mod.rs
index 91f682d..81bdd87 100644
--- a/src/simd/mod.rs
+++ b/src/simd/mod.rs
@@ -5,6 +5,7 @@ mod swar;
     any(
         target_arch = "x86",
         target_arch = "x86_64",
+        target_arch = "aarch64",
     ),
 )))]
 pub use self::swar::*;
@@ -132,3 +133,15 @@ mod avx2_compile_time {
     ),
 ))]
 pub use self::avx2_compile_time::*;
+
+#[cfg(all(
+    httparse_simd,
+    target_arch = "aarch64",
+))]
+mod neon;
+
+#[cfg(all(
+    httparse_simd,
+    target_arch = "aarch64",
+))]
+pub use self::neon::*;

diff --git a/src/simd/neon.rs b/src/simd/neon.rs
new file mode 100644
index 0000000..8897d65
--- /dev/null
+++ b/src/simd/neon.rs
@@ -0,0 +1,256 @@
+use crate::iter::Bytes;
+use core::arch::aarch64::*;
+
+#[inline]
+pub fn match_header_name_vectored(bytes: &mut Bytes) {
+    while bytes.as_ref().len() >= 16 {
+        unsafe {
+            let advance = match_header_name_char_16_neon(bytes.as_ref().as_ptr());
+            bytes.advance(advance);
+
+            if advance != 16 {
+                return;
+            }
+        }
+    }
+    super::swar::match_header_name_vectored(bytes);
+}
+
+#[inline]
+pub fn match_header_value_vectored(bytes: &mut Bytes) {
+    while bytes.as_ref().len() >= 16 {
+        unsafe {
+            let advance = match_header_value_char_16_neon(bytes.as_ref().as_ptr());
+            bytes.advance(advance);
+
+            if advance != 16 {
+                return;
+            }
+        }
+    }
+    super::swar::match_header_value_vectored(bytes);
+}
+
+#[inline]
+pub fn match_uri_vectored(bytes: &mut Bytes) {
+    while bytes.as_ref().len() >= 16 {
+        unsafe {
+            let advance = match_url_char_16_neon(bytes.as_ref().as_ptr());
+            bytes.advance(advance);
+
+            if advance != 16 {
+                return;
+            }
+        }
+    }
+    super::swar::match_uri_vectored(bytes);
+}
+
+const fn bit_set(x: u8) -> bool {
+    // Validates if a byte is a valid header name character
+    // https://tools.ietf.org/html/rfc7230#section-3.2.6
+    matches!(x, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'!' | b'#' | b'$' | b'%' | b'&' | b'\'' | b'*' | b'+' | b'-' | b'.' | b'^' | b'_' | b'`' | b'|' | b'~')
+}
+
+// A 256-bit bitmap, split into two halves:
+// the lower half covers bytes whose high nibble is <= 7,
+// the higher half covers bytes whose high nibble is >= 8.
+const fn build_bitmap() -> ([u8; 16], [u8; 16]) {
+    let mut bitmap_0_7 = [0u8; 16]; // 0x00..0x7F
+    let mut bitmap_8_15 = [0u8; 16]; // 0x80..0xFF
+    let mut i = 0;
+    while i < 256 {
+        if bit_set(i as u8) {
+            // Nibbles
+            let (lo, hi) = (i & 0x0F, i >> 4);
+            if i < 128 {
+                bitmap_0_7[lo] |= 1 << hi;
+            } else {
+                bitmap_8_15[lo] |= 1 << hi;
+            }
+        }
+        i += 1;
+    }
+    (bitmap_0_7, bitmap_8_15)
+}
+
+const BITMAPS: ([u8; 16], [u8; 16]) = build_bitmap();
+
+// NOTE: adapted from 256-bit version, with upper 128-bit ops commented out
+#[inline]
+unsafe fn match_header_name_char_16_neon(ptr: *const u8) -> usize {
+    let bitmaps = BITMAPS;
+    // NOTE: ideally compile-time constants
+    let (bitmap_0_7, _bitmap_8_15) = bitmaps;
+    let bitmap_0_7 = vld1q_u8(bitmap_0_7.as_ptr());
+    // let bitmap_8_15 = vld1q_u8(bitmap_8_15.as_ptr());
+
+    // Initialize the bitmask_lookup.
+    const BITMASK_LOOKUP_DATA: [u8; 16] =
+        [1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128];
+    let bitmask_lookup = vld1q_u8(BITMASK_LOOKUP_DATA.as_ptr());
+
+    // Load 16 input bytes.
+    let input = vld1q_u8(ptr);
+
+    // Extract indices for row_0_7.
+    let indices_0_7 = vandq_u8(input, vdupq_n_u8(0x8F)); // 0b1000_1111
+
+    // Extract indices for row_8_15.
+    // let msb = vandq_u8(input, vdupq_n_u8(0x80));
+    // let indices_8_15 = veorq_u8(indices_0_7, msb);
+
+    // Fetch row_0_7 and row_8_15.
+    let row_0_7 = vqtbl1q_u8(bitmap_0_7, indices_0_7);
+    // let row_8_15 = vqtbl1q_u8(bitmap_8_15, indices_8_15);
+
+    // Calculate a bitmask, i.e. (1 << hi_nibble % 8).
+    let bitmask = vqtbl1q_u8(bitmask_lookup, vshrq_n_u8(input, 4));
+
+    // Choose row halves depending on the high nibbles.
+    // let bitsets = vorrq_u8(row_0_7, row_8_15);
+    let bitsets = row_0_7;
+
+    // Finally check which bytes belong to the set.
+    let tmp = vandq_u8(bitsets, bitmask);
+    let result = vceqq_u8(tmp, bitmask);
+
+    offsetz(result) as usize
+}
+
+#[inline]
+unsafe fn match_url_char_16_neon(ptr: *const u8) -> usize {
+    let input = vld1q_u8(ptr);
+
+    // Check that b'!' <= input <= b'~'
+    let result = vandq_u8(
+        vcleq_u8(vdupq_n_u8(b'!'), input),
+        vcleq_u8(input, vdupq_n_u8(b'~')),
+    );
+    // Check that input != b'<' and input != b'>'
+    let lt = vceqq_u8(input, vdupq_n_u8(b'<'));
+    let gt = vceqq_u8(input, vdupq_n_u8(b'>'));
+    let ltgt = vorrq_u8(lt, gt);
+    // Clear (AND NOT) the '<' / '>' matches from result
+    let result = vbicq_u8(result, ltgt);
+
+    offsetz(result) as usize
+}
+
+#[inline]
+unsafe fn match_header_value_char_16_neon(ptr: *const u8) -> usize {
+    let input = vld1q_u8(ptr);
+
+    // Check that b' ' <= b and b != 127, or b == 9
+    let result = vcleq_u8(vdupq_n_u8(b' '), input);
+
+    // Allow tab
+    let tab = vceqq_u8(input, vdupq_n_u8(0x09));
+    let result = vorrq_u8(result, tab);
+
+    // Disallow del
+    let del = vceqq_u8(input, vdupq_n_u8(0x7F));
+    let result = vbicq_u8(result, del);
+
+    offsetz(result) as usize
+}
+
+#[inline]
+unsafe fn offsetz(x: uint8x16_t) -> u32 {
+    // NOT the vector since it's faster to operate with zeros instead
+    offsetnz(vmvnq_u8(x))
+}
+
+#[inline]
+unsafe fn offsetnz(x: uint8x16_t) -> u32 {
+    // Extract two u64
+    let x = vreinterpretq_u64_u8(x);
+    let low: u64 = core::mem::transmute(vget_low_u64(x));
+    let high: u64 = core::mem::transmute(vget_high_u64(x));
+
+    #[inline]
+    fn clz(x: u64) -> u32 {
+        // perf: rust will unroll this loop
+        // and it's much faster than rbit + clz so voila
+        for (i, b) in x.to_ne_bytes().iter().copied().enumerate() {
+            if b != 0 {
+                return i as u32;
+            }
+        }
+        8 // Technically not reachable since zero-guarded
+    }
+
+    if low != 0 {
+        return clz(low);
+    } else if high != 0 {
+        return 8 + clz(high);
+    } else {
+        return 16;
+    }
+}
+
+#[test]
+fn neon_code_matches_uri_chars_table() {
+    unsafe {
+        assert!(byte_is_allowed(b'_', match_uri_vectored));
+
+        for (b, allowed) in crate::URI_MAP.iter().cloned().enumerate() {
+            assert_eq!(
+                byte_is_allowed(b as u8, match_uri_vectored),
+                allowed,
+                "byte_is_allowed({:?}) should be {:?}",
+                b,
+                allowed,
+            );
+        }
+    }
+}
+
+#[test]
+fn neon_code_matches_header_value_chars_table() {
+    unsafe {
+        assert!(byte_is_allowed(b'_', match_header_value_vectored));
+
+        for (b, allowed) in crate::HEADER_VALUE_MAP.iter().cloned().enumerate() {
+            assert_eq!(
+                byte_is_allowed(b as u8, match_header_value_vectored),
+                allowed,
+                "byte_is_allowed({:?}) should be {:?}",
+                b,
+                allowed,
+            );
+        }
+    }
+}
+
+#[test]
+fn neon_code_matches_header_name_chars_table() {
+    unsafe {
+        assert!(byte_is_allowed(b'_', match_header_name_vectored));
+
+        for (b, allowed) in crate::HEADER_NAME_MAP.iter().cloned().enumerate() {
+            assert_eq!(
+                byte_is_allowed(b as u8, match_header_name_vectored),
+                allowed,
+                "byte_is_allowed({:?}) should be {:?}",
+                b,
+                allowed,
+            );
+        }
+    }
+}
+
+#[cfg(test)]
+unsafe fn byte_is_allowed(byte: u8, f: unsafe fn(bytes: &mut Bytes<'_>)) -> bool {
+    let mut slice = [b'_'; 16];
+    slice[10] = byte;
+    let mut bytes = Bytes::new(&slice);
+
+    f(&mut bytes);
+
+    match bytes.pos() {
+        16 => true,
+        10 => false,
+        x => panic!("unexpected pos: {}", x),
+    }
+}
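
Note, not part of the patch: the NEON classification in match_header_name_char_16_neon can be
cross-checked against a plain scalar model of the same nibble-bitmap scheme. The sketch below is
hypothetical (the helper name scalar_bit_set does not exist in the patch) and only assumes the
(bitmap_0_7, bitmap_8_15) layout produced by build_bitmap(): the low nibble of each byte selects a
table lane (what vqtbl1q_u8 does) and the high nibble selects a bit inside that lane (what the
BITMASK_LOOKUP_DATA table encodes).

    // Hypothetical scalar reference for the nibble-bitmap lookup; illustration only.
    // For each byte: low nibble -> lane index, high nibble -> bit position (mod 8).
    fn scalar_bit_set(b: u8, bitmap_0_7: &[u8; 16], bitmap_8_15: &[u8; 16]) -> bool {
        let lo = (b & 0x0F) as usize; // lane picked by vqtbl1q_u8 in the NEON path
        let hi = b >> 4;              // bit picked via BITMASK_LOOKUP_DATA
        let row = if b < 0x80 { bitmap_0_7[lo] } else { bitmap_8_15[lo] };
        (row & (1 << (hi % 8))) != 0
    }

For every byte below 0x80 this agrees with bit_set() above, and bitmap_8_15 stays all zero because
no token characters are >= 0x80; that is why the commented-out upper-half (row_8_15) lookups do not
change the result.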