Skip to content

Commit

Permalink
Merge pull request #570 from jturner314/slice-new-axis
Browse files Browse the repository at this point in the history
Add support for inserting new axes while slicing
  • Loading branch information
bluss authored Mar 13, 2021
2 parents 9b23ba5 + 7506f90 commit 09884fc
Show file tree
Hide file tree
Showing 12 changed files with 828 additions and 353 deletions.
10 changes: 5 additions & 5 deletions blas-tests/tests/oper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ extern crate num_traits;
use ndarray::linalg::general_mat_mul;
use ndarray::linalg::general_mat_vec_mul;
use ndarray::prelude::*;
use ndarray::{AxisSliceInfo, Ix, Ixs};
use ndarray::{Data, LinalgScalar};
use ndarray::{Ix, Ixs, SliceInfo, SliceOrIndex};

use approx::{assert_abs_diff_eq, assert_relative_eq};
use defmac::defmac;
Expand Down Expand Up @@ -420,19 +420,19 @@ fn scaled_add_3() {
let mut answer = a.clone();
let cdim = if n == 1 { vec![q] } else { vec![n, q] };
let cslice = if n == 1 {
vec![SliceOrIndex::from(..).step_by(s2)]
vec![AxisSliceInfo::from(..).step_by(s2)]
} else {
vec![
SliceOrIndex::from(..).step_by(s1),
SliceOrIndex::from(..).step_by(s2),
AxisSliceInfo::from(..).step_by(s1),
AxisSliceInfo::from(..).step_by(s2),
]
};

let c = range_mat64(n, q).into_shape(cdim).unwrap();

{
let mut av = a.slice_mut(s![..;s1, ..;s2]);
let c = c.slice(SliceInfo::<_, IxDyn>::new(cslice).unwrap().as_ref());
let c = c.slice(&*cslice);

let mut answerv = answer.slice_mut(s![..;s1, ..;s2]);
answerv += &(beta * &c);
Expand Down
30 changes: 8 additions & 22 deletions src/dimension/dimension_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ use alloc::vec::Vec;

use super::axes_of;
use super::conversion::Convert;
use super::ops::DimAdd;
use super::{stride_offset, stride_offset_checked};
use crate::itertools::{enumerate, zip};
use crate::{Axis, DimMax};
use crate::IntoDimension;
use crate::RemoveAxis;
use crate::{ArrayView1, ArrayViewMut1};
use crate::{Dim, Ix, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, IxDynImpl, Ixs, SliceOrIndex};
use crate::{Dim, Ix, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn, IxDynImpl, Ixs};

/// Array shape and index trait.
///
Expand Down Expand Up @@ -51,26 +52,17 @@ pub trait Dimension:
+ DimMax<IxDyn, Output=IxDyn>
+ DimMax<<Self as Dimension>::Smaller, Output=Self>
+ DimMax<<Self as Dimension>::Larger, Output=<Self as Dimension>::Larger>
+ DimAdd<Self>
+ DimAdd<<Self as Dimension>::Smaller>
+ DimAdd<<Self as Dimension>::Larger>
+ DimAdd<Ix0, Output = Self>
+ DimAdd<Ix1, Output = <Self as Dimension>::Larger>
+ DimAdd<IxDyn, Output = IxDyn>
{
/// For fixed-size dimension representations (e.g. `Ix2`), this should be
/// `Some(ndim)`, and for variable-size dimension representations (e.g.
/// `IxDyn`), this should be `None`.
const NDIM: Option<usize>;
/// `SliceArg` is the type which is used to specify slicing for this
/// dimension.
///
/// For the fixed size dimensions it is a fixed size array of the correct
/// size, which you pass by reference. For the dynamic dimension it is
/// a slice.
///
/// - For `Ix1`: `[SliceOrIndex; 1]`
/// - For `Ix2`: `[SliceOrIndex; 2]`
/// - and so on..
/// - For `IxDyn`: `[SliceOrIndex]`
///
/// The easiest way to create a `&SliceInfo<SliceArg, Do>` is using the
/// [`s![]`](macro.s!.html) macro.
type SliceArg: ?Sized + AsRef<[SliceOrIndex]>;
/// Pattern matching friendly form of the dimension value.
///
/// - For `Ix1`: `usize`,
Expand Down Expand Up @@ -399,7 +391,6 @@ macro_rules! impl_insert_axis_array(

impl Dimension for Dim<[Ix; 0]> {
const NDIM: Option<usize> = Some(0);
type SliceArg = [SliceOrIndex; 0];
type Pattern = ();
type Smaller = Self;
type Larger = Ix1;
Expand Down Expand Up @@ -443,7 +434,6 @@ impl Dimension for Dim<[Ix; 0]> {

impl Dimension for Dim<[Ix; 1]> {
const NDIM: Option<usize> = Some(1);
type SliceArg = [SliceOrIndex; 1];
type Pattern = Ix;
type Smaller = Ix0;
type Larger = Ix2;
Expand Down Expand Up @@ -559,7 +549,6 @@ impl Dimension for Dim<[Ix; 1]> {

impl Dimension for Dim<[Ix; 2]> {
const NDIM: Option<usize> = Some(2);
type SliceArg = [SliceOrIndex; 2];
type Pattern = (Ix, Ix);
type Smaller = Ix1;
type Larger = Ix3;
Expand Down Expand Up @@ -716,7 +705,6 @@ impl Dimension for Dim<[Ix; 2]> {

impl Dimension for Dim<[Ix; 3]> {
const NDIM: Option<usize> = Some(3);
type SliceArg = [SliceOrIndex; 3];
type Pattern = (Ix, Ix, Ix);
type Smaller = Ix2;
type Larger = Ix4;
Expand Down Expand Up @@ -839,7 +827,6 @@ macro_rules! large_dim {
($n:expr, $name:ident, $pattern:ty, $larger:ty, { $($insert_axis:tt)* }) => (
impl Dimension for Dim<[Ix; $n]> {
const NDIM: Option<usize> = Some($n);
type SliceArg = [SliceOrIndex; $n];
type Pattern = $pattern;
type Smaller = Dim<[Ix; $n - 1]>;
type Larger = $larger;
Expand Down Expand Up @@ -890,7 +877,6 @@ large_dim!(6, Ix6, (Ix, Ix, Ix, Ix, Ix, Ix), IxDyn, {
/// and memory wasteful, but it allows an arbitrary and dynamic number of axes.
impl Dimension for IxDyn {
const NDIM: Option<usize> = None;
type SliceArg = [SliceOrIndex];
type Pattern = Self;
type Smaller = Self;
type Larger = Self;
Expand Down
74 changes: 55 additions & 19 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
// except according to those terms.

use crate::error::{from_kind, ErrorKind, ShapeError};
use crate::{Ix, Ixs, Slice, SliceOrIndex};
use crate::slice::SliceArg;
use crate::{AxisSliceInfo, Ix, Ixs, Slice};
use num_integer::div_floor;

pub use self::axes::{axes_of, Axes, AxisDescription};
Expand All @@ -18,6 +19,7 @@ pub use self::dim::*;
pub use self::dimension_trait::Dimension;
pub use self::dynindeximpl::IxDynImpl;
pub use self::ndindex::NdIndex;
pub use self::ops::DimAdd;
pub use self::remove_axis::RemoveAxis;

use crate::shape_builder::Strides;
Expand All @@ -35,6 +37,7 @@ pub mod dim;
mod dimension_trait;
mod dynindeximpl;
mod ndindex;
mod ops;
mod remove_axis;

/// Calculate offset from `Ix` stride converting sign properly
Expand Down Expand Up @@ -596,20 +599,24 @@ fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> {
/// Returns `true` iff the slices intersect.
pub fn slices_intersect<D: Dimension>(
dim: &D,
indices1: &D::SliceArg,
indices2: &D::SliceArg,
indices1: &impl SliceArg<D>,
indices2: &impl SliceArg<D>,
) -> bool {
debug_assert_eq!(indices1.as_ref().len(), indices2.as_ref().len());
for (&axis_len, &si1, &si2) in izip!(dim.slice(), indices1.as_ref(), indices2.as_ref()) {
// The slices do not intersect iff any pair of `SliceOrIndex` does not intersect.
debug_assert_eq!(indices1.in_ndim(), indices2.in_ndim());
for (&axis_len, &si1, &si2) in izip!(
dim.slice(),
indices1.as_ref().iter().filter(|si| !si.is_new_axis()),
indices2.as_ref().iter().filter(|si| !si.is_new_axis()),
) {
// The slices do not intersect iff any pair of `AxisSliceInfo` does not intersect.
match (si1, si2) {
(
SliceOrIndex::Slice {
AxisSliceInfo::Slice {
start: start1,
end: end1,
step: step1,
},
SliceOrIndex::Slice {
AxisSliceInfo::Slice {
start: start2,
end: end2,
step: step2,
Expand All @@ -630,8 +637,8 @@ pub fn slices_intersect<D: Dimension>(
return false;
}
}
(SliceOrIndex::Slice { start, end, step }, SliceOrIndex::Index(ind))
| (SliceOrIndex::Index(ind), SliceOrIndex::Slice { start, end, step }) => {
(AxisSliceInfo::Slice { start, end, step }, AxisSliceInfo::Index(ind))
| (AxisSliceInfo::Index(ind), AxisSliceInfo::Slice { start, end, step }) => {
let ind = abs_index(axis_len, ind);
let (min, max) = match slice_min_max(axis_len, Slice::new(start, end, step)) {
Some(m) => m,
Expand All @@ -641,13 +648,14 @@ pub fn slices_intersect<D: Dimension>(
return false;
}
}
(SliceOrIndex::Index(ind1), SliceOrIndex::Index(ind2)) => {
(AxisSliceInfo::Index(ind1), AxisSliceInfo::Index(ind2)) => {
let ind1 = abs_index(axis_len, ind1);
let ind2 = abs_index(axis_len, ind2);
if ind1 != ind2 {
return false;
}
}
(AxisSliceInfo::NewAxis, _) | (_, AxisSliceInfo::NewAxis) => unreachable!(),
}
}
true
Expand Down Expand Up @@ -719,7 +727,7 @@ mod test {
};
use crate::error::{from_kind, ErrorKind};
use crate::slice::Slice;
use crate::{Dim, Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn};
use crate::{Dim, Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn, NewAxis};
use num_integer::gcd;
use quickcheck::{quickcheck, TestResult};

Expand Down Expand Up @@ -993,17 +1001,45 @@ mod test {

#[test]
fn slices_intersect_true() {
assert!(slices_intersect(&Dim([4, 5]), s![.., ..], s![.., ..]));
assert!(slices_intersect(&Dim([4, 5]), s![0, ..], s![0, ..]));
assert!(slices_intersect(&Dim([4, 5]), s![..;2, ..], s![..;3, ..]));
assert!(slices_intersect(&Dim([4, 5]), s![.., ..;2], s![.., 1..;3]));
assert!(slices_intersect(
&Dim([4, 5]),
s![NewAxis, .., NewAxis, ..],
s![.., NewAxis, .., NewAxis]
));
assert!(slices_intersect(
&Dim([4, 5]),
s![NewAxis, 0, ..],
s![0, ..]
));
assert!(slices_intersect(
&Dim([4, 5]),
s![..;2, ..],
s![..;3, NewAxis, ..]
));
assert!(slices_intersect(
&Dim([4, 5]),
s![.., ..;2],
s![.., 1..;3, NewAxis]
));
assert!(slices_intersect(&Dim([4, 10]), s![.., ..;9], s![.., 3..;6]));
}

#[test]
fn slices_intersect_false() {
assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![1..;2, ..]));
assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![1..;3, ..]));
assert!(!slices_intersect(&Dim([4, 5]), s![.., ..;9], s![.., 3..;6]));
assert!(!slices_intersect(
&Dim([4, 5]),
s![..;2, ..],
s![NewAxis, 1..;2, ..]
));
assert!(!slices_intersect(
&Dim([4, 5]),
s![..;2, NewAxis, ..],
s![1..;3, ..]
));
assert!(!slices_intersect(
&Dim([4, 5]),
s![.., ..;9],
s![.., 3..;6, NewAxis]
));
}
}
90 changes: 90 additions & 0 deletions src/dimension/ops.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
use crate::imp_prelude::*;

/// Adds the two dimensions at compile time.
pub trait DimAdd<D: Dimension> {
/// The sum of the two dimensions.
type Output: Dimension;
}

macro_rules! impl_dimadd_const_out_const {
($lhs:expr, $rhs:expr) => {
impl DimAdd<Dim<[usize; $rhs]>> for Dim<[usize; $lhs]> {
type Output = Dim<[usize; $lhs + $rhs]>;
}
};
}

macro_rules! impl_dimadd_const_out_dyn {
($lhs:expr, IxDyn) => {
impl DimAdd<IxDyn> for Dim<[usize; $lhs]> {
type Output = IxDyn;
}
};
($lhs:expr, $rhs:expr) => {
impl DimAdd<Dim<[usize; $rhs]>> for Dim<[usize; $lhs]> {
type Output = IxDyn;
}
};
}

impl<D: Dimension> DimAdd<D> for Ix0 {
type Output = D;
}

impl_dimadd_const_out_const!(1, 0);
impl_dimadd_const_out_const!(1, 1);
impl_dimadd_const_out_const!(1, 2);
impl_dimadd_const_out_const!(1, 3);
impl_dimadd_const_out_const!(1, 4);
impl_dimadd_const_out_const!(1, 5);
impl_dimadd_const_out_dyn!(1, 6);
impl_dimadd_const_out_dyn!(1, IxDyn);

impl_dimadd_const_out_const!(2, 0);
impl_dimadd_const_out_const!(2, 1);
impl_dimadd_const_out_const!(2, 2);
impl_dimadd_const_out_const!(2, 3);
impl_dimadd_const_out_const!(2, 4);
impl_dimadd_const_out_dyn!(2, 5);
impl_dimadd_const_out_dyn!(2, 6);
impl_dimadd_const_out_dyn!(2, IxDyn);

impl_dimadd_const_out_const!(3, 0);
impl_dimadd_const_out_const!(3, 1);
impl_dimadd_const_out_const!(3, 2);
impl_dimadd_const_out_const!(3, 3);
impl_dimadd_const_out_dyn!(3, 4);
impl_dimadd_const_out_dyn!(3, 5);
impl_dimadd_const_out_dyn!(3, 6);
impl_dimadd_const_out_dyn!(3, IxDyn);

impl_dimadd_const_out_const!(4, 0);
impl_dimadd_const_out_const!(4, 1);
impl_dimadd_const_out_const!(4, 2);
impl_dimadd_const_out_dyn!(4, 3);
impl_dimadd_const_out_dyn!(4, 4);
impl_dimadd_const_out_dyn!(4, 5);
impl_dimadd_const_out_dyn!(4, 6);
impl_dimadd_const_out_dyn!(4, IxDyn);

impl_dimadd_const_out_const!(5, 0);
impl_dimadd_const_out_const!(5, 1);
impl_dimadd_const_out_dyn!(5, 2);
impl_dimadd_const_out_dyn!(5, 3);
impl_dimadd_const_out_dyn!(5, 4);
impl_dimadd_const_out_dyn!(5, 5);
impl_dimadd_const_out_dyn!(5, 6);
impl_dimadd_const_out_dyn!(5, IxDyn);

impl_dimadd_const_out_const!(6, 0);
impl_dimadd_const_out_dyn!(6, 1);
impl_dimadd_const_out_dyn!(6, 2);
impl_dimadd_const_out_dyn!(6, 3);
impl_dimadd_const_out_dyn!(6, 4);
impl_dimadd_const_out_dyn!(6, 5);
impl_dimadd_const_out_dyn!(6, 6);
impl_dimadd_const_out_dyn!(6, IxDyn);

impl<D: Dimension> DimAdd<D> for IxDyn {
type Output = IxDyn;
}
2 changes: 1 addition & 1 deletion src/doc/ndarray_for_numpy_users/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@
//! `a[:] = b` | [`a.assign(&b)`][.assign()] | copy the data from array `b` into array `a`
//! `np.concatenate((a,b), axis=1)` | [`concatenate![Axis(1), a, b]`][concatenate!] or [`concatenate(Axis(1), &[a.view(), b.view()])`][concatenate()] | concatenate arrays `a` and `b` along axis 1
//! `np.stack((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), vec![a.view(), b.view()])`][stack()] | stack arrays `a` and `b` along axis 1
//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.insert_axis(Axis(1))`][.insert_axis()] | create an array from `a`, inserting a new axis 1
//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.slice(s![.., NewAxis])`][.slice()] or [`a.insert_axis(Axis(1))`][.insert_axis()] | create an view of 1-D array `a`, inserting a new axis 1
//! `a.transpose()` or `a.T` | [`a.t()`][.t()] or [`a.reversed_axes()`][.reversed_axes()] | transpose of array `a` (view for `.t()` or by-move for `.reversed_axes()`)
//! `np.diag(a)` | [`a.diag()`][.diag()] | view the diagonal of `a`
//! `a.flatten()` | [`use std::iter::FromIterator; Array::from_iter(a.iter().cloned())`][::from_iter()] | create a 1-D array by flattening `a`
Expand Down
Loading

0 comments on commit 09884fc

Please sign in to comment.