Skip to content

Commit

Permalink
Merge pull request #463 from robertknight/reduce-pack
Browse files Browse the repository at this point in the history
Use packing buffer when reducing non-contiguous tensors
  • Loading branch information
robertknight authored Dec 16, 2024
2 parents 9e17677 + 3168392 commit dbbe48f
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 140 deletions.
3 changes: 1 addition & 2 deletions rten-tensor/src/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,13 @@ fn copy_blocked<T: Clone>(src: Matrix<T>, mut dest: MatrixMut<MaybeUninit<T>>) {
///
/// Returns `dest` as an initialized slice.
pub fn copy_into_slice<'a, T: Clone>(
src: TensorView<T>,
mut src: TensorView<T>,
dest: &'a mut [MaybeUninit<T>],
) -> &'a [T] {
assert!(dest.len() == src.len());

// Merge axes to increase the chance that we can use the fast path and
// also maximize the iteration count of the innermost loops.
let mut src = src.clone();
src.merge_axes();

if src.ndim() > 4 {
Expand Down
39 changes: 39 additions & 0 deletions rten-tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ pub trait AsView: Layout {
self.view().broadcast(shape)
}

/// Copy elements from this tensor into `dest` in logical order.
///
/// Returns the initialized slice. The returned slice borrows `dest` for
/// its full lifetime `'a`.
///
/// Only available when the element type is `Copy`.
///
/// # Panics
///
/// Panics if the length of `dest` does not match the number of elements
/// in `self`.
fn copy_into_slice<'a>(&self, dest: &'a mut [MaybeUninit<Self::Elem>]) -> &'a [Self::Elem]
where
Self::Elem: Copy;

/// Return the elements of this tensor as a slice, if it is contiguous.
///
/// NOTE(review): presumably returns `None` when the tensor is not
/// contiguous; confirm against the implementations.
fn data(&self) -> Option<&[Self::Elem]>;

Expand Down Expand Up @@ -1676,6 +1684,22 @@ impl<T, S: Storage<Elem = T>, L: MutLayout + Clone> AsView for TensorBase<S, L>
self.view().iter()
}

// Copy this tensor's elements into `dest` in logical order and return
// `dest` reinterpreted as an initialized slice.
//
// When `self.data()` yields a slice the elements are copied with a single
// slice copy; otherwise the work is delegated to the free function
// `copy_into_slice`, which operates on a dynamic-rank view.
fn copy_into_slice<'a>(&self, dest: &'a mut [MaybeUninit<T>]) -> &'a [T]
where
    T: Copy,
{
    match self.data() {
        Some(contiguous) => {
            // SAFETY: `MaybeUninit<T>` has the same size and alignment as
            // `T`, so initialized elements may be viewed as
            // `MaybeUninit<T>`. Pointer and length come from a valid slice.
            let src: &[MaybeUninit<T>] = unsafe {
                std::slice::from_raw_parts(
                    contiguous.as_ptr().cast::<MaybeUninit<T>>(),
                    contiguous.len(),
                )
            };

            // Panics if `dest` and `src` differ in length.
            dest.copy_from_slice(src);

            // SAFETY: `copy_from_slice` either initialized every element
            // of `dest` or panicked, so all of `dest` is initialized here.
            unsafe { std::slice::from_raw_parts(dest.as_ptr().cast::<T>(), dest.len()) }
        }
        None => copy_into_slice(self.as_dyn(), dest),
    }
}

fn data(&self) -> Option<&[Self::Elem]> {
// Delegate to the `data` method of the view returned by `self.view()`.
self.view().data()
}
Expand Down Expand Up @@ -2540,6 +2564,21 @@ mod tests {
assert_eq!(dest.to_vec(), &[1., 2., 3., 4.]);
}

#[test]
fn test_copy_into_slice() {
    let tensor = NdTensor::from([[1, 2], [3, 4], [5, 6]]);
    let n_elts = tensor.len();

    // Use the spare capacity of an empty `Vec` as the uninitialized
    // destination buffer; it is reused for both copies below.
    let mut storage = Vec::with_capacity(n_elts);
    let scratch = &mut storage.spare_capacity_mut()[..n_elts];

    // Source is contiguous.
    let contiguous_elts = tensor.copy_into_slice(scratch);
    assert_eq!(contiguous_elts, &[1, 2, 3, 4, 5, 6]);

    // Source is non-contiguous (transposed view of the same tensor).
    let transposed_elts = tensor.transposed().copy_into_slice(scratch);
    assert_eq!(transposed_elts, &[1, 3, 5, 2, 4, 6]);
}

#[test]
fn test_data() {
let data = &[1., 2., 3., 4., 5., 6.];
Expand Down
Loading

0 comments on commit dbbe48f

Please sign in to comment.