-
Notifications
You must be signed in to change notification settings - Fork 488
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement chunk for different backends #1032
Implement chunk for different backends #1032
Conversation
Currently, this PR has an error of not implementing the traits:
I think once we come up with the solution for |
The current solution is to have a default implementation on all the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can remove the default function implementation of chunk
and narrow
in the Kind
API. I think you should create files in burn_tensor
for narrow
and chunk
where implementations will be located, which can reduce code duplication.
burn-candle/src/ops/base.rs
Outdated
let tensor = tensor.tensor.narrow(dim, start, length); | ||
match tensor { | ||
Ok(tensor) => CandleTensor::new(tensor), | ||
Err(e) => panic!("error chunk from Candle"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
error chunk
=> error narrow
burn-tensor/src/tensor/api/base.rs
Outdated
if i == dim { | ||
start..(start + length) | ||
} else { | ||
0..Self::shape(&tensor).dims[i] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would extract the shape before the loop, just a bit cleaner imo.
burn-tensor/src/tensor/api/base.rs
Outdated
fn chunk<const D: usize>( | ||
tensor: Self::Primitive<D>, | ||
chunks: usize, | ||
dim: usize, | ||
) -> Vec<Self::Primitive<D>> { | ||
let size = Self::shape(&tensor).dims[dim]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The default implementation should be in the Backend
trait, not in the API Kind trait system.
So the change would be =>
Extract those functions into their own modules burn-tensor/src/ops/chunk.rs
and burn-tensor/src/ops/narrow.rs
. In the TensorOps
, IntTensorOps
and BoolTensorOps
trait, we should call those functions by default, so backends are not forced to implement the function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to confirm the interfaces, I presume they should be the following right?:
pub(crate) fn chunk<B: Backend, const D: usize, K: TensorKind>(
tensor: Tensor<B, D, K>,
chunks: usize,
dim: usize,
) -> Vec<Tensor<B, D, K>> {
}
pub(crate) fn narrow<B: Backend, const D: usize, K: TensorKind>(
tensor: Tensor<B, D, K>,
dim: usize,
start: usize,
length: usize,
) -> Tensor<B, D, K> {
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Kelvinyu1117 Yes sorry for the delay, I wasn't available last week 😅
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code is refactored. Please have a look, thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, it's good that we can leverage tch and candle implementations!
It appears we have worked a bit in parallel, see my other comment.
/// | ||
/// A vectors of tensors | ||
/// | ||
fn chunk<const D: usize>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems you have started your branch before #1006 was merged.
In #1006 we modified the default chunk algorithm --don't mind the name of the PR which is about fixing the docs, at first I thought the doc was wrong and the code was right, but in the end the doc was right and the code was wrong.
Could you please adjust the algorithm so it passes the test in #1006 (or in main as #1006 was merged) ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, let me work on that.
Codecov ReportAttention:
Additional details and impacted files@@ Coverage Diff @@
## main #1032 +/- ##
==========================================
- Coverage 85.63% 85.33% -0.30%
==========================================
Files 507 509 +2
Lines 54074 54324 +250
==========================================
+ Hits 46305 46357 +52
- Misses 7769 7967 +198 ☔ View full report in Codecov by Sentry. |
Thanks for the change, looks good to me now |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some smalls changes, but it's getting ready to be merged.
burn-tensor/src/tensor/api/base.rs
Outdated
// fn chunk<const D: usize>( | ||
// tensor: Self::Primitive<D>, | ||
// chunks: usize, | ||
// dim: usize, | ||
// ) -> Vec<Self::Primitive<D>> { | ||
// B::bool_chunk(tensor, chunks, dim) | ||
// } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Deed code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed.
burn-tensor/src/tensor/api/chunk.rs
Outdated
/// Split the tensor along the given dimension into chunks. | ||
/// | ||
/// # Arguments | ||
/// | ||
/// * `tensor` - The tensor. | ||
/// * `chunks` - The number of chunks to be produced | ||
/// * `times` - The dimension along which the tensor will be split. | ||
/// | ||
/// # Returns | ||
/// | ||
/// A vectors of tensors | ||
/// | ||
/// # Remarks | ||
/// | ||
/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation. | ||
/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved | ||
/// by static dispatch. It is not designed for direct usage by users, and not recommended to import | ||
/// or use this function directly. | ||
/// |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small details, but I think spacing between the comments and the imports would be nice. Also, the last empty line of the comment isn't necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed.
This reverts commit 7c6f017.
Pull Request Template
Checklist
run-checks all
script has been executed.Related Issues/PRs
#970
Changes
This PR is in progress, mainly utilizing the backend implementation for
chunk()
instead of the default, currently onlybool_chunk
was implemented and encountered issues forburn-fusion
.Testing
As it is still in progress, it has not yet been tested.