Skip to content

Commit

Permalink
Make fbgemm::jagged_index_select pt2_compliant (#2170)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2170

- The problem with the original op was that in the Autograd implementation, it
  needed to call Tensor.item(). This doesn't work with FakeTensors (maybe it
  can some day in the future).
- We create two new ops, `jagged_index_select_2d_forward_v2` and
  `jagged_index_add_2d_forward_v2` (which is effectively the backward) that do
  the Tensor.item() calls, and change fbgemm::jagged_index_select's Autograd
  implementation to call those.
- We add abstract impls for those two new ops.
- Finally, we move the fbgemm::jagged_index_select implementation to
  CompositeImplicitAutograd (and delete the CPU/CUDA impls, because those are
  redundant).

Reviewed By: williamwen42, aakhundov

Differential Revision: D51670069

fbshipit-source-id: b7ae86dcb02a993ec3bad94a839b707faf4f9098
  • Loading branch information
zou3519 authored and facebook-github-bot committed Nov 30, 2023
1 parent 63d1198 commit 1c40928
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 41 deletions.
32 changes: 32 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,38 @@ def tbe_input_combine_abstract(
return combined_indices, combined_offsets, combined_weights


@impl_abstract("fbgemm::jagged_index_select_2d_forward_v2")
def jagged_index_select_2d_forward_v2_abstract(
values: Tensor,
indices: Tensor,
input_offsets: Tensor,
output_offsets: Tensor,
) -> Tensor:
torch._check(values.device == indices.device)
torch._check(values.device == input_offsets.device)
torch._check(values.device == output_offsets.device)
torch._check(values.dim() == 2)
num_dense_output_rows = torch.library.get_ctx().new_dynamic_size()
num_cols = values.size(1)
return values.new_empty([num_dense_output_rows, num_cols])


@impl_abstract("fbgemm::jagged_index_add_2d_forward_v2")
def jagged_index_add_2d_forward_v2_abstract(
values: Tensor,
indices: Tensor,
input_offsets: Tensor,
output_offsets: Tensor,
num_output_rows: int,
) -> Tensor:
torch._check(values.device == indices.device)
torch._check(values.device == input_offsets.device)
torch._check(values.device == output_offsets.device)
torch._check(values.dim() == 2)
num_cols = values.size(1)
return values.new_empty([num_output_rows, num_cols])


@impl_abstract("fbgemm::expand_into_jagged_permute")
def expand_into_jagged_permute_meta(
permute: Tensor,
Expand Down
5 changes: 0 additions & 5 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,3 @@ FBGEMM_OP_DISPATCH(CUDA, "jagged_2d_to_dense", fbgemm_gpu::jagged_2d_to_dense);
FBGEMM_OP_DISPATCH(CUDA, "jagged_softmax", fbgemm_gpu::jagged_softmax);
FBGEMM_OP_DISPATCH(CUDA, "jagged_jagged_bmm", fbgemm_gpu::jagged_jagged_bmm);
FBGEMM_OP_DISPATCH(CUDA, "jagged_dense_bmm", fbgemm_gpu::jagged_dense_bmm);

FBGEMM_OP_DISPATCH(
CUDA,
"jagged_index_select",
fbgemm_gpu::jagged_index_select_2d);
40 changes: 12 additions & 28 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,30 +524,20 @@ class JaggedIndexSelect2dOp
Tensor output_offsets = output_lengths.cumsum(0);
Tensor input_offsets = lengths.cumsum(0);

int64_t num_dense_output_rows =
output_offsets[output_offsets.numel() - 1].item<int64_t>();

ctx->save_for_backward({indices, output_offsets, input_offsets});
ctx->saved_data["num_dense_grad_rows"] = num_dense_output_rows;
ctx->saved_data["num_input_rows"] = values.size(0);
ctx->saved_data["num_input_rows"] = values.sym_size(0);

static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::jagged_index_select_2d_forward", "")
.findSchemaOrThrow("fbgemm::jagged_index_select_2d_forward_v2", "")
.typed<at::Tensor(
const Tensor& values,
const Tensor& indices,
const Tensor& input_offsets,
const Tensor& output_offsets,
const int64_t num_dense_output_rows)>();
const Tensor& output_offsets)>();

return {
op.call(
values,
indices,
input_offsets,
output_offsets,
num_dense_output_rows),
op.call(values, indices, input_offsets, output_offsets),
output_lengths};
}

