Skip to content

Commit

Permalink
Revert "Implement chunk for different backends (#1032)"
Browse files Browse the repository at this point in the history
This reverts commit 7c6f017.
  • Loading branch information
syl20bnr committed Jan 4, 2024
1 parent 71230cf commit 83278ad
Show file tree
Hide file tree
Showing 17 changed files with 45 additions and 450 deletions.
17 changes: 0 additions & 17 deletions burn-autodiff/src/ops/bool_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,21 +92,4 @@ impl<B: Backend> BoolTensorOps<Self> for Autodiff<B> {
) -> <Autodiff<B> as Backend>::BoolTensorPrimitive<D> {
B::bool_swap_dims(tensor, dim1, dim2)
}

// Narrowing a bool tensor needs no autodiff bookkeeping here: the call is
// forwarded unchanged to the inner backend `B`.
fn bool_narrow<const D: usize>(
    tensor: BoolTensor<B, D>,
    dim: usize, start: usize, length: usize,
) -> BoolTensor<B, D> {
    B::bool_narrow(tensor, dim, start, length)
}

// Chunking a bool tensor is forwarded as-is to the inner backend `B`;
// no autodiff state is involved.
fn bool_chunk<const D: usize>(
    tensor: BoolTensor<B, D>,
    chunks: usize, dim: usize,
) -> Vec<BoolTensor<B, D>> {
    B::bool_chunk(tensor, chunks, dim)
}
}
17 changes: 0 additions & 17 deletions burn-autodiff/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,21 +316,4 @@ impl<B: Backend> IntTensorOps<Autodiff<B>> for Autodiff<B> {
) -> <Autodiff<B> as Backend>::IntTensorPrimitive<D> {
B::int_swap_dims(tensor, dim1, dim2)
}

// Narrowing an int tensor is a plain pass-through to the inner backend `B`.
fn int_narrow<const D: usize>(
    tensor: <Autodiff<B> as Backend>::IntTensorPrimitive<D>,
    dim: usize, start: usize, length: usize,
) -> <Autodiff<B> as Backend>::IntTensorPrimitive<D> {
    B::int_narrow(tensor, dim, start, length)
}

// Chunking an int tensor is a plain pass-through to the inner backend `B`.
fn int_chunk<const D: usize>(
    tensor: <Autodiff<B> as Backend>::IntTensorPrimitive<D>,
    chunks: usize, dim: usize,
) -> Vec<<Autodiff<B> as Backend>::IntTensorPrimitive<D>> {
    B::int_chunk(tensor, chunks, dim)
}
}
28 changes: 0 additions & 28 deletions burn-candle/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,31 +88,3 @@ pub fn slice_assign<E: CandleElement, const D1: usize, const D2: usize>(
) -> CandleTensor<E, D1> {
panic!("slice_assign not supported by Candle")
}

/// Narrows `tensor` to `length` elements along `dim`, starting at `start`.
///
/// # Panics
///
/// Panics if the underlying Candle `narrow` call fails, including the
/// Candle error in the message.
pub fn narrow<E: CandleElement, const D: usize>(
    tensor: CandleTensor<E, D>,
    dim: usize,
    start: usize,
    length: usize,
) -> CandleTensor<E, D> {
    match tensor.tensor.narrow(dim, start, length) {
        Ok(tensor) => CandleTensor::new(tensor),
        // Surface the Candle error instead of binding `e` unused (the
        // original discarded it, hiding the failure cause and triggering an
        // unused-variable warning).
        Err(e) => panic!("error narrow from Candle: {e}"),
    }
}

/// Splits `tensor` into `chunks` pieces along `dim` via Candle's `chunk`.
///
/// # Panics
///
/// Panics if the underlying Candle `chunk` call fails, including the
/// Candle error in the message.
pub fn chunk<E: CandleElement, const D: usize>(
    tensor: CandleTensor<E, D>,
    chunks: usize,
    dim: usize,
) -> Vec<CandleTensor<E, D>> {
    match tensor.tensor.chunk(chunks, dim) {
        // `map(CandleTensor::new)` instead of a redundant closure
        // (clippy::redundant_closure).
        Ok(tensors) => tensors.into_iter().map(CandleTensor::new).collect(),
        // Surface the Candle error instead of binding `e` unused.
        Err(e) => panic!("error chunk from Candle: {e}"),
    }
}
17 changes: 0 additions & 17 deletions burn-candle/src/ops/bool_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,21 +109,4 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<
) -> <Candle<F, I> as burn_tensor::backend::Backend>::BoolTensorPrimitive<D> {
super::base::swap_dims(tensor, dim1, dim2)
}

