Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: simd neon #133

Merged
merged 1 commit into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,41 @@ jobs:

- name: Test
run: MIRIFLAGS="-Zmiri-tag-raw-pointers -Zmiri-check-number-validity" cargo miri test

# CI job: cross-compile the test suite for aarch64 and execute the resulting
# test binaries under QEMU user-mode emulation on an ubuntu-latest runner.
aarch64:
name: Test aarch64 (neon)
runs-on: ubuntu-latest
env:
# Linker cargo should use for the aarch64-unknown-linux-gnu target
# (provided by the gcc-aarch64-linux-gnu package installed below).
CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER: aarch64-linux-gnu-gcc
steps:
- name: Checkout repository
uses: actions/checkout@v2

# Stable toolchain with the aarch64 target added for cross-compilation.
- name: Setup Rust
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
target: aarch64-unknown-linux-gnu

# QEMU runs the foreign binaries; the cross GCC supplies the linker.
- name: Install QEMU and dependencies
run: |
sudo apt-get update
sudo apt-get install -y qemu qemu-user gcc-aarch64-linux-gnu

# Builds the test binaries for the aarch64 target without running them.
- name: Build tests
run: cargo build --tests --target aarch64-unknown-linux-gnu

# Finds every executable `httparse-*` test binary in the debug deps directory
# and runs each one through qemu-aarch64, using /usr/aarch64-linux-gnu/ as the
# emulated library root (-L).
- name: Run tests with QEMU
run: |
test_binaries=$(find target/aarch64-unknown-linux-gnu/debug/deps/ -type f -executable -name 'httparse-*')
if [ -n "$test_binaries" ]; then
for test_binary in $test_binaries
do
echo "Running tests in ${test_binary}"
/usr/bin/qemu-aarch64 -L /usr/aarch64-linux-gnu/ "${test_binary}"
done
else
echo "No test binaries found."
fi
13 changes: 13 additions & 0 deletions src/simd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod swar;
any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "aarch64",
),
)))]
pub use self::swar::*;
Expand Down Expand Up @@ -132,3 +133,15 @@ mod avx2_compile_time {
),
))]
pub use self::avx2_compile_time::*;

// NEON implementation: only compiled when the `httparse_simd` cfg is set and
// the target architecture is aarch64.
#[cfg(all(
httparse_simd,
target_arch = "aarch64",
))]
mod neon;

// Re-export everything public from the NEON module under the same cfg, so
// callers reach it through `simd::*`.
#[cfg(all(
httparse_simd,
target_arch = "aarch64",
))]
pub use self::neon::*;
256 changes: 256 additions & 0 deletions src/simd/neon.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
use crate::iter::Bytes;
use core::arch::aarch64::*;

/// Advances `bytes` past leading header-name characters.
///
/// Input is consumed 16 bytes at a time with NEON. As soon as a chunk contains
/// a non-matching byte the cursor is left on that byte and the function
/// returns. Once fewer than 16 bytes remain, the rest is handed to the SWAR
/// implementation.
#[inline]
pub fn match_header_name_vectored(bytes: &mut Bytes) {
    loop {
        if bytes.as_ref().len() < 16 {
            // Tail shorter than one vector: finish with the SWAR fallback.
            super::swar::match_header_name_vectored(bytes);
            return;
        }

        // SAFETY: at least 16 bytes remain (checked just above), so the
        // 16-byte load is in bounds, and the helper reports at most 16
        // matched bytes, which does not exceed the remaining length.
        let matched = unsafe {
            let matched = match_header_name_char_16_neon(bytes.as_ref().as_ptr());
            bytes.advance(matched);
            matched
        };

        if matched != 16 {
            return;
        }
    }
}

