diff --git a/src/doc/ndarray_for_numpy_users/mod.rs b/src/doc/ndarray_for_numpy_users/mod.rs
index 6469e4f5d..a9e864a55 100644
--- a/src/doc/ndarray_for_numpy_users/mod.rs
+++ b/src/doc/ndarray_for_numpy_users/mod.rs
@@ -531,6 +531,7 @@
//! `a[:] = 3.` | [`a.fill(3.)`][.fill()] | set all array elements to the same scalar value
//! `a[:] = b` | [`a.assign(&b)`][.assign()] | copy the data from array `b` into array `a`
//! `np.concatenate((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), &[a.view(), b.view()])`][stack()] | concatenate arrays `a` and `b` along axis 1
+//! `np.stack((a,b), axis=1)` | [`stack_new_axis![Axis(1), a, b]`][stack_new_axis!] or [`stack_new_axis(Axis(1), vec![a.view(), b.view()])`][stack_new_axis()] | stack arrays `a` and `b` along axis 1
//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.insert_axis(Axis(1))`][.insert_axis()] | create an array from `a`, inserting a new axis 1
//! `a.transpose()` or `a.T` | [`a.t()`][.t()] or [`a.reversed_axes()`][.reversed_axes()] | transpose of array `a` (view for `.t()` or by-move for `.reversed_axes()`)
//! `np.diag(a)` | [`a.diag()`][.diag()] | view the diagonal of `a`
@@ -640,6 +641,8 @@
//! [.shape()]: ../../struct.ArrayBase.html#method.shape
//! [stack!]: ../../macro.stack.html
//! [stack()]: ../../fn.stack.html
+//! [stack_new_axis!]: ../../macro.stack_new_axis.html
+//! [stack_new_axis()]: ../../fn.stack_new_axis.html
//! [.strides()]: ../../struct.ArrayBase.html#method.strides
//! [.index_axis()]: ../../struct.ArrayBase.html#method.index_axis
//! [.sum_axis()]: ../../struct.ArrayBase.html#method.sum_axis
diff --git a/src/impl_methods.rs b/src/impl_methods.rs
index 8a906be3d..ca5b5499a 100644
--- a/src/impl_methods.rs
+++ b/src/impl_methods.rs
@@ -28,7 +28,7 @@ use crate::iter::{
IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows,
};
use crate::slice::MultiSlice;
-use crate::stacking::stack;
+use crate::stacking::concatenate;
use crate::{NdIndex, Slice, SliceInfo, SliceOrIndex};
/// # Methods For All Array Types
@@ -840,7 +840,7 @@ where
dim.set_axis(axis, 0);
unsafe { Array::from_shape_vec_unchecked(dim, vec![]) }
} else {
- stack(axis, &subs).unwrap()
+ concatenate(axis, &subs).unwrap()
}
}
diff --git a/src/lib.rs b/src/lib.rs
index f65e82ec9..1b7590da1 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -131,7 +131,9 @@ use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, Lane
pub use crate::arraytraits::AsArray;
pub use crate::linalg_traits::{LinalgScalar, NdFloat};
-pub use crate::stacking::stack;
+
+#[allow(deprecated)]
+pub use crate::stacking::{concatenate, stack, stack_new_axis};
pub use crate::impl_views::IndexLonger;
pub use crate::shape_builder::ShapeBuilder;
diff --git a/src/stacking.rs b/src/stacking.rs
index 0f6161a3c..3e3b1afd8 100644
--- a/src/stacking.rs
+++ b/src/stacking.rs
@@ -9,7 +9,7 @@
use crate::error::{from_kind, ErrorKind, ShapeError};
use crate::imp_prelude::*;
-/// Stack arrays along the given axis.
+/// Concatenate arrays along the given axis.
///
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
/// (may be made more flexible in the future).
@@ -29,10 +29,11 @@ use crate::imp_prelude::*;
/// [3., 3.]]))
/// );
/// ```
-pub fn stack<'a, A, D>(
- axis: Axis,
- arrays: &[ArrayView<'a, A, D>],
-) -> Result, ShapeError>
+#[deprecated(
+ since = "0.13.2",
+ note = "Please use the `concatenate` function instead"
+)]
+pub fn stack(axis: Axis, arrays: &[ArrayView]) -> Result, ShapeError>
where
A: Copy,
D: RemoveAxis,
@@ -76,7 +77,103 @@ where
Ok(res)
}
-/// Stack arrays along the given axis.
+/// Concatenate arrays along the given axis.
+///
+/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
+/// (may be made more flexible in the future).
+/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
+/// if the result is larger than is possible to represent.
+///
+/// ```
+/// use ndarray::{arr2, Axis, concatenate};
+///
+/// let a = arr2(&[[2., 2.],
+/// [3., 3.]]);
+/// assert!(
+/// concatenate(Axis(0), &[a.view(), a.view()])
+/// == Ok(arr2(&[[2., 2.],
+/// [3., 3.],
+/// [2., 2.],
+/// [3., 3.]]))
+/// );
+/// ```
+#[allow(deprecated)]
+pub fn concatenate(axis: Axis, arrays: &[ArrayView]) -> Result, ShapeError>
+where
+ A: Copy,
+ D: RemoveAxis,
+{
+ stack(axis, arrays)
+}
+
+/// Stack arrays along the new axis.
+///
+/// ***Errors*** if the arrays have mismatching shapes.
+/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
+/// if the result is larger than is possible to represent.
+///
+/// ```
+/// extern crate ndarray;
+///
+/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
+///
+/// # fn main() {
+///
+/// let a = arr2(&[[2., 2.],
+/// [3., 3.]]);
+/// assert!(
+/// stack_new_axis(Axis(0), &[a.view(), a.view()])
+/// == Ok(arr3(&[[[2., 2.],
+/// [3., 3.]],
+/// [[2., 2.],
+/// [3., 3.]]]))
+/// );
+/// # }
+/// ```
+pub fn stack_new_axis(
+ axis: Axis,
+ arrays: &[ArrayView],
+) -> Result, ShapeError>
+where
+ A: Copy,
+ D: Dimension,
+ D::Larger: RemoveAxis,
+{
+ if arrays.is_empty() {
+ return Err(from_kind(ErrorKind::Unsupported));
+ }
+ let common_dim = arrays[0].raw_dim();
+ // Avoid panic on `insert_axis` call, return an Err instead of it.
+ if axis.index() > common_dim.ndim() {
+ return Err(from_kind(ErrorKind::OutOfBounds));
+ }
+ let mut res_dim = common_dim.insert_axis(axis);
+
+ if arrays.iter().any(|a| a.raw_dim() != common_dim) {
+ return Err(from_kind(ErrorKind::IncompatibleShape));
+ }
+
+ res_dim.set_axis(axis, arrays.len());
+
+ // we can safely use uninitialized values here because they are Copy
+ // and we will only ever write to them
+ let size = res_dim.size();
+ let mut v = Vec::with_capacity(size);
+ unsafe {
+ v.set_len(size);
+ }
+ let mut res = Array::from_shape_vec(res_dim, v)?;
+
+ res.axis_iter_mut(axis)
+ .zip(arrays.into_iter())
+ .for_each(|(mut assign_view, array)| {
+ assign_view.assign(&array);
+ });
+
+ Ok(res)
+}
+
+/// Concatenate arrays along the given axis.
///
/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`.
@@ -101,9 +198,81 @@ where
/// );
/// # }
/// ```
+#[deprecated(
+ since = "0.13.2",
+ note = "Please use the `concatenate!` macro instead"
+)]
#[macro_export]
macro_rules! stack {
($axis:expr, $( $array:expr ),+ ) => {
$crate::stack($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
}
}
+
+/// Concatenate arrays along the given axis.
+///
+/// Uses the [`concatenate`][1] function, calling `ArrayView::from(&a)` on each
+/// argument `a`.
+///
+/// [1]: fn.concatenate.html
+///
+/// ***Panics*** if the `concatenate` function would return an error.
+///
+/// ```
+/// extern crate ndarray;
+///
+/// use ndarray::{arr2, concatenate, Axis};
+///
+/// # fn main() {
+///
+/// let a = arr2(&[[2., 2.],
+/// [3., 3.]]);
+/// assert!(
+/// concatenate![Axis(0), a, a]
+/// == arr2(&[[2., 2.],
+/// [3., 3.],
+/// [2., 2.],
+/// [3., 3.]])
+/// );
+/// # }
+/// ```
+#[macro_export]
+macro_rules! concatenate {
+ ($axis:expr, $( $array:expr ),+ ) => {
+ $crate::concatenate($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
+ }
+}
+
+/// Stack arrays along the new axis.
+///
+/// Uses the [`stack_new_axis`][1] function, calling `ArrayView::from(&a)` on each
+/// argument `a`.
+///
+/// [1]: fn.stack_new_axis.html
+///
+/// ***Panics*** if the `stack` function would return an error.
+///
+/// ```
+/// extern crate ndarray;
+///
+/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
+///
+/// # fn main() {
+///
+/// let a = arr2(&[[2., 2.],
+/// [3., 3.]]);
+/// assert!(
+/// stack_new_axis![Axis(0), a, a]
+/// == arr3(&[[[2., 2.],
+/// [3., 3.]],
+/// [[2., 2.],
+/// [3., 3.]]])
+/// );
+/// # }
+/// ```
+#[macro_export]
+macro_rules! stack_new_axis {
+ ($axis:expr, $( $array:expr ),+ ) => {
+ $crate::stack_new_axis($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
+ }
+}
diff --git a/tests/stacking.rs b/tests/stacking.rs
index a9a031711..94077def2 100644
--- a/tests/stacking.rs
+++ b/tests/stacking.rs
@@ -1,7 +1,9 @@
-use ndarray::{arr2, aview1, stack, Array2, Axis, ErrorKind};
+#![allow(deprecated)]
+
+use ndarray::{arr2, arr3, aview1, concatenate, stack, Array2, Axis, ErrorKind, Ix1};
#[test]
-fn stacking() {
+fn concatenating() {
let a = arr2(&[[2., 2.], [3., 3.]]);
let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap();
assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]]));
@@ -23,4 +25,43 @@ fn stacking() {
let res: Result, _> = ndarray::stack(Axis(0), &[]);
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
+
+ let a = arr2(&[[2., 2.], [3., 3.]]);
+ let b = ndarray::concatenate(Axis(0), &[a.view(), a.view()]).unwrap();
+ assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]]));
+
+ let c = concatenate![Axis(0), a, b];
+ assert_eq!(
+ c,
+ arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]])
+ );
+
+ let d = concatenate![Axis(0), a.row(0), &[9., 9.]];
+ assert_eq!(d, aview1(&[2., 2., 9., 9.]));
+
+ let res = ndarray::concatenate(Axis(1), &[a.view(), c.view()]);
+ assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);
+
+ let res = ndarray::concatenate(Axis(2), &[a.view(), c.view()]);
+ assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);
+
+ let res: Result, _> = ndarray::concatenate(Axis(0), &[]);
+ assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
+}
+
+#[test]
+fn stacking() {
+ let a = arr2(&[[2., 2.], [3., 3.]]);
+ let b = ndarray::stack_new_axis(Axis(0), &[a.view(), a.view()]).unwrap();
+ assert_eq!(b, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]]));
+
+ let c = arr2(&[[3., 2., 3.], [2., 3., 2.]]);
+ let res = ndarray::stack_new_axis(Axis(1), &[a.view(), c.view()]);
+ assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);
+
+ let res = ndarray::stack_new_axis(Axis(3), &[a.view(), a.view()]);
+ assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);
+
+ let res: Result, _> = ndarray::stack_new_axis::<_, Ix1>(Axis(0), &[]);
+ assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
}