Skip to content

Commit

Permalink
Get fbgemm::group_index_select_dim0 to pass tests (#2076)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2076

Follow https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit?usp=sharing to get `fbgemm::group_index_select_dim0` tests passing in `failures_dict.json`

Reviewed By: zou3519

Differential Revision: D49851284

fbshipit-source-id: 7194212a26786196b0c3df295022baaab11d4539
  • Loading branch information
williamwen42 authored and facebook-github-bot committed Nov 2, 2023
1 parent 1650d1f commit 80eaddf
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 87 deletions.
16 changes: 13 additions & 3 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2606,9 +2606,9 @@ Tensor index_select_dim0(
return at::index_select(input, 0, indices);
}

std::vector<Tensor> group_index_select_dim0(
const std::vector<Tensor>& input_group,
const std::vector<Tensor>& indices_group) {
torch::autograd::variable_list group_index_select_dim0(
at::TensorList input_group,
at::TensorList indices_group) {
int num_groups = input_group.size();
TORCH_CHECK(num_groups == (int)indices_group.size())
std::vector<Tensor> output_group;
Expand Down Expand Up @@ -2843,3 +2843,13 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
m.impl("pack_segments", &fbgemm_gpu::pack_segments_autograd);
}

TORCH_LIBRARY_IMPL(fbgemm, AutogradCPU, m) {
m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0);
}

TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
// CPU group_index_select_dim0 is decomposable
m.impl(
"group_index_select_dim0", TORCH_FN(fbgemm_gpu::group_index_select_dim0));
}
Loading

0 comments on commit 80eaddf

Please sign in to comment.