Skip to content

Commit

Permalink
Merge pull request #724 from LukeMathWalker/random-convenient-functions
Browse files Browse the repository at this point in the history
Add lane sampling to ndarray-rand
  • Loading branch information
bluss authored Oct 6, 2019
2 parents ecb7643 + 0c9a2e3 commit ad7efff
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ rawpointer = { version = "0.2" }

[dev-dependencies]
defmac = "0.2"
quickcheck = { version = "0.8", default-features = false }
quickcheck = { version = "0.9", default-features = false }
approx = "0.3.2"
itertools = { version = "0.8.0", default-features = false, features = ["use_std"] }

Expand Down
2 changes: 2 additions & 0 deletions ndarray-rand/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ keywords = ["multidimensional", "matrix", "rand", "ndarray"]
[dependencies]
ndarray = { version = "0.13", path = ".." }
rand_distr = "0.2.1"
quickcheck = { version = "0.9", default-features = false, optional = true }

[dependencies.rand]
version = "0.7.0"
features = ["small_rng"]

[dev-dependencies]
rand_isaac = "0.2.0"
quickcheck = { version = "0.9", default-features = false }

[package.metadata.release]
no-dev-version = true
Expand Down
185 changes: 176 additions & 9 deletions ndarray-rand/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@
//! that the items are not compatible (e.g. that a type doesn't implement a
//! necessary trait).
use crate::rand::distributions::Distribution;
use crate::rand::distributions::{Distribution, Uniform};
use crate::rand::rngs::SmallRng;
use crate::rand::seq::index;
use crate::rand::{thread_rng, Rng, SeedableRng};

use ndarray::ShapeBuilder;
use ndarray::{Array, Axis, RemoveAxis, ShapeBuilder};
use ndarray::{ArrayBase, DataOwned, Dimension};
#[cfg(feature = "quickcheck")]
use quickcheck::{Arbitrary, Gen};

