Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add for_each_mut method #910

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,36 @@ where
}
}

/// Move the axis which has the smallest absolute stride and a length
/// greater than one to be the last axis.
pub fn move_min_stride_axis_to_last<D>(dim: &mut D, strides: &mut D)
where
D: Dimension,
{
debug_assert_eq!(dim.ndim(), strides.ndim());
match dim.ndim() {
0 | 1 => {}
2 => {
if dim[1] <= 1
|| dim[0] > 1 && (strides[0] as isize).abs() < (strides[1] as isize).abs()
{
dim.slice_mut().swap(0, 1);
strides.slice_mut().swap(0, 1);
}
}
n => {
if let Some(min_stride_axis) = (0..n)
.filter(|&ax| dim[ax] > 1)
.min_by_key(|&ax| (strides[ax] as isize).abs())
{
let last = n - 1;
dim.slice_mut().swap(last, min_stride_axis);
strides.slice_mut().swap(last, min_stride_axis);
}
}
}
}

#[cfg(test)]
mod test {
use super::{
Expand Down
5 changes: 3 additions & 2 deletions src/doc/ndarray_for_numpy_users/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,8 @@
//! Note that [`.mapv()`][.mapv()] has corresponding methods [`.map()`][.map()],
//! [`.mapv_into()`][.mapv_into()], [`.map_inplace()`][.map_inplace()], and
//! [`.mapv_inplace()`][.mapv_inplace()]. Also look at [`.fold()`][.fold()],
//! [`.for_each()`][.for_each()], [`.fold_axis()`][.fold_axis()], and
//! [`.map_axis()`][.map_axis()].
//! [`.for_each()`][.for_each()], [`.for_each_mut()`][.for_each_mut()],
//! [`.fold_axis()`][.fold_axis()], and [`.map_axis()`][.map_axis()].
//!
//! <table>
//! <tr><th>
Expand Down Expand Up @@ -649,6 +649,7 @@
//! [.t()]: ../../struct.ArrayBase.html#method.t
//! [vec-* dot]: ../../struct.ArrayBase.html#method.dot
//! [.for_each()]: ../../struct.ArrayBase.html#method.for_each
//! [.for_each_mut()]: ../../struct.ArrayBase.html#method.for_each_mut
//! [::zeros()]: ../../struct.ArrayBase.html#method.zeros
//! [Zip]: ../../struct.Zip.html

Expand Down
66 changes: 37 additions & 29 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ use crate::arraytraits;
use crate::dimension;
use crate::dimension::IntoDimension;
use crate::dimension::{
abs_index, axes_of, do_slice, merge_axes, offset_from_ptr_to_memory, size_of_shape_checked,
stride_offset, Axes,
abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last,
offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes,
};
use crate::error::{self, ErrorKind, ShapeError};
use crate::math_cell::MathCell;
Expand Down Expand Up @@ -1456,20 +1456,29 @@ where
/// Return the array’s data as a slice if it is contiguous,
/// return `None` otherwise.
pub fn as_slice_memory_order_mut(&mut self) -> Option<&mut [A]>
where
S: DataMut,
{
self.try_as_slice_memory_order_mut().ok()
}

/// Return the array’s data as a slice if it is contiguous, otherwise
/// return `self` in the `Err` variant.
pub(crate) fn try_as_slice_memory_order_mut(&mut self) -> Result<&mut [A], &mut Self>
where
S: DataMut,
{
if self.is_contiguous() {
self.ensure_unique();
let offset = offset_from_ptr_to_memory(&self.dim, &self.strides);
unsafe {
Some(slice::from_raw_parts_mut(
Ok(slice::from_raw_parts_mut(
self.ptr.offset(offset).as_ptr(),
self.len(),
))
}
} else {
None
Err(self)
}
}

Expand Down Expand Up @@ -1976,7 +1985,7 @@ where
S: DataMut,
A: Clone,
{
self.unordered_foreach_mut(move |elt| *elt = x.clone());
self.for_each_mut(move |elt| *elt = x.clone());
}

fn zip_mut_with_same_shape<B, S2, E, F>(&mut self, rhs: &ArrayBase<S2, E>, mut f: F)
Expand Down Expand Up @@ -2028,7 +2037,7 @@ where
S: DataMut,
F: FnMut(&mut A, &B),
{
self.unordered_foreach_mut(move |elt| f(elt, rhs_elem));
self.for_each_mut(move |elt| f(elt, rhs_elem));
}

/// Traverse two arrays in unspecified order, in lock step,
Expand Down Expand Up @@ -2070,27 +2079,7 @@ where
slc.iter().fold(init, f)
} else {
let mut v = self.view();
// put the narrowest axis at the last position
match v.ndim() {
0 | 1 => {}
2 => {
if self.len_of(Axis(1)) <= 1
|| self.len_of(Axis(0)) > 1
&& self.stride_of(Axis(0)).abs() < self.stride_of(Axis(1)).abs()
{
v.swap_axes(0, 1);
}
}
n => {
let last = n - 1;
let narrow_axis = v
.axes()
.filter(|ax| ax.len() > 1)
.min_by_key(|ax| ax.stride().abs())
.map_or(last, |ax| ax.axis().index());
v.swap_axes(last, narrow_axis);
}
}
move_min_stride_axis_to_last(&mut v.dim, &mut v.strides);
v.into_elements_base().fold(init, f)
}
}
Expand Down Expand Up @@ -2205,7 +2194,7 @@ where
S: DataMut,
F: FnMut(&mut A),
{
self.unordered_foreach_mut(f);
self.for_each_mut(f);
}

/// Modify the array in place by calling `f` by **v**alue on each element.
Expand Down Expand Up @@ -2235,7 +2224,7 @@ where
F: FnMut(A) -> A,
A: Clone,
{
self.unordered_foreach_mut(move |x| *x = f(x.clone()));
self.for_each_mut(move |x| *x = f(x.clone()));
}

/// Call `f` for each element in the array.
Expand All @@ -2250,6 +2239,25 @@ where
self.fold((), move |(), elt| f(elt))
}

/// Call `f` for each element in the array.
///
/// Elements are visited in arbitrary order.
pub fn for_each_mut<'a, F>(&'a mut self, f: F)
where
F: FnMut(&'a mut A),
A: 'a,
S: DataMut,
{
match self.try_as_slice_memory_order_mut() {
Ok(slc) => slc.iter_mut().for_each(f),
Err(arr) => {
let mut v = arr.view_mut();
move_min_stride_axis_to_last(&mut v.dim, &mut v.strides);
v.into_elements_base().for_each(f);
}
}
}

/// Visit each element in the array by calling `f` by reference
/// on each element.
///
Expand Down
10 changes: 5 additions & 5 deletions src/impl_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ impl<A, S, D, B> $trt<B> for ArrayBase<S, D>
{
type Output = ArrayBase<S, D>;
fn $mth(mut self, x: B) -> ArrayBase<S, D> {
self.unordered_foreach_mut(move |elt| {
self.for_each_mut(move |elt| {
*elt = elt.clone() $operator x.clone();
});
self
Expand Down Expand Up @@ -194,7 +194,7 @@ impl<S, D> $trt<ArrayBase<S, D>> for $scalar
rhs.$mth(self)
} or {{
let mut rhs = rhs;
rhs.unordered_foreach_mut(move |elt| {
rhs.for_each_mut(move |elt| {
*elt = self $operator *elt;
});
rhs
Expand Down Expand Up @@ -299,7 +299,7 @@ mod arithmetic_ops {
type Output = Self;
/// Perform an elementwise negation of `self` and return the result.
fn neg(mut self) -> Self {
self.unordered_foreach_mut(|elt| {
self.for_each_mut(|elt| {
*elt = -elt.clone();
});
self
Expand Down Expand Up @@ -329,7 +329,7 @@ mod arithmetic_ops {
type Output = Self;
/// Perform an elementwise unary not of `self` and return the result.
fn not(mut self) -> Self {
self.unordered_foreach_mut(|elt| {
self.for_each_mut(|elt| {
*elt = !elt.clone();
});
self
Expand Down Expand Up @@ -386,7 +386,7 @@ mod assign_ops {
D: Dimension,
{
fn $method(&mut self, rhs: A) {
self.unordered_foreach_mut(move |elt| {
self.for_each_mut(move |elt| {
elt.$method(rhs.clone());
});
}
Expand Down
27 changes: 1 addition & 26 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ pub use crate::indexes::{indices, indices_of};
pub use crate::slice::{Slice, SliceInfo, SliceNextDim, SliceOrIndex};

use crate::iterators::Baseiter;
use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, LanesMut};
use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes};

pub use crate::arraytraits::AsArray;
#[cfg(feature = "std")]
Expand Down Expand Up @@ -1544,22 +1544,6 @@ where
self.strides.clone()
}

/// Apply closure `f` to each element in the array, in whatever
/// order is the fastest to visit.
fn unordered_foreach_mut<F>(&mut self, mut f: F)
where
S: DataMut,
F: FnMut(&mut A),
{
if let Some(slc) = self.as_slice_memory_order_mut() {
slc.iter_mut().for_each(f);
} else {
for row in self.inner_rows_mut() {
row.into_iter_().fold((), |(), elt| f(elt));
}
}
}

/// Remove array axis `axis` and return the result.
fn try_remove_axis(self, axis: Axis) -> ArrayBase<S, D::Smaller> {
let d = self.dim.try_remove_axis(axis);
Expand All @@ -1577,15 +1561,6 @@ where
let n = self.ndim();
Lanes::new(self.view(), Axis(n.saturating_sub(1)))
}

/// n-d generalization of rows, just like inner iter
fn inner_rows_mut(&mut self) -> iterators::LanesMut<'_, A, D::Smaller>
where
S: DataMut,
{
let n = self.ndim();
LanesMut::new(self.view_mut(), Axis(n.saturating_sub(1)))
}
}

// parallel methods
Expand Down