From 291a9245187e794f295e86290acca5e70e95d1f3 Mon Sep 17 00:00:00 2001 From: zachs18 <8355914+zachs18@users.noreply.github.com> Date: Tue, 30 Jul 2024 23:05:11 +0000 Subject: [PATCH] Allow casting between slices of ZSTs and slices of non-ZSTs in all cases. (#256) Casting ZST to non-ZST will result in a slice length of 0. Casting non-ZST to ZST will only work if the input slice has length 0, and results in a slice length of 0; if the input slice is not of length 0, PodCastError::OutputSliceWouldHaveSlop is returned. Updates the docs of the PodCastError variants to reflect when they can occur. Updates the docs of try_cast_slice (and checked::) to remove note about ZST <-> non-ZST not being allowed. Update bytes_of(_mut) to remove ZST check, since casting [ZST] -> [u8] is now allowed directly using cast_slice(_mut). Update must_cast_slice checks and doctests to allow [ZST] -> [non-ZST], but disallow [non-ZST] -> [ZST]. --- src/allocation.rs | 21 ++++++++++++-------- src/checked.rs | 2 -- src/internal.rs | 42 +++++++++++++++++---------------------- src/lib.rs | 11 +++++----- src/must.rs | 34 ++++++++++++++++++++++++------- tests/cast_slice_tests.rs | 39 +++++++++++++++++++++++++++--------- 6 files changed, 93 insertions(+), 56 deletions(-) diff --git a/src/allocation.rs b/src/allocation.rs index 78c30ff..bb4032f 100644 --- a/src/allocation.rs +++ b/src/allocation.rs @@ -178,7 +178,7 @@ pub fn try_cast_slice_box( { // If the size in bytes of the underlying buffer does not match an exact // multiple of the size of B, we cannot cast between them. - Err((PodCastError::SizeMismatch, input)) + Err((PodCastError::OutputSliceWouldHaveSlop, input)) } else { // Because the size is an exact multiple, we can now change the length // of the slice and recreate the Box @@ -239,7 +239,7 @@ pub fn try_cast_vec( // length and capacity are valid under B, as we do not want to // change which bytes are considered part of the initialized slice // of the Vec - Err((PodCastError::SizeMismatch, input)) + Err((PodCastError::OutputSliceWouldHaveSlop, input)) } else { // Because the size is an exact multiple, we can now change the length and // capacity and recreate the Vec @@ -431,7 +431,7 @@ pub fn try_cast_slice_rc< { // If the size in bytes of the underlying buffer does not match an exact // multiple of the size of B, we cannot cast between them. - Err((PodCastError::SizeMismatch, input)) + Err((PodCastError::OutputSliceWouldHaveSlop, input)) } else { // Because the size is an exact multiple, we can now change the length // of the slice and recreate the Rc @@ -499,7 +499,7 @@ pub fn try_cast_slice_arc< { // If the size in bytes of the underlying buffer does not match an exact // multiple of the size of B, we cannot cast between them. - Err((PodCastError::SizeMismatch, input)) + Err((PodCastError::OutputSliceWouldHaveSlop, input)) } else { // Because the size is an exact multiple, we can now change the length // of the slice and recreate the Arc @@ -850,13 +850,18 @@ impl sealed::FromBoxBytes for [T] { let single_layout = Layout::new::(); if bytes.layout.align() != single_layout.align() { Err((PodCastError::AlignmentMismatch, bytes)) - } else if single_layout.size() == 0 { - Err((PodCastError::SizeMismatch, bytes)) - } else if bytes.layout.size() % single_layout.size() != 0 { + } else if (single_layout.size() == 0 && bytes.layout.size() != 0) + || (single_layout.size() != 0 + && bytes.layout.size() % single_layout.size() != 0) + { Err((PodCastError::OutputSliceWouldHaveSlop, bytes)) } else { let (ptr, layout) = bytes.into_raw_parts(); - let length = layout.size() / single_layout.size(); + let length = if single_layout.size() != 0 { + layout.size() / single_layout.size() + } else { + 0 + }; let ptr = core::ptr::slice_from_raw_parts_mut(ptr.as_ptr() as *mut T, length); // SAFETY: See BoxBytes type invariant. diff --git a/src/checked.rs b/src/checked.rs index cb8c4c3..3299ce8 100644 --- a/src/checked.rs +++ b/src/checked.rs @@ -368,8 +368,6 @@ pub fn try_cast_mut< /// type, and the output slice wouldn't be a whole number of elements when /// accounting for the size change (eg: 3 `u16` values is 1.5 `u32` values, so /// that's a failure). -/// * Similarly, you can't convert between a [ZST](https://doc.rust-lang.org/nomicon/exotic-sizes.html#zero-sized-types-zsts) -/// and a non-ZST. /// * If any element of the converted slice would contain an invalid bit pattern /// for `B` this fails. #[inline] diff --git a/src/internal.rs b/src/internal.rs index 3ede50f..06935e6 100644 --- a/src/internal.rs +++ b/src/internal.rs @@ -51,13 +51,9 @@ pub(crate) fn something_went_wrong(_src: &str, _err: D) -> ! { /// empty slice might not match the pointer value of the input reference. #[inline(always)] pub(crate) unsafe fn bytes_of(t: &T) -> &[u8] { - if size_of::() == 0 { - &[] - } else { - match try_cast_slice::(core::slice::from_ref(t)) { - Ok(s) => s, - Err(_) => unreachable!(), - } + match try_cast_slice::(core::slice::from_ref(t)) { + Ok(s) => s, + Err(_) => unreachable!(), } } @@ -67,13 +63,9 @@ pub(crate) unsafe fn bytes_of(t: &T) -> &[u8] { /// empty slice might not match the pointer value of the input reference. #[inline] pub(crate) unsafe fn bytes_of_mut(t: &mut T) -> &mut [u8] { - if size_of::() == 0 { - &mut [] - } else { - match try_cast_slice_mut::(core::slice::from_mut(t)) { - Ok(s) => s, - Err(_) => unreachable!(), - } + match try_cast_slice_mut::(core::slice::from_mut(t)) { + Ok(s) => s, + Err(_) => unreachable!(), } } @@ -347,12 +339,11 @@ pub(crate) unsafe fn try_cast_mut( /// type, and the output slice wouldn't be a whole number of elements when /// accounting for the size change (eg: 3 `u16` values is 1.5 `u32` values, so /// that's a failure). -/// * Similarly, you can't convert between a [ZST](https://doc.rust-lang.org/nomicon/exotic-sizes.html#zero-sized-types-zsts) -/// and a non-ZST. #[inline] pub(crate) unsafe fn try_cast_slice( a: &[A], ) -> Result<&[B], PodCastError> { + let input_bytes = core::mem::size_of_val::<[A]>(a); // Note(Lokathor): everything with `align_of` and `size_of` will optimize away // after monomorphization. if align_of::() > align_of::() @@ -361,10 +352,11 @@ pub(crate) unsafe fn try_cast_slice( Err(PodCastError::TargetAlignmentGreaterAndInputNotAligned) } else if size_of::() == size_of::() { Ok(unsafe { core::slice::from_raw_parts(a.as_ptr() as *const B, a.len()) }) - } else if size_of::() == 0 || size_of::() == 0 { - Err(PodCastError::SizeMismatch) - } else if core::mem::size_of_val(a) % size_of::() == 0 { - let new_len = core::mem::size_of_val(a) / size_of::(); + } else if (size_of::() != 0 && input_bytes % size_of::() == 0) + || (size_of::() == 0 && input_bytes == 0) + { + let new_len = + if size_of::() != 0 { input_bytes / size_of::() } else { 0 }; Ok(unsafe { core::slice::from_raw_parts(a.as_ptr() as *const B, new_len) }) } else { Err(PodCastError::OutputSliceWouldHaveSlop) @@ -379,6 +371,7 @@ pub(crate) unsafe fn try_cast_slice( pub(crate) unsafe fn try_cast_slice_mut( a: &mut [A], ) -> Result<&mut [B], PodCastError> { + let input_bytes = core::mem::size_of_val::<[A]>(a); // Note(Lokathor): everything with `align_of` and `size_of` will optimize away // after monomorphization. if align_of::() > align_of::() @@ -389,10 +382,11 @@ pub(crate) unsafe fn try_cast_slice_mut( Ok(unsafe { core::slice::from_raw_parts_mut(a.as_mut_ptr() as *mut B, a.len()) }) - } else if size_of::() == 0 || size_of::() == 0 { - Err(PodCastError::SizeMismatch) - } else if core::mem::size_of_val(a) % size_of::() == 0 { - let new_len = core::mem::size_of_val(a) / size_of::(); + } else if (size_of::() != 0 && input_bytes % size_of::() == 0) + || (size_of::() == 0 && input_bytes == 0) + { + let new_len = + if size_of::() != 0 { input_bytes / size_of::() } else { 0 }; Ok(unsafe { core::slice::from_raw_parts_mut(a.as_mut_ptr() as *mut B, new_len) }) diff --git a/src/lib.rs b/src/lib.rs index 1526587..dfb8ae8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -198,15 +198,14 @@ pub use bytemuck_derive::{ /// The things that can go wrong when casting between [`Pod`] data forms. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum PodCastError { - /// You tried to cast a slice to an element type with a higher alignment - /// requirement but the slice wasn't aligned. + /// You tried to cast a reference into a reference to a type with a higher alignment + /// requirement but the input reference wasn't aligned. TargetAlignmentGreaterAndInputNotAligned, - /// If the element size changes then the output slice changes length - /// accordingly. If the output slice wouldn't be a whole number of elements + /// If the element size of a slice changes, then the output slice changes length + /// accordingly. If the output slice wouldn't be a whole number of elements, /// then the conversion fails. OutputSliceWouldHaveSlop, - /// When casting a slice you can't convert between ZST elements and non-ZST - /// elements. When casting an individual `T`, `&T`, or `&mut T` value the + /// When casting an individual `T`, `&T`, or `&mut T` value the /// source size and destination size must be an exact match. SizeMismatch, /// For this type of cast the alignments must be exactly the same and they diff --git a/src/must.rs b/src/must.rs index 96661be..4c8c613 100644 --- a/src/must.rs +++ b/src/must.rs @@ -11,9 +11,9 @@ impl Cast { const ASSERT_ALIGN_GREATER_THAN_EQUAL: () = assert!(align_of::() >= align_of::()); const ASSERT_SIZE_EQUAL: () = assert!(size_of::() == size_of::()); - const ASSERT_SIZE_MULTIPLE_OF: () = assert!( - (size_of::() == 0) == (size_of::() == 0) - && (size_of::() % size_of::() == 0) + const ASSERT_SIZE_MULTIPLE_OF_OR_INPUT_ZST: () = assert!( + (size_of::() == 0) + || (size_of::() != 0 && size_of::() % size_of::() == 0) ); } @@ -113,8 +113,8 @@ pub fn must_cast_mut< /// * If the target type has a greater alignment requirement. /// * If the target element type doesn't evenly fit into the the current element /// type (eg: 3 `u16` values is 1.5 `u32` values, so that's a failure). -/// * Similarly, you can't convert between a [ZST](https://doc.rust-lang.org/nomicon/exotic-sizes.html#zero-sized-types-zsts) -/// and a non-ZST. +/// * Similarly, you can't convert from a non-[ZST](https://doc.rust-lang.org/nomicon/exotic-sizes.html#zero-sized-types-zsts) +/// to a ZST (e.g. 3 `u8` values is not any number of `()` values). /// /// ## Examples /// ``` @@ -122,6 +122,11 @@ pub fn must_cast_mut< /// // compiles: /// let bytes: &[u8] = bytemuck::must_cast_slice(indicies); /// ``` +/// ``` +/// let zsts: &[()] = &[(), (), ()]; +/// // compiles: +/// let bytes: &[u8] = bytemuck::must_cast_slice(zsts); +/// ``` /// ```compile_fail,E0080 /// # let bytes : &[u8] = &[1, 0, 2, 0, 3, 0]; /// // fails to compile (bytes.len() might not be a multiple of 2): @@ -132,9 +137,14 @@ pub fn must_cast_mut< /// // fails to compile (alignment requirements increased): /// let indicies : &[u16] = bytemuck::must_cast_slice(byte_pairs); /// ``` +/// ```compile_fail,E0080 +/// let bytes: &[u8] = &[]; +/// // fails to compile: (bytes.len() might not be 0) +/// let zsts: &[()] = bytemuck::must_cast_slice(bytes); +/// ``` #[inline] pub fn must_cast_slice(a: &[A]) -> &[B] { - let _ = Cast::::ASSERT_SIZE_MULTIPLE_OF; + let _ = Cast::::ASSERT_SIZE_MULTIPLE_OF_OR_INPUT_ZST; let _ = Cast::::ASSERT_ALIGN_GREATER_THAN_EQUAL; let new_len = if size_of::() == size_of::() { a.len() @@ -156,6 +166,11 @@ pub fn must_cast_slice(a: &[A]) -> &[B] { /// // compiles: /// let bytes: &mut [u8] = bytemuck::must_cast_slice_mut(indicies); /// ``` +/// ``` +/// let zsts: &mut [()] = &mut [(), (), ()]; +/// // compiles: +/// let bytes: &mut [u8] = bytemuck::must_cast_slice_mut(zsts); +/// ``` /// ```compile_fail,E0080 /// # let mut bytes = [1, 0, 2, 0, 3, 0]; /// # let bytes : &mut [u8] = &mut bytes[..]; @@ -168,6 +183,11 @@ pub fn must_cast_slice(a: &[A]) -> &[B] { /// // fails to compile (alignment requirements increased): /// let indicies : &mut [u16] = bytemuck::must_cast_slice_mut(byte_pairs); /// ``` +/// ```compile_fail,E0080 +/// let bytes: &mut [u8] = &mut []; +/// // fails to compile: (bytes.len() might not be 0) +/// let zsts: &mut [()] = bytemuck::must_cast_slice_mut(bytes); +/// ``` #[inline] pub fn must_cast_slice_mut< A: NoUninit + AnyBitPattern, @@ -175,7 +195,7 @@ pub fn must_cast_slice_mut< >( a: &mut [A], ) -> &mut [B] { - let _ = Cast::::ASSERT_SIZE_MULTIPLE_OF; + let _ = Cast::::ASSERT_SIZE_MULTIPLE_OF_OR_INPUT_ZST; let _ = Cast::::ASSERT_ALIGN_GREATER_THAN_EQUAL; let new_len = if size_of::() == size_of::() { a.len() diff --git a/tests/cast_slice_tests.rs b/tests/cast_slice_tests.rs index 016f8f8..0f94f0b 100644 --- a/tests/cast_slice_tests.rs +++ b/tests/cast_slice_tests.rs @@ -196,6 +196,30 @@ fn test_panics() { should_panic!(from_bytes::(&aligned_bytes[1..5])); } +#[test] +fn test_zsts() { + #[derive(Debug, Clone, Copy)] + struct MyZst; + unsafe impl Zeroable for MyZst {} + unsafe impl Pod for MyZst {} + assert_eq!(42, cast_slice::<(), MyZst>(&[(); 42]).len()); + assert_eq!(42, cast_slice_mut::<(), MyZst>(&mut [(); 42]).len()); + assert_eq!(0, cast_slice::<(), u8>(&[(); 42]).len()); + assert_eq!(0, cast_slice_mut::<(), u8>(&mut [(); 42]).len()); + assert_eq!(0, cast_slice::(&[]).len()); + assert_eq!(0, cast_slice_mut::(&mut []).len()); + + assert_eq!( + PodCastError::OutputSliceWouldHaveSlop, + try_cast_slice::(&[42]).unwrap_err() + ); + + assert_eq!( + PodCastError::OutputSliceWouldHaveSlop, + try_cast_slice_mut::(&mut [42]).unwrap_err() + ); +} + #[cfg(feature = "extern_crate_alloc")] #[test] fn test_boxed_slices() { @@ -209,7 +233,6 @@ fn test_boxed_slices() { result.expect_err("u16 and i8 have different alignment"); assert_eq!(error, PodCastError::AlignmentMismatch); - // FIXME(#253): Should these next two casts' errors be consistent? let result: Result<&[[i8; 3]], PodCastError> = try_cast_slice(&*boxed_i8_slice); let error = @@ -220,7 +243,7 @@ fn test_boxed_slices() { try_cast_slice_box(boxed_i8_slice); let (error, boxed_i8_slice) = result.expect_err("slice of [i8; 3] cannot be made from slice of 4 i8s"); - assert_eq!(error, PodCastError::SizeMismatch); + assert_eq!(error, PodCastError::OutputSliceWouldHaveSlop); let empty: Box<[()]> = cast_slice_box::(Box::new([])); assert!(empty.is_empty()); @@ -229,7 +252,7 @@ fn test_boxed_slices() { try_cast_slice_box(boxed_i8_slice); let (error, boxed_i8_slice) = result.expect_err("slice of ZST cannot be made from slice of 4 u8s"); - assert_eq!(error, PodCastError::SizeMismatch); + assert_eq!(error, PodCastError::OutputSliceWouldHaveSlop); drop(boxed_i8_slice); @@ -254,7 +277,6 @@ fn test_rc_slices() { result.expect_err("u16 and i8 have different alignment"); assert_eq!(error, PodCastError::AlignmentMismatch); - // FIXME(#253): Should these next two casts' errors be consistent? let result: Result<&[[i8; 3]], PodCastError> = try_cast_slice(&*rc_i8_slice); let error = result.expect_err("slice of [i8; 3] cannot be made from slice of 4 i8s"); @@ -264,7 +286,7 @@ fn test_rc_slices() { try_cast_slice_rc(rc_i8_slice); let (error, rc_i8_slice) = result.expect_err("slice of [i8; 3] cannot be made from slice of 4 i8s"); - assert_eq!(error, PodCastError::SizeMismatch); + assert_eq!(error, PodCastError::OutputSliceWouldHaveSlop); let empty: Rc<[()]> = cast_slice_rc::(Rc::new([])); assert!(empty.is_empty()); @@ -273,7 +295,7 @@ fn test_rc_slices() { try_cast_slice_rc(rc_i8_slice); let (error, rc_i8_slice) = result.expect_err("slice of ZST cannot be made from slice of 4 u8s"); - assert_eq!(error, PodCastError::SizeMismatch); + assert_eq!(error, PodCastError::OutputSliceWouldHaveSlop); drop(rc_i8_slice); @@ -299,7 +321,6 @@ fn test_arc_slices() { result.expect_err("u16 and i8 have different alignment"); assert_eq!(error, PodCastError::AlignmentMismatch); - // FIXME(#253): Should these next two casts' errors be consistent? let result: Result<&[[i8; 3]], PodCastError> = try_cast_slice(&*arc_i8_slice); let error = result.expect_err("slice of [i8; 3] cannot be made from slice of 4 i8s"); @@ -309,7 +330,7 @@ fn test_arc_slices() { try_cast_slice_arc(arc_i8_slice); let (error, arc_i8_slice) = result.expect_err("slice of [i8; 3] cannot be made from slice of 4 i8s"); - assert_eq!(error, PodCastError::SizeMismatch); + assert_eq!(error, PodCastError::OutputSliceWouldHaveSlop); let empty: Arc<[()]> = cast_slice_arc::(Arc::new([])); assert!(empty.is_empty()); @@ -318,7 +339,7 @@ fn test_arc_slices() { try_cast_slice_arc(arc_i8_slice); let (error, arc_i8_slice) = result.expect_err("slice of ZST cannot be made from slice of 4 u8s"); - assert_eq!(error, PodCastError::SizeMismatch); + assert_eq!(error, PodCastError::OutputSliceWouldHaveSlop); drop(arc_i8_slice);