From ff2eeab8ed123cabf3cec11e370de17a0769b490 Mon Sep 17 00:00:00 2001
From: Edward Yang
Date: Wed, 25 Oct 2023 06:38:40 -0700
Subject: [PATCH] impl_abstract expand_into_jagged_permute (#2090)

Summary:
Register an abstract (meta) implementation for fbgemm::expand_into_jagged_permute,
import fbgemm_gpu.sparse_operators in the jagged tensor ops test so the registration
takes effect, and drop the corresponding xfail entries from failures_dict.json.

Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2090

Differential Revision: D50586828
---
 fbgemm_gpu/fbgemm_gpu/sparse_operators.py | 20 ++++++++++++++++++++
 fbgemm_gpu/test/failures_dict.json        | 15 +--------------
 fbgemm_gpu/test/jagged_tensor_ops_test.py |  1 +
 3 files changed, 22 insertions(+), 14 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_operators.py b/fbgemm_gpu/fbgemm_gpu/sparse_operators.py
index 091792ac19..8107ba09f2 100644
--- a/fbgemm_gpu/fbgemm_gpu/sparse_operators.py
+++ b/fbgemm_gpu/fbgemm_gpu/sparse_operators.py
@@ -71,3 +71,23 @@ def permute_1D_sparse_data_meta(
     # pyre-fixme
     permuted_weights = weights.new_empty(permuted_indices_size)
     return permuted_lengths, permuted_indices, permuted_weights
+
+
+@torch.library.impl_abstract("fbgemm::expand_into_jagged_permute")
+def expand_into_jagged_permute_meta(
+    permute: Tensor,
+    input_offsets: Tensor,
+    output_offsets: Tensor,
+    output_size: Tuple[int, ...],
+) -> Tensor:
+    torch._check(permute.numel() > 0, lambda: f"expected {permute.numel()} > 0")
+    torch._check(
+        permute.numel() == input_offsets.numel() - 1,
+        lambda: f"expected {permute.numel()} == {input_offsets.numel()} - 1",
+    )
+    torch._check(
+        permute.numel() == output_offsets.numel() - 1,
+        lambda: f"expected {permute.numel()} == {output_offsets.numel()} - 1",
+    )
+    output_permute = input_offsets.new_empty(output_size)
+    return output_permute
diff --git a/fbgemm_gpu/test/failures_dict.json b/fbgemm_gpu/test/failures_dict.json
index daca769e1e..b8cbae093b 100644
--- a/fbgemm_gpu/test/failures_dict.json
+++ b/fbgemm_gpu/test/failures_dict.json
@@ -154,20 +154,7 @@
                 "status": "xfail"
             }
         },
-        "fbgemm::expand_into_jagged_permute": {
-            "JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_expand_into_jagged_permute": {
-                "comment": "",
-                "status": "xfail"
-            },
-            "JaggedTensorOpsTest.test_aot_dispatch_static__test_expand_into_jagged_permute": {
-                "comment": "",
-                "status": "xfail"
-            },
-            "JaggedTensorOpsTest.test_faketensor__test_expand_into_jagged_permute": {
-                "comment": "",
-                "status": "xfail"
-            }
-        },
+        "fbgemm::expand_into_jagged_permute": {},
         "fbgemm::generic_histogram_binning_calibration_by_feature": {
             "SparseOpsTest.test_aot_dispatch_dynamic__test_generic_histogram_binning_calibration_by_feature": {
                 "comment": "",
diff --git a/fbgemm_gpu/test/jagged_tensor_ops_test.py b/fbgemm_gpu/test/jagged_tensor_ops_test.py
index ddd4464a7e..f83bd941bc 100644
--- a/fbgemm_gpu/test/jagged_tensor_ops_test.py
+++ b/fbgemm_gpu/test/jagged_tensor_ops_test.py
@@ -35,6 +35,7 @@
 except Exception:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+    import fbgemm_gpu.sparse_operators  # noqa: F401, E402
     from fbgemm_gpu.test.test_utils import (
         gpu_available,
         gpu_unavailable,
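
A minimal sketch (not part of the patch) of how the newly registered abstract impl
could be exercised under FakeTensor, assuming an installed fbgemm_gpu build that
defines the op with the (permute, input_offsets, output_offsets, output_size)
signature used above; the sizes below are made-up illustration values.

    import torch
    # Importing the module registers the abstract impl added by this patch.
    import fbgemm_gpu.sparse_operators  # noqa: F401
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode():
        # Only sizes matter to the meta function: it checks that the offset
        # tensors have permute.numel() + 1 elements and then allocates an
        # empty output of size output_size, without touching any real data.
        permute = torch.empty(2, dtype=torch.int64)
        input_offsets = torch.empty(3, dtype=torch.int64)
        output_offsets = torch.empty(3, dtype=torch.int64)
        out = torch.ops.fbgemm.expand_into_jagged_permute(
            permute, input_offsets, output_offsets, 5
        )
        print(out.shape)  # torch.Size([5])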