Migrate/jit/cat #1457
Conversation
Codecov Report

All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main    #1457      +/-   ##
==========================================
- Coverage   85.79%   85.78%   -0.02%
==========================================
  Files         647      649       +2
  Lines       72106    72414     +308
==========================================
+ Hits        61866    62119     +253
- Misses      10240    10295      +55
```

☔ View full report in Codecov by Sentry.
```rust
let first_tensor = tensors.first().expect("Tensors should not be empty");
let mut shape = B::float_shape(first_tensor);
let device = &B::float_device(first_tensor);

// The output shape matches the inputs except along `dim`, where it is the
// sum of the input lengths.
let output_dim_length: usize = tensors
    .iter()
    .map(|tensor: &FloatTensor<B, D>| B::float_shape(tensor).dims[dim])
    .sum();
shape.dims[dim] = output_dim_length;

let mut tensor_output = B::float_empty(shape.clone(), device);

// Full ranges over every dimension; only `dim` is narrowed per input below.
let mut i = 0;
let indices_select_all = [0; D].map(|_| {
    let start = 0;
    let end = shape.dims[i];
    i += 1;
    start..end
});

// Write each input into the output at a running offset along `dim`.
let mut output_index = 0;
for tensor in tensors {
    let mut indices = indices_select_all.clone();
    let tensor_dim_length = B::float_shape(&tensor).dims[dim];
    indices[dim] = output_index..output_index + tensor_dim_length;
    output_index += tensor_dim_length;

    tensor_output = B::float_slice_assign(tensor_output, indices, tensor);
}

tensor_output
```
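To make the bookkeeping concrete, here is a minimal, self-contained sketch (hypothetical, not part of the PR) that performs the same slice-assign style concatenation on plain 2-D `Vec`s: size the output up front, then copy each input at a running offset along `dim`.

```rust
// Hypothetical illustration of the algorithm above on plain 2-D vectors.
fn cat_2d(tensors: &[Vec<Vec<f32>>], dim: usize) -> Vec<Vec<f32>> {
    // Output shape matches the inputs except along `dim`, where lengths sum.
    let rows: usize = if dim == 0 {
        tensors.iter().map(|t| t.len()).sum()
    } else {
        tensors[0].len()
    };
    let cols: usize = if dim == 1 {
        tensors.iter().map(|t| t[0].len()).sum()
    } else {
        tensors[0][0].len()
    };
    let mut output = vec![vec![0.0; cols]; rows];

    // `output_index` advances by each input's length along `dim`,
    // exactly like the slice-assign loop above.
    let mut output_index = 0;
    for tensor in tensors {
        for (i, row) in tensor.iter().enumerate() {
            for (j, value) in row.iter().enumerate() {
                if dim == 0 {
                    output[output_index + i][j] = *value;
                } else {
                    output[i][output_index + j] = *value;
                }
            }
        }
        output_index += if dim == 0 { tensor.len() } else { tensor[0].len() };
    }
    output
}

fn main() {
    let a = vec![vec![1.0, 2.0]];
    let b = vec![vec![3.0, 4.0], vec![5.0, 6.0]];
    // Concatenating along dim 0: shapes [1, 2] and [2, 2] -> [3, 2].
    let out = cat_2d(&[a, b], 0);
    assert_eq!(out, vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]);
}
```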
I think it would be better to implement this function in a specific module that works with the `Tensor` struct instead of the backend API. Here we could simply:

```rust
fn float_cat<const D: usize>(tensors: Vec<FloatTensor<B, D>>, dim: usize) -> FloatTensor<B, D> {
    cat_with_slice_assign(tensors.into_iter().map(Tensor::from_primitive).collect(), dim)
        .into_primitive()
}
```
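For reference, this is roughly how the public API then gets used; `Tensor::cat` is burn's existing entry point, though the exact import paths may differ between versions:

```rust
use burn::tensor::{backend::Backend, Tensor};

// Concatenate along dimension 0: shapes [n1, c] and [n2, c] -> [n1 + n2, c].
fn concat_rows<B: Backend>(a: Tensor<B, 2>, b: Tensor<B, 2>) -> Tensor<B, 2> {
    Tensor::cat(vec![a, b], 0)
}
```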
I tried, see #1473
See my comment.
Cat kernel for #1422
Note: we chose to rely on a default implementation based on slice assign instead of writing a dedicated kernel; because of the dynamic number of inputs, a custom kernel would have the same performance anyway.
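As a quick illustration of why a fused kernel would not help: the work is one bulk copy per input either way, and the slice ranges below are all the scheduling a fused kernel would recover. The input lengths here are hypothetical.

```rust
fn main() {
    // Hypothetical lengths of three inputs along the cat dimension.
    let lengths = [2usize, 3, 4];
    let mut output_index = 0;
    for len in lengths {
        // Each input becomes one slice_assign into this range of the output;
        // a variadic cat kernel would issue the same per-input copies.
        println!("{:?}", output_index..output_index + len);
        output_index += len;
    }
    // Prints 0..2, 2..5, 5..9; the output length along `dim` is 9.
}
```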