From 0c9a2e3eaff9e9d801762bed62c7c66a3a2b453a Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Sun, 22 Sep 2019 19:12:54 +0100 Subject: [PATCH] Add `sample_axis` and `sample_axis_using` to ndarray-rand. Add `quickcheck` feature flag to `ndarray-rand` to get `Arbitrary` implementation for `ndarray-rand` types. Update quickcheck dependency in ndarray. Fix CI scripts to run ndarray-rand's tests. --- Cargo.toml | 2 +- ndarray-rand/Cargo.toml | 2 + ndarray-rand/src/lib.rs | 185 ++++++++++++++++++++++++++++++++++-- ndarray-rand/tests/tests.rs | 98 ++++++++++++++++++- scripts/all-tests.sh | 2 + 5 files changed, 277 insertions(+), 12 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b0a5ed74c..3da63d1d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,7 @@ rawpointer = { version = "0.2" } [dev-dependencies] defmac = "0.2" -quickcheck = { version = "0.8", default-features = false } +quickcheck = { version = "0.9", default-features = false } approx = "0.3.2" itertools = { version = "0.8.0", default-features = false, features = ["use_std"] } diff --git a/ndarray-rand/Cargo.toml b/ndarray-rand/Cargo.toml index ce9936233..aa4715fc4 100644 --- a/ndarray-rand/Cargo.toml +++ b/ndarray-rand/Cargo.toml @@ -16,6 +16,7 @@ keywords = ["multidimensional", "matrix", "rand", "ndarray"] [dependencies] ndarray = { version = "0.13", path = ".." } rand_distr = "0.2.1" +quickcheck = { version = "0.9", default-features = false, optional = true } [dependencies.rand] version = "0.7.0" @@ -23,6 +24,7 @@ features = ["small_rng"] [dev-dependencies] rand_isaac = "0.2.0" +quickcheck = { version = "0.9", default-features = false } [package.metadata.release] no-dev-version = true diff --git a/ndarray-rand/src/lib.rs b/ndarray-rand/src/lib.rs index 49511cbe7..19235237f 100644 --- a/ndarray-rand/src/lib.rs +++ b/ndarray-rand/src/lib.rs @@ -29,12 +29,15 @@ //! that the items are not compatible (e.g. that a type doesn't implement a //! necessary trait). -use crate::rand::distributions::Distribution; +use crate::rand::distributions::{Distribution, Uniform}; use crate::rand::rngs::SmallRng; +use crate::rand::seq::index; use crate::rand::{thread_rng, Rng, SeedableRng}; -use ndarray::ShapeBuilder; +use ndarray::{Array, Axis, RemoveAxis, ShapeBuilder}; use ndarray::{ArrayBase, DataOwned, Dimension}; +#[cfg(feature = "quickcheck")] +use quickcheck::{Arbitrary, Gen}; /// [`rand`](https://docs.rs/rand/0.7), re-exported for convenience and version-compatibility. pub mod rand { @@ -59,9 +62,9 @@ pub mod rand_distr { /// low-quality random numbers, and reproducibility is not guaranteed. See its /// documentation for information. You can select a different RNG with /// [`.random_using()`](#tymethod.random_using). -pub trait RandomExt +pub trait RandomExt where - S: DataOwned, + S: DataOwned, D: Dimension, { /// Create an array with shape `dim` with elements drawn from @@ -116,11 +119,117 @@ where IdS: Distribution, R: Rng + ?Sized, Sh: ShapeBuilder; + + /// Sample `n_samples` lanes slicing along `axis` using the default RNG. + /// + /// If `strategy==SamplingStrategy::WithoutReplacement`, each lane can only be sampled once. + /// If `strategy==SamplingStrategy::WithReplacement`, each lane can be sampled multiple times. + /// + /// ***Panics*** when: + /// - creation of the RNG fails; + /// - `n_samples` is greater than the length of `axis` (if sampling without replacement); + /// - length of `axis` is 0. + /// + /// ``` + /// use ndarray::{array, Axis}; + /// use ndarray_rand::{RandomExt, SamplingStrategy}; + /// + /// # fn main() { + /// let a = array![ + /// [1., 2., 3.], + /// [4., 5., 6.], + /// [7., 8., 9.], + /// [10., 11., 12.], + /// ]; + /// // Sample 2 rows, without replacement + /// let sample_rows = a.sample_axis(Axis(0), 2, SamplingStrategy::WithoutReplacement); + /// println!("{:?}", sample_rows); + /// // Example Output: (1st and 3rd rows) + /// // [ + /// // [1., 2., 3.], + /// // [7., 8., 9.] + /// // ] + /// // Sample 2 columns, with replacement + /// let sample_columns = a.sample_axis(Axis(1), 1, SamplingStrategy::WithReplacement); + /// println!("{:?}", sample_columns); + /// // Example Output: (2nd column, sampled twice) + /// // [ + /// // [2., 2.], + /// // [5., 5.], + /// // [8., 8.], + /// // [11., 11.] + /// // ] + /// # } + /// ``` + fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array + where + A: Copy, + D: RemoveAxis; + + /// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`. + /// + /// If `strategy==SamplingStrategy::WithoutReplacement`, each lane can only be sampled once. + /// If `strategy==SamplingStrategy::WithReplacement`, each lane can be sampled multiple times. + /// + /// ***Panics*** when: + /// - creation of the RNG fails; + /// - `n_samples` is greater than the length of `axis` (if sampling without replacement); + /// - length of `axis` is 0. + /// + /// ``` + /// use ndarray::{array, Axis}; + /// use ndarray_rand::{RandomExt, SamplingStrategy}; + /// use ndarray_rand::rand::SeedableRng; + /// use rand_isaac::isaac64::Isaac64Rng; + /// + /// # fn main() { + /// // Get a seeded random number generator for reproducibility (Isaac64 algorithm) + /// let seed = 42; + /// let mut rng = Isaac64Rng::seed_from_u64(seed); + /// + /// let a = array![ + /// [1., 2., 3.], + /// [4., 5., 6.], + /// [7., 8., 9.], + /// [10., 11., 12.], + /// ]; + /// // Sample 2 rows, without replacement + /// let sample_rows = a.sample_axis_using(Axis(0), 2, SamplingStrategy::WithoutReplacement, &mut rng); + /// println!("{:?}", sample_rows); + /// // Example Output: (1st and 3rd rows) + /// // [ + /// // [1., 2., 3.], + /// // [7., 8., 9.] + /// // ] + /// + /// // Sample 2 columns, with replacement + /// let sample_columns = a.sample_axis_using(Axis(1), 1, SamplingStrategy::WithReplacement, &mut rng); + /// println!("{:?}", sample_columns); + /// // Example Output: (2nd column, sampled twice) + /// // [ + /// // [2., 2.], + /// // [5., 5.], + /// // [8., 8.], + /// // [11., 11.] + /// // ] + /// # } + /// ``` + fn sample_axis_using( + &self, + axis: Axis, + n_samples: usize, + strategy: SamplingStrategy, + rng: &mut R, + ) -> Array + where + R: Rng + ?Sized, + A: Copy, + D: RemoveAxis; } -impl RandomExt for ArrayBase +impl RandomExt for ArrayBase where - S: DataOwned, + S: DataOwned, D: Dimension, { fn random(shape: Sh, dist: IdS) -> ArrayBase @@ -128,9 +237,7 @@ where IdS: Distribution, Sh: ShapeBuilder, { - let mut rng = - SmallRng::from_rng(thread_rng()).expect("create SmallRng from thread_rng failed"); - Self::random_using(shape, dist, &mut rng) + Self::random_using(shape, dist, &mut get_rng()) } fn random_using(shape: Sh, dist: IdS, rng: &mut R) -> ArrayBase @@ -141,6 +248,66 @@ where { Self::from_shape_fn(shape, |_| dist.sample(rng)) } + + fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array + where + A: Copy, + D: RemoveAxis, + { + self.sample_axis_using(axis, n_samples, strategy, &mut get_rng()) + } + + fn sample_axis_using( + &self, + axis: Axis, + n_samples: usize, + strategy: SamplingStrategy, + rng: &mut R, + ) -> Array + where + R: Rng + ?Sized, + A: Copy, + D: RemoveAxis, + { + let indices: Vec<_> = match strategy { + SamplingStrategy::WithReplacement => { + let distribution = Uniform::from(0..self.len_of(axis)); + (0..n_samples).map(|_| distribution.sample(rng)).collect() + } + SamplingStrategy::WithoutReplacement => { + index::sample(rng, self.len_of(axis), n_samples).into_vec() + } + }; + self.select(axis, &indices) + } +} + +/// Used as parameter in [`sample_axis`] and [`sample_axis_using`] to determine +/// if lanes from the original array should only be sampled once (*without replacement*) or +/// multiple times (*with replacement*). +/// +/// [`sample_axis`]: trait.RandomExt.html#tymethod.sample_axis +/// [`sample_axis_using`]: trait.RandomExt.html#tymethod.sample_axis_using +#[derive(Debug, Clone)] +pub enum SamplingStrategy { + WithReplacement, + WithoutReplacement, +} + +// `Arbitrary` enables `quickcheck` to generate random `SamplingStrategy` values for testing. +#[cfg(feature = "quickcheck")] +impl Arbitrary for SamplingStrategy { + fn arbitrary(g: &mut G) -> Self { + if g.gen_bool(0.5) { + SamplingStrategy::WithReplacement + } else { + SamplingStrategy::WithoutReplacement + } + } +} + +fn get_rng() -> SmallRng { + SmallRng::from_rng(thread_rng()).expect("create SmallRng from thread_rng failed") } /// A wrapper type that allows casting f64 distributions to f32 diff --git a/ndarray-rand/tests/tests.rs b/ndarray-rand/tests/tests.rs index 6e64a10b9..52de74595 100644 --- a/ndarray-rand/tests/tests.rs +++ b/ndarray-rand/tests/tests.rs @@ -1,6 +1,9 @@ -use ndarray::Array; +use ndarray::{Array, Array2, ArrayView1, Axis}; +#[cfg(feature = "quickcheck")] +use ndarray_rand::rand::{distributions::Distribution, thread_rng}; use ndarray_rand::rand_distr::Uniform; -use ndarray_rand::RandomExt; +use ndarray_rand::{RandomExt, SamplingStrategy}; +use quickcheck::quickcheck; #[test] fn test_dim() { @@ -14,3 +17,94 @@ fn test_dim() { } } } + +#[test] +#[should_panic] +fn oversampling_without_replacement_should_panic() { + let m = 5; + let a = Array::random((m, 4), Uniform::new(0., 2.)); + let _samples = a.sample_axis(Axis(0), m + 1, SamplingStrategy::WithoutReplacement); +} + +quickcheck! { + fn oversampling_with_replacement_is_fine(m: usize, n: usize) -> bool { + let a = Array::random((m, n), Uniform::new(0., 2.)); + // Higher than the length of both axes + let n_samples = m + n + 1; + + // We don't want to deal with sampling from 0-length axes in this test + if m != 0 { + if !sampling_works(&a, SamplingStrategy::WithReplacement, Axis(0), n_samples) { + return false; + } + } + + // We don't want to deal with sampling from 0-length axes in this test + if n != 0 { + if !sampling_works(&a, SamplingStrategy::WithReplacement, Axis(1), n_samples) { + return false; + } + } + + true + } +} + +#[cfg(feature = "quickcheck")] +quickcheck! { + fn sampling_behaves_as_expected(m: usize, n: usize, strategy: SamplingStrategy) -> bool { + let a = Array::random((m, n), Uniform::new(0., 2.)); + let mut rng = &mut thread_rng(); + + // We don't want to deal with sampling from 0-length axes in this test + if m != 0 { + let n_row_samples = Uniform::from(1..m+1).sample(&mut rng); + if !sampling_works(&a, strategy.clone(), Axis(0), n_row_samples) { + return false; + } + } + + // We don't want to deal with sampling from 0-length axes in this test + if n != 0 { + let n_col_samples = Uniform::from(1..n+1).sample(&mut rng); + if !sampling_works(&a, strategy, Axis(1), n_col_samples) { + return false; + } + } + + true + } +} + +fn sampling_works( + a: &Array2, + strategy: SamplingStrategy, + axis: Axis, + n_samples: usize, +) -> bool { + let samples = a.sample_axis(axis, n_samples, strategy); + samples + .axis_iter(axis) + .all(|lane| is_subset(&a, &lane, axis)) +} + +// Check if, when sliced along `axis`, there is at least one lane in `a` equal to `b` +fn is_subset(a: &Array2, b: &ArrayView1, axis: Axis) -> bool { + a.axis_iter(axis).any(|lane| &lane == b) +} + +#[test] +#[should_panic] +fn sampling_without_replacement_from_a_zero_length_axis_should_panic() { + let n = 5; + let a = Array::random((0, n), Uniform::new(0., 2.)); + let _samples = a.sample_axis(Axis(0), 1, SamplingStrategy::WithoutReplacement); +} + +#[test] +#[should_panic] +fn sampling_with_replacement_from_a_zero_length_axis_should_panic() { + let n = 5; + let a = Array::random((0, n), Uniform::new(0., 2.)); + let _samples = a.sample_axis(Axis(0), 1, SamplingStrategy::WithReplacement); +} diff --git a/scripts/all-tests.sh b/scripts/all-tests.sh index 7925fd91c..9b41b41d8 100755 --- a/scripts/all-tests.sh +++ b/scripts/all-tests.sh @@ -13,6 +13,8 @@ cargo test --verbose --no-default-features cargo test --release --verbose --no-default-features cargo build --verbose --features "$FEATURES" cargo test --verbose --features "$FEATURES" +cargo test --manifest-path=ndarray-rand/Cargo.toml --no-default-features --verbose +cargo test --manifest-path=ndarray-rand/Cargo.toml --features quickcheck --verbose cargo test --manifest-path=serialization-tests/Cargo.toml --verbose cargo test --manifest-path=blas-tests/Cargo.toml --verbose CARGO_TARGET_DIR=target/ cargo test --manifest-path=numeric-tests/Cargo.toml --verbose