Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement chunk for different backends #1032

Merged
merged 5 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions burn-autodiff/src/ops/bool_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,21 @@ impl<B: Backend> BoolTensorOps<Self> for Autodiff<B> {
) -> <Autodiff<B> as Backend>::BoolTensorPrimitive<D> {
B::bool_swap_dims(tensor, dim1, dim2)
}

/// Narrows `tensor` along `dim` to the range `start..start + length`.
///
/// Boolean tensors carry no gradient, so this forwards directly to the
/// inner backend `B` with no autodiff bookkeeping.
fn bool_narrow<const D: usize>(
tensor: BoolTensor<B, D>,
dim: usize,
start: usize,
length: usize,
) -> BoolTensor<B, D> {
B::bool_narrow(tensor, dim, start, length)
}

/// Splits `tensor` into `chunks` pieces along `dim`.
///
/// Boolean tensors carry no gradient, so this forwards directly to the
/// inner backend `B` with no autodiff bookkeeping.
fn bool_chunk<const D: usize>(
tensor: BoolTensor<B, D>,
chunks: usize,
dim: usize,
) -> Vec<BoolTensor<B, D>> {
B::bool_chunk(tensor, chunks, dim)
}
}
17 changes: 17 additions & 0 deletions burn-autodiff/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,4 +316,21 @@ impl<B: Backend> IntTensorOps<Autodiff<B>> for Autodiff<B> {
) -> <Autodiff<B> as Backend>::IntTensorPrimitive<D> {
B::int_swap_dims(tensor, dim1, dim2)
}

/// Narrows `tensor` along `dim` to the range `start..start + length`.
///
/// Integer tensors carry no gradient, so this forwards directly to the
/// inner backend `B` with no autodiff bookkeeping.
fn int_narrow<const D: usize>(
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive<D>,
dim: usize,
start: usize,
length: usize,
) -> <Autodiff<B> as Backend>::IntTensorPrimitive<D> {
B::int_narrow(tensor, dim, start, length)
}

/// Splits `tensor` into `chunks` pieces along `dim`.
///
/// Integer tensors carry no gradient, so this forwards directly to the
/// inner backend `B` with no autodiff bookkeeping.
fn int_chunk<const D: usize>(
tensor: <Autodiff<B> as Backend>::IntTensorPrimitive<D>,
chunks: usize,
dim: usize,
) -> Vec<<Autodiff<B> as Backend>::IntTensorPrimitive<D>> {
B::int_chunk(tensor, chunks, dim)
}
}
28 changes: 28 additions & 0 deletions burn-candle/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,31 @@ pub fn slice_assign<E: CandleElement, const D1: usize, const D2: usize>(
) -> CandleTensor<E, D1> {
panic!("slice_assign not supported by Candle")
}

/// Narrows `tensor` along `dim` to the range `start..start + length`,
/// delegating to Candle's `Tensor::narrow`.
///
/// # Panics
///
/// Panics if Candle reports an error (e.g. out-of-range `dim`/`start`/`length`),
/// including the backend error in the message.
pub fn narrow<E: CandleElement, const D: usize>(
    tensor: CandleTensor<E, D>,
    dim: usize,
    start: usize,
    length: usize,
) -> CandleTensor<E, D> {
    match tensor.tensor.narrow(dim, start, length) {
        Ok(tensor) => CandleTensor::new(tensor),
        // Surface the Candle error instead of discarding it: the previous
        // version bound `e` without using it (compiler warning) and panicked
        // with no diagnostic detail.
        Err(e) => panic!("error narrow from Candle: {e:?}"),
    }
}

