Skip to content

Commit

Permalink
Add abstract impl fbgemm::dense_to_jagged (#2193)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2193

1. Remove the meta cpp function fbgemm::dense_to_jagged() and fbgemm::dense_to_jagged_forward()
2. Replace it with the Python abstract impl dense_to_jagged()

Reviewed By: zou3519, yanboliang

Differential Revision: D51216256

fbshipit-source-id: 532f25e5f9574e75f310ce98325523eb684cc7c8
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Dec 7, 2023
1 parent 0a3016f commit f8def44
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 64 deletions.
28 changes: 28 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,31 @@ def segment_sum_csr_abstract(
output_size = csr_seg.numel() - 1
output = values.new_empty(output_size)
return output


@impl_abstract("fbgemm::dense_to_jagged_forward")
def dense_to_jagged_forward(
dense: torch.Tensor,
offsets: List[torch.Tensor],
total_L: Optional[torch.SymInt] = None,
) -> torch.Tensor:
if not total_L:
total_L = torch.library.get_ctx().new_dynamic_size()
return dense.new_zeros(
total_L,
dense.size()[-1],
dtype=dense.dtype,
device=dense.device,
layout=dense.layout,
)


@impl_abstract("fbgemm::dense_to_jagged")
def dense_to_jagged(
dense: torch.Tensor,
offsets: List[torch.Tensor],
total_L: Optional[torch.SymInt] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
if not total_L:
total_L = torch.library.get_ctx().new_dynamic_size()
return (dense_to_jagged_forward(dense, offsets, total_L), offsets)
3 changes: 3 additions & 0 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1635,6 +1635,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
// SymInt is a new PyTorch 2.0 feature to support dynamic shape. See more
// details at https://pytorch.org/get-started/pytorch-2.0/#dynamic-shapes. If
// you find it doesn't compile, please pull the new PyTorch 2.0 code
m.impl_abstract_pystub(
"fbgemm_gpu.sparse_ops",
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py");
m.def(
"dense_to_jagged(Tensor dense, Tensor[] x_offsets, SymInt? total_L=None) -> (Tensor, Tensor[])",
{PT2_COMPLIANT_TAG});
Expand Down
26 changes: 0 additions & 26 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,28 +92,6 @@ Tensor jagged_dense_elementwise_add_meta(
return at::empty_like(y);
}

Tensor dense_to_jagged_forward_meta(
const Tensor& dense,
const std::vector<Tensor>& offsets,
c10::optional<at::SymInt> total_L) {
auto dense_values = dense;
at::SymInt D = dense_values.sym_size(-1);
TORCH_CHECK_NOT_IMPLEMENTED(
total_L.has_value(), "total_L is required for meta backend");
auto& total_L_computed = total_L.value();
auto values = at::zeros_symint({total_L_computed, D}, dense_values.options());

TORCH_CHECK(values.is_meta());
return values;
}

std::tuple<Tensor, std::vector<Tensor>> dense_to_jagged_meta(
const Tensor& dense,
const std::vector<Tensor>& offsets,
c10::optional<at::SymInt> total_L) {
return {dense_to_jagged_forward_meta(dense, offsets, total_L), offsets};
}

std::tuple<Tensor, std::vector<Tensor>> jagged_dense_elementwise_mul_meta(
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
Expand Down Expand Up @@ -241,10 +219,6 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
m.impl(
"jagged_to_padded_dense_backward",
TORCH_FN(fbgemm_gpu::jagged_to_padded_dense_backward_meta));
m.impl(
"dense_to_jagged_forward",
TORCH_FN(fbgemm_gpu::dense_to_jagged_forward_meta));
m.impl("dense_to_jagged", TORCH_FN(fbgemm_gpu::dense_to_jagged_meta));
m.impl(
"jagged_dense_dense_elementwise_add_jagged_output_forward",
TORCH_FN(
Expand Down
35 changes: 1 addition & 34 deletions fbgemm_gpu/test/failures_dict.json
Original file line number Diff line number Diff line change
Expand Up @@ -115,40 +115,7 @@
"status": "xfail"
}
},
"fbgemm::dense_to_jagged": {
"JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_dense_to_jagged": {
"comment": "",
"status": "xfail"
},
"JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_dense_to_jagged_meta_backend": {
"comment": "",
"status": "xfail"
},
"JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_dense_to_jagged_opt": {
"comment": "",
"status": "xfail"
},
"JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_dense_to_jagged_opt_large_batch": {
"comment": "",
"status": "xfail"
},
"JaggedTensorOpsTest.test_faketensor__test_dense_to_jagged": {
"comment": "",
"status": "xfail"
},
"JaggedTensorOpsTest.test_faketensor__test_dense_to_jagged_meta_backend": {
"comment": "",
"status": "xfail"
},
"JaggedTensorOpsTest.test_faketensor__test_dense_to_jagged_opt": {
"comment": "",
"status": "xfail"
},
"JaggedTensorOpsTest.test_faketensor__test_dense_to_jagged_opt_large_batch": {
"comment": "",
"status": "xfail"
}
},
"fbgemm::dense_to_jagged": {},
"fbgemm::expand_into_jagged_permute": {},
"fbgemm::generic_histogram_binning_calibration_by_feature": {
"SparseOpsTest.test_aot_dispatch_dynamic__test_generic_histogram_binning_calibration_by_feature": {
Expand Down
4 changes: 0 additions & 4 deletions fbgemm_gpu/test/jagged_tensor_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,6 @@ def hash_size_cumsum_to_offsets(hash_size_cum_sum_list: List[int]) -> List[int]:
# skips and failures in deeplearning/fbgemm/fbgemm_gpu/test/failures_dict.json
# pyre-ignore[24]: Generic type `Callable` expects 2 type parameters.
additional_decorators: Dict[str, List[Callable]] = {
"test_pt2_compliant_tag_fbgemm_dense_to_jagged": [
# This operator has been grandfathered in. We need to fix this test failure.
unittest.expectedFailure,
],
"test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add": [
# This operator has been grandfathered in. We need to fix this test failure.
unittest.expectedFailure,
Expand Down

0 comments on commit f8def44

Please sign in to comment.