diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_operators.py b/fbgemm_gpu/fbgemm_gpu/sparse_operators.py index 091792ac19..8107ba09f2 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_operators.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_operators.py @@ -71,3 +71,23 @@ def permute_1D_sparse_data_meta( # pyre-fixme permuted_weights = weights.new_empty(permuted_indices_size) return permuted_lengths, permuted_indices, permuted_weights + + +@torch.library.impl_abstract("fbgemm::expand_into_jagged_permute") +def expand_into_jagged_permute_meta( + permute: Tensor, + input_offsets: Tensor, + output_offsets: Tensor, + output_size: Tuple[int, ...], +) -> Tensor: + torch._check(permute.numel() > 0, lambda: "expected {permute.numel} > 0") + torch._check( + permute.numel() == input_offsets.numel() - 1, + lambda: f"expected {permute.numel()} == {input_offsets.numel()} - 1", + ) + torch._check( + permute.numel() == output_offsets.numel() - 1, + lambda: f"expected {permute.numel()} == {output_offsets.numel()} - 1", + ) + output_permute = input_offsets.new_empty(output_size) + return output_permute diff --git a/fbgemm_gpu/test/failures_dict.json b/fbgemm_gpu/test/failures_dict.json index daca769e1e..b8cbae093b 100644 --- a/fbgemm_gpu/test/failures_dict.json +++ b/fbgemm_gpu/test/failures_dict.json @@ -154,20 +154,7 @@ "status": "xfail" } }, - "fbgemm::expand_into_jagged_permute": { - "JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_expand_into_jagged_permute": { - "comment": "", - "status": "xfail" - }, - "JaggedTensorOpsTest.test_aot_dispatch_static__test_expand_into_jagged_permute": { - "comment": "", - "status": "xfail" - }, - "JaggedTensorOpsTest.test_faketensor__test_expand_into_jagged_permute": { - "comment": "", - "status": "xfail" - } - }, + "fbgemm::expand_into_jagged_permute": {}, "fbgemm::generic_histogram_binning_calibration_by_feature": { "SparseOpsTest.test_aot_dispatch_dynamic__test_generic_histogram_binning_calibration_by_feature": { "comment": "", diff --git a/fbgemm_gpu/test/jagged_tensor_ops_test.py b/fbgemm_gpu/test/jagged_tensor_ops_test.py index ddd4464a7e..f83bd941bc 100644 --- a/fbgemm_gpu/test/jagged_tensor_ops_test.py +++ b/fbgemm_gpu/test/jagged_tensor_ops_test.py @@ -35,6 +35,7 @@ except Exception: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + import fbgemm_gpu.sparse_operators # noqa: F401, E402 from fbgemm_gpu.test.test_utils import ( gpu_available, gpu_unavailable,