diff --git a/benches/iter.rs b/benches/iter.rs index 289f1fb50..24bda95b9 100644 --- a/benches/iter.rs +++ b/benches/iter.rs @@ -370,3 +370,23 @@ fn iter_axis_chunks_5_iter_sum(bench: &mut Bencher) { .sum::() }); } + +pub fn zip_mut_with(data: &Array3, out: &mut Array3) { + out.zip_mut_with(&data, |o, &i| { + *o = i; + }); +} + +#[bench] +fn zip_mut_with_cc(b: &mut Bencher) { + let data: Array3 = Array3::zeros((ISZ, ISZ, ISZ)); + let mut out = Array3::zeros(data.dim()); + b.iter(|| black_box(zip_mut_with(&data, &mut out))); +} + +#[bench] +fn zip_mut_with_ff(b: &mut Bencher) { + let data: Array3 = Array3::zeros((ISZ, ISZ, ISZ).f()); + let mut out = Array3::zeros(data.dim().f()); + b.iter(|| black_box(zip_mut_with(&data, &mut out))); +} diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index 4bfe7c0b2..0bc03f9d5 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -229,6 +229,25 @@ pub trait Dimension: !end_iteration } + /// Returns `true` iff `strides1` and `strides2` are equivalent for the + /// shape `self`. + /// + /// The strides are equivalent if, for each axis with length > 1, the + /// strides are equal. + /// + /// Note: Returns `false` if any of the ndims don't match. + #[doc(hidden)] + fn strides_equivalent<D>(&self, strides1: &Self, strides2: &D) -> bool + where + D: Dimension, + { + let shape_ndim = self.ndim(); + shape_ndim == strides1.ndim() + && shape_ndim == strides2.ndim() + && izip!(self.slice(), strides1.slice(), strides2.slice()) + .all(|(&d, &s1, &s2)| d <= 1 || s1 as isize == s2 as isize) + } + #[doc(hidden)] /// Return stride offset for index. fn stride_offset(index: &Self, strides: &Self) -> isize { diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 027f5a8af..78c4dd0af 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -6,7 +6,6 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. 
-use std::cmp; use std::ptr as std_ptr; use std::slice; @@ -1917,18 +1916,19 @@ where F: FnMut(&mut A, &B), { debug_assert_eq!(self.shape(), rhs.shape()); - if let Some(self_s) = self.as_slice_mut() { - if let Some(rhs_s) = rhs.as_slice() { - let len = cmp::min(self_s.len(), rhs_s.len()); - let s = &mut self_s[..len]; - let r = &rhs_s[..len]; - for i in 0..len { - f(&mut s[i], &r[i]); + + if self.dim.strides_equivalent(&self.strides, &rhs.strides) { + if let Some(self_s) = self.as_slice_memory_order_mut() { + if let Some(rhs_s) = rhs.as_slice_memory_order() { + for (s, r) in self_s.iter_mut().zip(rhs_s) { + f(s, &r); + } + return; } - return; } } - // otherwise, fall back to the outer iter + + // Otherwise, fall back to the outer iter self.zip_mut_with_by_rows(rhs, f); }