diff --git a/burn-tensor/src/tensor/api/base.rs b/burn-tensor/src/tensor/api/base.rs index 9bd7b78e09..dd8f0d540c 100644 --- a/burn-tensor/src/tensor/api/base.rs +++ b/burn-tensor/src/tensor/api/base.rs @@ -532,20 +532,22 @@ where return (0..size).map(|i| self.clone().narrow(dim, i, 1)).collect(); } - let chunk_size = size / chunks; - let cnt_additional = size % chunks; let mut tensors = Vec::with_capacity(chunks); - let mut sum_chunk_size = 0; - for i in 0..chunks { - let chunk_size = if i < cnt_additional { - chunk_size + 1 - } else { - chunk_size - }; - - tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size)); - sum_chunk_size += chunk_size; + if size % chunks == 0 { + let chunk_size = size / chunks; + for _ in 0..chunks { + tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size)); + sum_chunk_size += chunk_size; + } + } else { + let chunk_size = (size / chunks) + 1; // assumes not divisible + for _ in 0..chunks - 1 { + tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size)); + sum_chunk_size += chunk_size; + } + let remainder = size % chunk_size; + tensors.push(self.clone().narrow(dim, sum_chunk_size, remainder)); } tensors diff --git a/burn-tensor/src/tests/ops/chunk.rs b/burn-tensor/src/tests/ops/chunk.rs index b7d472fa75..53b439549d 100644 --- a/burn-tensor/src/tests/ops/chunk.rs +++ b/burn-tensor/src/tests/ops/chunk.rs @@ -41,6 +41,18 @@ mod tests { } } + #[test] + fn test_chunk_not_evenly_divisible_remains_several() { + let tensors: Vec> = Tensor::arange(0..100).chunk(8, 0); + assert_eq!(tensors.len(), 8); + + let expected = [13, 13, 13, 13, 13, 13, 13, 9]; + + for (index, tensor) in tensors.iter().enumerate() { + assert_eq!(tensor.shape().dims[0], expected[index]); + } + } + #[test] fn test_chunk_not_divisible() { let tensors: Vec> = Tensor::arange(0..6).chunk(7, 0);