Expand All @@ -565,29 +555,20 @@ class JaggedIndexSelect2dOp

TENSORS_ON_SAME_DEVICE(grad, indices);

int64_t num_dense_grad_rows =
ctx->saved_data["num_dense_grad_rows"].toInt();
int64_t num_output_rows = ctx->saved_data["num_input_rows"].toInt();
auto num_output_rows = ctx->saved_data["num_input_rows"].toSymInt();

static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::jagged_index_add_2d_forward", "")
.findSchemaOrThrow("fbgemm::jagged_index_add_2d_forward_v2", "")
.typed<at::Tensor(
const Tensor& values,
const Tensor& indices,
const Tensor& input_offsets,
const Tensor& output_offsets,
const int64_t num_dense_input_rows,
const int64_t num_output_rows)>();
c10::SymInt num_output_rows)>();

return {
op.call(
grad,
indices,
grad_offsets,
output_offsets,
num_dense_grad_rows,
num_output_rows),
op.call(grad, indices, grad_offsets, output_offsets, num_output_rows),
torch::autograd::Variable(), // lengths
torch::autograd::Variable() // indices
};
Expand Down Expand Up @@ -883,6 +864,9 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
m.impl("jagged_softmax", TORCH_FN(fbgemm_gpu::jagged_softmax));
m.impl("jagged_jagged_bmm", TORCH_FN(fbgemm_gpu::jagged_jagged_bmm));
m.impl("jagged_dense_bmm", TORCH_FN(fbgemm_gpu::jagged_dense_bmm));
m.impl("jagged_index_select", TORCH_FN(fbgemm_gpu::jagged_index_select_2d));
m.impl("jagged_slice", TORCH_FN(fbgemm_gpu::jagged_slice));
}

TORCH_LIBRARY_IMPL(fbgemm, CompositeImplicitAutograd, m) {
m.impl("jagged_index_select", TORCH_FN(fbgemm_gpu::jagged_index_select_2d));
}
76 changes: 74 additions & 2 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,63 @@ Tensor jagged_index_select_2d_forward_cpu(
return output;
}

// v2 supports PT2 Dynamic Shapes.
// The problem with v1 is that it accepts a redundant num_dense_output_rows arg
// that we compute by peeking at output_offsets.data.
// PT2 has problems with data access, so we hide the data access inside
// the new operator.
Tensor jagged_index_select_2d_forward_v2_impl(
const Tensor& values,
const Tensor& indices,
const Tensor& input_offsets,
const Tensor& output_offsets) {
int64_t num_dense_output_rows =
output_offsets[output_offsets.numel() - 1].item<int64_t>();
static auto v1_op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::jagged_index_select_2d_forward", "")
.typed<at::Tensor(
const Tensor& values,
const Tensor& indices,
const Tensor& input_offsets,
const Tensor& output_offsets,
const int64_t num_dense_output_rows)>();
return v1_op.call(
values, indices, input_offsets, output_offsets, num_dense_output_rows);
}

// v2 supports PT2 Dynamic Shapes.
// The problem with v1 is that it accepts a redundant num_dense_output_rows arg
// that we compute by peeking at input_offsets.data.
// PT2 has problems with data access, so we hide the data access inside
// the new operator.
Tensor jagged_index_add_2d_forward_v2_impl(
const Tensor& values,
const Tensor& indices,
const Tensor& input_offsets,
const Tensor& output_offsets,
const int64_t num_output_rows) {
int64_t num_dense_output_rows =
input_offsets[input_offsets.numel() - 1].item<int64_t>();
static auto v1_op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::jagged_index_add_2d_forward", "")
.typed<at::Tensor(
const Tensor& values,
const Tensor& indices,
const Tensor& input_offsets,
const Tensor& output_offsets,
const int64_t num_dense_input_rows,
const int64_t num_output_rows)>();
return v1_op.call(
values,
indices,
input_offsets,
output_offsets,
num_dense_output_rows,
num_output_rows);
}

