Support for CPU/GPU compilation (pytorch#2040)
Summary:

The backward pass was not working on CPU.

Moving the Autograd registration to the CPU file fixes the issue: in a CPU-only compilation the GPU code is not built, so the autograd registration was never performed.

Also re-activated the test on GitHub.
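
For context, a minimal sketch of the dispatch-key registration pattern at play; the handler declarations here are simplified stand-ins, and the actual registrations appear in the diffs below:

#include <ATen/ATen.h>
#include <torch/library.h>

// Illustrative declarations; the real definitions live in fbgemm_gpu.
at::Tensor pack_segments_cpu(
    const at::Tensor& t_in, const at::Tensor& lengths, int64_t max_length);
at::Tensor pack_segments_cuda(
    const at::Tensor& t_in, const at::Tensor& lengths, int64_t max_length);
at::Tensor pack_segments_autograd(
    const at::Tensor& t_in, const at::Tensor& lengths, at::SymInt max_length);

// Per-backend kernel registrations: each one exists only in builds that
// compile its translation unit.
TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
  m.impl("pack_segments", &pack_segments_cpu);  // compiled in every build
}
TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
  m.impl("pack_segments", &pack_segments_cuda); // GPU builds only
}

// The Autograd registration is what wires backward() to the op. It used
// to live in the GPU translation unit, so CPU-only builds never ran it
// and backward failed; registering it from the CPU file, which is
// compiled in every build, fixes that.
TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
  m.impl("pack_segments", &pack_segments_autograd);
}

With this layout a CPU-only build still registers the Autograd wrapper, which redispatches to whichever backend kernel is available.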

Reviewed By: spcyppt

Differential Revision: D49615498
Flavio Sales Truzzi authored and facebook-github-bot committed Sep 25, 2023
1 parent 85de33b commit 439292a
Showing 3 changed files with 75 additions and 81 deletions.
72 changes: 72 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -15,6 +15,7 @@
#include <torch/library.h>
#include "ATen/Parallel.h"

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/autograd/custom_function.h>
#include "fbgemm_gpu/sparse_ops.h"
#include "fbgemm_gpu/sparse_ops_utils.h"
@@ -55,6 +56,73 @@ using Tensor = at::Tensor;

namespace fbgemm_gpu {

// Custom PackSegments operator that is based on the Caffe2 PackSegments and
// UnpackSegments.
// Needed this to support backward pass.
class PackSegments : public torch::autograd::Function<PackSegments> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const Tensor& t_in,
const Tensor& lengths,
const at::SymInt& max_length) {
const at::SymInt total_length = t_in.sym_size(0);

at::AutoDispatchBelowADInplaceOrView guard;

static auto custom_pack_segments_op =
at::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::pack_segments", "")
.typed<at::Tensor(
const at::Tensor&, const at::Tensor&, const at::SymInt)>();

Tensor res = custom_pack_segments_op.call(t_in, lengths, max_length);

ctx->saved_data["max_length"] = max_length;
ctx->saved_data["total_length"] = total_length;
ctx->save_for_backward({lengths});

return {res};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
TORCH_CHECK(grad_output.size() == 2 or grad_output.size() == 1);
const Tensor& grad = grad_output[0];
const auto& max_length = ctx->saved_data["max_length"].toSymInt();
const auto& total_length = ctx->saved_data["total_length"].toSymInt();

// Retrieve saved variables for backward.
const auto& saved_variables = ctx->get_saved_variables();
const auto& lengths = saved_variables[0];

torch::autograd::variable_list grad_inputs(5);

static auto custom_pack_segments_backward_op =
at::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::pack_segments_backward", "")
.typed<at::Tensor(
const at::Tensor&,
const at::Tensor&,
const at::SymInt,
const at::SymInt)>();

grad_inputs[0] = custom_pack_segments_backward_op.call(
grad, lengths, total_length, max_length);
return grad_inputs;
}
};

Tensor pack_segments_autograd(
const Tensor& t_in,
const Tensor& lengths,
const at::SymInt max_length

) {
return PackSegments::apply(t_in, lengths, max_length)[0];
}

Tensor native_empty_like(const Tensor& self) {
return at::native::empty_like(
self,
@@ -2767,3 +2835,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
"group_index_select_dim0", fbgemm_gpu::group_index_select_dim0);
DISPATCH_TO_CPU("bottom_k_per_row", fbgemm_gpu::bottom_k_per_row);
}

TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
m.impl("pack_segments", &fbgemm_gpu::pack_segments_autograd);
}
74 changes: 1 addition & 73 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -58,73 +58,6 @@ void offset_args(
}
} // namespace

// Custom PackSegments operator that is based on the Caffe2 PackSegments and
// UnpackSegments.
// Needed this to support backward pass.
class PackSegments : public torch::autograd::Function<PackSegments> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const Tensor& t_in,
const Tensor& lengths,
const at::SymInt& max_length) {
const at::SymInt total_length = t_in.sym_size(0);

at::AutoDispatchBelowADInplaceOrView guard;

static auto custom_pack_segments_op =
torch::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::pack_segments", "")
.typed<at::Tensor(
const at::Tensor&, const at::Tensor&, const at::SymInt)>();

Tensor res = custom_pack_segments_op.call(t_in, lengths, max_length);

ctx->saved_data["max_length"] = max_length;
ctx->saved_data["total_length"] = total_length;
ctx->save_for_backward({lengths});

return {res};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
TORCH_CHECK(grad_output.size() == 2 or grad_output.size() == 1);
const Tensor& grad = grad_output[0];
const auto& max_length = ctx->saved_data["max_length"].toSymInt();
const auto& total_length = ctx->saved_data["total_length"].toSymInt();

// Retrieve saved variables for backward.
const auto& saved_variables = ctx->get_saved_variables();
const auto& lengths = saved_variables[0];

torch::autograd::variable_list grad_inputs(5);

static auto custom_pack_segments_backward_op =
torch::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::pack_segments_backward", "")
.typed<at::Tensor(
const at::Tensor&,
const at::Tensor&,
const at::SymInt,
const at::SymInt)>();

grad_inputs[0] = custom_pack_segments_backward_op.call(
grad, lengths, total_length, max_length);
return grad_inputs;
}
};

torch::Tensor pack_segments_autograd(
const Tensor& t_in,
const Tensor& lengths,
const at::SymInt max_length

) {
return PackSegments::apply(t_in, lengths, max_length)[0];
}

class LookupFunctionBatchedUnaryEmbeddingOp
: public torch::autograd::Function<LookupFunctionBatchedUnaryEmbeddingOp> {
public:
@@ -610,8 +543,7 @@ Tensor pack_segments_cuda(
const Tensor& t_in,
const Tensor& lengths,
const int64_t max_length) {
const auto& res = PackSegments::apply(t_in, lengths, max_length);
return res[0];
return fbgemm_gpu::pack_segments_forward_cuda(t_in, lengths, max_length)[0];
}

Tensor index_select_dim0_gpu(
@@ -683,7 +615,3 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
DISPATCH_TO_CUDA(
"group_index_select_dim0", fbgemm_gpu::group_index_select_dim0_gpu);
}

TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
m.impl("pack_segments", &fbgemm_gpu::pack_segments_autograd);
}
10 changes: 2 additions & 8 deletions fbgemm_gpu/test/sparse_ops_test.py
@@ -26,17 +26,12 @@
from fbgemm_gpu import open_source # noqa: F401

# pyre-ignore[21]
from test_utils import gpu_available, gpu_unavailable, running_on_github, skipIfRocm
from test_utils import gpu_available, gpu_unavailable, skipIfRocm
except Exception:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops")
from fbgemm_gpu.test.test_utils import (
gpu_available,
gpu_unavailable,
running_on_github,
skipIfRocm,
)
from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable, skipIfRocm


def unbucketize_indices_value(
@@ -1825,7 +1820,6 @@ def _pack_segments_ref(
),
torch_compile=st.booleans(),
)
@unittest.skipIf(*running_on_github)
@settings(deadline=None)
def test_pack_segments(
self,
