jagged_softmax backward #1594

Closed
wants to merge 1 commit into from
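This patch adds a `jagged_softmax_backward` operator with CPU and CUDA implementations, registers it in the `fbgemm` library, and routes `torch.ops.fbgemm.jagged_softmax` through a new `JaggedSoftmaxOp` autograd function so the op becomes differentiable with respect to its values; the unit test is extended to check the gradient. A minimal usage sketch, assuming `fbgemm_gpu` is installed; the shapes and `max_L` below are illustrative only:

```python
import torch
import fbgemm_gpu  # noqa: F401  (registers the torch.ops.fbgemm.* operators)

# Three bags with lengths 2, 0, 3 and feature dimension D = 4.
lengths = torch.tensor([2, 0, 3], dtype=torch.int32)
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
values = torch.rand(int(lengths.sum()), 4, requires_grad=True)

# Softmax is taken over each bag's rows (the jagged length dimension),
# independently for every feature column; max_L caps the rows used per bag.
output, _ = torch.ops.fbgemm.jagged_softmax(values, offsets, 3)

# With this change the op participates in autograd, so gradients reach `values`.
output.sum().backward()
print(values.grad.shape)  # torch.Size([5, 4])
```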
90 changes: 84 additions & 6 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -1834,11 +1834,7 @@ __global__ __launch_bounds__(kMaxThreads) void jagged_softmax_kernel(
const int row_start = offsets[b];
const int row_end = offsets[b + 1];
const int length = min(row_end - row_start, max_L);
- if (length == 0) {
- for (int d = threadIdx.x; d < D; d += blockDim.x) {
- output[b][d] = 0;
- }
- } else {
+ if (length != 0) {
// TODO: use shared memory and better reduction
for (int d = threadIdx.x; d < D; d += blockDim.x) {
scalar_t max_value = values[row_start][d];
@@ -1872,7 +1868,7 @@ Tensor jagged_softmax_forward(
device_guard.set_index(values.get_device());

const int B = offsets.numel() - 1;
- const int D = values.size(-1);
+ const int D = values.size(1);
auto output = at::empty_like(values);

if (B > 0 && D > 0) {
@@ -1905,6 +1901,86 @@ Tensor jagged_softmax_forward(
return output;
}

template <typename index_t, typename scalar_t>
__global__ __launch_bounds__(kMaxThreads) void jagged_softmax_backward_kernel(
const at::PackedTensorAccessor32<scalar_t, 2> grad_output,
const at::PackedTensorAccessor32<scalar_t, 2> output,
const at::PackedTensorAccessor32<index_t, 1> offsets,
at::PackedTensorAccessor32<scalar_t, 2> grad_input,
const int max_L) {
const int B = offsets.size(0) - 1;
const int D = grad_output.size(1);

const int b_begin = blockIdx.x * blockDim.y + threadIdx.y;
const int b_step = gridDim.x * blockDim.y;
for (int b = b_begin; b < B; b += b_step) {
const int row_start = offsets[b];
const int row_end = offsets[b + 1];
const int length = min(row_end - row_start, max_L);
if (length != 0) {
// TODO: use shared memory and better reduction
for (int d = threadIdx.x; d < D; d += blockDim.x) {
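// Per-feature softmax backward over this bag's rows:
//   grad_x[l] = y[l] * (grad_y[l] - sum_k grad_y[k] * y[k]).
// First accumulate sum_k grad_y[k] * y[k] for feature d.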
scalar_t sum_value = grad_output[row_start][d] * output[row_start][d];
for (int l = 1; l < length; ++l) {
sum_value += grad_output[row_start + l][d] * output[row_start + l][d];
}

for (int l = 0; l < length; ++l) {
grad_input[row_start + l][d] =
(grad_output[row_start + l][d] - sum_value) *
output[row_start + l][d];
}
}
}
}
}

Tensor jagged_softmax_backward(
const Tensor& grad_output,
const Tensor& output,
const Tensor& offsets,
const int64_t max_L) {
TENSOR_ON_CUDA_GPU(grad_output);
TENSOR_ON_CUDA_GPU(output);
TENSOR_ON_CUDA_GPU(offsets);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_output.get_device());

const int B = offsets.numel() - 1;
const int D = grad_output.size(1);
auto grad_input = at::empty_like(grad_output);

if (B > 0 && D > 0) {
const int block_dim_x =
std::min(div_round_up(D, kWarpSize) * kWarpSize, kMaxThreads);
const int block_dim_y = kMaxThreads / block_dim_x;

AT_DISPATCH_INDEX_TYPES(
offsets.scalar_type(), "jagged_softmax_backward_kernel_1", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
grad_output.scalar_type(),
"jagged_softmax_backward_kernel_2",
[&] {
jagged_softmax_backward_kernel<index_t, scalar_t>
<<<div_round_up(B, block_dim_y),
dim3(block_dim_x, block_dim_y),
0,
at::cuda::getCurrentCUDAStream()>>>(
grad_output.packed_accessor32<scalar_t, 2>(),
output.packed_accessor32<scalar_t, 2>(),
offsets.packed_accessor32<index_t, 1>(),
grad_input.packed_accessor32<scalar_t, 2>(),
(int)max_L);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
}
return grad_input;
}

template <typename index_t, typename scalar_t>
__global__ __launch_bounds__(kMaxThreads) void jagged_jagged_bmm_kernel(
const at::PackedTensorAccessor32<scalar_t, 2> x_values,
@@ -3099,6 +3175,8 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
DISPATCH_TO_CUDA("jagged_softmax", fbgemm_gpu::jagged_softmax);
DISPATCH_TO_CUDA(
"jagged_softmax_forward", fbgemm_gpu::jagged_softmax_forward);
DISPATCH_TO_CUDA(
"jagged_softmax_backward", fbgemm_gpu::jagged_softmax_backward);
DISPATCH_TO_CUDA("jagged_jagged_bmm", fbgemm_gpu::jagged_jagged_bmm);
DISPATCH_TO_CUDA(
"jagged_jagged_bmm_forward", fbgemm_gpu::jagged_jagged_bmm_forward);
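For reference, both backward kernels (the CUDA one above and the CPU one below) implement the standard softmax gradient, restricted to one bag's rows and applied independently per feature column. The derivation is textbook material rather than part of the patch: with $y_l = e^{x_l} / \sum_k e^{x_k}$ over the $L$ rows of a bag and upstream gradient $g_l = \partial \mathcal{L} / \partial y_l$,

$$
\frac{\partial y_l}{\partial x_m} = y_l\left(\delta_{lm} - y_m\right)
\qquad\Longrightarrow\qquad
\frac{\partial \mathcal{L}}{\partial x_m}
= \sum_{l} g_l\, y_l\left(\delta_{lm} - y_m\right)
= y_m\left(g_m - \sum_{l} g_l\, y_l\right),
$$

which is exactly `grad_input[l][d] = (grad_output[l][d] - sum_value) * output[l][d]` with `sum_value = Σ_l grad_output[l][d] * output[l][d]`, as computed in the kernels.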
60 changes: 51 additions & 9 deletions fbgemm_gpu/src/jagged_tensor_ops_autograd.cpp
@@ -296,6 +296,56 @@ class DenseToJaggedOp : public torch::autograd::Function<DenseToJaggedOp> {
}
};

class JaggedSoftmaxOp : public torch::autograd::Function<JaggedSoftmaxOp> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const Tensor& values,
const Tensor& offsets,
const int64_t max_L) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::jagged_softmax_forward", "")
.typed<Tensor(
const Tensor& values, const Tensor& offsets, int64_t max_L)>();

auto output = op.call(values, offsets, max_L);

ctx->save_for_backward({output, offsets});
ctx->saved_data["max_L"] = max_L;

return {output};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
const auto saved = ctx->get_saved_variables();
auto savedItr = std::begin(saved);
Tensor output = *savedItr++;
Tensor offsets = *savedItr++;
int64_t max_L = ctx->saved_data["max_L"].toInt();
TORCH_CHECK(grad_outputs.size() == 1);

static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::jagged_softmax_backward", "")
.typed<Tensor(
const Tensor& grad_output,
const Tensor& output,
const Tensor& offsets,
int64_t max_L)>();

auto grad_input = op.call(grad_outputs[0], output, offsets, max_L);

return {
grad_input,
torch::autograd::Variable(), // offsets
torch::autograd::Variable() // max_L
};
}
};

} // namespace

///@ingroup jagged-tensor-ops-cpu
@@ -416,15 +466,7 @@ std::tuple<Tensor, Tensor> jagged_softmax(
const Tensor& values,
const Tensor& offsets,
const int64_t max_L) {
- static auto op =
- c10::Dispatcher::singleton()
- .findSchemaOrThrow("fbgemm::jagged_softmax_forward", "")
- .typed<Tensor(
- const Tensor& values, const Tensor& offsets, int64_t max_L)>();
-
- auto output = op.call(values, offsets, max_L);
-
- return {output, offsets};
+ return {JaggedSoftmaxOp::apply(values, offsets, max_L)[0], offsets};
}

Tensor jagged_jagged_bmm(
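Because `JaggedSoftmaxOp` now routes the backward call through `fbgemm::jagged_softmax_backward`, the analytic gradient can be compared against finite differences with `torch.autograd.gradcheck`. A quick sketch, not part of the patch; it assumes `fbgemm_gpu` is installed and uses double precision, which `gradcheck` requires:

```python
import torch
import fbgemm_gpu  # noqa: F401  (registers the torch.ops.fbgemm.* operators)

lengths = torch.tensor([2, 3], dtype=torch.int32)
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
values = torch.rand(5, 4, dtype=torch.double, requires_grad=True)
max_L = 3

# Check the new analytic backward against numerical gradients on the CPU path.
torch.autograd.gradcheck(
    lambda v: torch.ops.fbgemm.jagged_softmax(v, offsets, max_L)[0],
    (values,),
)
```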
73 changes: 72 additions & 1 deletion fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
@@ -1068,6 +1068,9 @@ void jagged_softmax_kernel(
const int row_end = offsets[b + 1];
const int length = std::min(row_end - row_start, (int)max_L);

if (length == 0)
continue;

for (int d = 0; d < D; ++d) {
// use is_cuda=true because acc_type<float, false> = double is too
// conservative
@@ -1093,8 +1096,9 @@ Tensor jagged_softmax_forward(
const Tensor& offsets,
const int64_t max_L) {
TENSOR_ON_CPU(values);
TENSOR_ON_CPU(offsets);
const int B = offsets.numel() - 1;
- const int D = values.size(-1);
+ const int D = values.size(1);
auto output = at::empty_like(values);

if (B > 0 && D > 0) {
@@ -1117,6 +1121,69 @@ return output;
return output;
}

template <typename index_t, typename scalar_t>
void jagged_softmax_backward_kernel(
const at::TensorAccessor<scalar_t, 2>& grad_output,
const at::TensorAccessor<scalar_t, 2>& output,
const at::TensorAccessor<index_t, 1>& offsets,
at::TensorAccessor<scalar_t, 2> grad_input,
const int64_t max_L) {
const int B = offsets.size(0) - 1;
const int D = grad_output.size(1);
for (int b = 0; b < B; ++b) {
const int row_start = offsets[b];
const int row_end = offsets[b + 1];
const int length = std::min(row_end - row_start, (int)max_L);
if (length == 0)
continue;
for (int d = 0; d < D; ++d) {
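// Accumulate in acc_type with is_cuda=true (as in the forward kernel) so
// float stays float while half/bfloat16 are widened to float.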
at::acc_type<scalar_t, true> sum_value =
grad_output[row_start][d] * output[row_start][d];
for (int l = 1; l < length; ++l) {
sum_value += grad_output[row_start + l][d] * output[row_start + l][d];
}
for (int l = 0; l < length; ++l) {
grad_input[row_start + l][d] =
(grad_output[row_start + l][d] - sum_value) *
output[row_start + l][d];
}
}
}
}

Tensor jagged_softmax_backward(
const Tensor& grad_output,
const Tensor& output,
const Tensor& offsets,
const int64_t max_L) {
TENSOR_ON_CPU(grad_output);
TENSOR_ON_CPU(output);
TENSOR_ON_CPU(offsets);
const int B = offsets.numel() - 1;
const int D = grad_output.size(1);
auto grad_input = at::empty_like(grad_output);

if (B > 0 && D > 0) {
AT_DISPATCH_INDEX_TYPES(
offsets.scalar_type(), "jagged_backward_kernel_1", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
grad_output.scalar_type(),
"jagged_softmax_backward_kernel_2",
[&] {
jagged_softmax_backward_kernel<index_t, scalar_t>(
grad_output.accessor<scalar_t, 2>(),
output.accessor<scalar_t, 2>(),
offsets.accessor<index_t, 1>(),
grad_input.accessor<scalar_t, 2>(),
max_L);
});
});
}
return grad_input;
}

template <typename index_t, typename scalar_t>
void jagged_jagged_bmm_kernel(
const at::TensorAccessor<scalar_t, 2>& x_values,
@@ -1300,6 +1367,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"jagged_softmax(Tensor values, Tensor x_offsets, int max_L) -> (Tensor, Tensor)");
m.def(
"jagged_softmax_forward(Tensor values, Tensor x_offsets, int max_L) -> Tensor");
m.def(
"jagged_softmax_backward(Tensor grad_output, Tensor output, Tensor x_offsets, int max_L) -> Tensor");
m.def(
"jagged_jagged_bmm(Tensor x_values, Tensor y_values, Tensor x_offsets, int max_L) -> Tensor");
m.def(
Expand Down Expand Up @@ -1362,6 +1431,8 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
"masked_select_jagged_1d", fbgemm_gpu::masked_select_jagged_1d);
DISPATCH_TO_CPU("jagged_softmax", fbgemm_gpu::jagged_softmax);
DISPATCH_TO_CPU("jagged_softmax_forward", fbgemm_gpu::jagged_softmax_forward);
DISPATCH_TO_CPU(
"jagged_softmax_backward", fbgemm_gpu::jagged_softmax_backward);
DISPATCH_TO_CPU("jagged_jagged_bmm", fbgemm_gpu::jagged_jagged_bmm);
DISPATCH_TO_CPU(
"jagged_jagged_bmm_forward", fbgemm_gpu::jagged_jagged_bmm_forward);
43 changes: 28 additions & 15 deletions fbgemm_gpu/test/jagged_tensor_ops_test.py
@@ -1761,9 +1761,9 @@ def test_keyed_jagged_index_select_dim1(

# pyre-ignore [56]
@given(
- B=st.integers(0, 32),
- max_L=st.integers(1, 32),
- D=st.integers(0, 32),
+ B=st.integers(1, 512),
+ max_L=st.integers(1, 1000),
+ D=st.integers(1, 32),
dtype=st.sampled_from([torch.float, torch.double]),
device_type=st.sampled_from(["cpu", "cuda"])
if gpu_available
@@ -1778,32 +1778,45 @@ def test_jagged_softmax(
dtype: torch.dtype,
device_type: str,
) -> None:
assume(B != 0)
device = torch.device(device_type)
torch.backends.cuda.matmul.allow_tf32 = False
lengths = torch.randint(max_L + 1, size=(B,), device=device)
total_length = int(lengths.sum().item())
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
- values = torch.rand((offsets[-1], D), dtype=dtype, device=device)
+ values = torch.rand(
+ (total_length, D), requires_grad=True, dtype=dtype, device=device
+ )
output, _ = torch.ops.fbgemm.jagged_softmax(
values,
offsets,
max_L,
)
- dense = torch.ops.fbgemm.jagged_to_padded_dense(
- values,
- [offsets],
- max_lengths=[max_L],
- padding_value=-5e7,
- )
- dense_softmax = torch.nn.functional.softmax(
- dense.transpose(1, 2), dim=-1
- ).permute(0, 2, 1)
+ values_ref = values.detach().clone().requires_grad_(True)
output_ref, _ = torch.ops.fbgemm.dense_to_jagged(
- dense_softmax, [offsets], offsets[-1]
+ torch.nn.functional.softmax(
+ torch.ops.fbgemm.jagged_to_padded_dense(
+ values_ref,
+ [offsets],
+ max_lengths=[max_L],
+ padding_value=-5e7,
+ ).transpose(1, 2),
+ dim=-1,
+ ).permute(0, 2, 1),
+ [offsets],
+ total_length,
)

# verify forward
torch.testing.assert_close(output, output_ref)

# verify backward
grad_output = output.detach().clone().requires_grad_(True)

output.backward(grad_output)
output_ref.backward(grad_output)

torch.testing.assert_close(values.grad, values_ref.grad)

# pyre-ignore [56]
@given(
B=st.integers(10, 512),