Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve autovectorization of to_lowercase / to_uppercase functions #123778

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions library/alloc/benches/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,5 @@ make_test!(rsplitn_space_char, s, s.rsplitn(10, ' ').count());

make_test!(split_space_str, s, s.split(" ").count());
make_test!(split_ad_str, s, s.split("ad").count());

make_test!(to_lowercase, s, s.to_lowercase());
129 changes: 77 additions & 52 deletions library/alloc/src/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

use core::borrow::{Borrow, BorrowMut};
use core::iter::FusedIterator;
use core::mem::MaybeUninit;
#[stable(feature = "encode_utf16", since = "1.8.0")]
pub use core::str::EncodeUtf16;
#[stable(feature = "split_ascii_whitespace", since = "1.34.0")]
Expand Down Expand Up @@ -365,14 +366,9 @@ impl str {
without modifying the original"]
#[stable(feature = "unicode_case_mapping", since = "1.2.0")]
pub fn to_lowercase(&self) -> String {
let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_lowercase);
let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_lowercase);

// Safety: we know this is a valid char boundary since
// out.len() is only progressed if ascii bytes are found
let rest = unsafe { self.get_unchecked(out.len()..) };

// Safety: We have written only valid ASCII to our vec
let mut s = unsafe { String::from_utf8_unchecked(out) };
let prefix_len = s.len();

