diff --git a/fbgemm_gpu/src/jagged_tensor_ops.cu b/fbgemm_gpu/src/jagged_tensor_ops.cu
index 470f363e38..f2bedcb18e 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops.cu
+++ b/fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -1834,11 +1834,7 @@ __global__ __launch_bounds__(kMaxThreads) void jagged_softmax_kernel(
     const int row_start = offsets[b];
     const int row_end = offsets[b + 1];
     const int length = min(row_end - row_start, max_L);
-    if (length == 0) {
-      for (int d = threadIdx.x; d < D; d += blockDim.x) {
-        output[b][d] = 0;
-      }
-    } else {
+    if (length != 0) {
       // TODO: use shared memory and better reduction
       for (int d = threadIdx.x; d < D; d += blockDim.x) {
         scalar_t max_value = values[row_start][d];
@@ -1872,7 +1868,7 @@ Tensor jagged_softmax_forward(
   device_guard.set_index(values.get_device());
 
   const int B = offsets.numel() - 1;
-  const int D = values.size(-1);
+  const int D = values.size(1);
   auto output = at::empty_like(values);
 
   if (B > 0 && D > 0) {
@@ -1905,6 +1901,86 @@ Tensor jagged_softmax_forward(
   return output;
 }
 
+template <typename index_t, typename scalar_t>
+__global__ __launch_bounds__(kMaxThreads) void jagged_softmax_backward_kernel(
+    const at::PackedTensorAccessor32<scalar_t, 2> grad_output,
+    const at::PackedTensorAccessor32<scalar_t, 2> output,
+    const at::PackedTensorAccessor32<index_t, 1> offsets,
+    at::PackedTensorAccessor32<scalar_t, 2> grad_input,
+    const int max_L) {
+  const int B = offsets.size(0) - 1;
+  const int D = grad_output.size(1);
+
+  const int b_begin = blockIdx.x * blockDim.y + threadIdx.y;
+  const int b_step = gridDim.x * blockDim.y;
+  for (int b = b_begin; b < B; b += b_step) {
+    const int row_start = offsets[b];
+    const int row_end = offsets[b + 1];
+    const int length = min(row_end - row_start, max_L);
+    if (length != 0) {
+      // TODO: use shared memory and better reduction
+      for (int d = threadIdx.x; d < D; d += blockDim.x) {
+        scalar_t sum_value = grad_output[row_start][d] * output[row_start][d];
+        for (int l = 1; l < length; ++l) {
+          sum_value += grad_output[row_start + l][d] * output[row_start + l][d];
+        }
+
+        for (int l = 0; l < length; ++l) {
+          grad_input[row_start + l][d] =
+              (grad_output[row_start + l][d] - sum_value) *
+              output[row_start + l][d];
+        }
+      }
+    }
+  }
+}
+
+Tensor jagged_softmax_backward(
+    const Tensor& grad_output,
+    const Tensor& output,
+    const Tensor& offsets,
+    const int64_t max_L) {
+  TENSOR_ON_CUDA_GPU(grad_output);
+  TENSOR_ON_CUDA_GPU(output);
+  TENSOR_ON_CUDA_GPU(offsets);
+
+  at::cuda::OptionalCUDAGuard device_guard;
+  device_guard.set_index(grad_output.get_device());
+
+  const int B = offsets.numel() - 1;
+  const int D = grad_output.size(1);
+  auto grad_input = at::empty_like(grad_output);
+
+  if (B > 0 && D > 0) {
+    const int block_dim_x =
+        std::min(div_round_up(D, kWarpSize) * kWarpSize, kMaxThreads);
+    const int block_dim_y = kMaxThreads / block_dim_x;
+
+    AT_DISPATCH_INDEX_TYPES(
+        offsets.scalar_type(), "jagged_softmax_backward_kernel_1", [&] {
+          AT_DISPATCH_FLOATING_TYPES_AND2(
+              at::ScalarType::Half,
+              at::ScalarType::BFloat16,
+              grad_output.scalar_type(),
+              "jagged_softmax_backward_kernel_2",
+              [&] {
+                jagged_softmax_backward_kernel<index_t, scalar_t>
+                    <<<div_round_up(B, block_dim_y),
+                       dim3(block_dim_x, block_dim_y),
+                       0,
+                       at::cuda::getCurrentCUDAStream()>>>(
+                        grad_output.packed_accessor32<scalar_t, 2>(),
+                        output.packed_accessor32<scalar_t, 2>(),
+                        offsets.packed_accessor32<index_t, 1>(),
+                        grad_input.packed_accessor32<scalar_t, 2>(),
+                        (int)max_L);
+                C10_CUDA_KERNEL_LAUNCH_CHECK();
+              });
+        });
+  }
+  return grad_input;
+}
+
 template <typename index_t, typename scalar_t>
 __global__ __launch_bounds__(kMaxThreads) void jagged_jagged_bmm_kernel(
     const at::PackedTensorAccessor32<scalar_t, 2> x_values,
@@ -3099,6 +3175,8 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   DISPATCH_TO_CUDA("jagged_softmax", fbgemm_gpu::jagged_softmax);
   DISPATCH_TO_CUDA(
       "jagged_softmax_forward", fbgemm_gpu::jagged_softmax_forward);
+  DISPATCH_TO_CUDA(
+      "jagged_softmax_backward", fbgemm_gpu::jagged_softmax_backward);
   DISPATCH_TO_CUDA("jagged_jagged_bmm", fbgemm_gpu::jagged_jagged_bmm);
   DISPATCH_TO_CUDA(
       "jagged_jagged_bmm_forward", fbgemm_gpu::jagged_jagged_bmm_forward);
diff --git a/fbgemm_gpu/src/jagged_tensor_ops_autograd.cpp b/fbgemm_gpu/src/jagged_tensor_ops_autograd.cpp
index 3a915e8fc5..b99f7394d8 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops_autograd.cpp
+++ b/fbgemm_gpu/src/jagged_tensor_ops_autograd.cpp
@@ -296,6 +296,56 @@ class DenseToJaggedOp : public torch::autograd::Function<DenseToJaggedOp> {
   }
 };
 
+class JaggedSoftmaxOp : public torch::autograd::Function<JaggedSoftmaxOp> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const Tensor& values,
+      const Tensor& offsets,
+      const int64_t max_L) {
+    static auto op =
+        c10::Dispatcher::singleton()
+            .findSchemaOrThrow("fbgemm::jagged_softmax_forward", "")
+            .typed<Tensor(
+                const Tensor& values,
+                const Tensor& offsets,
+                const int64_t max_L)>();
+
+    auto output = op.call(values, offsets, max_L);
+
+    ctx->save_for_backward({output, offsets});
+    ctx->saved_data["max_L"] = max_L;
+
+    return {output};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::variable_list grad_outputs) {
+    const auto saved = ctx->get_saved_variables();
+    auto savedItr = std::begin(saved);
+    Tensor output = *savedItr++;
+    Tensor offsets = *savedItr++;
+    int64_t max_L = ctx->saved_data["max_L"].toInt();
+    TORCH_CHECK(grad_outputs.size() == 1);
+
+    static auto op =
+        c10::Dispatcher::singleton()
+            .findSchemaOrThrow("fbgemm::jagged_softmax_backward", "")
+            .typed<Tensor(
+                const Tensor& grad_output,
+                const Tensor& output,
+                const Tensor& offsets,
+                const int64_t max_L)>();
+
+    auto grad_input = op.call(grad_outputs[0], output, offsets, max_L);
+
+    return {
+        grad_input,
+        torch::autograd::Variable(), // offsets
+        torch::autograd::Variable() // max_L
+    };
+  }
+};
+
 } // namespace
 
 ///@ingroup jagged-tensor-ops-cpu
@@ -416,15 +466,7 @@ std::tuple<Tensor, Tensor> jagged_softmax(
     const Tensor& values,
     const Tensor& offsets,
     const int64_t max_L) {
-  static auto op =
-      c10::Dispatcher::singleton()
-          .findSchemaOrThrow("fbgemm::jagged_softmax_forward", "")
-          .typed<Tensor(
-              const Tensor& values, const Tensor& offsets, const int64_t max_L)>();
-
-  auto output = op.call(values, offsets, max_L);
-
-  return {output, offsets};
+  return {JaggedSoftmaxOp::apply(values, offsets, max_L)[0], offsets};
 }
 
 Tensor jagged_jagged_bmm(
diff --git a/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp b/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
index ccbf6d1cf8..8921c6b3af 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
+++ b/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
@@ -1068,6 +1068,9 @@ void jagged_softmax_kernel(
     const int row_end = offsets[b + 1];
     const int length = std::min(row_end - row_start, (int)max_L);
 
+    if (length == 0)
+      continue;
+
     for (int d = 0; d < D; ++d) {
       // use is_cuda=true because acc_type<scalar_t, false> = double is too
       // conservative
@@ -1093,8 +1096,9 @@ Tensor jagged_softmax_forward(
     const Tensor& offsets,
     const int64_t max_L) {
   TENSOR_ON_CPU(values);
+  TENSOR_ON_CPU(offsets);
   const int B = offsets.numel() - 1;
-  const int D = values.size(-1);
+  const int D = values.size(1);
   auto output = at::empty_like(values);
 
   if (B > 0 && D > 0) {
@@ -1117,6 +1121,69 @@ Tensor jagged_softmax_forward(
   return output;
 }
 
+template <typename index_t, typename scalar_t>
+void jagged_softmax_backward_kernel(
+    const at::TensorAccessor<scalar_t, 2>& grad_output,
+    const at::TensorAccessor<scalar_t, 2>& output,
+    const at::TensorAccessor<index_t, 1>& offsets,
+    at::TensorAccessor<scalar_t, 2> grad_input,
+    const int64_t max_L) {
+  const int B = offsets.size(0) - 1;
+  const int D = grad_output.size(1);
+  for (int b = 0; b < B; ++b) {
+    const int row_start = offsets[b];
+    const int row_end = offsets[b + 1];
+    const int length = std::min(row_end - row_start, (int)max_L);
+    if (length == 0)
+      continue;
+    for (int d = 0; d < D; ++d) {
+      at::acc_type<scalar_t, true> sum_value =
+          grad_output[row_start][d] * output[row_start][d];
+      for (int l = 1; l < length; ++l) {
+        sum_value += grad_output[row_start + l][d] * output[row_start + l][d];
+      }
+      for (int l = 0; l < length; ++l) {
+        grad_input[row_start + l][d] =
+            (grad_output[row_start + l][d] - sum_value) *
+            output[row_start + l][d];
+      }
+    }
+  }
+}
+
+Tensor jagged_softmax_backward(
+    const Tensor& grad_output,
+    const Tensor& output,
+    const Tensor& offsets,
+    const int64_t max_L) {
+  TENSOR_ON_CPU(grad_output);
+  TENSOR_ON_CPU(output);
+  TENSOR_ON_CPU(offsets);
+  const int B = offsets.numel() - 1;
+  const int D = grad_output.size(1);
+  auto grad_input = at::empty_like(grad_output);
+
+  if (B > 0 && D > 0) {
+    AT_DISPATCH_INDEX_TYPES(
+        offsets.scalar_type(), "jagged_backward_kernel_1", [&] {
+          AT_DISPATCH_FLOATING_TYPES_AND2(
+              at::ScalarType::Half,
+              at::ScalarType::BFloat16,
+              grad_output.scalar_type(),
+              "jagged_softmax_backward_kernel_2",
+              [&] {
+                jagged_softmax_backward_kernel<index_t, scalar_t>(
+                    grad_output.accessor<scalar_t, 2>(),
+                    output.accessor<scalar_t, 2>(),
+                    offsets.accessor<index_t, 1>(),
+                    grad_input.accessor<scalar_t, 2>(),
+                    max_L);
+              });
+        });
+  }
+  return grad_input;
+}
+
 template <typename index_t, typename scalar_t>
 void jagged_jagged_bmm_kernel(
     const at::TensorAccessor<scalar_t, 2>& x_values,
@@ -1300,6 +1367,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "jagged_softmax(Tensor values, Tensor x_offsets, int max_L) -> (Tensor, Tensor)");
   m.def(
       "jagged_softmax_forward(Tensor values, Tensor x_offsets, int max_L) -> Tensor");
+  m.def(
+      "jagged_softmax_backward(Tensor grad_output, Tensor output, Tensor x_offsets, int max_L) -> Tensor");
   m.def(
       "jagged_jagged_bmm(Tensor x_values, Tensor y_values, Tensor x_offsets, int max_L) -> Tensor");
   m.def(
@@ -1362,6 +1431,8 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
       "masked_select_jagged_1d", fbgemm_gpu::masked_select_jagged_1d);
   DISPATCH_TO_CPU("jagged_softmax", fbgemm_gpu::jagged_softmax);
   DISPATCH_TO_CPU("jagged_softmax_forward", fbgemm_gpu::jagged_softmax_forward);
+  DISPATCH_TO_CPU(
+      "jagged_softmax_backward", fbgemm_gpu::jagged_softmax_backward);
   DISPATCH_TO_CPU("jagged_jagged_bmm", fbgemm_gpu::jagged_jagged_bmm);
   DISPATCH_TO_CPU(
       "jagged_jagged_bmm_forward", fbgemm_gpu::jagged_jagged_bmm_forward);
diff --git a/fbgemm_gpu/test/jagged_tensor_ops_test.py b/fbgemm_gpu/test/jagged_tensor_ops_test.py
index fcbbc99300..a2e53d29cb 100644
--- a/fbgemm_gpu/test/jagged_tensor_ops_test.py
+++ b/fbgemm_gpu/test/jagged_tensor_ops_test.py
@@ -1761,9 +1761,9 @@ def test_keyed_jagged_index_select_dim1(
 
     # pyre-ignore [56]
     @given(
-        B=st.integers(0, 32),
-        max_L=st.integers(1, 32),
-        D=st.integers(0, 32),
+        B=st.integers(1, 512),
+        max_L=st.integers(1, 1000),
+        D=st.integers(1, 32),
         dtype=st.sampled_from([torch.float, torch.double]),
         device_type=st.sampled_from(["cpu", "cuda"])
         if gpu_available
@@ -1778,32 +1778,45 @@ def test_jagged_softmax(
         dtype: torch.dtype,
         device_type: str,
     ) -> None:
-        assume(B != 0)
         device = torch.device(device_type)
         torch.backends.cuda.matmul.allow_tf32 = False
         lengths = torch.randint(max_L + 1, size=(B,), device=device)
+        total_length = int(lengths.sum().item())
         offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
-        values = torch.rand((offsets[-1], D), dtype=dtype, device=device)
+        values = torch.rand(
+            (total_length, D), requires_grad=True, dtype=dtype, device=device
+        )
         output, _ = torch.ops.fbgemm.jagged_softmax(
             values,
             offsets,
             max_L,
         )
-        dense = torch.ops.fbgemm.jagged_to_padded_dense(
-            values,
-            [offsets],
-            max_lengths=[max_L],
-            padding_value=-5e7,
-        )
-        dense_softmax = torch.nn.functional.softmax(
-            dense.transpose(1, 2), dim=-1
-        ).permute(0, 2, 1)
+        values_ref = values.detach().clone().requires_grad_(True)
         output_ref, _ = torch.ops.fbgemm.dense_to_jagged(
-            dense_softmax, [offsets], offsets[-1]
+            torch.nn.functional.softmax(
+                torch.ops.fbgemm.jagged_to_padded_dense(
+                    values_ref,
+                    [offsets],
+                    max_lengths=[max_L],
+                    padding_value=-5e7,
+                ).transpose(1, 2),
+                dim=-1,
+            ).permute(0, 2, 1),
+            [offsets],
+            total_length,
         )
 
+        # verify forward
         torch.testing.assert_close(output, output_ref)
 
+        # verify backward
+        grad_output = output.detach().clone().requires_grad_(True)
+
+        output.backward(grad_output)
+        output_ref.backward(grad_output)
+
+        torch.testing.assert_close(values.grad, values_ref.grad)
+
     # pyre-ignore [56]
     @given(
         B=st.integers(10, 512),
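
For context, a minimal usage sketch (not part of the diff) of the autograd path this change enables: `torch.ops.fbgemm.jagged_softmax` now routes its backward through `jagged_softmax_backward` on both CPU and CUDA. It assumes `fbgemm_gpu` is installed so the `fbgemm` ops are registered; the shapes and values are illustrative only.

```python
import torch
import fbgemm_gpu  # noqa: F401  # assumed installed; registers torch.ops.fbgemm.*

# Three jagged rows with lengths 2, 0, 3 over an embedding dimension of 4.
lengths = torch.tensor([2, 0, 3])
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
values = torch.rand(int(lengths.sum()), 4, requires_grad=True)

# Softmax is applied per row segment along the jagged dimension.
output, _ = torch.ops.fbgemm.jagged_softmax(values, offsets, int(lengths.max()))

# With the JaggedSoftmaxOp autograd wrapper, gradients flow back to `values`.
output.sum().backward()
print(values.grad.shape)  # torch.Size([5, 4])
```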