Skip to content

Commit

Permalink
impl_abstract expand_into_jagged_permute (#2090)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #2090

Reviewed By: zou3519

Differential Revision: D50586828

fbshipit-source-id: 2f92a717877dfaf7b56fbea56df67acc272fd8f5
  • Loading branch information
ezyang authored and facebook-github-bot committed Oct 26, 2023
1 parent 34f62ad commit 3283a68
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 14 deletions.
20 changes: 20 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sparse_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,23 @@ def permute_1D_sparse_data_meta(
# pyre-fixme
permuted_weights = weights.new_empty(permuted_indices_size)
return permuted_lengths, permuted_indices, permuted_weights


@torch.library.impl_abstract("fbgemm::expand_into_jagged_permute")
def expand_into_jagged_permute_meta(
permute: Tensor,
input_offsets: Tensor,
output_offsets: Tensor,
output_size: Tuple[int, ...],
) -> Tensor:
torch._check(permute.numel() > 0, lambda: "expected {permute.numel} > 0")
torch._check(
permute.numel() == input_offsets.numel() - 1,
lambda: f"expected {permute.numel()} == {input_offsets.numel()} - 1",
)
torch._check(
permute.numel() == output_offsets.numel() - 1,
lambda: f"expected {permute.numel()} == {output_offsets.numel()} - 1",
)
output_permute = input_offsets.new_empty(output_size)
return output_permute
15 changes: 1 addition & 14 deletions fbgemm_gpu/test/failures_dict.json
Original file line number Diff line number Diff line change
Expand Up @@ -154,20 +154,7 @@
"status": "xfail"
}
},
"fbgemm::expand_into_jagged_permute": {
"JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_expand_into_jagged_permute": {
"comment": "",
"status": "xfail"
},
"JaggedTensorOpsTest.test_aot_dispatch_static__test_expand_into_jagged_permute": {
"comment": "",
"status": "xfail"
},
"JaggedTensorOpsTest.test_faketensor__test_expand_into_jagged_permute": {
"comment": "",
"status": "xfail"
}
},
"fbgemm::expand_into_jagged_permute": {},
"fbgemm::generic_histogram_binning_calibration_by_feature": {
"SparseOpsTest.test_aot_dispatch_dynamic__test_generic_histogram_binning_calibration_by_feature": {
"comment": "",
Expand Down
1 change: 1 addition & 0 deletions fbgemm_gpu/test/jagged_tensor_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
except Exception:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
import fbgemm_gpu.sparse_operators # noqa: F401, E402
from fbgemm_gpu.test.test_utils import (
gpu_available,
gpu_unavailable,
Expand Down

0 comments on commit 3283a68

Please sign in to comment.