template <typename index_t, typename offset_t, typename scalar_t>
void jagged_index_add_2d_kernel(
at::TensorAccessor<scalar_t, 2> output,
Expand Down Expand Up @@ -1642,11 +1699,18 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"batched_dense_vec_jagged_2d_mul_backward(Tensor grad_output, Tensor v, Tensor a_values, Tensor a_offsets) -> (Tensor, Tensor)",
{PT2_COMPLIANT_TAG});
m.def(
"jagged_index_select(Tensor values, Tensor lengths, Tensor indices) -> Tensor[]");
"jagged_index_select(Tensor values, Tensor lengths, Tensor indices) -> Tensor[]",
{PT2_COMPLIANT_TAG});
m.def(
"jagged_index_select_2d_forward(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, int num_dense_output_rows) -> Tensor");
m.def(
"jagged_index_select_2d_forward_v2(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"jagged_index_add_2d_forward(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, int num_dense_input_rows, int num_output_rows) -> Tensor");
m.def(
"jagged_index_add_2d_forward_v2(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, SymInt num_output_rows) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"jagged_1d_to_truncated_values(Tensor values, Tensor lengths, int max_truncated_length) -> Tensor");
m.def(
Expand Down Expand Up @@ -1728,7 +1792,6 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
DISPATCH_TO_CPU(
"jagged_index_select_2d_forward",
fbgemm_gpu::jagged_index_select_2d_forward_cpu);
DISPATCH_TO_CPU("jagged_index_select", fbgemm_gpu::jagged_index_select_2d);
DISPATCH_TO_CPU(
"jagged_index_add_2d_forward",
fbgemm_gpu::jagged_index_add_2d_forward_cpu);
Expand All @@ -1749,3 +1812,12 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
"jagged_dense_bmm_forward", fbgemm_gpu::jagged_dense_bmm_forward);
DISPATCH_TO_CPU("jagged_slice_forward", fbgemm_gpu::jagged_slice_forward_cpu);
}

TORCH_LIBRARY_IMPL(fbgemm, CompositeExplicitAutograd, m) {
m.impl(
"jagged_index_select_2d_forward_v2",
fbgemm_gpu::jagged_index_select_2d_forward_v2_impl);
m.impl(
"jagged_index_add_2d_forward_v2",
fbgemm_gpu::jagged_index_add_2d_forward_v2_impl);
}
12 changes: 6 additions & 6 deletions fbgemm_gpu/test/failures_dict.json
Original file line number Diff line number Diff line change
Expand Up @@ -268,27 +268,27 @@
"fbgemm::jagged_index_select": {
"JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_jagged_index_select_2d": {
"comment": "",
"status": "xfail"
"status": "xsuccess"
},
"JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_jagged_index_select_2d_in_inference": {
"comment": "",
"status": "xfail"
"status": "xsuccess"
},
"JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_keyed_jagged_index_select_dim1": {
"comment": "",
"status": "xfail"
"status": "xsuccess"
},
"JaggedTensorOpsTest.test_faketensor__test_jagged_index_select_2d": {
"comment": "",
"status": "xfail"
"status": "xsuccess"
},
"JaggedTensorOpsTest.test_faketensor__test_jagged_index_select_2d_in_inference": {
"comment": "",
"status": "xfail"
"status": "xsuccess"
},
"JaggedTensorOpsTest.test_faketensor__test_keyed_jagged_index_select_dim1": {
"comment": "",
"status": "xfail"
"status": "xsuccess"
}
},
"fbgemm::jagged_jagged_bmm": {},
Expand Down

0 comments on commit 1c40928

Please sign in to comment.