Skip to content

Commit

Permalink
fix/docs/chunk (#1006)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored Nov 29, 2023
1 parent 3301aed commit aa3180d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
26 changes: 14 additions & 12 deletions burn-tensor/src/tensor/api/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,20 +532,22 @@ where
return (0..size).map(|i| self.clone().narrow(dim, i, 1)).collect();
}

let chunk_size = size / chunks;
let cnt_additional = size % chunks;
let mut tensors = Vec::with_capacity(chunks);

let mut sum_chunk_size = 0;
for i in 0..chunks {
let chunk_size = if i < cnt_additional {
chunk_size + 1
} else {
chunk_size
};

tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size));
sum_chunk_size += chunk_size;
if size % chunks == 0 {
let chunk_size = size / chunks;
for _ in 0..chunks {
tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size));
sum_chunk_size += chunk_size;
}
} else {
let chunk_size = (size / chunks) + 1; // assumes not divisible
for _ in 0..chunks - 1 {
tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size));
sum_chunk_size += chunk_size;
}
let remainder = size % chunk_size;
tensors.push(self.clone().narrow(dim, sum_chunk_size, remainder));
}

tensors
Expand Down
12 changes: 12 additions & 0 deletions burn-tensor/src/tests/ops/chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ mod tests {
}
}

#[test]
fn test_chunk_not_evenly_divisible_remains_several() {
let tensors: Vec<Tensor<TestBackend, 1, Int>> = Tensor::arange(0..100).chunk(8, 0);
assert_eq!(tensors.len(), 8);

let expected = [13, 13, 13, 13, 13, 13, 13, 9];

for (index, tensor) in tensors.iter().enumerate() {
assert_eq!(tensor.shape().dims[0], expected[index]);
}
}

#[test]
fn test_chunk_not_divisible() {
let tensors: Vec<Tensor<TestBackend, 1, Int>> = Tensor::arange(0..6).chunk(7, 0);
Expand Down

0 comments on commit aa3180d

Please sign in to comment.