// Bool narrow shares the generic Candle implementation in `super::base`.
fn bool_narrow<const D: usize>(
    tensor: BoolTensor<Self, D>,
    dim: usize, start: usize, length: usize,
) -> BoolTensor<Self, D> {
    super::base::narrow(tensor, dim, start, length)
}

// Bool chunk shares the generic Candle implementation in `super::base`.
fn bool_chunk<const D: usize>(
    tensor: BoolTensor<Self, D>,
    chunks: usize, dim: usize,
) -> Vec<BoolTensor<Self, D>> {
    super::base::chunk(tensor, chunks, dim)
}
}
17 changes: 0 additions & 17 deletions burn-candle/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,21 +359,4 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
) -> <Candle<F, I> as burn_tensor::backend::Backend>::IntTensorPrimitive<D> {
super::base::swap_dims(tensor, dim1, dim2)
}

// Int narrow shares the generic Candle implementation in `super::base`.
fn int_narrow<const D: usize>(
    tensor: IntTensor<Self, D>,
    dim: usize, start: usize, length: usize,
) -> IntTensor<Self, D> {
    super::base::narrow(tensor, dim, start, length)
}

// Int chunk shares the generic Candle implementation in `super::base`.
fn int_chunk<const D: usize>(
    tensor: IntTensor<Self, D>,
    chunks: usize, dim: usize,
) -> Vec<IntTensor<Self, D>> {
    super::base::chunk(tensor, chunks, dim)
}
}
17 changes: 0 additions & 17 deletions burn-candle/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,21 +457,4 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.recip().unwrap())
}

// Float narrow shares the generic Candle implementation in `super::base`.
fn narrow<const D: usize>(
    tensor: FloatTensor<Self, D>,
    dim: usize, start: usize, length: usize,
) -> FloatTensor<Self, D> {
    super::base::narrow(tensor, dim, start, length)
}

// Float chunk shares the generic Candle implementation in `super::base`.
fn chunk<const D: usize>(
    tensor: FloatTensor<Self, D>,
    chunks: usize, dim: usize,
) -> Vec<FloatTensor<Self, D>> {
    super::base::chunk(tensor, chunks, dim)
}
}
26 changes: 0 additions & 26 deletions burn-tch/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,30 +413,4 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
let tensor = tensor.tensor.transpose(dim1 as i64, dim2 as i64);
TchTensor::new(tensor)
}

/// Narrow `tensor` to `length` elements along `dim`, starting at `start`.
///
/// tch takes `i64` indices, hence the casts from `usize`.
pub fn narrow<const D: usize>(
    tensor: TchTensor<E, D>,
    dim: usize,
    start: usize,
    length: usize,
) -> TchTensor<E, D> {
    let narrowed = tensor.tensor.narrow(dim as i64, start as i64, length as i64);
    TchTensor::new(narrowed)
}

/// Split `tensor` into `chunks` pieces along `dim`, delegating to tch.
///
/// tch takes `i64` arguments, hence the casts from `usize`.
pub fn chunk<const D: usize>(
    tensor: TchTensor<E, D>,
    chunks: usize,
    dim: usize,
) -> Vec<TchTensor<E, D>> {
    tensor
        .tensor
        .chunk(chunks as i64, dim as i64)
        .into_iter()
        // Function reference instead of a redundant closure
        // (clippy::redundant_closure).
        .map(TchTensor::new)
        .collect()
}
}
17 changes: 0 additions & 17 deletions burn-tch/src/ops/bool_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,21 +114,4 @@ impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
) -> <LibTorch<E> as Backend>::BoolTensorPrimitive<D> {
TchOps::swap_dims(tensor, dim1, dim2)
}

// Bool narrow forwards to the shared tch implementation in `TchOps`.
fn bool_narrow<const D: usize>(
    tensor: TchTensor<bool, D>,
    dim: usize, start: usize, length: usize,
) -> TchTensor<bool, D> {
    TchOps::narrow(tensor, dim, start, length)
}

// Bool chunk forwards to the shared tch implementation in `TchOps`.
fn bool_chunk<const D: usize>(
    tensor: TchTensor<bool, D>,
    chunks: usize, dim: usize,
) -> Vec<TchTensor<bool, D>> {
    TchOps::chunk(tensor, chunks, dim)
}
}
17 changes: 0 additions & 17 deletions burn-tch/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,21 +401,4 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
) -> <LibTorch<E> as Backend>::IntTensorPrimitive<D> {
TchOps::swap_dims(tensor, dim1, dim2)
}

