Merge pull request #472 from robertknight/reshape-cow
Make `TensorBase::reshaped` copy instead of panic if input is not contiguous
robertknight authored Dec 20, 2024
2 parents 24313cf + d291e1e commit 8655867
Showing 5 changed files with 143 additions and 101 deletions.
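
For callers, the upshot is that `reshaped` no longer requires a contiguous input: a contiguous tensor still produces a zero-copy view, while a non-contiguous one (such as a transposed view) is copied into an owned result with the requested shape. A minimal sketch of the new behaviour, assuming `rten_tensor`'s prelude exports the relevant view traits:

use rten_tensor::prelude::*;
use rten_tensor::NdTensorView;

fn main() {
    let data = &[1., 2., 3., 4., 5., 6.];
    let tensor = NdTensorView::from_data([2, 3], data);

    // Contiguous input: `reshaped` borrows the existing data, no copy is made.
    let flat = tensor.reshaped([6]);
    assert_eq!(flat.to_vec(), &[1., 2., 3., 4., 5., 6.]);

    // Non-contiguous input (a transposed view): previously a panic, now the
    // elements are copied into an owned tensor with the new shape.
    let flat_t = tensor.transposed().reshaped([6]);
    assert_eq!(flat_t.to_vec(), &[1., 4., 2., 5., 3., 6.]);
}
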
152 changes: 90 additions & 62 deletions rten-tensor/src/tensor.rs
@@ -6,7 +6,7 @@ use std::ops::{Index, IndexMut, Range};
use crate::copy::{
copy_into, copy_into_slice, copy_into_uninit, copy_range_into_slice, map_into_slice,
};
use crate::errors::{DimensionError, ExpandError, FromDataError, SliceError};
use crate::errors::{DimensionError, ExpandError, FromDataError, ReshapeError, SliceError};
use crate::iterators::{
for_each_mut, AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, InnerIter, InnerIterDyn,
InnerIterDynMut, InnerIterMut, Iter, IterMut, Lanes, LanesMut, MutViewRef, ViewRef,
@@ -249,18 +249,43 @@ pub trait AsView: Layout {
self.view().permuted(order)
}

/// Return a view with a given shape, without copying any data. This
/// requires that the tensor is contiguous.
/// Return either a view or a copy of `self` with the given shape.
///
/// The new shape must have the same number of elements as the current
/// shape. The result will have a static rank if `shape` is an array or
/// a dynamic rank if it is a slice.
///
/// Panics if the tensor is not contiguous.
fn reshaped<S: IntoLayout>(&self, shape: S) -> TensorBase<ViewData<'_, Self::Elem>, S::Layout> {
/// If `self` is contiguous this will return a view, as changing the shape
/// can be done without moving data. Otherwise it will copy elements into
/// a new tensor.
///
/// # Panics
///
/// Panics if the number of elements in the new shape does not match the
/// current shape.
fn reshaped<S: Copy + IntoLayout>(
&self,
shape: S,
) -> TensorBase<CowData<'_, Self::Elem>, S::Layout>
where
Self::Elem: Clone,
{
self.view().reshaped(shape)
}

/// A variant of [`reshaped`](AsView::reshaped) that allows specifying the
/// allocator to use if a copy is needed.
fn reshaped_in<A: Alloc, S: Copy + IntoLayout>(
&self,
alloc: A,
shape: S,
) -> TensorBase<CowData<'_, Self::Elem>, S::Layout>
where
Self::Elem: Clone,
{
self.view().reshaped_in(alloc, shape)
}

/// Reverse the order of dimensions in this tensor.
fn transpose(&mut self);

@@ -750,18 +775,18 @@ impl<S: StorageMut, L: MutLayout> TensorBase<S, L> {

/// Change the layout of the tensor without moving any data.
///
/// See [`AsView::reshaped`].
/// This will return an error if the view is not contiguous.
///
/// See also [`AsView::reshaped`].
pub fn reshaped_mut<SH: IntoLayout>(
&mut self,
shape: SH,
) -> TensorBase<ViewMutData<S::Elem>, SH::Layout> {
TensorBase {
layout: self
.layout
.reshaped_for_view(shape)
.expect("reshape failed"),
) -> Result<TensorBase<ViewMutData<S::Elem>, SH::Layout>, ReshapeError> {
let layout = self.layout.reshaped_for_view(shape)?;
Ok(TensorBase {
layout,
data: self.data.view_mut(),
}
})
}

/// Slice this tensor along a given axis.
@@ -1451,16 +1476,39 @@ impl<'a, T, L: Clone + MutLayout> TensorBase<ViewData<'a, T>, L> {
}
}

/// Change the shape of this tensor without copying data.
/// Return a view or owned tensor that has the given shape.
///
/// See [`AsView::reshaped`].
pub fn reshaped<S: IntoLayout>(&self, shape: S) -> TensorBase<ViewData<'a, T>, S::Layout> {
TensorBase {
data: self.data,
layout: self
pub fn reshaped<S: Copy + IntoLayout>(&self, shape: S) -> TensorBase<CowData<'a, T>, S::Layout>
where
T: Clone,
{
self.reshaped_in(GlobalAlloc::new(), shape)
}

/// Variant of [`reshaped`](Self::reshaped) that takes an allocator.
pub fn reshaped_in<A: Alloc, S: Copy + IntoLayout>(
&self,
alloc: A,
shape: S,
) -> TensorBase<CowData<'a, T>, S::Layout>
where
T: Clone,
{
if let Ok(layout) = self.layout.reshaped_for_view(shape) {
TensorBase {
data: CowData::Borrowed(self.data),
layout,
}
} else {
let layout = self
.layout
.reshaped_for_view(shape)
.expect("reshape failed"),
.reshaped_for_copy(shape)
.expect("invalid target shape for `reshape`");
TensorBase {
data: CowData::Owned(self.to_vec_in(alloc)),
layout,
}
}
}

@@ -1879,19 +1927,6 @@ impl<T> TensorBase<Vec<T>, DynLayout> {
}
}

impl<T> TensorBase<ViewData<'_, T>, DynLayout> {
/// Reshape this view.
///
/// Panics if the view is not contiguous.
pub fn reshape(&mut self, shape: &[usize])
where
T: Clone,
{
assert!(self.is_contiguous(), "can only reshape contiguous views");
self.layout = DynLayout::from_shape(shape);
}
}

impl<'a, T, L: MutLayout> TensorBase<ViewMutData<'a, T>, L> {
/// Divide this tensor into two mutable views along a given axis.
///
@@ -1930,19 +1965,6 @@ impl<'a, T, L: MutLayout> TensorBase<ViewMutData<'a, T>, L> {
}
}

impl<T> TensorBase<ViewMutData<'_, T>, DynLayout> {
/// Reshape this view.
///
/// Panics if the view is not contiguous.
pub fn reshape(&mut self, shape: &[usize])
where
T: Clone,
{
assert!(self.is_contiguous(), "can only reshape contiguous views");
self.layout = DynLayout::from_shape(shape);
}
}

impl<T, L: Clone + MutLayout> FromIterator<T> for TensorBase<Vec<T>, L>
where
[usize; 1]: AsIndex<L>,
@@ -3314,22 +3336,11 @@ mod tests {

#[test]
fn test_reshape() {
// Owned tensor
let mut tensor = Tensor::<f32>::from_data(&[2, 2], vec![1., 2., 3., 4.]);
tensor.transpose();
tensor.reshape(&[4]);
assert_eq!(tensor.shape(), &[4]);
assert_eq!(tensor.to_vec(), &[1., 3., 2., 4.]);

// View
let mut view = tensor.view();
view.reshape(&[2, 2]);
assert_eq!(view.shape(), &[2, 2]);

// Mut view
let mut view_mut = tensor.view_mut();
view_mut.reshape(&[2, 2]);
assert_eq!(view_mut.shape(), &[2, 2]);
}

#[test]
@@ -3342,19 +3353,36 @@
#[test]
fn test_reshaped() {
let data = &[1., 2., 3., 4., 5., 6.];
let tensor = NdTensorView::from_data([1, 1, 2, 1, 3], data);
let tensor = NdTensorView::from_data([2, 3], data);

// Reshape to static dim count
// Non-copying reshape to static dim count
let reshaped = tensor.reshaped([6]);
assert_eq!(reshaped.shape(), [6]);
assert_eq!(
reshaped.view().storage().as_ptr(),
tensor.view().storage().as_ptr()
);

// Copying reshape to static dim count
let reshaped = tensor.transposed().reshaped([6]);
assert_eq!(reshaped.shape(), [6]);
assert_ne!(
reshaped.view().storage().as_ptr(),
tensor.view().storage().as_ptr()
);
assert_eq!(reshaped.to_vec(), &[1., 4., 2., 5., 3., 6.]);

// Reshape to dynamic dim count
// Non-copying reshape to dynamic dim count
let reshaped = tensor.reshaped([6].as_slice());
assert_eq!(reshaped.shape(), &[6]);
assert_eq!(
reshaped.view().storage().as_ptr(),
tensor.view().storage().as_ptr()
);
}

#[test]
#[should_panic(expected = "reshape failed")]
#[should_panic(expected = "invalid target shape for `reshape`: LengthMismatch")]
fn test_reshaped_invalid() {
let tensor = NdTensor::arange(0, 16, None);
tensor.reshaped([2, 2]);
@@ -3365,7 +3393,7 @@
let data = vec![1., 2., 3., 4., 5., 6.];
let mut tensor = NdTensor::from_data([1, 1, 2, 1, 3], data);

let mut reshaped = tensor.reshaped_mut([6]);
let mut reshaped = tensor.reshaped_mut([6]).unwrap();
reshaped[[0]] = 0.;
reshaped[[5]] = 0.;

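
Relatedly, the in-place `reshaped_mut` now reports failure through a `Result<_, ReshapeError>` rather than panicking, since a mutable view cannot fall back to copying. A small sketch of the updated calling pattern, again assuming the crate prelude and that a transposed layout counts as non-contiguous:

use rten_tensor::prelude::*;
use rten_tensor::NdTensor;

fn main() {
    let mut tensor = NdTensor::from_data([2, 3], vec![1., 2., 3., 4., 5., 6.]);

    // Contiguous storage: the reshape succeeds and yields a mutable view.
    {
        let mut flat = tensor.reshaped_mut([6]).unwrap();
        flat[[0]] = 0.;
    }

    // After transposing, the layout is no longer contiguous, so the same call
    // returns a `ReshapeError` for the caller to handle or propagate.
    tensor.transpose();
    assert!(tensor.reshaped_mut([6]).is_err());
}
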
69 changes: 40 additions & 29 deletions src/ops/conv.rs
@@ -34,11 +34,7 @@ where
let [out_c, in_c, _, _]: [usize; 4] = kernel.shape();
let mut output = NdTensor::uninit_in(pool, [batch, out_c, in_h * in_w]);

// Get input and kernel as contiguous tensors so we can create reshaped
// views.
let input = input.to_contiguous_in(pool).auto_return(pool);
let kernel = kernel.to_contiguous_in(pool).auto_return(pool);
let kernel_mat = kernel.reshaped([out_c, in_c]);
let kernel_mat = kernel.reshaped_in(pool, [out_c, in_c]).auto_return(pool);

// Bias must be contiguous for use with `gemm_bias`.
let bias = bias.as_ref().map(|b| b.to_contiguous());
@@ -51,13 +47,16 @@ where
let mut out_item = output.slice_mut([n]);
let out_row_stride = out_item.stride(0);

let in_mat = input.slice([n]).reshaped([in_c, in_h * in_w]);
let in_mat = input
.slice([n])
.reshaped_in(pool, [in_c, in_h * in_w])
.auto_return(pool);

gemm.gemm_uninit_bias(
out_item.data_mut().unwrap(),
out_row_stride,
GemmInputA::Unpacked(kernel_mat),
GemmInputB::Unpacked(in_mat),
GemmInputA::Unpacked(kernel_mat.view()),
GemmInputB::Unpacked(in_mat.view()),
1., // alpha
bias_vec,
);
@@ -242,14 +241,14 @@ where
let in_group = input.slice((.., in_chan_start..in_chan_end));
let mut out_group = output.slice_mut((.., out_chans.clone()));

let kernel = kernel.to_contiguous_in(pool);
let kernel_mat = kernel
.slice([out_chans.clone()])
.reshaped([out_channels_per_group, in_channels_per_group * k_h * k_w]);
let kernel_mat = kernel.slice([out_chans.clone()]).reshaped_in(
pool,
[out_channels_per_group, in_channels_per_group * k_h * k_w],
);

// Prepack kernel if we'll be able to reuse packed weights.
let prepacked_kernel = if in_group.size(0) > 1 {
Some(gemm.prepack_a_in(pool, kernel_mat).auto_return(pool))
Some(gemm.prepack_a_in(pool, kernel_mat.view()).auto_return(pool))
} else {
None
};
@@ -260,7 +259,9 @@ where
.zip(in_group.axis_iter(0))
.par_bridge()
.for_each(|(mut out_item, in_item)| {
let mut out_mat = out_item.reshaped_mut([out_channels_per_group, out_h * out_w]);
let mut out_mat = out_item
.reshaped_mut([out_channels_per_group, out_h * out_w])
.unwrap();
let out_row_stride = out_mat.stride(0);

let im2col = VirtualIm2Col::new(
@@ -281,7 +282,7 @@
out_row_stride,
prepacked_kernel
.map(GemmInputA::Packed)
.unwrap_or(GemmInputA::Unpacked(kernel_mat)),
.unwrap_or(GemmInputA::Unpacked(kernel_mat.view())),
GemmInputB::Virtual(&im2col),
1., // alpha
bias_vec,
@@ -476,11 +477,12 @@ pub fn conv_transpose(
if let &[n, c, w] = input.shape() {
let [out_c, k_in_c, k_w] = static_dims!(kernel, 3, "OCW")?.shape();

let mut input_2d = input.clone();
input_2d.reshape(&[n, c, 1, w]);

let mut kernel_2d = kernel.clone();
kernel_2d.reshape(&[out_c, k_in_c, 1, k_w]);
let input_2d = input
.reshaped_in(pool, [n, c, 1, w].as_slice())
.auto_return(pool);
let kernel_2d = kernel
.reshaped_in(pool, [out_c, k_in_c, 1, k_w].as_slice())
.auto_return(pool);

let padding_2d = padding.expand_1d_to_2d()?;

@@ -491,7 +493,14 @@
}
};

let result_2d = conv_transpose(pool, input_2d, kernel_2d, bias, padding_2d, &strides_2d);
let result_2d = conv_transpose(
pool,
input_2d.view(),
kernel_2d.view(),
bias,
padding_2d,
&strides_2d,
);

return result_2d.map(|mut t| {
let [n, c, _h, w]: [usize; 4] = t.shape().try_into().expect("expected 4D output");
@@ -529,26 +538,28 @@ pub fn conv_transpose(

let mut output = NdTensor::uninit_in(pool, [batch, out_c, out_h, out_w]);

// Ensure input and kernel are contiguous to support reshaping.
let input = input.to_contiguous_in(pool).auto_return(pool);
let kernel = kernel.to_contiguous_in(pool).auto_return(pool);

let mut col2im_mat =
NdTensor::uninit_in(pool, [out_c * k_h * k_w, in_h * in_w]).auto_return(pool);
let kernel_mat = kernel.reshaped([k_in_c, out_c * k_h * k_w]).transposed();
let kernel_mat = kernel
.reshaped_in(pool, [k_in_c, out_c * k_h * k_w])
.auto_return(pool);
let kernel_mat = kernel_mat.transposed();
let gemm = GemmExecutor::new();

// The implementation here is the inverse of the im2col-based convolution.
let mut n_init = 0;
for n in 0..batch {
let input_mat = input.slice([n]).reshaped([in_c, in_h * in_w]);
let input_mat = input
.slice([n])
.reshaped_in(pool, [in_c, in_h * in_w])
.auto_return(pool);

let col2im_row_stride = col2im_mat.stride(0);
gemm.gemm_uninit(
col2im_mat.data_mut().unwrap(),
col2im_row_stride,
GemmInputA::Unpacked(kernel_mat),
GemmInputB::Unpacked(input_mat),
GemmInputB::Unpacked(input_mat.view()),
1., /* alpha */
);

@@ -558,7 +569,7 @@

col2im(
&mut out_img,
&col2im_mat.reshaped([out_c, k_h, k_w, in_h, in_w]),
&col2im_mat.reshaped([out_c, k_h, k_w, in_h, in_w]).view(),
[pad_top, pad_left, pad_right, pad_bottom],
[stride_h, stride_w],
bias,
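
In the conv kernels the copy-on-demand behaviour replaces the old `to_contiguous_in` + `reshaped` pairing: `reshaped_in` takes the operator's buffer pool as its allocator, so a copy is made (and pooled) only when the slice being reshaped is not contiguous. A rough sketch of the same allocator hook using the default allocator follows; the `GlobalAlloc` import path from the crate root is an assumption, and in the operators above the `pool` argument plays this role.

use rten_tensor::prelude::*;
use rten_tensor::{GlobalAlloc, NdTensorView};

fn main() {
    let data = &[1., 2., 3., 4., 5., 6.];
    let tensor = NdTensorView::from_data([2, 3], data);

    // `reshaped_in` behaves like `reshaped`, but any copy triggered by a
    // non-contiguous input is allocated via the supplied allocator.
    // `GlobalAlloc` is the default used by `reshaped` itself (assumed to be
    // re-exported from the crate root); rten's operators pass a tensor pool.
    let copied = tensor.transposed().reshaped_in(GlobalAlloc::new(), [6]);
    assert_eq!(copied.to_vec(), &[1., 4., 2., 5., 3., 6.]);
}
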
(Diffs for the remaining 3 changed files are not shown here.)
