diff --git a/rten-tensor/src/tensor.rs b/rten-tensor/src/tensor.rs
index 6971944e..e3030d7f 100644
--- a/rten-tensor/src/tensor.rs
+++ b/rten-tensor/src/tensor.rs
@@ -6,7 +6,7 @@ use std::ops::{Index, IndexMut, Range};
 use crate::copy::{
     copy_into, copy_into_slice, copy_into_uninit, copy_range_into_slice, map_into_slice,
 };
-use crate::errors::{DimensionError, ExpandError, FromDataError, SliceError};
+use crate::errors::{DimensionError, ExpandError, FromDataError, ReshapeError, SliceError};
 use crate::iterators::{
     for_each_mut, AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, InnerIter, InnerIterDyn,
     InnerIterDynMut, InnerIterMut, Iter, IterMut, Lanes, LanesMut, MutViewRef, ViewRef,
@@ -249,18 +249,43 @@ pub trait AsView: Layout {
 
         self.view().permuted(order)
     }
 
-    /// Return a view with a given shape, without copying any data. This
-    /// requires that the tensor is contiguous.
+    /// Return either a view or a copy of `self` with the given shape.
     ///
     /// The new shape must have the same number of elements as the current
     /// shape. The result will have a static rank if `shape` is an array or
     /// a dynamic rank if it is a slice.
     ///
-    /// Panics if the tensor is not contiguous.
-    fn reshaped<S: IntoLayout>(&self, shape: S) -> TensorBase<ViewData<Self::Elem>, S::Layout> {
+    /// If `self` is contiguous this will return a view, as changing the shape
+    /// can be done without moving data. Otherwise it will copy elements into
+    /// a new tensor.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the number of elements in the new shape does not match the
+    /// current shape.
+    fn reshaped<S: Copy + IntoLayout>(
+        &self,
+        shape: S,
+    ) -> TensorBase<CowData<Self::Elem>, S::Layout>
+    where
+        Self::Elem: Clone,
+    {
         self.view().reshaped(shape)
     }
 
+    /// A variant of [`reshaped`](AsView::reshaped) that allows specifying the
+    /// allocator to use if a copy is needed.
+    fn reshaped_in<A: Alloc, S: Copy + IntoLayout>(
+        &self,
+        alloc: A,
+        shape: S,
+    ) -> TensorBase<CowData<Self::Elem>, S::Layout>
+    where
+        Self::Elem: Clone,
+    {
+        self.view().reshaped_in(alloc, shape)
+    }
+
     /// Reverse the order of dimensions in this tensor.
     fn transpose(&mut self);
 
@@ -750,18 +775,18 @@ impl<S: StorageMut, L: MutLayout> TensorBase<S, L> {
 
     /// Change the layout of the tensor without moving any data.
     ///
-    /// See [`AsView::reshaped`].
+    /// This will return an error if the view is not contiguous.
+    ///
+    /// See also [`AsView::reshaped`].
     pub fn reshaped_mut<SH: IntoLayout>(
         &mut self,
         shape: SH,
-    ) -> TensorBase<ViewMutData<S::Elem>, SH::Layout> {
-        TensorBase {
-            layout: self
-                .layout
-                .reshaped_for_view(shape)
-                .expect("reshape failed"),
+    ) -> Result<TensorBase<ViewMutData<S::Elem>, SH::Layout>, ReshapeError> {
+        let layout = self.layout.reshaped_for_view(shape)?;
+        Ok(TensorBase {
+            layout,
             data: self.data.view_mut(),
-        }
+        })
     }
 
     /// Slice this tensor along a given axis.
@@ -1451,16 +1476,39 @@ impl<'a, T, L: Clone + MutLayout> TensorBase<ViewData<'a, T>, L> {
         }
     }
 
-    /// Change the shape of this tensor without copying data.
+    /// Return a view or owned tensor that has the given shape.
     ///
     /// See [`AsView::reshaped`].
-    pub fn reshaped<S: IntoLayout>(&self, shape: S) -> TensorBase<ViewData<'a, T>, S::Layout> {
-        TensorBase {
-            data: self.data,
-            layout: self
+    pub fn reshaped<S: Copy + IntoLayout>(&self, shape: S) -> TensorBase<CowData<'a, T>, S::Layout>
+    where
+        T: Clone,
+    {
+        self.reshaped_in(GlobalAlloc::new(), shape)
+    }
+
+    /// Variant of [`reshaped`](Self::reshaped) that takes an allocator.
+    pub fn reshaped_in<A: Alloc, S: Copy + IntoLayout>(
+        &self,
+        alloc: A,
+        shape: S,
+    ) -> TensorBase<CowData<'a, T>, S::Layout>
+    where
+        T: Clone,
+    {
+        if let Ok(layout) = self.layout.reshaped_for_view(shape) {
+            TensorBase {
+                data: CowData::Borrowed(self.data),
+                layout,
+            }
+        } else {
+            let layout = self
                 .layout
-                .reshaped_for_view(shape)
-                .expect("reshape failed"),
+                .reshaped_for_copy(shape)
+                .expect("invalid target shape for `reshape`");
+            TensorBase {
+                data: CowData::Owned(self.to_vec_in(alloc)),
+                layout,
+            }
         }
     }
 
@@ -1879,19 +1927,6 @@ impl<T> TensorBase<Vec<T>, DynLayout> {
     }
 }
 
-impl<T> TensorBase<ViewData<'_, T>, DynLayout> {
-    /// Reshape this view.
-    ///
-    /// Panics if the view is not contiguous.
-    pub fn reshape(&mut self, shape: &[usize])
-    where
-        T: Clone,
-    {
-        assert!(self.is_contiguous(), "can only reshape contiguous views");
-        self.layout = DynLayout::from_shape(shape);
-    }
-}
-
 impl<'a, T, L: MutLayout> TensorBase<ViewMutData<'a, T>, L> {
     /// Divide this tensor into two mutable views along a given axis.
     ///
@@ -1930,19 +1965,6 @@ impl<'a, T, L: MutLayout> TensorBase<ViewMutData<'a, T>, L> {
     }
 }
 
-impl<T> TensorBase<ViewMutData<'_, T>, DynLayout> {
-    /// Reshape this view.
-    ///
-    /// Panics if the view is not contiguous.
-    pub fn reshape(&mut self, shape: &[usize])
-    where
-        T: Clone,
-    {
-        assert!(self.is_contiguous(), "can only reshape contiguous views");
-        self.layout = DynLayout::from_shape(shape);
-    }
-}
-
 impl<T, L: MutLayout> FromIterator<T> for TensorBase<Vec<T>, L>
 where
     [usize; 1]: AsIndex<L>,
@@ -3314,22 +3336,11 @@ mod tests {
 
     #[test]
     fn test_reshape() {
-        // Owned tensor
         let mut tensor = Tensor::<f32>::from_data(&[2, 2], vec![1., 2., 3., 4.]);
         tensor.transpose();
         tensor.reshape(&[4]);
         assert_eq!(tensor.shape(), &[4]);
         assert_eq!(tensor.to_vec(), &[1., 3., 2., 4.]);
-
-        // View
-        let mut view = tensor.view();
-        view.reshape(&[2, 2]);
-        assert_eq!(view.shape(), &[2, 2]);
-
-        // Mut view
-        let mut view_mut = tensor.view_mut();
-        view_mut.reshape(&[2, 2]);
-        assert_eq!(view_mut.shape(), &[2, 2]);
     }
 
     #[test]
@@ -3342,19 +3353,36 @@
     #[test]
     fn test_reshaped() {
         let data = &[1., 2., 3., 4., 5., 6.];
-        let tensor = NdTensorView::from_data([1, 1, 2, 1, 3], data);
+        let tensor = NdTensorView::from_data([2, 3], data);
 
-        // Reshape to static dim count
+        // Non-copying reshape to static dim count
         let reshaped = tensor.reshaped([6]);
         assert_eq!(reshaped.shape(), [6]);
+        assert_eq!(
+            reshaped.view().storage().as_ptr(),
+            tensor.view().storage().as_ptr()
+        );
+
+        // Copying reshape to static dim count
+        let reshaped = tensor.transposed().reshaped([6]);
+        assert_eq!(reshaped.shape(), [6]);
+        assert_ne!(
+            reshaped.view().storage().as_ptr(),
+            tensor.view().storage().as_ptr()
+        );
+        assert_eq!(reshaped.to_vec(), &[1., 4., 2., 5., 3., 6.]);
 
-        // Reshape to dynamic dim count
+        // Non-copying reshape to dynamic dim count
        let reshaped = tensor.reshaped([6].as_slice());
         assert_eq!(reshaped.shape(), &[6]);
+        assert_eq!(
+            reshaped.view().storage().as_ptr(),
+            tensor.view().storage().as_ptr()
+        );
     }
 
     #[test]
-    #[should_panic(expected = "reshape failed")]
+    #[should_panic(expected = "invalid target shape for `reshape`: LengthMismatch")]
     fn test_reshaped_invalid() {
         let tensor = NdTensor::arange(0, 16, None);
         tensor.reshaped([2, 2]);
@@ -3365,7 +3393,7 @@
         let data = vec![1., 2., 3., 4., 5., 6.];
         let mut tensor = NdTensor::from_data([1, 1, 2, 1, 3], data);
 
-        let mut reshaped = tensor.reshaped_mut([6]);
+        let mut reshaped = tensor.reshaped_mut([6]).unwrap();
 
         reshaped[[0]] = 0.;
         reshaped[[5]] = 0.;
diff --git a/src/ops/conv.rs b/src/ops/conv.rs
index b890d697..dfd4c5fd 100644
--- a/src/ops/conv.rs
+++ b/src/ops/conv.rs
@@ -34,11 +34,7 @@ where
     let [out_c, in_c, _, _]: [usize; 4] = kernel.shape();
     let mut output = NdTensor::uninit_in(pool, [batch, out_c, in_h * in_w]);
 
-    // Get input and kernel as contiguous tensors so we can create reshaped
-    // views.
-    let input = input.to_contiguous_in(pool).auto_return(pool);
-    let kernel = kernel.to_contiguous_in(pool).auto_return(pool);
-    let kernel_mat = kernel.reshaped([out_c, in_c]);
+    let kernel_mat = kernel.reshaped_in(pool, [out_c, in_c]).auto_return(pool);
 
     // Bias must be contiguous for use with `gemm_bias`.
     let bias = bias.as_ref().map(|b| b.to_contiguous());
@@ -51,13 +47,16 @@ where
         let mut out_item = output.slice_mut([n]);
         let out_row_stride = out_item.stride(0);
 
-        let in_mat = input.slice([n]).reshaped([in_c, in_h * in_w]);
+        let in_mat = input
+            .slice([n])
+            .reshaped_in(pool, [in_c, in_h * in_w])
+            .auto_return(pool);
 
         gemm.gemm_uninit_bias(
             out_item.data_mut().unwrap(),
             out_row_stride,
-            GemmInputA::Unpacked(kernel_mat),
-            GemmInputB::Unpacked(in_mat),
+            GemmInputA::Unpacked(kernel_mat.view()),
+            GemmInputB::Unpacked(in_mat.view()),
             1., // alpha
             bias_vec,
         );
@@ -242,14 +241,14 @@ where
         let in_group = input.slice((.., in_chan_start..in_chan_end));
         let mut out_group = output.slice_mut((.., out_chans.clone()));
 
-        let kernel = kernel.to_contiguous_in(pool);
-        let kernel_mat = kernel
-            .slice([out_chans.clone()])
-            .reshaped([out_channels_per_group, in_channels_per_group * k_h * k_w]);
+        let kernel_mat = kernel.slice([out_chans.clone()]).reshaped_in(
+            pool,
+            [out_channels_per_group, in_channels_per_group * k_h * k_w],
+        );
 
         // Prepack kernel if we'll be able to reuse packed weights.
         let prepacked_kernel = if in_group.size(0) > 1 {
-            Some(gemm.prepack_a_in(pool, kernel_mat).auto_return(pool))
+            Some(gemm.prepack_a_in(pool, kernel_mat.view()).auto_return(pool))
         } else {
             None
         };
@@ -260,7 +259,9 @@ where
             .zip(in_group.axis_iter(0))
             .par_bridge()
             .for_each(|(mut out_item, in_item)| {
-                let mut out_mat = out_item.reshaped_mut([out_channels_per_group, out_h * out_w]);
+                let mut out_mat = out_item
+                    .reshaped_mut([out_channels_per_group, out_h * out_w])
+                    .unwrap();
                 let out_row_stride = out_mat.stride(0);
 
                 let im2col = VirtualIm2Col::new(
@@ -281,7 +282,7 @@ where
                     out_row_stride,
                     prepacked_kernel
                         .map(GemmInputA::Packed)
-                        .unwrap_or(GemmInputA::Unpacked(kernel_mat)),
+                        .unwrap_or(GemmInputA::Unpacked(kernel_mat.view())),
                     GemmInputB::Virtual(&im2col),
                     1., // alpha
                     bias_vec,
@@ -476,11 +477,12 @@ pub fn conv_transpose(
     if let &[n, c, w] = input.shape() {
         let [out_c, k_in_c, k_w] = static_dims!(kernel, 3, "OCW")?.shape();
 
-        let mut input_2d = input.clone();
-        input_2d.reshape(&[n, c, 1, w]);
-
-        let mut kernel_2d = kernel.clone();
-        kernel_2d.reshape(&[out_c, k_in_c, 1, k_w]);
+        let input_2d = input
+            .reshaped_in(pool, [n, c, 1, w].as_slice())
+            .auto_return(pool);
+        let kernel_2d = kernel
+            .reshaped_in(pool, [out_c, k_in_c, 1, k_w].as_slice())
+            .auto_return(pool);
 
         let padding_2d = padding.expand_1d_to_2d()?;
 
@@ -491,7 +493,14 @@
             }
         };
 
-        let result_2d = conv_transpose(pool, input_2d, kernel_2d, bias, padding_2d, &strides_2d);
+        let result_2d = conv_transpose(
+            pool,
+            input_2d.view(),
+            kernel_2d.view(),
+            bias,
+            padding_2d,
+            &strides_2d,
+        );
 
         return result_2d.map(|mut t| {
             let [n, c, _h, w]: [usize; 4] = t.shape().try_into().expect("expected 4D output");
@@ -529,26 +538,28 @@ pub fn conv_transpose(
 
     let mut output = NdTensor::uninit_in(pool, [batch, out_c, out_h, out_w]);
 
-    // Ensure input and kernel are contiguous to support reshaping.
-    let input = input.to_contiguous_in(pool).auto_return(pool);
-    let kernel = kernel.to_contiguous_in(pool).auto_return(pool);
-
     let mut col2im_mat =
         NdTensor::uninit_in(pool, [out_c * k_h * k_w, in_h * in_w]).auto_return(pool);
-    let kernel_mat = kernel.reshaped([k_in_c, out_c * k_h * k_w]).transposed();
+    let kernel_mat = kernel
+        .reshaped_in(pool, [k_in_c, out_c * k_h * k_w])
+        .auto_return(pool);
+    let kernel_mat = kernel_mat.transposed();
     let gemm = GemmExecutor::new();
 
     // The implementation here is the inverse of the im2col-based convolution.
     let mut n_init = 0;
 
     for n in 0..batch {
-        let input_mat = input.slice([n]).reshaped([in_c, in_h * in_w]);
+        let input_mat = input
+            .slice([n])
+            .reshaped_in(pool, [in_c, in_h * in_w])
+            .auto_return(pool);
 
         let col2im_row_stride = col2im_mat.stride(0);
         gemm.gemm_uninit(
             col2im_mat.data_mut().unwrap(),
             col2im_row_stride,
             GemmInputA::Unpacked(kernel_mat),
-            GemmInputB::Unpacked(input_mat),
+            GemmInputB::Unpacked(input_mat.view()),
             1., /* alpha */
         );
@@ -558,7 +569,7 @@
 
         col2im(
             &mut out_img,
-            &col2im_mat.reshaped([out_c, k_h, k_w, in_h, in_w]),
+            &col2im_mat.reshaped([out_c, k_h, k_w, in_h, in_w]).view(),
             [pad_top, pad_left, pad_right, pad_bottom],
             [stride_h, stride_w],
             bias,
diff --git a/src/ops/einsum.rs b/src/ops/einsum.rs
index 19c7cf26..369f8c37 100644
--- a/src/ops/einsum.rs
+++ b/src/ops/einsum.rs
@@ -379,8 +379,8 @@ fn einsum_step(
         .collect();
     einsum_matmul(
         pool,
-        &x,
-        &y,
+        &x.view(),
+        &y.view(),
         &term_simplified,
         &term_simplified,
         &step.output,
diff --git a/src/ops/layout.rs b/src/ops/layout.rs
index d326dc70..d8946dcf 100644
--- a/src/ops/layout.rs
+++ b/src/ops/layout.rs
@@ -45,12 +45,12 @@ pub fn depth_to_space(
     // See https://onnx.ai/onnx/operators/onnx__DepthToSpace.html#summary
     let tmp = input.to_contiguous_in(pool);
     let tmp = match mode {
-        DepthToSpaceMode::DepthColumnRow => tmp
-            .reshaped([n, block_size, block_size, new_c, h, w])
-            .permuted([0, 3, 4, 1, 5, 2]),
-        DepthToSpaceMode::ColumnRowDepth => tmp
-            .reshaped([n, new_c, block_size, block_size, h, w])
-            .permuted([0, 1, 4, 2, 5, 3]),
+        DepthToSpaceMode::DepthColumnRow => tmp.reshaped([n, block_size, block_size, new_c, h, w]),
+        DepthToSpaceMode::ColumnRowDepth => tmp.reshaped([n, new_c, block_size, block_size, h, w]),
+    };
+    let tmp = match mode {
+        DepthToSpaceMode::DepthColumnRow => tmp.permuted([0, 3, 4, 1, 5, 2]),
+        DepthToSpaceMode::ColumnRowDepth => tmp.permuted([0, 1, 4, 2, 5, 3]),
     };
     let mut tmp = tmp.to_tensor_in(pool).into_dyn();
     tmp.reshape(&new_shape);
@@ -1372,7 +1372,10 @@ mod tests {
 
         let reference_transpose_stats = run_bench(100, None, || {
             let transposed = tensor.permuted(perm);
-            reference_transpose_into(transposed.view(), dest.reshaped_mut(transposed.shape()));
+            reference_transpose_into(
+                transposed.view(),
+                dest.reshaped_mut(transposed.shape()).unwrap(),
+            );
         });
 
         let transpose_stats = run_bench(100, None, || {
diff --git a/src/ops/matmul.rs b/src/ops/matmul.rs
index 729390a4..7005009f 100644
--- a/src/ops/matmul.rs
+++ b/src/ops/matmul.rs
@@ -188,7 +188,7 @@ where
         // nb. We assume `a` is likely already contiguous, so this will be cheap.
         let a_contig = a.to_contiguous_in(pool).auto_return(pool);
         let a_matrix = a_contig.reshaped([num_a_matrices * a_rows, a_cols].as_slice());
-        let mut output = matmul_impl(pool, a_matrix, b.clone(), strategy, bias)?;
+        let mut output = matmul_impl(pool, a_matrix.view(), b.clone(), strategy, bias)?;
         output.reshape(out_shape);
         return Ok(output);
     }
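
Reviewer note (not part of the patch): a minimal usage sketch of the behavior this change introduces. It assumes `rten_tensor`'s `prelude` re-exports the `AsView` trait as on main; the expected values mirror `test_reshaped` above.

```rust
use rten_tensor::prelude::*;
use rten_tensor::NdTensor;

fn main() {
    let mut tensor = NdTensor::from_data([2, 3], vec![1., 2., 3., 4., 5., 6.]);

    // Contiguous input: `reshaped` borrows the existing storage
    // (`CowData::Borrowed`), so no copy is made.
    assert_eq!(tensor.reshaped([6]).to_vec(), &[1., 2., 3., 4., 5., 6.]);

    // Non-contiguous input: `reshaped` now copies into a new tensor
    // (`CowData::Owned`) instead of panicking as it did before this change.
    assert_eq!(
        tensor.transposed().reshaped([6]).to_vec(),
        &[1., 4., 2., 5., 3., 6.]
    );

    // `reshaped_mut` must alias the original storage, so it cannot fall back
    // to a copy; it now reports non-contiguous layouts via `ReshapeError`
    // instead of panicking.
    assert!(tensor.reshaped_mut([6]).is_ok());
    tensor.transpose();
    assert!(tensor.reshaped_mut([6]).is_err());
}
```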