From 80199a3ed53f1297d2d797cf628de2f561e9ee26 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Tue, 24 Oct 2023 16:41:40 -0700 Subject: [PATCH] impl_abstract expand_into_jagged_permute (#2090) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2090 Differential Revision: D50586828 --- fbgemm_gpu/fbgemm_gpu/sparse_operators.py | 17 +++++++++++++++++ fbgemm_gpu/test/failures_dict.json | 15 +-------------- fbgemm_gpu/test/jagged_tensor_ops_test.py | 1 + 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_operators.py b/fbgemm_gpu/fbgemm_gpu/sparse_operators.py index 07fd2d0963..59a5a19fdd 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_operators.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_operators.py @@ -71,3 +71,20 @@ def permute_1D_sparse_data_meta( # pyre-fixme permuted_weights = weights.new_empty(permuted_indices_size) return permuted_lengths, permuted_indices, permuted_weights + + +@torch.library.impl_abstract("fbgemm::expand_into_jagged_permute") +def expand_into_jagged_permute_meta( + permute, input_offsets, output_offsets, output_size +): + torch._check(permute.numel() > 0, lambda: "expected {permute.numel} > 0") + torch._check( + permute.numel() == input_offsets.numel() - 1, + lambda: f"expected {permute.numel()} == {input_offsets.numel()} - 1", + ) + torch._check( + permute.numel() == output_offsets.numel() - 1, + lambda: f"expected {permute.numel()} == {output_offsets.numel()} - 1", + ) + output_permute = input_offsets.new_empty(output_size) + return output_permute diff --git a/fbgemm_gpu/test/failures_dict.json b/fbgemm_gpu/test/failures_dict.json index daca769e1e..b8cbae093b 100644 --- a/fbgemm_gpu/test/failures_dict.json +++ b/fbgemm_gpu/test/failures_dict.json @@ -154,20 +154,7 @@ "status": "xfail" } }, - "fbgemm::expand_into_jagged_permute": { - "JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_expand_into_jagged_permute": { - "comment": "", - "status": "xfail" - }, - "JaggedTensorOpsTest.test_aot_dispatch_static__test_expand_into_jagged_permute": { - "comment": "", - "status": "xfail" - }, - "JaggedTensorOpsTest.test_faketensor__test_expand_into_jagged_permute": { - "comment": "", - "status": "xfail" - } - }, + "fbgemm::expand_into_jagged_permute": {}, "fbgemm::generic_histogram_binning_calibration_by_feature": { "SparseOpsTest.test_aot_dispatch_dynamic__test_generic_histogram_binning_calibration_by_feature": { "comment": "", diff --git a/fbgemm_gpu/test/jagged_tensor_ops_test.py b/fbgemm_gpu/test/jagged_tensor_ops_test.py index ddd4464a7e..f83bd941bc 100644 --- a/fbgemm_gpu/test/jagged_tensor_ops_test.py +++ b/fbgemm_gpu/test/jagged_tensor_ops_test.py @@ -35,6 +35,7 @@ except Exception: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + import fbgemm_gpu.sparse_operators # noqa: F401, E402 from fbgemm_gpu.test.test_utils import ( gpu_available, gpu_unavailable,