diff --git a/src/impl_methods.rs b/src/impl_methods.rs index c005ff22d..4fbfcd98a 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -2231,4 +2231,60 @@ where }) } } + + /// Iterates over pairs of consecutive elements along the axis. + /// + /// The first argument to the closure is an element, and the second + /// argument is the next element along the axis. Iteration is guaranteed to + /// proceed in order along the specified axis, but in all other respects + /// the iteration order is unspecified. + /// + /// # Example + /// + /// For example, this can be used to compute the cumulative sum along an + /// axis: + /// + /// ``` + /// use ndarray::{array, Axis}; + /// + /// let mut arr = array![ + /// [[1, 2], [3, 4], [5, 6]], + /// [[7, 8], [9, 10], [11, 12]], + /// ]; + /// arr.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev); + /// assert_eq!( + /// arr, + /// array![ + /// [[1, 2], [4, 6], [9, 12]], + /// [[7, 8], [16, 18], [27, 30]], + /// ], + /// ); + /// ``` + pub fn accumulate_axis_inplace<F>(&mut self, axis: Axis, mut f: F) + where + F: FnMut(&A, &mut A), + S: DataMut, + { + if self.len_of(axis) <= 1 { + return; + } + let mut curr = self.raw_view_mut(); // mut borrow of the array here + let mut prev = curr.raw_view(); // derive further raw views from the same borrow + prev.slice_axis_inplace(axis, Slice::from(..-1)); + curr.slice_axis_inplace(axis, Slice::from(1..)); + // This implementation relies on `Zip` iterating along `axis` in order. + Zip::from(prev).and(curr).apply(|prev, curr| unsafe { + // These pointer dereferences and borrows are safe because: + // + // 1. They're pointers to elements in the array. + // + // 2. `S: DataMut` guarantees that elements are safe to borrow + // mutably and that they don't alias. + // + // 3. The lifetimes of the borrows last only for the duration + // of the call to `f`, so aliasing across calls to `f` + // cannot occur. + f(&*prev, &mut *curr) + }); + } } diff --git a/src/zip/mod.rs b/src/zip/mod.rs index 940d87cc3..ccc123db7 100644 --- a/src/zip/mod.rs +++ b/src/zip/mod.rs @@ -47,7 +47,7 @@ where impl<S, D> ArrayBase<S, D> where - S: Data, + S: RawData, D: Dimension, { pub(crate) fn layout_impl(&self) -> Layout { @@ -57,7 +57,7 @@ where } else { CORDER } - } else if self.ndim() > 1 && self.t().is_standard_layout() { + } else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() { FORDER } else { 0 @@ -192,6 +192,14 @@ pub trait Offset: Copy { private_decl! {} } +impl<T> Offset for *const T { + type Stride = isize; + unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self { + self.offset(s * (index as isize)) + } + private_impl! {} +} + impl<T> Offset for *mut T { type Stride = isize; unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self { @@ -389,6 +397,112 @@ impl<'a, A, D: Dimension> NdProducer for ArrayViewMut<'a, A, D> { } } +impl<A, D: Dimension> NdProducer for RawArrayView<A, D> { + type Item = *const A; + type Dim = D; + type Ptr = *const A; + type Stride = isize; + + private_impl! {} + #[doc(hidden)] + fn raw_dim(&self) -> Self::Dim { + self.raw_dim() + } + + #[doc(hidden)] + fn equal_dim(&self, dim: &Self::Dim) -> bool { + self.dim.equal(dim) + } + + #[doc(hidden)] + fn as_ptr(&self) -> *const A { + self.as_ptr() + } + + #[doc(hidden)] + fn layout(&self) -> Layout { + self.layout_impl() + } + + #[doc(hidden)] + unsafe fn as_ref(&self, ptr: *const A) -> *const A { + ptr + } + + #[doc(hidden)] + unsafe fn uget_ptr(&self, i: &Self::Dim) -> *const A { + self.ptr.as_ptr().offset(i.index_unchecked(&self.strides)) + } + + #[doc(hidden)] + fn stride_of(&self, axis: Axis) -> isize { + self.stride_of(axis) + } + + #[inline(always)] + fn contiguous_stride(&self) -> Self::Stride { + 1 + } + + #[doc(hidden)] + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { + self.split_at(axis, index) + } +} + +impl<A, D: Dimension> NdProducer for RawArrayViewMut<A, D> { + type Item = *mut A; + type Dim = D; + type Ptr = *mut A; + type Stride = isize; + + private_impl! {} + #[doc(hidden)] + fn raw_dim(&self) -> Self::Dim { + self.raw_dim() + } + + #[doc(hidden)] + fn equal_dim(&self, dim: &Self::Dim) -> bool { + self.dim.equal(dim) + } + + #[doc(hidden)] + fn as_ptr(&self) -> *mut A { + self.as_ptr() as _ + } + + #[doc(hidden)] + fn layout(&self) -> Layout { + self.layout_impl() + } + + #[doc(hidden)] + unsafe fn as_ref(&self, ptr: *mut A) -> *mut A { + ptr + } + + #[doc(hidden)] + unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A { + self.ptr.as_ptr().offset(i.index_unchecked(&self.strides)) + } + + #[doc(hidden)] + fn stride_of(&self, axis: Axis) -> isize { + self.stride_of(axis) + } + + #[inline(always)] + fn contiguous_stride(&self) -> Self::Stride { + 1 + } + + #[doc(hidden)] + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) { + self.split_at(axis, index) + } +} + /// Lock step function application across several arrays or other producers. /// /// Zip allows matching several producers to each other elementwise and applying diff --git a/tests/array.rs b/tests/array.rs index ea5c1e82a..807253104 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -1098,7 +1098,7 @@ fn owned_array_with_stride() { #[test] fn owned_array_discontiguous() { - use ::std::iter::repeat; + use std::iter::repeat; let v: Vec<_> = (0..12).flat_map(|x| repeat(x).take(2)).collect(); let dim = (3, 2, 2); let strides = (8, 4, 2); @@ -1111,9 +1111,9 @@ fn owned_array_discontiguous() { #[test] fn owned_array_discontiguous_drop() { - use ::std::cell::RefCell; - use ::std::collections::BTreeSet; - use ::std::rc::Rc; + use std::cell::RefCell; + use std::collections::BTreeSet; + use std::rc::Rc; struct InsertOnDrop<T: Ord>(Rc<RefCell<BTreeSet<T>>>, Option<T>); impl<T: Ord> Drop for InsertOnDrop<T> { @@ -1952,6 +1952,48 @@ fn test_map_axis() { itertools::assert_equal(result.iter().cloned().sorted(), 1..=3 * 4); } +#[test] +fn test_accumulate_axis_inplace_noop() { + let mut a = Array2::<u8>::zeros((0, 3)); + a.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev); + assert_eq!(a, Array2::zeros((0, 3))); + + let mut a = Array2::<u8>::zeros((3, 1)); + a.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev); + assert_eq!(a, Array2::zeros((3, 1))); +} + +#[rustfmt::skip] // Allow block array formatting +#[test] +fn test_accumulate_axis_inplace_nonstandard_layout() { + let a = arr2(&[[1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [10,11,12]]); + + let mut a_t = a.clone().reversed_axes(); + a_t.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev); + assert_eq!(a_t, aview2(&[[1, 4, 7, 10], + [3, 9, 15, 21], + [6, 15, 24, 33]])); + + let mut a0 = a.clone(); + a0.invert_axis(Axis(0)); + a0.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev); + assert_eq!(a0, aview2(&[[10, 11, 12], + [17, 19, 21], + [21, 24, 27], + [22, 26, 30]])); + + let mut a1 = a.clone(); + a1.invert_axis(Axis(1)); + a1.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev); + assert_eq!(a1, aview2(&[[3, 5, 6], + [6, 11, 15], + [9, 17, 24], + [12, 23, 33]])); +} + #[test] fn test_to_vec() { let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]);