/// Splits `tensor` into `chunks` pieces along `dim`, delegating to Candle's
/// `Tensor::chunk` and wrapping each resulting tensor.
///
/// # Panics
///
/// Panics if Candle reports an error (e.g. invalid `dim`), including the
/// backend error in the message.
pub fn chunk<E: CandleElement, const D: usize>(
    tensor: CandleTensor<E, D>,
    chunks: usize,
    dim: usize,
) -> Vec<CandleTensor<E, D>> {
    match tensor.tensor.chunk(chunks, dim) {
        Ok(tensors) => tensors.into_iter().map(CandleTensor::new).collect(),
        // Surface the Candle error instead of discarding it: the previous
        // version bound `e` without using it (compiler warning) and panicked
        // with no diagnostic detail.
        Err(e) => panic!("error chunk from Candle: {e:?}"),
    }
}
17 changes: 17 additions & 0 deletions burn-candle/src/ops/bool_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,21 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<
) -> <Candle<F, I> as burn_tensor::backend::Backend>::BoolTensorPrimitive<D> {
super::base::swap_dims(tensor, dim1, dim2)
}

/// Narrows `tensor` along `dim` to the range `start..start + length`,
/// delegating to the shared Candle `narrow` helper in `super::base`.
fn bool_narrow<const D: usize>(
tensor: BoolTensor<Self, D>,
dim: usize,
start: usize,
length: usize,
) -> BoolTensor<Self, D> {
super::base::narrow(tensor, dim, start, length)
}

/// Splits `tensor` into `chunks` pieces along `dim`, delegating to the
/// shared Candle `chunk` helper in `super::base`.
fn bool_chunk<const D: usize>(
tensor: BoolTensor<Self, D>,
chunks: usize,
dim: usize,
) -> Vec<BoolTensor<Self, D>> {
super::base::chunk(tensor, chunks, dim)
}
}
17 changes: 17 additions & 0 deletions burn-candle/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,4 +359,21 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
) -> <Candle<F, I> as burn_tensor::backend::Backend>::IntTensorPrimitive<D> {
super::base::swap_dims(tensor, dim1, dim2)
}

/// Narrows `tensor` along `dim` to the range `start..start + length`,
/// delegating to the shared Candle `narrow` helper in `super::base`.
fn int_narrow<const D: usize>(
tensor: IntTensor<Self, D>,
dim: usize,
start: usize,
length: usize,
) -> IntTensor<Self, D> {
super::base::narrow(tensor, dim, start, length)
}

/// Splits `tensor` into `chunks` pieces along `dim`, delegating to the
/// shared Candle `chunk` helper in `super::base`.
fn int_chunk<const D: usize>(
tensor: IntTensor<Self, D>,
chunks: usize,
dim: usize,
) -> Vec<IntTensor<Self, D>> {
super::base::chunk(tensor, chunks, dim)
}
}
17 changes: 17 additions & 0 deletions burn-candle/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,4 +448,21 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.recip().unwrap())
}

/// Narrows `tensor` along `dim` to the range `start..start + length`,
/// delegating to the shared Candle `narrow` helper in `super::base`.
fn narrow<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
start: usize,
length: usize,
) -> FloatTensor<Self, D> {
super::base::narrow(tensor, dim, start, length)
}

/// Splits `tensor` into `chunks` pieces along `dim`, delegating to the
/// shared Candle `chunk` helper in `super::base`.
fn chunk<const D: usize>(
tensor: FloatTensor<Self, D>,
chunks: usize,
dim: usize,
) -> Vec<FloatTensor<Self, D>> {
super::base::chunk(tensor, chunks, dim)
}
}
26 changes: 26 additions & 0 deletions burn-tch/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,4 +413,30 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
let tensor = tensor.tensor.transpose(dim1 as i64, dim2 as i64);
TchTensor::new(tensor)
}

/// Narrows `tensor` along `dim` to the range `start..start + length` using
/// LibTorch's `narrow`.
///
/// tch expects `i64` indices, so the `usize` arguments are cast; tensor
/// dimensions and offsets fit comfortably in `i64` in practice.
pub fn narrow<const D: usize>(
    tensor: TchTensor<E, D>,
    dim: usize,
    start: usize,
    length: usize,
) -> TchTensor<E, D> {
    let narrowed = tensor.tensor.narrow(dim as i64, start as i64, length as i64);
    TchTensor::new(narrowed)
}

