Skip to content

Commit

Permalink
Merge pull request #728 from rust-ndarray/rand-from-fn-opt
Browse files Browse the repository at this point in the history
Faster version of from_shape_fn and speed up ndarray-rand
bluss authored Oct 13, 2019
2 parents 6ee8853 + c516803 commit 792e17c
Showing 4 changed files with 54 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ndarray-rand/src/lib.rs
Original file line number Diff line number Diff line change
@@ -246,7 +246,7 @@ where
R: Rng + ?Sized,
Sh: ShapeBuilder<Dim = D>,
{
Self::from_shape_fn(shape, |_| dist.sample(rng))
Self::from_shape_simple_fn(shape, move || dist.sample(rng))
}

fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
17 changes: 17 additions & 0 deletions ndarray-rand/tests/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use ndarray::{Array, Array2, ArrayView1, Axis};
#[cfg(feature = "quickcheck")]
use ndarray_rand::rand::{distributions::Distribution, thread_rng};

use ndarray::ShapeBuilder;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::{RandomExt, SamplingStrategy};
use quickcheck::quickcheck;
@@ -14,6 +16,21 @@ fn test_dim() {
assert_eq!(a.shape(), &[m, n]);
assert!(a.iter().all(|x| *x < 2.));
assert!(a.iter().all(|x| *x >= 0.));
assert!(a.is_standard_layout());
}
}
}

#[test]
fn test_dim_f() {
let (mm, nn) = (5, 5);
for m in 0..mm {
for n in 0..nn {
let a = Array::random((m, n).f(), Uniform::new(0., 2.));
assert_eq!(a.shape(), &[m, n]);
assert!(a.iter().all(|x| *x < 2.));
assert!(a.iter().all(|x| *x >= 0.));
assert!(a.t().is_standard_layout());
}
}
}
36 changes: 34 additions & 2 deletions src/impl_constructors.rs
Original file line number Diff line number Diff line change
@@ -305,10 +305,28 @@ where
where
A: Default,
Sh: ShapeBuilder<Dim = D>,
{
Self::from_shape_simple_fn(shape, A::default)
}

/// Create an array with values created by the function `f`.
///
/// `f` is called with no argument, and it should return the element to
/// create. If the precise index of the element to create is needed,
/// use [`from_shape_fn`](ArrayBase::from_shape_fn) instead.
///
/// This constructor can be useful if the element order is not important,
/// for example if they are identical or random.
///
/// **Panics** if the product of non-zero axis lengths overflows `isize`.
pub fn from_shape_simple_fn<Sh, F>(shape: Sh, mut f: F) -> Self
where
Sh: ShapeBuilder<Dim = D>,
F: FnMut() -> A,
{
let shape = shape.into_shape();
let size = size_of_shape_checked_unwrap!(&shape.dim);
let v = to_vec((0..size).map(|_| A::default()));
let len = size_of_shape_checked_unwrap!(&shape.dim);
let v = to_vec_mapped(0..len, move |_| f());
unsafe { Self::from_shape_vec_unchecked(shape, v) }
}

@@ -318,6 +336,20 @@ where
/// visited in arbitrary order.
///
/// **Panics** if the product of non-zero axis lengths overflows `isize`.
///
/// ```
/// use ndarray::{Array, arr2};
///
/// // Create a table of i × j (with i and j from 1 to 3)
/// let ij_table = Array::from_shape_fn((3, 3), |(i, j)| (1 + i) * (1 + j));
///
/// assert_eq!(
/// ij_table,
/// arr2(&[[1, 2, 3],
/// [2, 4, 6],
/// [3, 6, 9]])
/// );
/// ```
pub fn from_shape_fn<Sh, F>(shape: Sh, f: F) -> Self
where
Sh: ShapeBuilder<Dim = D>,
4 changes: 2 additions & 2 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
@@ -2187,7 +2187,7 @@ where
let view_stride = self.strides.axis(axis);
if view_len == 0 {
let new_dim = self.dim.remove_axis(axis);
Array::from_shape_fn(new_dim, move |_| mapping(ArrayView::from(&[])))
Array::from_shape_simple_fn(new_dim, move || mapping(ArrayView::from(&[])))
} else {
// use the 0th subview as a map to each 1d array view extended from
// the 0th element.
@@ -2218,7 +2218,7 @@ where
let view_stride = self.strides.axis(axis);
if view_len == 0 {
let new_dim = self.dim.remove_axis(axis);
Array::from_shape_fn(new_dim, move |_| mapping(ArrayViewMut::from(&mut [])))
Array::from_shape_simple_fn(new_dim, move || mapping(ArrayViewMut::from(&mut [])))
} else {
// use the 0th subview as a map to each 1d array view extended from
// the 0th element.

0 comments on commit 792e17c

Please sign in to comment.