From b36528e75ffbd1bb90d061e4db5efa5209e59b64 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Wed, 3 Feb 2021 15:29:58 -0500 Subject: [PATCH 1/5] Rename unordered_foreach_mut to for_each_mut --- src/impl_methods.rs | 25 +++++++++++++++++++++---- src/impl_ops.rs | 10 +++++----- src/lib.rs | 16 ---------------- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 7b6f3f6f5..2aa7bf637 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1976,7 +1976,7 @@ where S: DataMut, A: Clone, { - self.unordered_foreach_mut(move |elt| *elt = x.clone()); + self.for_each_mut(move |elt| *elt = x.clone()); } fn zip_mut_with_same_shape(&mut self, rhs: &ArrayBase, mut f: F) @@ -2028,7 +2028,7 @@ where S: DataMut, F: FnMut(&mut A, &B), { - self.unordered_foreach_mut(move |elt| f(elt, rhs_elem)); + self.for_each_mut(move |elt| f(elt, rhs_elem)); } /// Traverse two arrays in unspecified order, in lock step, @@ -2205,7 +2205,7 @@ where S: DataMut, F: FnMut(&mut A), { - self.unordered_foreach_mut(f); + self.for_each_mut(f); } /// Modify the array in place by calling `f` by **v**alue on each element. @@ -2235,7 +2235,7 @@ where F: FnMut(A) -> A, A: Clone, { - self.unordered_foreach_mut(move |x| *x = f(x.clone())); + self.for_each_mut(move |x| *x = f(x.clone())); } /// Call `f` for each element in the array. @@ -2250,6 +2250,23 @@ where self.fold((), move |(), elt| f(elt)) } + /// Call `f` for each element in the array. + /// + /// Elements are visited in arbitrary order. + pub(crate) fn for_each_mut(&mut self, mut f: F) + where + S: DataMut, + F: FnMut(&mut A), + { + if let Some(slc) = self.as_slice_memory_order_mut() { + slc.iter_mut().for_each(f); + } else { + for row in self.inner_rows_mut() { + row.into_iter_().fold((), |(), elt| f(elt)); + } + } + } + /// Visit each element in the array by calling `f` by reference /// on each element. /// diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 51d432ee6..1cc2b652c 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -141,7 +141,7 @@ impl $trt for ArrayBase { type Output = ArrayBase; fn $mth(mut self, x: B) -> ArrayBase { - self.unordered_foreach_mut(move |elt| { + self.for_each_mut(move |elt| { *elt = elt.clone() $operator x.clone(); }); self @@ -194,7 +194,7 @@ impl $trt> for $scalar rhs.$mth(self) } or {{ let mut rhs = rhs; - rhs.unordered_foreach_mut(move |elt| { + rhs.for_each_mut(move |elt| { *elt = self $operator *elt; }); rhs @@ -299,7 +299,7 @@ mod arithmetic_ops { type Output = Self; /// Perform an elementwise negation of `self` and return the result. fn neg(mut self) -> Self { - self.unordered_foreach_mut(|elt| { + self.for_each_mut(|elt| { *elt = -elt.clone(); }); self @@ -329,7 +329,7 @@ mod arithmetic_ops { type Output = Self; /// Perform an elementwise unary not of `self` and return the result. fn not(mut self) -> Self { - self.unordered_foreach_mut(|elt| { + self.for_each_mut(|elt| { *elt = !elt.clone(); }); self @@ -386,7 +386,7 @@ mod assign_ops { D: Dimension, { fn $method(&mut self, rhs: A) { - self.unordered_foreach_mut(move |elt| { + self.for_each_mut(move |elt| { elt.$method(rhs.clone()); }); } diff --git a/src/lib.rs b/src/lib.rs index c71f09aaf..66c813401 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1544,22 +1544,6 @@ where self.strides.clone() } - /// Apply closure `f` to each element in the array, in whatever - /// order is the fastest to visit. - fn unordered_foreach_mut(&mut self, mut f: F) - where - S: DataMut, - F: FnMut(&mut A), - { - if let Some(slc) = self.as_slice_memory_order_mut() { - slc.iter_mut().for_each(f); - } else { - for row in self.inner_rows_mut() { - row.into_iter_().fold((), |(), elt| f(elt)); - } - } - } - /// Remove array axis `axis` and return the result. fn try_remove_axis(self, axis: Axis) -> ArrayBase { let d = self.dim.try_remove_axis(axis); From e117049ecd14bd468650811b2e0293f7bb15854d Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Wed, 3 Feb 2021 15:32:41 -0500 Subject: [PATCH 2/5] Make for_each_mut public --- src/doc/ndarray_for_numpy_users/mod.rs | 5 +++-- src/impl_methods.rs | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/doc/ndarray_for_numpy_users/mod.rs b/src/doc/ndarray_for_numpy_users/mod.rs index 45ef3ed06..3c6c40140 100644 --- a/src/doc/ndarray_for_numpy_users/mod.rs +++ b/src/doc/ndarray_for_numpy_users/mod.rs @@ -282,8 +282,8 @@ //! Note that [`.mapv()`][.mapv()] has corresponding methods [`.map()`][.map()], //! [`.mapv_into()`][.mapv_into()], [`.map_inplace()`][.map_inplace()], and //! [`.mapv_inplace()`][.mapv_inplace()]. Also look at [`.fold()`][.fold()], -//! [`.for_each()`][.for_each()], [`.fold_axis()`][.fold_axis()], and -//! [`.map_axis()`][.map_axis()]. +//! [`.for_each()`][.for_each()], [`.for_each_mut()`][.for_each_mut()], +//! [`.fold_axis()`][.fold_axis()], and [`.map_axis()`][.map_axis()]. //! //! //!
@@ -649,6 +649,7 @@ //! [.t()]: ../../struct.ArrayBase.html#method.t //! [vec-* dot]: ../../struct.ArrayBase.html#method.dot //! [.for_each()]: ../../struct.ArrayBase.html#method.for_each +//! [.for_each_mut()]: ../../struct.ArrayBase.html#method.for_each_mut //! [::zeros()]: ../../struct.ArrayBase.html#method.zeros //! [Zip]: ../../struct.Zip.html diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 2aa7bf637..c584b0bfe 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -2253,7 +2253,7 @@ where /// Call `f` for each element in the array. /// /// Elements are visited in arbitrary order. - pub(crate) fn for_each_mut(&mut self, mut f: F) + pub fn for_each_mut(&mut self, mut f: F) where S: DataMut, F: FnMut(&mut A), From 7ceafef2286fb130087d1f6d68de544c31132cbe Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Wed, 3 Feb 2021 15:56:49 -0500 Subject: [PATCH 3/5] Unify implementation of fold and for_each_mut --- src/dimension/mod.rs | 30 ++++++++++++++++++++++++++++++ src/impl_methods.rs | 34 +++++++--------------------------- src/lib.rs | 11 +---------- 3 files changed, 38 insertions(+), 37 deletions(-) diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 3b14ea221..1359b8f39 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -678,6 +678,36 @@ where } } +/// Move the axis which has the smallest absolute stride and a length +/// greater than one to be the last axis. +pub fn move_min_stride_axis_to_last(dim: &mut D, strides: &mut D) +where + D: Dimension, +{ + debug_assert_eq!(dim.ndim(), strides.ndim()); + match dim.ndim() { + 0 | 1 => {} + 2 => { + if dim[1] <= 1 + || dim[0] > 1 && (strides[0] as isize).abs() < (strides[1] as isize).abs() + { + dim.slice_mut().swap(0, 1); + strides.slice_mut().swap(0, 1); + } + } + n => { + if let Some(min_stride_axis) = (0..n) + .filter(|&ax| dim[ax] > 1) + .min_by_key(|&ax| (strides[ax] as isize).abs()) + { + let last = n - 1; + dim.slice_mut().swap(last, min_stride_axis); + strides.slice_mut().swap(last, min_stride_axis); + } + } + } +} + #[cfg(test)] mod test { use super::{ diff --git a/src/impl_methods.rs b/src/impl_methods.rs index c584b0bfe..b6c12bd5a 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -18,8 +18,8 @@ use crate::arraytraits; use crate::dimension; use crate::dimension::IntoDimension; use crate::dimension::{ - abs_index, axes_of, do_slice, merge_axes, offset_from_ptr_to_memory, size_of_shape_checked, - stride_offset, Axes, + abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last, + offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes, }; use crate::error::{self, ErrorKind, ShapeError}; use crate::math_cell::MathCell; @@ -2070,27 +2070,7 @@ where slc.iter().fold(init, f) } else { let mut v = self.view(); - // put the narrowest axis at the last position - match v.ndim() { - 0 | 1 => {} - 2 => { - if self.len_of(Axis(1)) <= 1 - || self.len_of(Axis(0)) > 1 - && self.stride_of(Axis(0)).abs() < self.stride_of(Axis(1)).abs() - { - v.swap_axes(0, 1); - } - } - n => { - let last = n - 1; - let narrow_axis = v - .axes() - .filter(|ax| ax.len() > 1) - .min_by_key(|ax| ax.stride().abs()) - .map_or(last, |ax| ax.axis().index()); - v.swap_axes(last, narrow_axis); - } - } + move_min_stride_axis_to_last(&mut v.dim, &mut v.strides); v.into_elements_base().fold(init, f) } } @@ -2253,7 +2233,7 @@ where /// Call `f` for each element in the array. /// /// Elements are visited in arbitrary order. - pub fn for_each_mut(&mut self, mut f: F) + pub fn for_each_mut(&mut self, f: F) where S: DataMut, F: FnMut(&mut A), @@ -2261,9 +2241,9 @@ where if let Some(slc) = self.as_slice_memory_order_mut() { slc.iter_mut().for_each(f); } else { - for row in self.inner_rows_mut() { - row.into_iter_().fold((), |(), elt| f(elt)); - } + let mut v = self.view_mut(); + move_min_stride_axis_to_last(&mut v.dim, &mut v.strides); + v.into_elements_base().for_each(f); } } diff --git a/src/lib.rs b/src/lib.rs index 66c813401..8faaf5ba7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -143,7 +143,7 @@ pub use crate::indexes::{indices, indices_of}; pub use crate::slice::{Slice, SliceInfo, SliceNextDim, SliceOrIndex}; use crate::iterators::Baseiter; -use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, LanesMut}; +use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes}; pub use crate::arraytraits::AsArray; #[cfg(feature = "std")] @@ -1561,15 +1561,6 @@ where let n = self.ndim(); Lanes::new(self.view(), Axis(n.saturating_sub(1))) } - - /// n-d generalization of rows, just like inner iter - fn inner_rows_mut(&mut self) -> iterators::LanesMut<'_, A, D::Smaller> - where - S: DataMut, - { - let n = self.ndim(); - LanesMut::new(self.view_mut(), Axis(n.saturating_sub(1))) - } } // parallel methods From 2931a1bf6297f088fa4f31c7a8d824c98bce9a1d Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Wed, 3 Feb 2021 16:53:12 -0500 Subject: [PATCH 4/5] Lengthen lifetimes of for_each_mut element borrows It's unfortunate that the `.is_contiguous()` check is called twice (once explicitly, and once in `.as_slice_memory_order_mut()). The compiler rejects the following: if let Some(slc) = self.as_slice_memory_order_mut() { slc.iter_mut().for_each(f); } else { let mut v = self.view_mut(); ... } So, the chosen implementation is a safe compromise. --- src/impl_methods.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index b6c12bd5a..c73e1d39a 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -2233,12 +2233,14 @@ where /// Call `f` for each element in the array. /// /// Elements are visited in arbitrary order. - pub fn for_each_mut(&mut self, f: F) + pub fn for_each_mut<'a, F>(&'a mut self, f: F) where + F: FnMut(&'a mut A), + A: 'a, S: DataMut, - F: FnMut(&mut A), { - if let Some(slc) = self.as_slice_memory_order_mut() { + if self.is_contiguous() { + let slc = self.as_slice_memory_order_mut().unwrap(); slc.iter_mut().for_each(f); } else { let mut v = self.view_mut(); From 9131e29d60630c5186cd81019006bb656b352306 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Wed, 3 Feb 2021 19:58:17 -0500 Subject: [PATCH 5/5] Remove duplicate check in for_each_mut --- src/impl_methods.rs | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index c73e1d39a..ddce77a25 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1456,6 +1456,15 @@ where /// Return the array’s data as a slice if it is contiguous, /// return `None` otherwise. pub fn as_slice_memory_order_mut(&mut self) -> Option<&mut [A]> + where + S: DataMut, + { + self.try_as_slice_memory_order_mut().ok() + } + + /// Return the array’s data as a slice if it is contiguous, otherwise + /// return `self` in the `Err` variant. + pub(crate) fn try_as_slice_memory_order_mut(&mut self) -> Result<&mut [A], &mut Self> where S: DataMut, { @@ -1463,13 +1472,13 @@ where self.ensure_unique(); let offset = offset_from_ptr_to_memory(&self.dim, &self.strides); unsafe { - Some(slice::from_raw_parts_mut( + Ok(slice::from_raw_parts_mut( self.ptr.offset(offset).as_ptr(), self.len(), )) } } else { - None + Err(self) } } @@ -2239,13 +2248,13 @@ where A: 'a, S: DataMut, { - if self.is_contiguous() { - let slc = self.as_slice_memory_order_mut().unwrap(); - slc.iter_mut().for_each(f); - } else { - let mut v = self.view_mut(); - move_min_stride_axis_to_last(&mut v.dim, &mut v.strides); - v.into_elements_base().for_each(f); + match self.try_as_slice_memory_order_mut() { + Ok(slc) => slc.iter_mut().for_each(f), + Err(arr) => { + let mut v = arr.view_mut(); + move_min_stride_axis_to_last(&mut v.dim, &mut v.strides); + v.into_elements_base().for_each(f); + } } }