Add backward of jagged_to_padded_dense #1008

Closed
wants to merge 3 commits into from
282 changes: 186 additions & 96 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -226,58 +226,6 @@ void jagged_dense_elementwise_dense_output_(
#undef INVOKE_KERNEL_WITH_DIM
}

// Almost identical copy of jagged_to_padded_dense in jagged_tensor_ops_cpu.cpp
Tensor jagged_to_padded_dense(
const Tensor& values,
const std::vector<Tensor>& offsets,
const std::vector<int64_t>& max_lengths,
const int64_t padding_value) {
const size_t num_jagged_dim = offsets.size();
TORCH_CHECK(
max_lengths.size() == num_jagged_dim,
"max_lengths.size(), ",
max_lengths.size(),
" != num_jagged_dim, ",
num_jagged_dim);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(values.get_device());

const Tensor values_canonicalized = values.view(
{values.size(0),
std::accumulate(
values.sizes().begin() + 1,
values.sizes().end(),
1,
std::multiplies<size_t>())});
at::DimVector padded_values_shape({offsets[0].size(0) - 1});
padded_values_shape.insert(
padded_values_shape.end(), max_lengths.begin(), max_lengths.end());
if (values.dim() > 1) {
padded_values_shape.push_back(values.size(-1));
}
Tensor padded_values = at::empty(padded_values_shape, values.options());
Tensor padded_values_view =
values.dim() == 1 ? padded_values.unsqueeze(-1) : padded_values;

AT_DISPATCH_ALL_TYPES_AND(
at::ScalarType::Half,
values.scalar_type(),
"jagged_to_padded_dense",
[&] {
jagged_dense_elementwise_dense_output_<scalar_t>(
values_canonicalized,
offsets,
padded_values_view, // dummy not used in the lambda function
padded_values_view,
[] __device__(scalar_t x, scalar_t /*unused*/) -> scalar_t {
return x;
},
static_cast<scalar_t>(padding_value));
});

return padded_values;
}

template <typename scalar_t, typename F>
Tensor jagged_dense_elementwise_dense_output_(
const Tensor& x_values,
@@ -396,6 +344,117 @@ Tensor jagged_dense_elementwise_jagged_output_(
return output;
}

class JaggedToPaddedDenseGPUOp
: public torch::autograd::Function<JaggedToPaddedDenseGPUOp> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const Tensor& values,
const std::vector<Tensor>& offsets,
const std::vector<int64_t>& max_lengths,
const int64_t padding_value) {
ctx->save_for_backward(offsets);
ctx->saved_data["total_L"] = values.size(0);

const size_t num_jagged_dim = offsets.size();
TORCH_CHECK(
max_lengths.size() == num_jagged_dim,
"max_lengths.size(), ",
max_lengths.size(),
" != num_jagged_dim, ",
num_jagged_dim);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(values.get_device());

const Tensor values_canonicalized = values.view(
{values.size(0),
std::accumulate(
values.sizes().begin() + 1,
values.sizes().end(),
1,
std::multiplies<size_t>())});
at::DimVector padded_values_shape({offsets[0].size(0) - 1});
padded_values_shape.insert(
padded_values_shape.end(), max_lengths.begin(), max_lengths.end());
if (values.dim() > 1) {
padded_values_shape.push_back(values.size(-1));
}
Tensor padded_values = at::empty(padded_values_shape, values.options());
Tensor padded_values_view =
values.dim() == 1 ? padded_values.unsqueeze(-1) : padded_values;

AT_DISPATCH_ALL_TYPES_AND(
at::ScalarType::Half,
values.scalar_type(),
"jagged_to_padded_dense",
[&] {
jagged_dense_elementwise_dense_output_<scalar_t>(
values_canonicalized,
offsets,
padded_values_view, // dummy not used in the lambda function
padded_values_view,
[] __device__(scalar_t x, scalar_t /*unused*/) -> scalar_t {
return x;
},
static_cast<scalar_t>(padding_value));
});

return {padded_values};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
auto offsets = ctx->get_saved_variables();
int32_t total_L = ctx->saved_data["total_L"].toInt();
TORCH_CHECK(grad_outputs.size() == 1);

TORCH_CHECK(total_L >= 0);
auto grad_padded_values = grad_outputs[0];
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_padded_values.get_device());

int32_t D = grad_padded_values.size(-1);
auto grad_values = at::zeros({total_L, D}, grad_padded_values.options());

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_padded_values.scalar_type(),
"jagged_2d_to_dense_backward_kernel",
[&] {
jagged_dense_elementwise_jagged_output_<scalar_t>(
grad_values, // dummy not used in the lambda function
{offsets},
grad_padded_values,
grad_values,
[] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t {
return y;
});
});

return {
grad_values,
torch::autograd::Variable(), // offsets
torch::autograd::Variable(), // max_lengths
torch::autograd::Variable(), // padding_value
};
}
};

Tensor jagged_to_padded_dense(
const Tensor& values,
const std::vector<Tensor>& offsets,
const std::vector<int64_t>& max_lengths,
const int64_t padding_value) {
return JaggedToPaddedDenseGPUOp::apply(
values, offsets, max_lengths, padding_value)[0];
}

Tensor
jagged_2d_to_dense(Tensor values, Tensor offsets, int64_t max_sequence_length) {
return jagged_to_padded_dense(
values, {offsets}, {max_sequence_length}, /*padding_value=*/0L);
}
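// Editorial usage sketch (not part of this patch): how the new autograd
// support for jagged_to_padded_dense could be exercised. The function name,
// shapes, and the int64 offsets dtype below are made up for illustration; it
// assumes the declarations above in this translation unit and a CUDA device.
inline void jagged_to_padded_dense_autograd_example() {
  // 5 rows of embedding dim 3, requiring grad; offsets [0, 2, 5] describe
  // two bags of lengths 2 and 3.
  auto values = torch::randn(
      {5, 3},
      torch::dtype(torch::kFloat).device(torch::kCUDA).requires_grad(true));
  auto offsets = torch::tensor(
      {0, 2, 5}, torch::dtype(torch::kLong).device(torch::kCUDA));

  // Pad each bag to max length 4; the padded output has shape [2, 4, 3].
  auto padded =
      jagged_to_padded_dense(values, {offsets}, {4}, /*padding_value=*/0);

  // The backward added in this change gathers the [2, 4, 3] gradient back
  // into a [total_L, D] = [5, 3] tensor; padding positions get no gradient.
  padded.sum().backward();
  TORCH_CHECK(values.grad().sizes() == values.sizes());
}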

class JaggedDenseAddGPUOp
: public torch::autograd::Function<JaggedDenseAddGPUOp> {
public:
@@ -406,7 +465,6 @@ class JaggedDenseAddGPUOp
const Tensor& y) {
ctx->save_for_backward(x_offsets);
ctx->saved_data["x_values_shape"] = x_values.sizes();
ctx->saved_data["y_shape"] = y.sizes();

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(x_values.get_device());
Expand All @@ -431,7 +489,6 @@ class JaggedDenseAddGPUOp
torch::autograd::variable_list grad_outputs) {
auto offsets = ctx->get_saved_variables();
auto x_values_shape = ctx->saved_data["x_values_shape"].toIntVector();
auto y_shape = ctx->saved_data["y_shape"].toIntVector();
TORCH_CHECK(grad_outputs.size() == 1);

at::cuda::OptionalCUDAGuard device_guard;
@@ -466,6 +523,73 @@ Tensor jagged_dense_elementwise_add(
return JaggedDenseAddGPUOp::apply(x_values, x_offsets, y)[0];
}

// Unlike JaggedDenseAddGPUOp, which treats the jagged tensor's implicit
// "zeros" as actual zeros (so adding a dense tensor yields a dense tensor),
// this operator treats them as undefined, so the result is a jagged tensor.
class JaggedDenseAddJaggedOutputGPUOp
: public torch::autograd::Function<JaggedDenseAddJaggedOutputGPUOp> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
const Tensor& y) {
ctx->save_for_backward(x_offsets);
ctx->saved_data["y_shape"] = y.sizes();

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(x_values.get_device());

Tensor output;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values.scalar_type(), "jagged_dense_add_forward", [&] {
output = jagged_dense_elementwise_jagged_output_<scalar_t>(
x_values,
x_offsets,
y,
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
return x + y;
});
});

return {output};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
auto offsets = ctx->get_saved_variables();
auto y_shape = ctx->saved_data["y_shape"].toIntVector();
TORCH_CHECK(grad_outputs.size() == 1);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_outputs[0].get_device());

Tensor y_values_grad = jagged_to_padded_dense(
grad_outputs[0],
offsets,
std::vector<int64_t>(y_shape.begin() + 1, y_shape.end() - 1),
/*padding_value=*/0);
TORCH_CHECK(y_values_grad.sizes() == y_shape);

return {
grad_outputs[0],
torch::autograd::Variable(), // x_offsets
y_values_grad};
}
};

// output = x + y where x is jagged, y is dense, and output is jagged
std::tuple<Tensor, std::vector<Tensor>>
jagged_dense_elementwise_add_jagged_output(
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
const Tensor& y) {
return {
JaggedDenseAddJaggedOutputGPUOp::apply(x_values, x_offsets, y)[0],
x_offsets};
}
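// Editorial sketch (not part of this patch) of the jagged-output add: the
// result keeps the jagged layout of x, and the class above pads the incoming
// jagged gradient back to y's dense shape to form y's gradient. The function
// name, shapes, and offsets dtype below are made up for illustration.
inline void jagged_dense_add_jagged_output_example() {
  auto x_values = torch::randn(
      {5, 3}, torch::dtype(torch::kFloat).device(torch::kCUDA));
  auto x_offsets = torch::tensor(
      {0, 2, 5}, torch::dtype(torch::kLong).device(torch::kCUDA));
  // Dense y with shape [B, max_L, D] = [2, 4, 3].
  auto y = torch::randn({2, 4, 3}, x_values.options());

  auto result =
      jagged_dense_elementwise_add_jagged_output(x_values, {x_offsets}, y);
  const Tensor& out_values = std::get<0>(result);
  const std::vector<Tensor>& out_offsets = std::get<1>(result);

  // The output stays jagged: values keep the [5, 3] shape and the offsets
  // are passed through unchanged.
  TORCH_CHECK(out_values.sizes() == x_values.sizes());
  TORCH_CHECK(out_offsets.size() == 1);
}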

/**
* output = f(x, y) where x and y are jagged (and share x_offsets), and output
* is dense.
@@ -924,45 +1048,6 @@ Tensor batched_dense_vec_jagged_2d_mul(

} // namespace

Tensor
jagged_2d_to_dense_forward_cuda(Tensor values, Tensor offsets, int32_t max_L) {
TORCH_CHECK(values.dim() == 2);
TORCH_CHECK(offsets.dim() == 1);
TORCH_CHECK(max_L > 0);

return jagged_to_padded_dense(values, {offsets}, {max_L}, 0);
}

Tensor jagged_2d_to_dense_backward_cuda(
Tensor grad_padded_values,
Tensor offsets,
int32_t total_L) {
TORCH_CHECK(grad_padded_values.dim() == 3);
TORCH_CHECK(offsets.dim() == 1);
TORCH_CHECK(total_L >= 0);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_padded_values.get_device());

int32_t D = grad_padded_values.size(2);
auto grad_values = at::zeros({total_L, D}, grad_padded_values.options());

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_padded_values.scalar_type(),
"jagged_2d_to_dense_backward_kernel",
[&] {
jagged_dense_elementwise_jagged_output_<scalar_t>(
grad_values, // dummy not used in the lambda function
{offsets},
grad_padded_values,
grad_values,
[] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t {
return y;
});
});

return grad_values;
}

Tensor jagged_1d_to_dense_gpu(
Tensor values,
Tensor offsets,
@@ -1023,10 +1108,11 @@ stacked_jagged_2d_to_dense_forward_cuda(
});
offsets_tensor_per_key.push_back(offsets);

padded_values_per_key.push_back(jagged_2d_to_dense_forward_cuda(
padded_values_per_key.push_back(jagged_to_padded_dense(
values.slice(0, offset_per_key[t], offset_per_key[t + 1]),
offsets,
max_L));
{offsets},
{max_L},
/*padding_value=*/0L));
}

return std::make_tuple(padded_values_per_key, offsets_tensor_per_key);
@@ -1128,8 +1214,12 @@ std::vector<Tensor> stacked_jagged_1d_to_dense_gpu(
TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
DISPATCH_TO_CUDA(
"jagged_to_padded_dense", fbgemm_gpu::jagged_to_padded_dense);
DISPATCH_TO_CUDA("jagged_2d_to_dense", fbgemm_gpu::jagged_2d_to_dense);
DISPATCH_TO_CUDA(
"jagged_dense_elementwise_add", fbgemm_gpu::jagged_dense_elementwise_add);
DISPATCH_TO_CUDA(
"jagged_dense_elementwise_add_jagged_output",
fbgemm_gpu::jagged_dense_elementwise_add_jagged_output);
DISPATCH_TO_CUDA(
"jagged_dense_elementwise_mul", fbgemm_gpu::jagged_dense_elementwise_mul);
DISPATCH_TO_CUDA(