// Int narrow forwards to the shared tch implementation in `TchOps`.
fn int_narrow<const D: usize>(
    tensor: TchTensor<i64, D>,
    dim: usize, start: usize, length: usize,
) -> TchTensor<i64, D> {
    TchOps::narrow(tensor, dim, start, length)
}

// Int chunk forwards to the shared tch implementation in `TchOps`.
fn int_chunk<const D: usize>(
    tensor: TchTensor<i64, D>,
    chunks: usize, dim: usize,
) -> Vec<TchTensor<i64, D>> {
    TchOps::chunk(tensor, chunks, dim)
}
}
17 changes: 0 additions & 17 deletions burn-tch/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,21 +440,4 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
TchTensor::new(tensor)
}

// Float narrow forwards to the shared tch implementation in `TchOps`.
fn narrow<const D: usize>(
    tensor: TchTensor<E, D>,
    dim: usize, start: usize, length: usize,
) -> TchTensor<E, D> {
    TchOps::narrow(tensor, dim, start, length)
}

// Float chunk forwards to the shared tch implementation in `TchOps`.
fn chunk<const D: usize>(
    tensor: TchTensor<E, D>,
    chunks: usize, dim: usize,
) -> Vec<TchTensor<E, D>> {
    TchOps::chunk(tensor, chunks, dim)
}
}
51 changes: 42 additions & 9 deletions burn-tensor/src/tensor/api/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@ use alloc::vec;
use burn_common::{reader::Reader, stub::Mutex};
use core::{fmt::Debug, ops::Range};

use crate::check::TensorCheck;
use crate::tensor::api::chunk::chunk;
use crate::tensor::api::narrow::narrow;
use crate::{backend::Backend, check, Bool, Data, Float, Int, Shape, TensorKind};
use crate::{
backend::Backend, check, check::TensorCheck, Bool, Data, Float, Int, Shape, TensorKind,
};

/// A tensor with a given backend, shape and data type.
#[derive(new, Clone, Debug)]
Expand Down Expand Up @@ -508,7 +507,20 @@ where
pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
check!(TensorCheck::dim_ops::<D>("narrow", dim));
check!(TensorCheck::narrow(&self, dim, start, length));
Self::new(narrow::<B, D, K>(self.primitive, dim, start, length))

let ranges: Vec<_> = (0..D)
.map(|i| {
if i == dim {
start..(start + length)
} else {
0..self.shape().dims[i]
}
})
.collect();

let ranges_array: [_; D] = ranges.try_into().unwrap();

self.slice(ranges_array)
}

/// Attempts to split the tensor along the given dimension into chunks.
Expand All @@ -525,10 +537,31 @@ where
/// A vector of tensors.
pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
    check!(TensorCheck::dim_ops::<D>("chunk", dim));

    let size = self.shape().dims[dim];

    // Ceil division matches PyTorch's `chunk` sizing and also fixes two
    // defects of the branch-per-case version:
    //   * size=6, chunks=4: the old code emitted a zero-length trailing
    //     chunk at start == size (remainder `size % chunk_size` was 0);
    //   * size=5, chunks=4: the old code narrowed past the end
    //     (narrow(dim, 4, 2) on a dimension of size 5).
    // When `size < chunks` this yields `chunk_size == 1`, reproducing the
    // old one-element-per-chunk branch.
    let chunk_size = (size + chunks - 1) / chunks;

    let mut tensors = Vec::with_capacity(chunks);
    let mut start = 0;
    while start < size {
        // The final chunk may be shorter; never read past the dimension end.
        let length = chunk_size.min(size - start);
        tensors.push(self.clone().narrow(dim, start, length));
        start += length;
    }

    tensors
}
}

Expand Down
69 changes: 0 additions & 69 deletions burn-tensor/src/tensor/api/chunk.rs

This file was deleted.

4 changes: 0 additions & 4 deletions burn-tensor/src/tensor/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,12 @@ pub(crate) mod check;
mod autodiff;
mod base;
mod bool;
mod chunk;
mod float;
mod int;
mod kind;
mod narrow;
mod numeric;

pub use autodiff::*;
pub use base::*;
pub use chunk::chunk;
pub use kind::*;
pub use narrow::narrow;
pub use numeric::*;
Loading

0 comments on commit 83278ad

Please sign in to comment.