From efc7d6e68326f53d3c03b9a3d1b44c3d98b1435b Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Wed, 7 Feb 2024 17:04:34 -0500 Subject: [PATCH] Remove unnecessary string cloning from the parser --- crates/ruff_python_parser/src/ascii.rs | 345 +++++++++++++++++++ crates/ruff_python_parser/src/lib.rs | 1 + crates/ruff_python_parser/src/python.lalrpop | 4 +- crates/ruff_python_parser/src/python.rs | 6 +- crates/ruff_python_parser/src/string.rs | 266 +++++++++----- 5 files changed, 539 insertions(+), 83 deletions(-) create mode 100644 crates/ruff_python_parser/src/ascii.rs diff --git a/crates/ruff_python_parser/src/ascii.rs b/crates/ruff_python_parser/src/ascii.rs new file mode 100644 index 00000000000000..87614dc98a0e8f --- /dev/null +++ b/crates/ruff_python_parser/src/ascii.rs @@ -0,0 +1,345 @@ +#![allow( + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::cast_ptr_alignment, + clippy::inline_always, + clippy::ptr_as_ptr, + unsafe_code +)] + +//! Source: + +// The following ~400 lines of code exists for exactly one purpose, which is +// to optimize this code: +// +// byte_slice.iter().position(|&b| b > 0x7F).unwrap_or(byte_slice.len()) +// +// Yes... Overengineered is a word that comes to mind, but this is effectively +// a very similar problem to memchr, and virtually nobody has been able to +// resist optimizing the crap out of that (except for perhaps the BSD and MUSL +// folks). In particular, this routine makes a very common case (ASCII) very +// fast, which seems worth it. We do stop short of adding AVX variants of the +// code below in order to retain our sanity and also to avoid needing to deal +// with runtime target feature detection. RESIST! 
+// +// In order to understand the SIMD version below, it would be good to read this +// comment describing how my memchr routine works: +// https://github.com/BurntSushi/rust-memchr/blob/b0a29f267f4a7fad8ffcc8fe8377a06498202883/src/x86/sse2.rs#L19-L106 +// +// The primary difference with memchr is that for ASCII, we can do a bit less +// work. In particular, we don't need to detect the presence of a specific +// byte, but rather, whether any byte has its most significant bit set. That +// means we can effectively skip the _mm_cmpeq_epi8 step and jump straight to +// _mm_movemask_epi8. + +#[cfg(any(test, miri, not(target_arch = "x86_64")))] +const USIZE_BYTES: usize = core::mem::size_of::(); +#[cfg(any(test, miri, not(target_arch = "x86_64")))] +const FALLBACK_LOOP_SIZE: usize = 2 * USIZE_BYTES; + +// This is a mask where the most significant bit of each byte in the usize +// is set. We test this bit to determine whether a character is ASCII or not. +// Namely, a single byte is regarded as an ASCII codepoint if and only if its +// most significant bit is not set. +#[cfg(any(test, miri, not(target_arch = "x86_64")))] +const ASCII_MASK_U64: u64 = 0x8080_8080_8080_8080; +#[cfg(any(test, miri, not(target_arch = "x86_64")))] +const ASCII_MASK: usize = ASCII_MASK_U64 as usize; + +/// Returns the index of the first non ASCII byte in the given slice. +/// +/// If slice only contains ASCII bytes, then the length of the slice is +/// returned. 
+pub(crate) fn first_non_ascii_byte(slice: &[u8]) -> usize { + #[cfg(any(miri, not(target_arch = "x86_64")))] + { + first_non_ascii_byte_fallback(slice) + } + + #[cfg(all(not(miri), target_arch = "x86_64"))] + { + first_non_ascii_byte_sse2(slice) + } +} + +#[cfg(any(test, miri, not(target_arch = "x86_64")))] +fn first_non_ascii_byte_fallback(slice: &[u8]) -> usize { + let align = USIZE_BYTES - 1; + let start_ptr = slice.as_ptr(); + let end_ptr = slice[slice.len()..].as_ptr(); + let mut ptr = start_ptr; + + unsafe { + if slice.len() < USIZE_BYTES { + return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr); + } + + let chunk = read_unaligned_usize(ptr); + let mask = chunk & ASCII_MASK; + if mask != 0 { + return first_non_ascii_byte_mask(mask); + } + + ptr = ptr_add(ptr, USIZE_BYTES - (start_ptr as usize & align)); + debug_assert!(ptr > start_ptr); + debug_assert!(ptr_sub(end_ptr, USIZE_BYTES) >= start_ptr); + if slice.len() >= FALLBACK_LOOP_SIZE { + while ptr <= ptr_sub(end_ptr, FALLBACK_LOOP_SIZE) { + debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES); + + let a = *ptr.cast::(); + let b = *ptr_add(ptr, USIZE_BYTES).cast::(); + if (a | b) & ASCII_MASK != 0 { + // What a kludge. We wrap the position finding code into + // a non-inlineable function, which makes the codegen in + // the tight loop above a bit better by avoiding a + // couple extra movs. We pay for it by two additional + // stores, but only in the case of finding a non-ASCII + // byte. 
+ #[inline(never)] + unsafe fn findpos(start_ptr: *const u8, ptr: *const u8) -> usize { + let a = *ptr.cast::(); + let b = *ptr_add(ptr, USIZE_BYTES).cast::(); + + let mut at = sub(ptr, start_ptr); + let maska = a & ASCII_MASK; + if maska != 0 { + return at + first_non_ascii_byte_mask(maska); + } + + at += USIZE_BYTES; + let maskb = b & ASCII_MASK; + debug_assert!(maskb != 0); + at + first_non_ascii_byte_mask(maskb) + } + return findpos(start_ptr, ptr); + } + ptr = ptr_add(ptr, FALLBACK_LOOP_SIZE); + } + } + first_non_ascii_byte_slow(start_ptr, end_ptr, ptr) + } +} + +#[cfg(all(not(miri), target_arch = "x86_64"))] +fn first_non_ascii_byte_sse2(slice: &[u8]) -> usize { + use core::arch::x86_64::{ + __m128i, _mm_load_si128, _mm_loadu_si128, _mm_movemask_epi8, _mm_or_si128, + }; + + const VECTOR_SIZE: usize = core::mem::size_of::<__m128i>(); + const VECTOR_ALIGN: usize = VECTOR_SIZE - 1; + const VECTOR_LOOP_SIZE: usize = 4 * VECTOR_SIZE; + + let start_ptr = slice.as_ptr(); + let end_ptr = slice[slice.len()..].as_ptr(); + let mut ptr = start_ptr; + + unsafe { + if slice.len() < VECTOR_SIZE { + return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr); + } + + let chunk = _mm_loadu_si128(ptr as *const __m128i); + let mask = _mm_movemask_epi8(chunk); + if mask != 0 { + return mask.trailing_zeros() as usize; + } + + ptr = ptr.add(VECTOR_SIZE - (start_ptr as usize & VECTOR_ALIGN)); + debug_assert!(ptr > start_ptr); + debug_assert!(end_ptr.sub(VECTOR_SIZE) >= start_ptr); + if slice.len() >= VECTOR_LOOP_SIZE { + while ptr <= ptr_sub(end_ptr, VECTOR_LOOP_SIZE) { + debug_assert_eq!(0, (ptr as usize) % VECTOR_SIZE); + + let a = _mm_load_si128(ptr as *const __m128i); + let b = _mm_load_si128(ptr.add(VECTOR_SIZE) as *const __m128i); + let c = _mm_load_si128(ptr.add(2 * VECTOR_SIZE) as *const __m128i); + let d = _mm_load_si128(ptr.add(3 * VECTOR_SIZE) as *const __m128i); + + let or1 = _mm_or_si128(a, b); + let or2 = _mm_or_si128(c, d); + let or3 = _mm_or_si128(or1, or2); + if 
_mm_movemask_epi8(or3) != 0 { + let mut at = sub(ptr, start_ptr); + let mask = _mm_movemask_epi8(a); + if mask != 0 { + return at + mask.trailing_zeros() as usize; + } + + at += VECTOR_SIZE; + let mask = _mm_movemask_epi8(b); + if mask != 0 { + return at + mask.trailing_zeros() as usize; + } + + at += VECTOR_SIZE; + let mask = _mm_movemask_epi8(c); + if mask != 0 { + return at + mask.trailing_zeros() as usize; + } + + at += VECTOR_SIZE; + let mask = _mm_movemask_epi8(d); + debug_assert!(mask != 0); + return at + mask.trailing_zeros() as usize; + } + ptr = ptr_add(ptr, VECTOR_LOOP_SIZE); + } + } + while ptr <= end_ptr.sub(VECTOR_SIZE) { + debug_assert!(sub(end_ptr, ptr) >= VECTOR_SIZE); + + let chunk = _mm_loadu_si128(ptr as *const __m128i); + let mask = _mm_movemask_epi8(chunk); + if mask != 0 { + return sub(ptr, start_ptr) + mask.trailing_zeros() as usize; + } + ptr = ptr.add(VECTOR_SIZE); + } + first_non_ascii_byte_slow(start_ptr, end_ptr, ptr) + } +} + +#[inline(always)] +unsafe fn first_non_ascii_byte_slow( + start_ptr: *const u8, + end_ptr: *const u8, + mut ptr: *const u8, +) -> usize { + debug_assert!(start_ptr <= ptr); + debug_assert!(ptr <= end_ptr); + + while ptr < end_ptr { + if *ptr > 0x7F { + return sub(ptr, start_ptr); + } + ptr = ptr.offset(1); + } + sub(end_ptr, start_ptr) +} + +/// Compute the position of the first non-ASCII byte in the given mask. +/// +/// The mask should be computed by `chunk & ASCII_MASK`, where `chunk` is +/// 8 contiguous bytes of the slice being checked where *at least* one of those +/// bytes is not an ASCII byte. +/// +/// The position returned is always in the inclusive range [0, 7]. +#[cfg(any(test, miri, not(target_arch = "x86_64")))] +fn first_non_ascii_byte_mask(mask: usize) -> usize { + #[cfg(target_endian = "little")] + { + mask.trailing_zeros() as usize / 8 + } + #[cfg(target_endian = "big")] + { + mask.leading_zeros() as usize / 8 + } +} + +/// Increment the given pointer by the given amount. 
+unsafe fn ptr_add(ptr: *const u8, amt: usize) -> *const u8 { + debug_assert!(amt < ::core::isize::MAX as usize); + ptr.add(amt) +} + +/// Decrement the given pointer by the given amount. +unsafe fn ptr_sub(ptr: *const u8, amt: usize) -> *const u8 { + debug_assert!(amt < ::core::isize::MAX as usize); + ptr.offset((amt as isize).wrapping_neg()) +} + +#[cfg(any(test, miri, not(target_arch = "x86_64")))] +unsafe fn read_unaligned_usize(ptr: *const u8) -> usize { + use core::ptr; + + let mut n: usize = 0; + ptr::copy_nonoverlapping(ptr, std::ptr::addr_of_mut!(n) as *mut u8, USIZE_BYTES); + n +} + +/// Subtract `b` from `a` and return the difference. `a` should be greater than +/// or equal to `b`. +fn sub(a: *const u8, b: *const u8) -> usize { + debug_assert!(a >= b); + (a as usize) - (b as usize) +} + +#[cfg(test)] +mod tests { + use super::*; + + // Our testing approach here is to try and exhaustively test every case. + // This includes the position at which a non-ASCII byte occurs in addition + // to the alignment of the slice that we're searching. 
+ + #[test] + fn positive_fallback_forward() { + for i in 0..517 { + let s = "a".repeat(i); + assert_eq!( + i, + first_non_ascii_byte_fallback(s.as_bytes()), + "i: {:?}, len: {:?}, s: {:?}", + i, + s.len(), + s + ); + } + } + + #[test] + #[cfg(target_arch = "x86_64")] + #[cfg(not(miri))] + fn positive_sse2_forward() { + for i in 0..517 { + let b = "a".repeat(i).into_bytes(); + assert_eq!(b.len(), first_non_ascii_byte_sse2(&b)); + } + } + + #[test] + #[cfg(not(miri))] + fn negative_fallback_forward() { + for i in 0..517 { + for align in 0..65 { + let mut s = "a".repeat(i); + s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃"); + let s = s.get(align..).unwrap_or(""); + assert_eq!( + i.saturating_sub(align), + first_non_ascii_byte_fallback(s.as_bytes()), + "i: {:?}, align: {:?}, len: {:?}, s: {:?}", + i, + align, + s.len(), + s + ); + } + } + } + + #[test] + #[cfg(target_arch = "x86_64")] + #[cfg(not(miri))] + fn negative_sse2_forward() { + for i in 0..517 { + for align in 0..65 { + let mut s = "a".repeat(i); + s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃"); + let s = s.get(align..).unwrap_or(""); + assert_eq!( + i.saturating_sub(align), + first_non_ascii_byte_sse2(s.as_bytes()), + "i: {:?}, align: {:?}, len: {:?}, s: {:?}", + i, + align, + s.len(), + s + ); + } + } + } +} diff --git a/crates/ruff_python_parser/src/lib.rs b/crates/ruff_python_parser/src/lib.rs index 2f95c684e87d97..074746ec01ab80 100644 --- a/crates/ruff_python_parser/src/lib.rs +++ b/crates/ruff_python_parser/src/lib.rs @@ -121,6 +121,7 @@ use crate::lexer::LexResult; mod function; // Skip flattening lexer to distinguish from full ruff_python_parser +mod ascii; mod context; mod invalid; pub mod lexer; diff --git a/crates/ruff_python_parser/src/python.lalrpop b/crates/ruff_python_parser/src/python.lalrpop index cc9bf71e8a1100..a3ae2dd514c350 100644 --- a/crates/ruff_python_parser/src/python.lalrpop +++ b/crates/ruff_python_parser/src/python.lalrpop @@ -1606,7 +1606,7 @@ StringLiteralOrFString: 
StringType = { StringLiteral: StringType = { =>? { let (source, kind, triple_quoted) = string; - Ok(parse_string_literal(&source, kind, triple_quoted, (location..end_location).into())?) + Ok(parse_string_literal(source, kind, triple_quoted, (location..end_location).into())?) } }; @@ -1623,7 +1623,7 @@ FStringMiddlePattern: ast::FStringElement = { FStringReplacementField, =>? { let (source, is_raw, _) = fstring_middle; - Ok(parse_fstring_literal_element(&source, is_raw, (location..end_location).into())?) + Ok(parse_fstring_literal_element(source, is_raw, (location..end_location).into())?) } }; diff --git a/crates/ruff_python_parser/src/python.rs b/crates/ruff_python_parser/src/python.rs index c409f91eeebf49..5addad53844267 100644 --- a/crates/ruff_python_parser/src/python.rs +++ b/crates/ruff_python_parser/src/python.rs @@ -1,5 +1,5 @@ // auto-generated: "lalrpop 0.20.0" -// sha3: aa0540221d25f4eadfc9e043fb4fc631d537b672b8a96785dfec2407e0524b79 +// sha3: 83dd2ba251ff635b813dfe48854debd5935c1a506789893ac1c2638639f27353 use ruff_text_size::{Ranged, TextLen, TextRange, TextSize}; use ruff_python_ast::{self as ast, Int, IpyEscapeKind}; use crate::{ @@ -36369,7 +36369,7 @@ fn __action217< { { let (source, kind, triple_quoted) = string; - Ok(parse_string_literal(&source, kind, triple_quoted, (location..end_location).into())?) + Ok(parse_string_literal(source, kind, triple_quoted, (location..end_location).into())?) } } @@ -36419,7 +36419,7 @@ fn __action220< { { let (source, is_raw, _) = fstring_middle; - Ok(parse_fstring_literal_element(&source, is_raw, (location..end_location).into())?) + Ok(parse_fstring_literal_element(source, is_raw, (location..end_location).into())?) } } diff --git a/crates/ruff_python_parser/src/string.rs b/crates/ruff_python_parser/src/string.rs index 80f42e453b089a..ca85495b2dacdb 100644 --- a/crates/ruff_python_parser/src/string.rs +++ b/crates/ruff_python_parser/src/string.rs @@ -1,8 +1,9 @@ //! 
Parsing of string literals, bytes literals, and implicit string concatenation. use ruff_python_ast::{self as ast, Expr}; -use ruff_text_size::{Ranged, TextLen, TextRange, TextSize}; +use ruff_text_size::{Ranged, TextRange, TextSize}; +use crate::ascii::first_non_ascii_byte; use crate::lexer::{LexicalError, LexicalErrorType}; use crate::token::{StringKind, Tok}; @@ -32,34 +33,40 @@ impl From for Expr { } } -struct StringParser<'a> { - rest: &'a str, +enum EscapedChar { + Literal(char), + Escape(char), +} + +struct StringParser { + source: String, + cursor: usize, kind: StringKind, - location: TextSize, + offset: TextSize, range: TextRange, } -impl<'a> StringParser<'a> { - fn new(source: &'a str, kind: StringKind, start: TextSize, range: TextRange) -> Self { +impl StringParser { + fn new(source: String, kind: StringKind, offset: TextSize, range: TextRange) -> Self { Self { - rest: source, + source, + cursor: 0, kind, - location: start, + offset, range, } } #[inline] - fn skip_bytes(&mut self, bytes: usize) -> &'a str { - let skipped_str = &self.rest[..bytes]; - self.rest = &self.rest[bytes..]; - self.location += skipped_str.text_len(); + fn skip_bytes(&mut self, bytes: usize) -> &str { + let skipped_str = &self.source[self.cursor..self.cursor + bytes]; + self.cursor += bytes; skipped_str } #[inline] fn get_pos(&self) -> TextSize { - self.location + self.offset + TextSize::try_from(self.cursor).unwrap() } /// Returns the next byte in the string, if there is one. @@ -69,25 +76,23 @@ impl<'a> StringParser<'a> { /// When the next byte is a part of a multi-byte character. 
#[inline] fn next_byte(&mut self) -> Option { - self.rest.as_bytes().first().map(|&byte| { - self.rest = &self.rest[1..]; - self.location += TextSize::new(1); + self.source[self.cursor..].as_bytes().first().map(|&byte| { + self.cursor += 1; byte }) } #[inline] fn next_char(&mut self) -> Option { - self.rest.chars().next().map(|c| { - self.rest = &self.rest[c.len_utf8()..]; - self.location += c.text_len(); + self.source[self.cursor..].chars().next().map(|c| { + self.cursor += c.len_utf8(); c }) } #[inline] fn peek_byte(&self) -> Option { - self.rest.as_bytes().first().copied() + self.source[self.cursor..].as_bytes().first().copied() } fn parse_unicode_literal(&mut self, literal_number: usize) -> Result { @@ -135,7 +140,7 @@ impl<'a> StringParser<'a> { }; let start_pos = self.get_pos(); - let Some(close_idx) = self.rest.find('}') else { + let Some(close_idx) = self.source[self.cursor..].find('}') else { return Err(LexicalError::new( LexicalErrorType::StringError, self.get_pos(), @@ -149,7 +154,8 @@ impl<'a> StringParser<'a> { .ok_or_else(|| LexicalError::new(LexicalErrorType::UnicodeError, start_pos)) } - fn parse_escaped_char(&mut self, string: &mut String) -> Result<(), LexicalError> { + /// Parse an escaped character, returning the new character. 
+ fn parse_escaped_char(&mut self) -> Result, LexicalError> { let Some(first_char) = self.next_char() else { return Err(LexicalError { error: LexicalErrorType::StringError, @@ -174,13 +180,13 @@ impl<'a> StringParser<'a> { 'U' if !self.kind.is_any_bytes() => self.parse_unicode_literal(8)?, 'N' if !self.kind.is_any_bytes() => self.parse_unicode_name()?, // Special cases where the escape sequence is not a single character - '\n' => return Ok(()), + '\n' => return Ok(None), '\r' => { if self.peek_byte() == Some(b'\n') { self.next_byte(); } - return Ok(()); + return Ok(None); } _ => { if self.kind.is_any_bytes() && !first_char.is_ascii() { @@ -192,21 +198,42 @@ impl<'a> StringParser<'a> { }); } - string.push('\\'); - - first_char + return Ok(Some(EscapedChar::Escape(first_char))); } }; - string.push(new_char); - - Ok(()) + Ok(Some(EscapedChar::Literal(new_char))) } - fn parse_fstring_middle(&mut self) -> Result { - let mut value = String::with_capacity(self.rest.len()); - while let Some(ch) = self.next_char() { - match ch { + fn parse_fstring_middle(mut self) -> Result { + // Fast-path: if the f-string doesn't contain any escape sequences, return the literal. + let Some(mut index) = memchr::memchr3(b'{', b'}', b'\\', self.source.as_bytes()) else { + return Ok(ast::FStringElement::Literal(ast::FStringLiteralElement { + value: self.source, + range: self.range, + })); + }; + + let mut value = String::with_capacity(self.source.len()); + loop { + // Add the characters before the escape sequence to the string. + let before_with_slash = self.skip_bytes(index + 1); + let before = &before_with_slash[..before_with_slash.len() - 1]; + value.push_str(before); + + // Add the escaped character to the string. + match &self.source.as_bytes()[self.cursor - 1] { + // If there are any curly braces inside a `FStringMiddle` token, + // then they were escaped (i.e. `{{` or `}}`). This means that + // we need to increase the location by 2 instead of 1. 
+ b'{' => { + self.offset += TextSize::from(1); + value.push('{'); + } + b'}' => { + self.offset += TextSize::from(1); + value.push('}'); + } // We can encounter a `\` as the last character in a `FStringMiddle` // token which is valid in this context. For example, // @@ -227,69 +254,152 @@ impl<'a> StringParser<'a> { // This is still an invalid escape sequence, but we don't want to // raise a syntax error as is done by the CPython parser. It might // be supported in the future, refer to point 3: https://peps.python.org/pep-0701/#rejected-ideas - '\\' if !self.kind.is_raw() && self.peek_byte().is_some() => { - self.parse_escaped_char(&mut value)?; + b'\\' if !self.kind.is_raw() && self.peek_byte().is_some() => { + match self.parse_escaped_char()? { + None => {} + Some(EscapedChar::Literal(c)) => value.push(c), + Some(EscapedChar::Escape(c)) => { + value.push('\\'); + value.push(c); + } + } } - // If there are any curly braces inside a `FStringMiddle` token, - // then they were escaped (i.e. `{{` or `}}`). This means that - // we need increase the location by 2 instead of 1. - ch @ ('{' | '}') => { - self.location += ch.text_len(); - value.push(ch); + ch => { + value.push(char::from(*ch)); } - ch => value.push(ch), } + + let Some(next_index) = + memchr::memchr3(b'{', b'}', b'\\', self.source[self.cursor..].as_bytes()) + else { + // Add the rest of the string to the value. 
+ let rest = &self.source[self.cursor..]; + value.push_str(rest); + break; + }; + + index = next_index; } + Ok(ast::FStringElement::Literal(ast::FStringLiteralElement { value, range: self.range, })) } - fn parse_bytes(&mut self) -> Result { - let mut content = String::with_capacity(self.rest.len()); - while let Some(ch) = self.next_char() { - match ch { - '\\' if !self.kind.is_raw() => { - self.parse_escaped_char(&mut content)?; - } - ch => { - if !ch.is_ascii() { - return Err(LexicalError::new( - LexicalErrorType::OtherError( - "bytes can only contain ASCII literal characters".to_string(), - ), - self.get_pos(), - )); - } - content.push(ch); + fn parse_bytes(mut self) -> Result { + let index = first_non_ascii_byte(self.source.as_bytes()); + if index < self.source.len() { + return Err(LexicalError::new( + LexicalErrorType::OtherError( + "bytes can only contain ASCII literal characters".to_string(), + ), + self.offset + TextSize::try_from(index).unwrap(), + )); + } + + if self.kind.is_raw() { + // For raw strings, no escaping is necessary. + return Ok(StringType::Bytes(ast::BytesLiteral { + value: self.source.into_bytes(), + range: self.range, + })); + } + + let Some(mut escape) = memchr::memchr(b'\\', self.source.as_bytes()) else { + // If the string doesn't contain any escape sequences, return the owned string. + return Ok(StringType::Bytes(ast::BytesLiteral { + value: self.source.into_bytes(), + range: self.range, + })); + }; + + // If the string contains escape sequences, we need to parse them. + let mut value = Vec::with_capacity(self.source.len()); + loop { + // Add the characters before the escape sequence to the string. + let before_with_slash = self.skip_bytes(escape + 1); + let before = &before_with_slash[..before_with_slash.len() - 1]; + value.extend_from_slice(before.as_bytes()); + + // Add the escaped character to the string. + match self.parse_escaped_char()? 
{ + None => {} + Some(EscapedChar::Literal(c)) => value.push(c as u8), + Some(EscapedChar::Escape(c)) => { + value.push(b'\\'); + value.push(c as u8); } } + + let Some(next_escape) = memchr::memchr(b'\\', self.source[self.cursor..].as_bytes()) + else { + // Add the rest of the string to the value. + let rest = &self.source[self.cursor..]; + value.extend_from_slice(rest.as_bytes()); + break; + }; + + // Update the position of the next escape sequence. + escape = next_escape; } + Ok(StringType::Bytes(ast::BytesLiteral { - value: content.chars().map(|c| c as u8).collect::>(), + value, range: self.range, })) } - fn parse_string(&mut self) -> Result { - let mut value = String::with_capacity(self.rest.len()); + fn parse_string(mut self) -> Result { if self.kind.is_raw() { - value.push_str(self.skip_bytes(self.rest.len())); - } else { - loop { - let Some(escape_idx) = self.rest.find('\\') else { - value.push_str(self.skip_bytes(self.rest.len())); - break; - }; + // For raw strings, no escaping is necessary. + return Ok(StringType::Str(ast::StringLiteral { + value: self.source, + unicode: self.kind.is_unicode(), + range: self.range, + })); + } - let before_with_slash = self.skip_bytes(escape_idx + 1); - let before = &before_with_slash[..before_with_slash.len() - 1]; + let Some(mut escape) = memchr::memchr(b'\\', self.source.as_bytes()) else { + // If the string doesn't contain any escape sequences, return the owned string. + return Ok(StringType::Str(ast::StringLiteral { + value: self.source, + unicode: self.kind.is_unicode(), + range: self.range, + })); + }; - value.push_str(before); - self.parse_escaped_char(&mut value)?; + // If the string contains escape sequences, we need to parse them. + let mut value = String::with_capacity(self.source.len()); + + loop { + // Add the characters before the escape sequence to the string. 
+ let before_with_slash = self.skip_bytes(escape + 1); + let before = &before_with_slash[..before_with_slash.len() - 1]; + value.push_str(before); + + // Add the escaped character to the string. + match self.parse_escaped_char()? { + None => {} + Some(EscapedChar::Literal(c)) => value.push(c), + Some(EscapedChar::Escape(c)) => { + value.push('\\'); + value.push(c); + } } + + let Some(next_escape) = memchr::memchr(b'\\', self.source[self.cursor..].as_bytes()) + else { + // Add the rest of the string to the value. + let rest = &self.source[self.cursor..]; + value.push_str(rest); + break; + }; + + // Update the position of the next escape sequence. + escape = next_escape; } + Ok(StringType::Str(ast::StringLiteral { value, unicode: self.kind.is_unicode(), @@ -297,7 +407,7 @@ impl<'a> StringParser<'a> { })) } - fn parse(&mut self) -> Result { + fn parse(self) -> Result { if self.kind.is_any_bytes() { self.parse_bytes() } else { @@ -307,7 +417,7 @@ impl<'a> StringParser<'a> { } pub(crate) fn parse_string_literal( - source: &str, + source: String, kind: StringKind, triple_quoted: bool, range: TextRange, @@ -323,7 +433,7 @@ pub(crate) fn parse_string_literal( } pub(crate) fn parse_fstring_literal_element( - source: &str, + source: String, is_raw: bool, range: TextRange, ) -> Result { @@ -356,7 +466,7 @@ pub(crate) fn concatenated_strings( if has_bytes && byte_literal_count < strings.len() { return Err(LexicalError { error: LexicalErrorType::OtherError( - "cannot mix bytes and nonbytes literals".to_owned(), + "cannot mix bytes and non-bytes literals".to_owned(), ), location: range.start(), });