Add scaffolding for Python impl_abstract in fbgemm, implement fbgemm.permute_2D_sparse_data (pytorch#2084)

Summary:
This also fixes a minor bug in GPU permute_2D_sparse_data where we need to clone the zero-size tensors to correctly set up the (lack of) aliasing.


Reviewed By: sryap

Differential Revision: D50563192
ezyang authored and facebook-github-bot committed Oct 23, 2023
1 parent b1049cf commit f798d76
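
For background, torch.library.impl_abstract attaches a Python "abstract" (meta/fake) implementation to an operator, so FakeTensor tracing, torch.compile, and the opcheck-generated tests can propagate shapes and dtypes without running the real kernels. A minimal sketch of the pattern, using a made-up mylib::add_one operator rather than the real fbgemm op:

import torch
from torch import Tensor

# "mylib::add_one" is a hypothetical operator used only for illustration.
lib = torch.library.Library("mylib", "DEF")
lib.define("add_one(Tensor x) -> Tensor")

@torch.library.impl_abstract("mylib::add_one")
def add_one_abstract(x: Tensor) -> Tensor:
    # Describe only output metadata (shape/dtype/device); no real data is touched.
    return torch.empty_like(x)

# Under FakeTensorMode the abstract impl supplies output shapes:
with torch._subclasses.fake_tensor.FakeTensorMode():
    y = torch.ops.mylib.add_one(torch.empty(3))
    assert y.shape == (3,)

The C++ side of the scaffolding is impl_abstract_pystub, which records which Python module is expected to provide this registration.
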
Showing 6 changed files with 58 additions and 28 deletions.
3 changes: 2 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/__init__.py
@@ -19,7 +19,8 @@
 open_source: bool = True
 
 # Re-export docs
-from . import _fbgemm_gpu_docs  # noqa: F401, E402
+# Trigger meta registrations
+from . import _fbgemm_gpu_docs, sparse_operators  # noqa: F401, E402
 
 # Re-export the version string from the auto-generated version file
 from ._fbgemm_gpu_version import __version__  # noqa: F401, E402
47 changes: 47 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sparse_operators.py
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor
+
+try:
+    # pyre-ignore
+    from fbgemm_gpu import open_source  # noqa: F401
+except Exception:
+    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
+    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+
+
+@torch.library.impl_abstract("fbgemm::permute_2D_sparse_data")
+def permute_2D_sparse_data_meta(
+    permute: Tensor,
+    lengths: Tensor,
+    values: Tensor,
+    weights: Optional[Tensor] = None,
+    permuted_lengths_sum: Optional[int] = None,
+) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+    torch._check(
+        lengths.dim() == 2, lambda: f"expected lengths.dim() == 2, got {lengths.dim()}"
+    )
+    T = permute.numel()
+    B = lengths.size(1)
+    indices = values
+    permuted_lengths = lengths.new_empty([T, B])
+    permuted_indices_size = 0
+    if permuted_lengths_sum is not None:
+        permuted_indices_size = permuted_lengths_sum
+    else:
+        ctx = torch._custom_op.impl.get_ctx()
+        permuted_indices_size = ctx.new_dynamic_size()
+    # pyre-fixme
+    permuted_indices = indices.new_empty(permuted_indices_size)
+    permuted_weights = None
+    if weights is not None:
+        # pyre-fixme
+        permuted_weights = weights.new_empty(permuted_indices_size)
+    return permuted_lengths, permuted_indices, permuted_weights
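
As a rough usage sketch (not part of the diff), the abstract implementation above can be exercised on the meta device, assuming a build where the fbgemm ops load and fbgemm_gpu.sparse_operators is importable; permuted_lengths_sum is passed explicitly so the data-dependent output size does not need ctx.new_dynamic_size():

import torch
import fbgemm_gpu.sparse_operators  # noqa: F401  # registers the abstract impl

# Two features (T=2), batch size three (B=3); values flattened to length 6.
permute = torch.tensor([1, 0], device="meta")
lengths = torch.ones(2, 3, dtype=torch.int64, device="meta")
values = torch.empty(6, device="meta")

out_lengths, out_values, out_weights = torch.ops.fbgemm.permute_2D_sparse_data(
    permute, lengths, values, None, 6  # permuted_lengths_sum=6 avoids the dynamic-size branch
)
print(out_lengths.shape, out_values.shape, out_weights)  # torch.Size([2, 3]) torch.Size([6]) None

When permuted_lengths_sum is omitted, the registration instead asks the fake-tensor context for an unbacked dynamic size, which is what the test_aot_dispatch_dynamic opcheck tests exercise.
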
4 changes: 4 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -2686,6 +2686,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"permute_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
m.def(
"permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
m.impl_abstract_pystub(
"permute_2D_sparse_data",
"fbgemm_gpu.operators",
"//deeplearning/fbgemm/fbgemm_gpu:operators");
m.def(
"permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
m.def("invert_permute(Tensor permute) -> Tensor");
7 changes: 4 additions & 3 deletions fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu
@@ -92,9 +92,10 @@ permute_2D_sparse_data_cuda(
     // When T = 0 or B = 0, permutation will not be performed. Return the
     // input tensors.
     return {
-        lengths,
-        indices,
-        weights,
+        lengths.clone(),
+        indices.clone(),
+        weights.has_value() ? c10::make_optional(weights->clone())
+                            : c10::nullopt,
     };
   }
 
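The .clone() calls matter because the operator's schema (-> (Tensor, Tensor, Tensor?)) declares outputs that do not alias the inputs; on the early-return path (T == 0 or B == 0) the old code handed the input tensors straight back, which breaks that contract even for zero-sized tensors. A toy illustration of the difference (plain Python, not the fbgemm code):

import torch

lengths = torch.empty(0, 3, dtype=torch.int64)

def early_return_aliasing(lengths):
    # Returning the input unchanged: the output is the same tensor object as the input.
    return lengths

def early_return_cloned(lengths):
    # Cloning gives a fresh tensor, matching a schema that promises non-aliasing outputs.
    return lengths.clone()

assert early_return_aliasing(lengths) is lengths
assert early_return_cloned(lengths) is not lengths
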
24 changes: 0 additions & 24 deletions fbgemm_gpu/test/failures_dict.json
@@ -525,18 +525,6 @@
       }
     },
     "fbgemm::permute_2D_sparse_data": {
-      "SparseOpsTest.test_aot_dispatch_dynamic__test_permute_embeddings": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SparseOpsTest.test_aot_dispatch_dynamic__test_permute_indices": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SparseOpsTest.test_aot_dispatch_dynamic__test_permute_indices_with_repeats": {
-        "comment": "",
-        "status": "xfail"
-      },
      "SparseOpsTest.test_aot_dispatch_static__test_permute_embeddings": {
        "comment": "",
        "status": "xfail"
@@ -549,18 +537,6 @@
"comment": "",
"status": "xfail"
},
"SparseOpsTest.test_faketensor__test_permute_embeddings": {
"comment": "",
"status": "xfail"
},
"SparseOpsTest.test_faketensor__test_permute_indices": {
"comment": "",
"status": "xfail"
},
"SparseOpsTest.test_faketensor__test_permute_indices_with_repeats": {
"comment": "",
"status": "xfail"
},
"SparseOpsTest.test_schema__test_permute_indices": {
"comment": "flaky",
"status": "skip"
1 change: 1 addition & 0 deletions fbgemm_gpu/test/sparse_ops_test.py
@@ -36,6 +36,7 @@
 torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
 torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
 torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops")
+import fbgemm_gpu.sparse_operators  # noqa: F401, E402
 from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable, skipIfRocm
 
 suppressed_list: List[HealthCheck] = (
