Skip to content

Commit

Permalink
Merge pull request #940 from jturner314/make-sliceinfo-sized
Browse files Browse the repository at this point in the history
Simplifications for slicing-related types
  • Loading branch information
bluss authored Mar 15, 2021
2 parents d399751 + fb7e94d commit 2bcbb2a
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 69 deletions.
4 changes: 2 additions & 2 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -599,8 +599,8 @@ fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> {
/// Returns `true` iff the slices intersect.
pub fn slices_intersect<D: Dimension>(
dim: &D,
indices1: &impl SliceArg<D>,
indices2: &impl SliceArg<D>,
indices1: impl SliceArg<D>,
indices2: impl SliceArg<D>,
) -> bool {
debug_assert_eq!(indices1.in_ndim(), indices2.in_ndim());
for (&axis_len, &si1, &si2) in izip!(
Expand Down
16 changes: 8 additions & 8 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,9 @@ where
///
/// **Panics** if an index is out of bounds or step size is zero.<br>
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.)
pub fn slice<I>(&self, info: &I) -> ArrayView<'_, A, I::OutDim>
pub fn slice<I>(&self, info: I) -> ArrayView<'_, A, I::OutDim>
where
I: SliceArg<D> + ?Sized,
I: SliceArg<D>,
S: Data,
{
self.view().slice_move(info)
Expand All @@ -353,9 +353,9 @@ where
///
/// **Panics** if an index is out of bounds or step size is zero.<br>
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.)
pub fn slice_mut<I>(&mut self, info: &I) -> ArrayViewMut<'_, A, I::OutDim>
pub fn slice_mut<I>(&mut self, info: I) -> ArrayViewMut<'_, A, I::OutDim>
where
I: SliceArg<D> + ?Sized,
I: SliceArg<D>,
S: DataMut,
{
self.view_mut().slice_move(info)
Expand Down Expand Up @@ -399,9 +399,9 @@ where
///
/// **Panics** if an index is out of bounds or step size is zero.<br>
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.)
pub fn slice_move<I>(mut self, info: &I) -> ArrayBase<S, I::OutDim>
pub fn slice_move<I>(mut self, info: I) -> ArrayBase<S, I::OutDim>
where
I: SliceArg<D> + ?Sized,
I: SliceArg<D>,
{
assert_eq!(
info.in_ndim(),
Expand Down Expand Up @@ -468,9 +468,9 @@ where
/// - if [`AxisSliceInfo::NewAxis`] is in `info`, e.g. if [`NewAxis`] was
/// used in the [`s!`] macro
/// - if `D` is `IxDyn` and `info` does not match the number of array axes
pub fn slice_collapse<I>(&mut self, info: &I)
pub fn slice_collapse<I>(&mut self, info: I)
where
I: SliceArg<D> + ?Sized,
I: SliceArg<D>,
{
assert_eq!(
info.in_ndim(),
Expand Down
7 changes: 3 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,9 +494,8 @@ pub type Ixs = isize;
///
/// The slicing argument can be passed using the macro [`s![]`](macro.s!.html),
/// which will be used in all examples. (The explicit form is an instance of
/// [`&SliceInfo`]; see its docs for more information.)
///
/// [`&SliceInfo`]: struct.SliceInfo.html
/// [`SliceInfo`] or another type which implements [`SliceArg`]; see their docs
/// for more information.)
///
/// If a range is used, the axis is preserved. If an index is used, that index
/// is selected and the axis is removed; this selects a subview. See
Expand All @@ -512,7 +511,7 @@ pub type Ixs = isize;
/// [`NewAxis`]: struct.NewAxis.html
///
/// When slicing arrays with generic dimensionality, creating an instance of
/// [`&SliceInfo`] to pass to the multi-axis slicing methods like [`.slice()`]
/// [`SliceInfo`] to pass to the multi-axis slicing methods like [`.slice()`]
/// is awkward. In these cases, it's usually more convenient to use
/// [`.slice_each_axis()`]/[`.slice_each_axis_mut()`]/[`.slice_each_axis_inplace()`]
/// or to create a view and then slice individual axes of the view using
Expand Down
88 changes: 40 additions & 48 deletions src/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ pub struct NewAxis;
/// A slice (range with step), an index, or a new axis token.
///
/// See also the [`s![]`](macro.s!.html) macro for a convenient way to create a
/// `&SliceInfo<[AxisSliceInfo; n], Din, Dout>`.
/// `SliceInfo<[AxisSliceInfo; n], Din, Dout>`.
///
/// ## Examples
///
Expand Down Expand Up @@ -324,6 +324,24 @@ pub unsafe trait SliceArg<D: Dimension>: AsRef<[AxisSliceInfo]> {
private_decl! {}
}

unsafe impl<T, D> SliceArg<D> for &T
where
T: SliceArg<D> + ?Sized,
D: Dimension,
{
type OutDim = T::OutDim;

fn in_ndim(&self) -> usize {
T::in_ndim(self)
}

fn out_ndim(&self) -> usize {
T::out_ndim(self)
}

private_impl! {}
}

macro_rules! impl_slicearg_samedim {
($in_dim:ty) => {
unsafe impl<T, Dout> SliceArg<$in_dim> for SliceInfo<T, $in_dim, Dout>
Expand Down Expand Up @@ -388,7 +406,7 @@ unsafe impl SliceArg<IxDyn> for [AxisSliceInfo] {

/// Represents all of the necessary information to perform a slice.
///
/// The type `T` is typically `[AxisSliceInfo; n]`, `[AxisSliceInfo]`, or
/// The type `T` is typically `[AxisSliceInfo; n]`, `&[AxisSliceInfo]`, or
/// `Vec<AxisSliceInfo>`. The type `Din` is the dimension of the array to be
/// sliced, and `Dout` is the output dimension after calling [`.slice()`]. Note
/// that if `Din` is a fixed dimension type (`Ix0`, `Ix1`, `Ix2`, etc.), the
Expand All @@ -397,14 +415,13 @@ unsafe impl SliceArg<IxDyn> for [AxisSliceInfo] {
///
/// [`.slice()`]: struct.ArrayBase.html#method.slice
#[derive(Debug)]
#[repr(transparent)]
pub struct SliceInfo<T: ?Sized, Din: Dimension, Dout: Dimension> {
pub struct SliceInfo<T, Din: Dimension, Dout: Dimension> {
in_dim: PhantomData<Din>,
out_dim: PhantomData<Dout>,
indices: T,
}

impl<T: ?Sized, Din, Dout> Deref for SliceInfo<T, Din, Dout>
impl<T, Din, Dout> Deref for SliceInfo<T, Din, Dout>
where
Din: Dimension,
Dout: Dimension,
Expand Down Expand Up @@ -464,14 +481,7 @@ where
indices,
}
}
}

impl<T, Din, Dout> SliceInfo<T, Din, Dout>
where
T: AsRef<[AxisSliceInfo]>,
Din: Dimension,
Dout: Dimension,
{
/// Returns a new `SliceInfo` instance.
///
/// Errors if `Din` or `Dout` is not consistent with `indices`.
Expand All @@ -490,14 +500,7 @@ where
indices,
})
}
}

impl<T: ?Sized, Din, Dout> SliceInfo<T, Din, Dout>
where
T: AsRef<[AxisSliceInfo]>,
Din: Dimension,
Dout: Dimension,
{
/// Returns the number of dimensions of the input array for
/// [`.slice()`](struct.ArrayBase.html#method.slice).
///
Expand Down Expand Up @@ -528,7 +531,7 @@ where
}
}

impl<'a, Din, Dout> TryFrom<&'a [AxisSliceInfo]> for &'a SliceInfo<[AxisSliceInfo], Din, Dout>
impl<'a, Din, Dout> TryFrom<&'a [AxisSliceInfo]> for SliceInfo<&'a [AxisSliceInfo], Din, Dout>
where
Din: Dimension,
Dout: Dimension,
Expand All @@ -537,16 +540,11 @@ where

fn try_from(
indices: &'a [AxisSliceInfo],
) -> Result<&'a SliceInfo<[AxisSliceInfo], Din, Dout>, ShapeError> {
check_dims_for_sliceinfo::<Din, Dout>(indices)?;
) -> Result<SliceInfo<&'a [AxisSliceInfo], Din, Dout>, ShapeError> {
unsafe {
// This is okay because we've already checked the correctness of
// `Din` and `Dout`, and the only non-zero-sized member of
// `SliceInfo` is `indices`, so `&SliceInfo<[AxisSliceInfo], Din,
// Dout>` should have the same bitwise representation as
// `&[AxisSliceInfo]`.
Ok(&*(indices as *const [AxisSliceInfo]
as *const SliceInfo<[AxisSliceInfo], Din, Dout>))
// This is okay because `&[AxisSliceInfo]` always returns the same
// value for `.as_ref()`.
Self::new(indices)
}
}
}
Expand Down Expand Up @@ -612,20 +610,18 @@ where
}
}