/// Splits `tensor` into `chunks` pieces along `dim` using LibTorch's `chunk`,
/// wrapping each resulting tensor back into a `TchTensor`.
pub fn chunk<const D: usize>(
    tensor: TchTensor<E, D>,
    chunks: usize,
    dim: usize,
) -> Vec<TchTensor<E, D>> {
    // tch expects i64 counts/indices, hence the casts.
    let parts = tensor.tensor.chunk(chunks as i64, dim as i64);
    parts.into_iter().map(TchTensor::new).collect()
}
}
17 changes: 17 additions & 0 deletions burn-tch/src/ops/bool_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,21 @@ impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
) -> <LibTorch<E> as Backend>::BoolTensorPrimitive<D> {
TchOps::swap_dims(tensor, dim1, dim2)
}

/// Narrows `tensor` along `dim` to the range `start..start + length`,
/// delegating to the shared `TchOps::narrow` helper.
fn bool_narrow<const D: usize>(
tensor: TchTensor<bool, D>,
dim: usize,
start: usize,
length: usize,
) -> TchTensor<bool, D> {
TchOps::narrow(tensor, dim, start, length)
}

/// Splits `tensor` into `chunks` pieces along `dim`, delegating to the
/// shared `TchOps::chunk` helper.
fn bool_chunk<const D: usize>(
tensor: TchTensor<bool, D>,
chunks: usize,
dim: usize,
) -> Vec<TchTensor<bool, D>> {
TchOps::chunk(tensor, chunks, dim)
}
}
17 changes: 17 additions & 0 deletions burn-tch/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,4 +401,21 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
) -> <LibTorch<E> as Backend>::IntTensorPrimitive<D> {
TchOps::swap_dims(tensor, dim1, dim2)
}

/// Narrows `tensor` along `dim` to the range `start..start + length`,
/// delegating to the shared `TchOps::narrow` helper.
fn int_narrow<const D: usize>(
tensor: TchTensor<i64, D>,
dim: usize,
start: usize,
length: usize,
) -> TchTensor<i64, D> {
TchOps::narrow(tensor, dim, start, length)
}

/// Splits `tensor` into `chunks` pieces along `dim`, delegating to the
/// shared `TchOps::chunk` helper.
fn int_chunk<const D: usize>(
tensor: TchTensor<i64, D>,
chunks: usize,
dim: usize,
) -> Vec<TchTensor<i64, D>> {
TchOps::chunk(tensor, chunks, dim)
}
}
17 changes: 17 additions & 0 deletions burn-tch/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,4 +440,21 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
TchTensor::new(tensor)
}

/// Narrows `tensor` along `dim` to the range `start..start + length`,
/// delegating to the shared `TchOps::narrow` helper.
fn narrow<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
start: usize,
length: usize,
) -> TchTensor<E, D> {
TchOps::narrow(tensor, dim, start, length)
}

/// Splits `tensor` into `chunks` pieces along `dim`, delegating to the
/// shared `TchOps::chunk` helper.
fn chunk<const D: usize>(
tensor: TchTensor<E, D>,
chunks: usize,
dim: usize,
) -> Vec<TchTensor<E, D>> {
TchOps::chunk(tensor, chunks, dim)
}
}
51 changes: 9 additions & 42 deletions burn-tensor/src/tensor/api/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ use alloc::vec;
use burn_common::{reader::Reader, stub::Mutex};
use core::{fmt::Debug, ops::Range};

use crate::{
backend::Backend, check, check::TensorCheck, Bool, Data, Float, Int, Shape, TensorKind,
};
use crate::check::TensorCheck;
use crate::tensor::api::chunk::chunk;
use crate::tensor::api::narrow::narrow;
use crate::{backend::Backend, check, Bool, Data, Float, Int, Shape, TensorKind};

/// A tensor with a given backend, shape and data type.
#[derive(new, Clone, Debug)]
Expand Down Expand Up @@ -496,20 +497,7 @@ where
pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
check!(TensorCheck::dim_ops::<D>("narrow", dim));
check!(TensorCheck::narrow(&self, dim, start, length));