/// Advances `bytes` past leading header-value characters.
///
/// Input is consumed 16 bytes at a time with NEON. As soon as a chunk contains
/// a non-matching byte the cursor is left on that byte and the function
/// returns. Once fewer than 16 bytes remain, the rest is handed to the SWAR
/// implementation.
#[inline]
pub fn match_header_value_vectored(bytes: &mut Bytes) {
    loop {
        if bytes.as_ref().len() < 16 {
            // Tail shorter than one vector: finish with the SWAR fallback.
            super::swar::match_header_value_vectored(bytes);
            return;
        }

        // SAFETY: at least 16 bytes remain (checked just above), so the
        // 16-byte load is in bounds, and the helper reports at most 16
        // matched bytes, which does not exceed the remaining length.
        let matched = unsafe {
            let matched = match_header_value_char_16_neon(bytes.as_ref().as_ptr());
            bytes.advance(matched);
            matched
        };

        if matched != 16 {
            return;
        }
    }
}

/// Advances `bytes` past leading URI characters.
///
/// Input is consumed 16 bytes at a time with NEON. As soon as a chunk contains
/// a non-matching byte the cursor is left on that byte and the function
/// returns. Once fewer than 16 bytes remain, the rest is handed to the SWAR
/// implementation.
#[inline]
pub fn match_uri_vectored(bytes: &mut Bytes) {
    loop {
        if bytes.as_ref().len() < 16 {
            // Tail shorter than one vector: finish with the SWAR fallback.
            super::swar::match_uri_vectored(bytes);
            return;
        }

        // SAFETY: at least 16 bytes remain (checked just above), so the
        // 16-byte load is in bounds, and the helper reports at most 16
        // matched bytes, which does not exceed the remaining length.
        let matched = unsafe {
            let matched = match_url_char_16_neon(bytes.as_ref().as_ptr());
            bytes.advance(matched);
            matched
        };

        if matched != 16 {
            return;
        }
    }
}

/// Returns `true` when `x` is a valid header-name (token) character:
/// an ASCII letter, an ASCII digit, or one of ``! # $ % & ' * + - . ^ _ ` | ~``.
///
/// See <https://tools.ietf.org/html/rfc7230#section-3.2.6>.
const fn bit_set(x: u8) -> bool {
    match x {
        // Alphanumerics
        b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' => true,
        // Punctuation permitted in a token
        b'!' | b'#' | b'$' | b'%' | b'&' | b'\'' | b'*' | b'+' => true,
        b'-' | b'.' | b'^' | b'_' | b'`' | b'|' | b'~' => true,
        _ => false,
    }
}

// A 256-bit membership bitmap of `bit_set`, split into two 16-byte halves.
// Each half is indexed by the byte's low nibble; within an entry, bit
// `high_nibble % 8` is set when the byte belongs to the set.
// - first half:  bytes whose high nibble is <= 7 (0x00..=0x7F)
// - second half: bytes whose high nibble is >= 8 (0x80..=0xFF)
const fn build_bitmap() -> ([u8; 16], [u8; 16]) {
    let mut bitmap_0_7 = [0u8; 16]; // 0x00..0x7F
    let mut bitmap_8_15 = [0u8; 16]; // 0x80..0xFF
    let mut i = 0;
    while i < 256 {
        if bit_set(i as u8) {
            // Nibbles
            let (lo, hi) = (i & 0x0F, i >> 4);
            // Reduce the high nibble modulo 8 so the shift always fits in a
            // `u8`. For the first half this is a no-op (hi <= 7); for the
            // second half it maps 8..=15 onto bits 0..=7 instead of
            // overflowing the shift.
            let bit = 1 << (hi & 0x07);
            if i < 128 {
                bitmap_0_7[lo] |= bit;
            } else {
                bitmap_8_15[lo] |= bit;
            }
        }
        i += 1;
    }
    (bitmap_0_7, bitmap_8_15)
}

// Header-name membership bitmaps (0x00..=0x7F half, 0x80..=0xFF half),
// evaluated once at compile time by `build_bitmap`.
const BITMAPS: ([u8; 16], [u8; 16]) = build_bitmap();