for (i, c) in rest.char_indices() {
if c == 'Σ' {
Expand All @@ -381,8 +377,7 @@ impl str {
// in `SpecialCasing.txt`,
// so hard-code it rather than have a generic "condition" mechanism.
// See https://github.com/rust-lang/rust/issues/26035
let out_len = self.len() - rest.len();
let sigma_lowercase = map_uppercase_sigma(&self, i + out_len);
let sigma_lowercase = map_uppercase_sigma(self, prefix_len + i);
s.push(sigma_lowercase);
} else {
match conversions::to_lower(c) {
Expand Down Expand Up @@ -458,14 +453,7 @@ impl str {
without modifying the original"]
#[stable(feature = "unicode_case_mapping", since = "1.2.0")]
pub fn to_uppercase(&self) -> String {
let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_uppercase);

// Safety: we know this is a valid char boundary since
// out.len() is only progressed if ascii bytes are found
let rest = unsafe { self.get_unchecked(out.len()..) };

// Safety: We have written only valid ASCII to our vec
let mut s = unsafe { String::from_utf8_unchecked(out) };
let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_uppercase);

for c in rest.chars() {
match conversions::to_upper(c) {
Expand Down Expand Up @@ -614,50 +602,87 @@ pub unsafe fn from_boxed_utf8_unchecked(v: Box<[u8]>) -> Box<str> {
unsafe { Box::from_raw(Box::into_raw(v) as *mut str) }
}

/// Converts the bytes while the bytes are still ascii.
/// Converts leading ascii bytes in `s` by calling the `convert` function.
jhorstmann marked this conversation as resolved.
Show resolved Hide resolved
///
/// For better average performance, this happens in chunks of `2*size_of::<usize>()`.
/// Returns a vec with the converted bytes.
///
/// Returns a tuple of the converted prefix and the remainder starting from
/// the first non-ascii character.
///
/// This function is only public so that it can be verified in a codegen test,
/// see `issue-123712-str-to-lower-autovectorization.rs`.
#[unstable(feature = "str_internals", issue = "none")]
#[doc(hidden)]
#[inline]
#[cfg(not(test))]
#[cfg(not(no_global_oom_handling))]
fn convert_while_ascii(b: &[u8], convert: fn(&u8) -> u8) -> Vec<u8> {
let mut out = Vec::with_capacity(b.len());
pub fn convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> (String, &str) {
// Process the input in chunks of 16 bytes to enable auto-vectorization.
// Previously the chunk size depended on the size of `usize`,
// but on 32-bit platforms with sse or neon is also the better choice.
// The only downside on other platforms would be a bit more loop-unrolling.
const N: usize = 16;

let mut slice = s.as_bytes();
let mut out = Vec::with_capacity(slice.len());
let mut out_slice = out.spare_capacity_mut();

let mut ascii_prefix_len = 0_usize;
let mut is_ascii = [false; N];

while slice.len() >= N {
// SAFETY: checked in loop condition
let chunk = unsafe { slice.get_unchecked(..N) };
// SAFETY: out_slice has at least same length as input slice and gets sliced with the same offsets
let out_chunk = unsafe { out_slice.get_unchecked_mut(..N) };

for j in 0..N {
is_ascii[j] = chunk[j] <= 127;
}

const USIZE_SIZE: usize = mem::size_of::<usize>();
const MAGIC_UNROLL: usize = 2;
const N: usize = USIZE_SIZE * MAGIC_UNROLL;
const NONASCII_MASK: usize = usize::from_ne_bytes([0x80; USIZE_SIZE]);
// Auto-vectorization for this check is a bit fragile, sum and comparing against the chunk
// size gives the best result, specifically a pmovmsk instruction on x86.
// See https://github.com/llvm/llvm-project/issues/96395 for why llvm currently does not
// currently recognize other similar idioms.
if is_ascii.iter().map(|x| *x as u8).sum::<u8>() as usize != N {
break;
}

let mut i = 0;
unsafe {
while i + N <= b.len() {
// Safety: we have checks the sizes `b` and `out` to know that our
let in_chunk = b.get_unchecked(i..i + N);
let out_chunk = out.spare_capacity_mut().get_unchecked_mut(i..i + N);

let mut bits = 0;
for j in 0..MAGIC_UNROLL {
// read the bytes 1 usize at a time (unaligned since we haven't checked the alignment)
// safety: in_chunk is valid bytes in the range
bits |= in_chunk.as_ptr().cast::<usize>().add(j).read_unaligned();
}
// if our chunks aren't ascii, then return only the prior bytes as init
if bits & NONASCII_MASK != 0 {
break;
}
for j in 0..N {
out_chunk[j] = MaybeUninit::new(convert(&chunk[j]));
}

// perform the case conversions on N bytes (gets heavily autovec'd)
for j in 0..N {
// safety: in_chunk and out_chunk is valid bytes in the range
let out = out_chunk.get_unchecked_mut(j);
out.write(convert(in_chunk.get_unchecked(j)));
}
ascii_prefix_len += N;
slice = unsafe { slice.get_unchecked(N..) };
out_slice = unsafe { out_slice.get_unchecked_mut(N..) };
}

// mark these bytes as initialised
i += N;
// handle the remainder as individual bytes
while slice.len() > 0 {
let byte = slice[0];
if byte > 127 {
break;
}
// SAFETY: out_slice has at least same length as input slice
unsafe {
*out_slice.get_unchecked_mut(0) = MaybeUninit::new(convert(&byte));
}
out.set_len(i);
ascii_prefix_len += 1;
slice = unsafe { slice.get_unchecked(1..) };
out_slice = unsafe { out_slice.get_unchecked_mut(1..) };
}

out
unsafe {
// SAFETY: ascii_prefix_len bytes have been initialized above
out.set_len(ascii_prefix_len);

// SAFETY: We have written only valid ascii to the output vec
let ascii_string = String::from_utf8_unchecked(out);

// SAFETY: we know this is a valid char boundary
// since we only skipped over leading ascii bytes
let rest = core::str::from_utf8_unchecked(slice);

(ascii_string, rest)
}
}
3 changes: 3 additions & 0 deletions library/alloc/tests/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1854,7 +1854,10 @@ fn to_lowercase() {
assert_eq!("ΑΣ''Α".to_lowercase(), "ασ''α");

// https://github.com/rust-lang/rust/issues/124714
// input lengths around the boundary of the chunk size used by the ascii prefix optimization
assert_eq!("abcdefghijklmnoΣ".to_lowercase(), "abcdefghijklmnoς");
assert_eq!("abcdefghijklmnopΣ".to_lowercase(), "abcdefghijklmnopς");
assert_eq!("abcdefghijklmnopqΣ".to_lowercase(), "abcdefghijklmnopqς");

// a really long string that has it's lowercase form
// even longer. this tests that implementations don't assume
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//@ only-x86_64
//@ compile-flags: -C opt-level=3
#![crate_type = "lib"]
#![no_std]
#![feature(str_internals)]

extern crate alloc;

/// Ensure that the ascii-prefix loop for `str::to_lowercase` and `str::to_uppercase` uses vector
/// instructions.
///
/// The llvm ir should be the same for all targets that support some form of simd. Only targets
/// without any simd instructions would see scalarized ir.
/// Unfortunately, there is no `only-simd` directive to only run this test on only such platforms,
/// and using test revisions would still require the core libraries for all platforms.
// CHECK-LABEL: @lower_while_ascii
// CHECK: [[A:%[0-9]]] = load <16 x i8>
// CHECK-NEXT: [[B:%[0-9]]] = icmp slt <16 x i8> [[A]], zeroinitializer
// CHECK-NEXT: [[C:%[0-9]]] = bitcast <16 x i1> [[B]] to i16
#[no_mangle]
pub fn lower_while_ascii(s: &str) -> (alloc::string::String, &str) {
alloc::str::convert_while_ascii(s, u8::to_ascii_lowercase)
}
Loading