let ranges: Vec<_> = (0..D)
.map(|i| {
if i == dim {
start..(start + length)
} else {
0..self.shape().dims[i]
}
})
.collect();

let ranges_array: [_; D] = ranges.try_into().unwrap();

self.slice(ranges_array)
Self::new(narrow::<B, D, K>(self.primitive, dim, start, length))
}

/// Attempts to split the tensor along the given dimension into chunks.
Expand All @@ -526,31 +514,10 @@ where
/// A vector of tensors.
pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
check!(TensorCheck::dim_ops::<D>("chunk", dim));

let size = self.shape().dims[dim];
if size < chunks {
return (0..size).map(|i| self.clone().narrow(dim, i, 1)).collect();
}

let mut tensors = Vec::with_capacity(chunks);
let mut sum_chunk_size = 0;
if size % chunks == 0 {
let chunk_size = size / chunks;
for _ in 0..chunks {
tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size));
sum_chunk_size += chunk_size;
}
} else {
let chunk_size = (size / chunks) + 1; // assumes not divisible
for _ in 0..chunks - 1 {
tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size));
sum_chunk_size += chunk_size;
}
let remainder = size % chunk_size;
tensors.push(self.clone().narrow(dim, sum_chunk_size, remainder));
}

tensors
chunk::<B, D, K>(self.primitive, chunks, dim)
.into_iter()
.map(|v| Self::new(v))
.collect()
}
}

Expand Down
69 changes: 69 additions & 0 deletions burn-tensor/src/tensor/api/chunk.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use super::narrow::narrow;
use crate::{backend::Backend, BasicOps, TensorKind};
use alloc::vec::Vec;

/// Split the tensor along the given dimension into chunks.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `chunks` - The number of chunks to be produced.
/// * `dim` - The dimension along which the tensor will be split.
///
/// # Returns
///
/// A vector of tensors. When the size of `dim` is not evenly divisible by
/// `chunks`, fewer than `chunks` tensors may be returned (matching PyTorch's
/// `torch.chunk` semantics); no chunk is ever empty.
///
/// # Remarks
///
/// This is a fallback solution that is used only when the backend doesn't have
/// the corresponding implementation. Ideally, it is supposed to be implemented
/// by the backend and the backend implementation will be resolved by static
/// dispatch. It is not designed for direct usage by users, and not recommended
/// to import or use this function directly.
pub fn chunk<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
    tensor: K::Primitive<D>,
    chunks: usize,
    dim: usize,
) -> Vec<K::Primitive<D>> {
    let size = K::shape(&tensor).dims[dim];

    // Fewer elements than requested chunks: return one single-element chunk
    // per element, as `torch.chunk` does.
    if size < chunks {
        return (0..size)
            .map(|i| narrow::<B, D, K>(tensor.clone(), dim, i, 1))
            .collect();
    }

    // Ceiling division: every chunk holds `chunk_size` elements except
    // possibly the last, which holds the remainder.
    //
    // Bug fix: the previous version computed the trailing chunk length as
    // `size % chunk_size`, which is 0 whenever `size` is divisible by the
    // rounded-up chunk size but not by `chunks` (e.g. size = 6, chunks = 4:
    // chunk_size = 2, remainder = 0), producing a spurious zero-length chunk
    // narrowed at `start == size`. The loop below stops exactly at `size`,
    // so no empty chunk is ever produced.
    let chunk_size = (size + chunks - 1) / chunks;
    let mut tensors = Vec::with_capacity(chunks);
    let mut start = 0;
    while start < size {
        let length = usize::min(chunk_size, size - start);
        tensors.push(narrow::<B, D, K>(tensor.clone(), dim, start, length));
        start += length;
    }

    tensors
}
4 changes: 4 additions & 0 deletions burn-tensor/src/tensor/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@ pub(crate) mod check;
mod autodiff;
mod base;
mod bool;
mod chunk;
mod float;
mod int;
mod kind;
mod narrow;
mod numeric;

pub use autodiff::*;
pub use base::*;
pub use chunk::chunk;
pub use kind::*;
pub use narrow::narrow;
pub use numeric::*;
Loading
Loading