// NOTE: adapted from 256-bit version, with upper 128-bit ops commented out
#[inline]
unsafe fn match_header_name_char_16_neon(ptr: *const u8) -> usize {
let bitmaps = BITMAPS;
// NOTE: ideally compile-time constants
let (bitmap_0_7, _bitmap_8_15) = bitmaps;
let bitmap_0_7 = vld1q_u8(bitmap_0_7.as_ptr());
// let bitmap_8_15 = vld1q_u8(bitmap_8_15.as_ptr());
AaronO marked this conversation as resolved.
Show resolved Hide resolved

// Initialize the bitmask_lookup.
const BITMASK_LOOKUP_DATA: [u8; 16] =
[1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128];
let bitmask_lookup = vld1q_u8(BITMASK_LOOKUP_DATA.as_ptr());

// Load 16 input bytes.
let input = vld1q_u8(ptr);

// Extract indices for row_0_7.
let indices_0_7 = vandq_u8(input, vdupq_n_u8(0x8F)); // 0b1000_1111;

// Extract indices for row_8_15.
// let msb = vandq_u8(input, vdupq_n_u8(0x80));
// let indices_8_15 = veorq_u8(indices_0_7, msb);

// Fetch row_0_7 and row_8_15.
let row_0_7 = vqtbl1q_u8(bitmap_0_7, indices_0_7);
// let row_8_15 = vqtbl1q_u8(bitmap_8_15, indices_8_15);

// Calculate a bitmask, i.e. (1 << hi_nibble % 8).
let bitmask = vqtbl1q_u8(bitmask_lookup, vshrq_n_u8(input, 4));

// Choose rows halves depending on higher nibbles.
// let bitsets = vorrq_u8(row_0_7, row_8_15);
let bitsets = row_0_7;

// Finally check which bytes belong to the set.
let tmp = vandq_u8(bitsets, bitmask);
let result = vceqq_u8(tmp, bitmask);

offsetz(result) as usize
}

/// Returns how many of the 16 bytes at `ptr`, counted from the first, are
/// accepted as URI characters (0..=16).
///
/// A byte is accepted when it lies in `b'!'..=b'~'` and is neither `b'<'`
/// nor `b'>'`.
///
/// # Safety
///
/// `ptr` must be valid for reading 16 bytes.
#[inline]
unsafe fn match_url_char_16_neon(ptr: *const u8) -> usize {
    let chunk = vld1q_u8(ptr);

    // Lanes holding one of the two excluded characters.
    let is_lt = vceqq_u8(chunk, vdupq_n_u8(b'<'));
    let is_gt = vceqq_u8(chunk, vdupq_n_u8(b'>'));
    let excluded = vorrq_u8(is_lt, is_gt);

    // Lanes within the range b'!' <= byte <= b'~'.
    let not_below = vcleq_u8(vdupq_n_u8(b'!'), chunk);
    let not_above = vcleq_u8(chunk, vdupq_n_u8(b'~'));
    let in_range = vandq_u8(not_below, not_above);

    // in_range AND NOT excluded
    let accepted = vbicq_u8(in_range, excluded);

    offsetz(accepted) as usize
}

/// Returns how many of the 16 bytes at `ptr`, counted from the first, are
/// accepted as header-value characters (0..=16).
///
/// A byte is accepted when it is `>= b' '` (unsigned comparison) or a
/// horizontal tab (0x09), and is not DEL (0x7F).
///
/// # Safety
///
/// `ptr` must be valid for reading 16 bytes.
#[inline]
unsafe fn match_header_value_char_16_neon(ptr: *const u8) -> usize {
    let chunk = vld1q_u8(ptr);

    // Per-lane classification masks.
    let at_least_space = vcleq_u8(vdupq_n_u8(b' '), chunk);
    let is_tab = vceqq_u8(chunk, vdupq_n_u8(0x09));
    let is_del = vceqq_u8(chunk, vdupq_n_u8(0x7F));

    // (>= ' ' OR tab) AND NOT del
    let accepted = vbicq_u8(vorrq_u8(at_least_space, is_tab), is_del);

    offsetz(accepted) as usize
}

/// Returns the index of the first all-zero byte lane in `x`, or 16 when no
/// lane is zero.
#[inline]
unsafe fn offsetz(x: uint8x16_t) -> u32 {
    // Invert every bit so the search becomes "first non-zero lane", which is
    // the faster form to compute.
    let inverted = vmvnq_u8(x);
    offsetnz(inverted)
}

