Skip to content

Commit

Permalink
impl_abstract for permute_1D_sparse_data
Browse files Browse the repository at this point in the history
Differential Revision: D50584541
  • Loading branch information
ezyang authored and facebook-github-bot committed Oct 24, 2023
1 parent ae45de6 commit ee60088
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
26 changes: 26 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sparse_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,29 @@ def permute_2D_sparse_data_meta(
# pyre-fixme
permuted_weights = weights.new_empty(permuted_indices_size)
return permuted_lengths, permuted_indices, permuted_weights


@torch.library.impl_abstract("fbgemm::permute_1D_sparse_data")
def permute_1D_sparse_data_meta(
permute: Tensor,
lengths: Tensor,
values: Tensor,
weights: Optional[Tensor] = None,
permuted_lengths_sum: Optional[int] = None,
):
indices = values
permuted_lengths_size = permute.numel()
permuted_lengths = lengths.new_empty([permuted_lengths_size])
permuted_indices_size = 0
if permuted_lengths_sum is not None:
permuted_indices_size = permuted_lengths_sum
else:
ctx = torch._custom_op.impl.get_ctx()
permuted_indices_size = ctx.new_dynamic_size()
# pyre-fixme
permuted_indices = indices.new_empty(permuted_indices_size)
permuted_weights = None
if weights is not None:
# pyre-fixme
permuted_weights = weights.new_empty(permuted_indices_size)
return permuted_lengths, permuted_indices, permuted_weights
8 changes: 6 additions & 2 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2688,10 +2688,14 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
m.impl_abstract_pystub(
"permute_2D_sparse_data",
"fbgemm_gpu.operators",
"//deeplearning/fbgemm/fbgemm_gpu:operators");
"fbgemm_gpu.sparse_operators",
"//deeplearning/fbgemm/fbgemm_gpu:sparse_operators");
m.def(
"permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
m.impl_abstract_pystub(
"permute_1D_sparse_data",
"fbgemm_gpu.sparse_operators",
"//deeplearning/fbgemm/fbgemm_gpu:sparse_operators");
m.def("invert_permute(Tensor permute) -> Tensor");
m.def(
"expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, int output_size) -> Tensor");
Expand Down
8 changes: 0 additions & 8 deletions fbgemm_gpu/test/failures_dict.json
Original file line number Diff line number Diff line change
Expand Up @@ -507,18 +507,10 @@
}
},
"fbgemm::permute_1D_sparse_data": {
"SparseOpsTest.test_aot_dispatch_dynamic__test_permute_indices": {
"comment": "",
"status": "xfail"
},
"SparseOpsTest.test_aot_dispatch_static__test_permute_indices": {
"comment": "",
"status": "xfail"
},
"SparseOpsTest.test_faketensor__test_permute_indices": {
"comment": "",
"status": "xfail"
},
"SparseOpsTest.test_schema__test_permute_indices": {
"comment": "flaky",
"status": "skip"
Expand Down

0 comments on commit ee60088

Please sign in to comment.