Skip to content

Commit

Permalink
impl_abstract for permute_1D_sparse_data (#2087)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #2087

Reviewed By: zou3519

Differential Revision: D50584541
  • Loading branch information
ezyang authored and facebook-github-bot committed Oct 25, 2023
1 parent 9cd8ce8 commit 0492277
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
26 changes: 26 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sparse_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,29 @@ def permute_2D_sparse_data_meta(
# pyre-fixme
permuted_weights = weights.new_empty(permuted_indices_size)
return permuted_lengths, permuted_indices, permuted_weights


@torch.library.impl_abstract("fbgemm::permute_1D_sparse_data")
def permute_1D_sparse_data_meta(
permute: Tensor,
lengths: Tensor,
values: Tensor,
weights: Optional[Tensor] = None,
permuted_lengths_sum: Optional[int] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
indices = values
permuted_lengths_size = permute.numel()
permuted_lengths = lengths.new_empty([permuted_lengths_size])
permuted_indices_size = 0
if permuted_lengths_sum is not None:
permuted_indices_size = permuted_lengths_sum
else:
ctx = torch._custom_op.impl.get_ctx()
permuted_indices_size = ctx.new_dynamic_size()
# pyre-fixme
permuted_indices = indices.new_empty(permuted_indices_size)
permuted_weights = None
if weights is not None:
# pyre-fixme
permuted_weights = weights.new_empty(permuted_indices_size)
return permuted_lengths, permuted_indices, permuted_weights
8 changes: 0 additions & 8 deletions fbgemm_gpu/test/failures_dict.json
Original file line number Diff line number Diff line change
Expand Up @@ -516,18 +516,10 @@
}
},
"fbgemm::permute_1D_sparse_data": {
"SparseOpsTest.test_aot_dispatch_dynamic__test_permute_indices": {
"comment": "",
"status": "xfail"
},
"SparseOpsTest.test_aot_dispatch_static__test_permute_indices": {
"comment": "",
"status": "xfail"
},
"SparseOpsTest.test_faketensor__test_permute_indices": {
"comment": "",
"status": "xfail"
},
"SparseOpsTest.test_schema__test_permute_indices": {
"comment": "flaky",
"status": "skip"
Expand Down

0 comments on commit 0492277

Please sign in to comment.