impl<T, Din, Dout> AsRef<SliceInfo<[AxisSliceInfo], Din, Dout>> for SliceInfo<T, Din, Dout>
impl<'a, T, Din, Dout> From<&'a SliceInfo<T, Din, Dout>>
for SliceInfo<&'a [AxisSliceInfo], Din, Dout>
where
T: AsRef<[AxisSliceInfo]>,
Din: Dimension,
Dout: Dimension,
{
fn as_ref(&self) -> &SliceInfo<[AxisSliceInfo], Din, Dout> {
unsafe {
// This is okay because the only non-zero-sized member of
// `SliceInfo` is `indices`, so `&SliceInfo<[AxisSliceInfo], Din, Dout>`
// should have the same bitwise representation as
// `&[AxisSliceInfo]`.
&*(self.indices.as_ref() as *const [AxisSliceInfo]
as *const SliceInfo<[AxisSliceInfo], Din, Dout>)
fn from(info: &'a SliceInfo<T, Din, Dout>) -> SliceInfo<&'a [AxisSliceInfo], Din, Dout> {
SliceInfo {
in_dim: info.in_dim,
out_dim: info.out_dim,
indices: info.indices.as_ref(),
}
}
}
Expand Down Expand Up @@ -703,9 +699,7 @@ impl_slicenextdim!((), NewAxis, Ix0, Ix1);
///
/// `s![]` takes a list of ranges/slices/indices/new-axes, separated by comma,
/// with optional step sizes that are separated from the range by a semicolon.
/// It is converted into a [`&SliceInfo`] instance.
///
/// [`&SliceInfo`]: struct.SliceInfo.html
/// It is converted into a [`SliceInfo`] instance.
///
/// Each range/slice/index uses signed indices, where a negative value is
/// counted from the end of the axis. Step sizes are also signed and may be
Expand Down Expand Up @@ -889,9 +883,7 @@ macro_rules! s(
<$crate::AxisSliceInfo as ::std::convert::From<_>>::from($r).step_by($s as isize)
};
($($t:tt)*) => {
// The extra `*&` is a workaround for this compiler bug:
// https://github.com/rust-lang/rust/issues/23014
&*&$crate::s![@parse
$crate::s![@parse
::std::marker::PhantomData::<$crate::Ix0>,
::std::marker::PhantomData::<$crate::Ix0>,
[]
Expand Down Expand Up @@ -933,7 +925,7 @@ where
private_impl! {}
}

impl<'a, A, D, I0> MultiSliceArg<'a, A, D> for (&I0,)
impl<'a, A, D, I0> MultiSliceArg<'a, A, D> for (I0,)
where
A: 'a,
D: Dimension,
Expand All @@ -942,7 +934,7 @@ where
type Output = (ArrayViewMut<'a, A, I0::OutDim>,);

fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output {
(view.slice_move(self.0),)
(view.slice_move(&self.0),)
}

private_impl! {}
Expand All @@ -953,7 +945,7 @@ macro_rules! impl_multislice_tuple {
impl_multislice_tuple!(@def_impl ($($but_last,)* $last,), [$($but_last)*] $last);
};
(@def_impl ($($all:ident,)*), [$($but_last:ident)*] $last:ident) => {
impl<'a, A, D, $($all,)*> MultiSliceArg<'a, A, D> for ($(&$all,)*)
impl<'a, A, D, $($all,)*> MultiSliceArg<'a, A, D> for ($($all,)*)
where
A: 'a,
D: Dimension,
Expand All @@ -963,7 +955,7 @@ macro_rules! impl_multislice_tuple {

fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output {
#[allow(non_snake_case)]
let &($($all,)*) = self;
let ($($all,)*) = self;

let shape = view.raw_dim();
assert!(!impl_multislice_tuple!(@intersects_self &shape, ($($all,)*)));
Expand Down
12 changes: 6 additions & 6 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ fn test_slice_dyninput_array_fixed() {
#[test]
fn test_slice_array_dyn() {
let mut arr = Array3::<f64>::zeros((5, 2, 5));
let info = &SliceInfo::<_, Ix3, IxDyn>::try_from([
let info = SliceInfo::<_, Ix3, IxDyn>::try_from([
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(NewAxis),
Expand All @@ -229,7 +229,7 @@ fn test_slice_array_dyn() {
arr.slice(info);
arr.slice_mut(info);
arr.view().slice_move(info);
let info2 = &SliceInfo::<_, Ix3, IxDyn>::try_from([
let info2 = SliceInfo::<_, Ix3, IxDyn>::try_from([
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(..).step_by(2),
Expand All @@ -241,7 +241,7 @@ fn test_slice_array_dyn() {
#[test]
fn test_slice_dyninput_array_dyn() {
let mut arr = Array3::<f64>::zeros((5, 2, 5)).into_dyn();
let info = &SliceInfo::<_, Ix3, IxDyn>::try_from([
let info = SliceInfo::<_, Ix3, IxDyn>::try_from([
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(NewAxis),
Expand All @@ -251,7 +251,7 @@ fn test_slice_dyninput_array_dyn() {
arr.slice(info);
arr.slice_mut(info);
arr.view().slice_move(info);
let info2 = &SliceInfo::<_, Ix3, IxDyn>::try_from([
let info2 = SliceInfo::<_, Ix3, IxDyn>::try_from([
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(..).step_by(2),
Expand All @@ -273,7 +273,7 @@ fn test_slice_dyninput_vec_fixed() {
arr.slice(info);
arr.slice_mut(info);
arr.view().slice_move(info);
let info2 = &SliceInfo::<_, Ix3, Ix2>::try_from(vec![
let info2 = SliceInfo::<_, Ix3, Ix2>::try_from(vec![
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(..).step_by(2),
Expand All @@ -295,7 +295,7 @@ fn test_slice_dyninput_vec_dyn() {
arr.slice(info);
arr.slice_mut(info);
arr.view().slice_move(info);
let info2 = &SliceInfo::<_, Ix3, IxDyn>::try_from(vec![
let info2 = SliceInfo::<_, Ix3, IxDyn>::try_from(vec![
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(..).step_by(2),
Expand Down
2 changes: 1 addition & 1 deletion tests/oper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ fn scaled_add_3() {

{
let mut av = a.slice_mut(s![..;s1, ..;s2]);
let c = c.slice(&SliceInfo::<_, IxDyn, IxDyn>::try_from(cslice).unwrap());
let c = c.slice(SliceInfo::<_, IxDyn, IxDyn>::try_from(cslice).unwrap());

let mut answerv = answer.slice_mut(s![..;s1, ..;s2]);
answerv += &(beta * &c);
Expand Down

0 comments on commit 2bcbb2a

Please sign in to comment.