/// Returns the index (0..=15) of the first non-zero byte lane in `x`, or 16
/// when every lane is zero.
#[inline]
unsafe fn offsetnz(x: uint8x16_t) -> u32 {
    // View the vector as two u64 halves and move each into a general-purpose
    // register with the lane-extract intrinsic (no `transmute`, no `std`).
    let x = vreinterpretq_u64_u8(x);
    let low: u64 = vgetq_lane_u64::<0>(x);
    let high: u64 = vgetq_lane_u64::<1>(x);

    // Index of the first non-zero byte of `x` in memory order.
    #[inline]
    fn clz(x: u64) -> u32 {
        // perf: rust will unroll this loop
        // and it's much faster than rbit + clz so voila
        for (i, b) in x.to_ne_bytes().iter().copied().enumerate() {
            if b != 0 {
                return i as u32;
            }
        }
        8 // Technically not reachable since zero-guarded
    }

    if low != 0 {
        clz(low)
    } else if high != 0 {
        8 + clz(high)
    } else {
        16
    }
}

/// The NEON URI matcher must agree with `crate::URI_MAP` for every byte value.
#[test]
fn neon_code_matches_uri_chars_table() {
    // Sanity check with the filler byte used by `byte_is_allowed`.
    let underscore_ok = unsafe { byte_is_allowed(b'_', match_uri_vectored) };
    assert!(underscore_ok);

    for (b, allowed) in crate::URI_MAP.iter().cloned().enumerate() {
        let actual = unsafe { byte_is_allowed(b as u8, match_uri_vectored) };
        assert_eq!(
            actual, allowed,
            "byte_is_allowed({:?}) should be {:?}",
            b, allowed,
        );
    }
}

/// The NEON header-value matcher must agree with `crate::HEADER_VALUE_MAP`
/// for every byte value.
#[test]
fn neon_code_matches_header_value_chars_table() {
    // Sanity check with the filler byte used by `byte_is_allowed`.
    let underscore_ok = unsafe { byte_is_allowed(b'_', match_header_value_vectored) };
    assert!(underscore_ok);

    for (b, allowed) in crate::HEADER_VALUE_MAP.iter().cloned().enumerate() {
        let actual = unsafe { byte_is_allowed(b as u8, match_header_value_vectored) };
        assert_eq!(
            actual, allowed,
            "byte_is_allowed({:?}) should be {:?}",
            b, allowed,
        );
    }
}

/// The NEON header-name matcher must agree with `crate::HEADER_NAME_MAP`
/// for every byte value.
#[test]
fn neon_code_matches_header_name_chars_table() {
    // Sanity check with the filler byte used by `byte_is_allowed`.
    let underscore_ok = unsafe { byte_is_allowed(b'_', match_header_name_vectored) };
    assert!(underscore_ok);

    for (b, allowed) in crate::HEADER_NAME_MAP.iter().cloned().enumerate() {
        let actual = unsafe { byte_is_allowed(b as u8, match_header_name_vectored) };
        assert_eq!(
            actual, allowed,
            "byte_is_allowed({:?}) should be {:?}",
            b, allowed,
        );
    }
}

/// Test helper: reports whether matcher `f` accepts `byte`.
///
/// Builds a 16-byte buffer of `b'_'` with `byte` placed at index 10, runs `f`
/// over it and inspects where the cursor stopped: 16 means every byte was
/// accepted, 10 means `f` stopped on `byte`. Any other position panics.
#[cfg(test)]
unsafe fn byte_is_allowed(byte: u8, f: unsafe fn(bytes: &mut Bytes<'_>)) -> bool {
    let mut slice = [b'_'; 16];
    slice[10] = byte;

    let mut bytes = Bytes::new(&slice);
    f(&mut bytes);

    let stopped_at = bytes.pos();
    if stopped_at == 16 {
        true
    } else if stopped_at == 10 {
        false
    } else {
        panic!("unexpected pos: {}", stopped_at)
    }
}