/// [`rand`](https://docs.rs/rand/0.7), re-exported for convenience and version-compatibility.
pub mod rand {
Expand All @@ -59,9 +62,9 @@ pub mod rand_distr {
/// low-quality random numbers, and reproducibility is not guaranteed. See its
/// documentation for information. You can select a different RNG with
/// [`.random_using()`](#tymethod.random_using).
pub trait RandomExt<S, D>
pub trait RandomExt<S, A, D>
where
S: DataOwned,
S: DataOwned<Elem = A>,
D: Dimension,
{
/// Create an array with shape `dim` with elements drawn from
Expand Down Expand Up @@ -116,21 +119,125 @@ where
IdS: Distribution<S::Elem>,
R: Rng + ?Sized,
Sh: ShapeBuilder<Dim = D>;

/// Sample `n_samples` lanes slicing along `axis` using the default RNG.
///
/// Returns a new owned array holding the selected lanes.
///
/// If `strategy==SamplingStrategy::WithoutReplacement`, each lane can only be sampled once.
/// If `strategy==SamplingStrategy::WithReplacement`, each lane can be sampled multiple times.
///
/// ***Panics*** when:
/// - creation of the RNG fails;
/// - `n_samples` is greater than the length of `axis` (if sampling without replacement);
/// - length of `axis` is 0.
///
/// ```
/// use ndarray::{array, Axis};
/// use ndarray_rand::{RandomExt, SamplingStrategy};
///
/// # fn main() {
/// let a = array![
///     [1., 2., 3.],
///     [4., 5., 6.],
///     [7., 8., 9.],
///     [10., 11., 12.],
/// ];
/// // Sample 2 rows, without replacement
/// let sample_rows = a.sample_axis(Axis(0), 2, SamplingStrategy::WithoutReplacement);
/// println!("{:?}", sample_rows);
/// // Example Output: (1st and 3rd rows)
/// // [
/// //  [1., 2., 3.],
/// //  [7., 8., 9.]
/// // ]
///
/// // Sample 2 columns, with replacement
/// let sample_columns = a.sample_axis(Axis(1), 2, SamplingStrategy::WithReplacement);
/// println!("{:?}", sample_columns);
/// // Example Output: (2nd column, sampled twice)
/// // [
/// //  [2., 2.],
/// //  [5., 5.],
/// //  [8., 8.],
/// //  [11., 11.]
/// // ]
/// # }
/// ```
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
where
A: Copy,
D: RemoveAxis;

/// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`.
///
/// Returns a new owned array holding the selected lanes.
///
/// If `strategy==SamplingStrategy::WithoutReplacement`, each lane can only be sampled once.
/// If `strategy==SamplingStrategy::WithReplacement`, each lane can be sampled multiple times.
///
/// ***Panics*** when:
/// - `n_samples` is greater than the length of `axis` (if sampling without replacement);
/// - length of `axis` is 0.
///
/// ```
/// use ndarray::{array, Axis};
/// use ndarray_rand::{RandomExt, SamplingStrategy};
/// use ndarray_rand::rand::SeedableRng;
/// use rand_isaac::isaac64::Isaac64Rng;
///
/// # fn main() {
/// // Get a seeded random number generator for reproducibility (Isaac64 algorithm)
/// let seed = 42;
/// let mut rng = Isaac64Rng::seed_from_u64(seed);
///
/// let a = array![
///     [1., 2., 3.],
///     [4., 5., 6.],
///     [7., 8., 9.],
///     [10., 11., 12.],
/// ];
/// // Sample 2 rows, without replacement
/// let sample_rows = a.sample_axis_using(Axis(0), 2, SamplingStrategy::WithoutReplacement, &mut rng);
/// println!("{:?}", sample_rows);
/// // Example Output: (1st and 3rd rows)
/// // [
/// //  [1., 2., 3.],
/// //  [7., 8., 9.]
/// // ]
///
/// // Sample 2 columns, with replacement
/// let sample_columns = a.sample_axis_using(Axis(1), 2, SamplingStrategy::WithReplacement, &mut rng);
/// println!("{:?}", sample_columns);
/// // Example Output: (2nd column, sampled twice)
/// // [
/// //  [2., 2.],
/// //  [5., 5.],
/// //  [8., 8.],
/// //  [11., 11.]
/// // ]
/// # }
/// ```
fn sample_axis_using<R>(
&self,
axis: Axis,
n_samples: usize,
strategy: SamplingStrategy,
rng: &mut R,
) -> Array<A, D>
where
R: Rng + ?Sized,
A: Copy,
D: RemoveAxis;
}

impl<S, D> RandomExt<S, D> for ArrayBase<S, D>
impl<S, A, D> RandomExt<S, A, D> for ArrayBase<S, D>
where
S: DataOwned,
S: DataOwned<Elem = A>,
D: Dimension,
{
fn random<Sh, IdS>(shape: Sh, dist: IdS) -> ArrayBase<S, D>
where
IdS: Distribution<S::Elem>,
Sh: ShapeBuilder<Dim = D>,
{
let mut rng =
SmallRng::from_rng(thread_rng()).expect("create SmallRng from thread_rng failed");
Self::random_using(shape, dist, &mut rng)
Self::random_using(shape, dist, &mut get_rng())
}

fn random_using<Sh, IdS, R>(shape: Sh, dist: IdS, rng: &mut R) -> ArrayBase<S, D>
Expand All @@ -141,6 +248,66 @@ where
{
Self::from_shape_fn(shape, |_| dist.sample(rng))
}

// Convenience wrapper: builds the default RNG and forwards every argument to
// `sample_axis_using`, which performs the actual lane selection.
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
where
    A: Copy,
    D: RemoveAxis,
{
    let mut default_rng = get_rng();
    self.sample_axis_using(axis, n_samples, strategy, &mut default_rng)
}

// Picks `n_samples` lane indices along `axis` according to `strategy`, then
// copies the corresponding lanes into a new owned array via `select`.
fn sample_axis_using<R>(
    &self,
    axis: Axis,
    n_samples: usize,
    strategy: SamplingStrategy,
    rng: &mut R,
) -> Array<A, D>
where
    R: Rng + ?Sized,
    A: Copy,
    D: RemoveAxis,
{
    let axis_len = self.len_of(axis);
    let lane_indices: Vec<usize> = match strategy {
        // Distinct indices: `index::sample` never yields the same index twice.
        SamplingStrategy::WithoutReplacement => {
            index::sample(rng, axis_len, n_samples).into_vec()
        }
        // Independent uniform draws over `0..axis_len`; repeats are allowed.
        SamplingStrategy::WithReplacement => {
            let lane_picker = Uniform::from(0..axis_len);
            std::iter::repeat_with(|| lane_picker.sample(rng))
                .take(n_samples)
                .collect()
        }
    };
    self.select(axis, &lane_indices)
}
}

/// Used as parameter in [`sample_axis`] and [`sample_axis_using`] to determine
/// if lanes from the original array should only be sampled once (*without replacement*) or
/// multiple times (*with replacement*).
///
/// The enum carries no data, so it is `Copy` and can be compared and hashed directly.
///
/// [`sample_axis`]: trait.RandomExt.html#tymethod.sample_axis
/// [`sample_axis_using`]: trait.RandomExt.html#tymethod.sample_axis_using
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SamplingStrategy {
    /// Every draw is independent: the same lane may appear several times in the
    /// result, and `n_samples` may exceed the length of the sampled axis.
    WithReplacement,
    /// Each lane can appear at most once in the result, so `n_samples` must not
    /// exceed the length of the sampled axis.
    WithoutReplacement,
}

// Lets `quickcheck` generate `SamplingStrategy` inputs for property tests:
// a single fair coin flip decides between the two variants.
#[cfg(feature = "quickcheck")]
impl Arbitrary for SamplingStrategy {
    fn arbitrary<G: Gen>(g: &mut G) -> Self {
        let with_replacement = g.gen_bool(0.5);
        if with_replacement {
            return SamplingStrategy::WithReplacement;
        }
        SamplingStrategy::WithoutReplacement
    }
}

// Builds the default RNG used by the convenience methods: a `SmallRng` seeded
// from the thread-local generator. Panics if that seeding step reports an error.
fn get_rng() -> SmallRng {
    let seed_source = thread_rng();
    SmallRng::from_rng(seed_source).expect("create SmallRng from thread_rng failed")
}

/// A wrapper type that allows casting f64 distributions to f32
Expand Down
98 changes: 96 additions & 2 deletions ndarray-rand/tests/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use ndarray::Array;
use ndarray::{Array, Array2, ArrayView1, Axis};
#[cfg(feature = "quickcheck")]
use ndarray_rand::rand::{distributions::Distribution, thread_rng};
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use ndarray_rand::{RandomExt, SamplingStrategy};
use quickcheck::quickcheck;

#[test]
fn test_dim() {
Expand All @@ -14,3 +17,94 @@ fn test_dim() {
}
}
}

// Asking for more distinct rows than the array has must panic.
#[test]
#[should_panic]
fn oversampling_without_replacement_should_panic() {
    let n_rows = 5;
    let too_many = n_rows + 1;
    let a = Array::random((n_rows, 4), Uniform::new(0., 2.));
    let _samples = a.sample_axis(Axis(0), too_many, SamplingStrategy::WithoutReplacement);
}

quickcheck! {
    // With replacement, requesting more lanes than an axis holds must still work.
    fn oversampling_with_replacement_is_fine(m: usize, n: usize) -> bool {
        let a = Array::random((m, n), Uniform::new(0., 2.));
        // Higher than the length of both axes
        let n_samples = m + n + 1;

        // Rows first, then columns. We don't want to deal with sampling from
        // 0-length axes in this test, so those pass trivially. `all` stops at the
        // first failing axis, so columns are only sampled if the rows check passed.
        [(m, Axis(0)), (n, Axis(1))].iter().all(|&(axis_len, axis)| {
            axis_len == 0
                || sampling_works(&a, SamplingStrategy::WithReplacement, axis, n_samples)
        })
    }
}

#[cfg(feature = "quickcheck")]
quickcheck! {
    // For a random strategy and a random in-range sample count, every sampled
    // lane must be a lane of the source array.
    fn sampling_behaves_as_expected(m: usize, n: usize, strategy: SamplingStrategy) -> bool {
        let a = Array::random((m, n), Uniform::new(0., 2.));
        let mut rng = thread_rng();

        // We don't want to deal with sampling from 0-length axes in this test,
        // so an empty row axis passes trivially.
        let rows_ok = m == 0 || {
            let n_row_samples = Uniform::from(1..m + 1).sample(&mut rng);
            sampling_works(&a, strategy.clone(), Axis(0), n_row_samples)
        };
        if !rows_ok {
            return false;
        }

        // Same rule for the column axis; its result is the overall verdict.
        n == 0 || {
            let n_col_samples = Uniform::from(1..n + 1).sample(&mut rng);
            sampling_works(&a, strategy, Axis(1), n_col_samples)
        }
    }
}

// Samples `n_samples` lanes of `a` along `axis` with the given `strategy` and
// reports whether every sampled lane also occurs in `a`.
fn sampling_works(
    a: &Array2<f64>,
    strategy: SamplingStrategy,
    axis: Axis,
    n_samples: usize,
) -> bool {
    let samples = a.sample_axis(axis, n_samples, strategy);
    for lane in samples.axis_iter(axis) {
        if !is_subset(a, &lane, axis) {
            return false;
        }
    }
    true
}

// Returns `true` when at least one lane of `a`, taken along `axis`, equals `b`.
fn is_subset(a: &Array2<f64>, b: &ArrayView1<f64>, axis: Axis) -> bool {
    for candidate in a.axis_iter(axis) {
        if &candidate == b {
            return true;
        }
    }
    false
}

// Sampling without replacement along an axis of length 0 must panic.
#[test]
#[should_panic]
fn sampling_without_replacement_from_a_zero_length_axis_should_panic() {
    let n_cols = 5;
    let empty_rows = Array::random((0, n_cols), Uniform::new(0., 2.));
    let _samples = empty_rows.sample_axis(Axis(0), 1, SamplingStrategy::WithoutReplacement);
}

// Sampling with replacement along an axis of length 0 must panic as well.
#[test]
#[should_panic]
fn sampling_with_replacement_from_a_zero_length_axis_should_panic() {
    let n_cols = 5;
    let empty_rows = Array::random((0, n_cols), Uniform::new(0., 2.));
    let _samples = empty_rows.sample_axis(Axis(0), 1, SamplingStrategy::WithReplacement);
}
2 changes: 2 additions & 0 deletions scripts/all-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ cargo test --verbose --no-default-features
cargo test --release --verbose --no-default-features
cargo build --verbose --features "$FEATURES"
cargo test --verbose --features "$FEATURES"
# ndarray-rand is tested twice: once without the optional `quickcheck` feature and
# once with it, since the feature gates the `Arbitrary` impl for `SamplingStrategy`
# and the property test that depends on it.
cargo test --manifest-path=ndarray-rand/Cargo.toml --no-default-features --verbose
cargo test --manifest-path=ndarray-rand/Cargo.toml --features quickcheck --verbose
cargo test --manifest-path=serialization-tests/Cargo.toml --verbose
cargo test --manifest-path=blas-tests/Cargo.toml --verbose
CARGO_TARGET_DIR=target/ cargo test --manifest-path=numeric-tests/Cargo.toml --verbose
Expand Down

0 comments on commit ad7efff

Please sign in to comment.