From ce6a758eb1f1027859d583aa8b9e00d38081ec01 Mon Sep 17 00:00:00 2001
From: Xing Liu
Date: Tue, 14 Dec 2021 14:04:04 -0800
Subject: [PATCH] unify function signature of jagged_xD_to_dense (#813)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/813

As title

Reviewed By: jiaqizhai, jianyuh

Differential Revision: D33066551

fbshipit-source-id: 07f1033412f23d38b8b4cb7e1e86ee69f9f6d265
---
 fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h |  4 +-
 fbgemm_gpu/src/sparse_ops.cu               | 70 +++++++++++-----------
 fbgemm_gpu/src/sparse_ops_cpu.cpp          | 36 +++++------
 fbgemm_gpu/src/sparse_ops_gpu.cpp          | 18 +++---
 fbgemm_gpu/test/sparse_ops_test.py         | 66 ++++++++++----------
 5 files changed, 97 insertions(+), 97 deletions(-)

diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
index f3f672e1dc..88339a1861 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -154,12 +154,12 @@ at::Tensor batched_unary_embeddings_backward_cuda(
     const at::Tensor& indices);
 
 at::Tensor jagged_2d_to_dense_forward_cuda(
-    at::Tensor embeddings,
+    at::Tensor values,
     at::Tensor offsets,
     int32_t max_L);
 
 at::Tensor jagged_2d_to_dense_backward_cuda(
-    at::Tensor grad_padded_embeddings,
+    at::Tensor grad_padded_values,
     at::Tensor offsets,
     int32_t total_L);
 
diff --git a/fbgemm_gpu/src/sparse_ops.cu b/fbgemm_gpu/src/sparse_ops.cu
index 4a059be2c8..3cccfddf16 100644
--- a/fbgemm_gpu/src/sparse_ops.cu
+++ b/fbgemm_gpu/src/sparse_ops.cu
@@ -1284,8 +1284,8 @@ __global__ void jagged_2d_to_dense_forward_kernel(
     int32_t max_L,
     int32_t D,
     at::PackedTensorAccessor32<index_t, 1> offsets,
-    at::PackedTensorAccessor64<scalar_t, 2> embeddings,
-    at::PackedTensorAccessor64<scalar_t, 3> padded_embeddings) {
+    at::PackedTensorAccessor64<scalar_t, 2> values,
+    at::PackedTensorAccessor64<scalar_t, 3> padded_values) {
   int32_t b_l = blockIdx.x * blockDim.y + threadIdx.y;
   int32_t l = b_l / B;
   int32_t b = b_l % B;
@@ -1298,39 +1298,39 @@ __global__ void jagged_2d_to_dense_forward_kernel(
   if (l < length) {
     for (int32_t d = 0; d < D; d += fbgemm_gpu::kWarpSize) {
       if (d + threadIdx.x < D) {
-        padded_embeddings[b][l][d + threadIdx.x] =
-            embeddings[row_start + l][d + threadIdx.x];
+        padded_values[b][l][d + threadIdx.x] =
+            values[row_start + l][d + threadIdx.x];
       }
     }
   } else {
     for (int32_t d = 0; d < D; d += fbgemm_gpu::kWarpSize) {
       if (d + threadIdx.x < D) {
-        padded_embeddings[b][l][d + threadIdx.x] = 0.0;
+        padded_values[b][l][d + threadIdx.x] = 0.0;
       }
     }
   }
 }
 
 Tensor jagged_2d_to_dense_forward_cuda(
-    Tensor embeddings,
+    Tensor values,
     Tensor offsets,
     int32_t max_L) {
-  TORCH_CHECK(embeddings.dim() == 2);
+  TORCH_CHECK(values.dim() == 2);
   TORCH_CHECK(offsets.dim() == 1);
   TORCH_CHECK(max_L > 0);
   at::cuda::OptionalCUDAGuard device_guard;
-  device_guard.set_index(embeddings.get_device());
+  device_guard.set_index(values.get_device());
 
-  int32_t D = embeddings.size(1);
+  int32_t D = values.size(1);
   int32_t B = offsets.numel() - 1;
-  auto padded_embeddings = at::empty({B, max_L, D}, embeddings.options());
-  const auto embeddings_contig = embeddings.contiguous();
+  auto padded_values = at::empty({B, max_L, D}, values.options());
+  const auto values_contig = values.contiguous();
   const auto offsets_contig = offsets.contiguous();
 
   AT_DISPATCH_INDEX_TYPES(
       offsets.scalar_type(), "jagged_2d_to_dense_forward_kernel_1", ([&]() {
         AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-            embeddings.scalar_type(),
+            values.scalar_type(),
             "jagged_2d_to_dense_forward_kernel_2",
             ([&]() {
               jagged_2d_to_dense_forward_kernel<index_t, scalar_t>
@@ -1346,12 +1346,12 @@ Tensor jagged_2d_to_dense_forward_cuda(
                       max_L,
                       D,
                       offsets_contig.packed_accessor32<index_t, 1>(),
-                      embeddings_contig.packed_accessor64<scalar_t, 2>(),
-                      padded_embeddings.packed_accessor64<scalar_t, 3>());
+                      values_contig.packed_accessor64<scalar_t, 2>(),
+                      padded_values.packed_accessor64<scalar_t, 3>());
             }));
       }));
 
-  return padded_embeddings;
+  return padded_values;
 }
 
 template <typename index_t, typename scalar_t>
@@ -1360,8 +1360,8 @@ __global__ void jagged_2d_to_dense_backward_kernel(
     int32_t max_L,
     int32_t D,
     at::PackedTensorAccessor32<index_t, 1> offsets,
-    at::PackedTensorAccessor64<scalar_t, 3> grad_padded_embeddings,
-    at::PackedTensorAccessor64<scalar_t, 2> grad_embeddings) {
+    at::PackedTensorAccessor64<scalar_t, 3> grad_padded_values,
+    at::PackedTensorAccessor64<scalar_t, 2> grad_values) {
   int32_t b_l = blockIdx.x * blockDim.y + threadIdx.y;
   int32_t l = b_l / B;
   int32_t b = b_l % B;
@@ -1374,37 +1374,37 @@ __global__ void jagged_2d_to_dense_backward_kernel(
   if (l < length) {
     for (int32_t d = 0; d < D; d += fbgemm_gpu::kWarpSize) {
       if (d + threadIdx.x < D) {
-        grad_embeddings[row_start + l][d + threadIdx.x] =
-            grad_padded_embeddings[b][l][d + threadIdx.x];
+        grad_values[row_start + l][d + threadIdx.x] =
+            grad_padded_values[b][l][d + threadIdx.x];
       }
     }
   }
 }
 
 Tensor jagged_2d_to_dense_backward_cuda(
-    Tensor grad_padded_embeddings,
+    Tensor grad_padded_values,
     Tensor offsets,
     int32_t total_L) {
-  TORCH_CHECK(grad_padded_embeddings.dim() == 3);
+  TORCH_CHECK(grad_padded_values.dim() == 3);
   TORCH_CHECK(offsets.dim() == 1);
   TORCH_CHECK(total_L >= 0);
-  TORCH_CHECK(offsets.numel() == grad_padded_embeddings.size(0) + 1);
+  TORCH_CHECK(offsets.numel() == grad_padded_values.size(0) + 1);
   at::cuda::OptionalCUDAGuard device_guard;
-  device_guard.set_index(grad_padded_embeddings.get_device());
-
-  int32_t B = grad_padded_embeddings.size(0);
-  int32_t max_L = grad_padded_embeddings.size(1);
-  int32_t D = grad_padded_embeddings.size(2);
-  auto grad_embeddings =
-      at::zeros({total_L, D}, grad_padded_embeddings.options());
-  const auto grad_padded_embeddings_config =
-      grad_padded_embeddings.contiguous();
+  device_guard.set_index(grad_padded_values.get_device());
+
+  int32_t B = grad_padded_values.size(0);
+  int32_t max_L = grad_padded_values.size(1);
+  int32_t D = grad_padded_values.size(2);
+  auto grad_values =
+      at::zeros({total_L, D}, grad_padded_values.options());
+  const auto grad_padded_values_config =
+      grad_padded_values.contiguous();
   const auto offsets_contig = offsets.contiguous();
 
   AT_DISPATCH_INDEX_TYPES(
      offsets.scalar_type(), "jagged_2d_to_dense_backward_kernel_1", ([&]() {
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-            grad_padded_embeddings.scalar_type(),
+            grad_padded_values.scalar_type(),
            "jagged_2d_to_dense_backward_kernel_2",
            ([&]() {
              jagged_2d_to_dense_backward_kernel<index_t, scalar_t>
@@ -1420,13 +1420,13 @@ Tensor jagged_2d_to_dense_backward_cuda(
                      max_L,
                      D,
                      offsets_contig.packed_accessor32<index_t, 1>(),
-                     grad_padded_embeddings_config
-                         .packed_accessor64<scalar_t, 3>(),
-                     grad_embeddings.packed_accessor64<scalar_t, 2>());
+                     grad_padded_values_config
+                         .packed_accessor64<scalar_t, 3>(),
+                     grad_values.packed_accessor64<scalar_t, 2>());
            }));
      }));
 
-  return grad_embeddings;
+  return grad_values;
 }
 
 template
diff --git a/fbgemm_gpu/src/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops_cpu.cpp
index 79ad535fcc..f3fa7e187b 100644
--- a/fbgemm_gpu/src/sparse_ops_cpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops_cpu.cpp
@@ -839,8 +839,8 @@ void jagged_2d_to_dense_forward_kernel(
     int32_t max_L,
     int32_t D,
     const index_t* offsets,
-    const scalar_t* embeddings_data,
-    scalar_t* padded_embeddings_data) {
+    const scalar_t* values_data,
+    scalar_t* padded_values_data) {
   const auto block_size = max_L * D;
   const auto embedding_byte_size = D * sizeof(scalar_t);
   for (auto b = 0; b < B; ++b) {
@@ -852,53 +852,53 @@ void jagged_2d_to_dense_forward_kernel(
     }
     auto padding_length = max_L - length;
     memcpy(
-        &padded_embeddings_data[b * block_size],
-        &embeddings_data[start_idx * D],
+        &padded_values_data[b * block_size],
+        &values_data[start_idx * D],
         length * embedding_byte_size);
     memset(
-        &padded_embeddings_data[b * block_size + length * D],
+        &padded_values_data[b * block_size + length * D],
         0,
         padding_length * embedding_byte_size);
   }
 }
 
 Tensor jagged_2d_to_dense_forward_cpu(
-    Tensor embeddings,
+    Tensor values,
     Tensor offsets,
     int64_t max_L) {
-  TORCH_CHECK(embeddings.dim() == 2);
+  TORCH_CHECK(values.dim() == 2);
   TORCH_CHECK(offsets.dim() == 1);
   TORCH_CHECK(max_L > 0);
 
   const auto B = offsets.numel() - 1;
-  const auto D = embeddings.size(1);
-  const auto embeddings_contig = embeddings.expect_contiguous();
+  const auto D = values.size(1);
+  const auto values_contig = values.expect_contiguous();
   const auto offsets_contig = offsets.expect_contiguous();
 
-  if (embeddings.size(0) == 0) {
-    return at::zeros({B, max_L, D}, embeddings.options());
+  if (values.size(0) == 0) {
+    return at::zeros({B, max_L, D}, values.options());
   }
 
-  auto padded_embeddings = at::empty({B, max_L, D}, embeddings.options());
+  auto padded_values = at::empty({B, max_L, D}, values.options());
   AT_DISPATCH_INDEX_TYPES(
       offsets_contig->scalar_type(),
       "jagged_2d_to_dense_forward_by_offsets",
       ([&]() {
         AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-            embeddings_contig->scalar_type(),
-            "jagged_2d_to_dense_forward_by_embeddings",
+            values_contig->scalar_type(),
+            "jagged_2d_to_dense_forward_by_values",
             ([&]() {
               jagged_2d_to_dense_forward_kernel<index_t, scalar_t>(
                   B,
                   max_L,
                   D,
                   offsets_contig->data_ptr<index_t>(),
-                  embeddings_contig->data_ptr<scalar_t>(),
-                  padded_embeddings.data_ptr<scalar_t>());
+                  values_contig->data_ptr<scalar_t>(),
+                  padded_values.data_ptr<scalar_t>());
             }));
       }));
 
-  return padded_embeddings;
+  return padded_values;
 }
 
 template
@@ -1193,7 +1193,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "batched_unary_embeddings(Tensor weight, Tensor table_offsets, Tensor offsets, Tensor indices) -> Tensor");
   m.def(
-      "jagged_2d_to_dense(Tensor embeddings, Tensor offsets, int max_sequence_length) -> Tensor");
+      "jagged_2d_to_dense(Tensor values, Tensor offsets, int max_sequence_length) -> Tensor");
   m.def(
       "jagged_1d_to_dense(Tensor values, Tensor offsets, int max_sequence_length, int padding_value) -> Tensor");
   m.def(
diff --git a/fbgemm_gpu/src/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops_gpu.cpp
index 1bafc682ad..0a9b3cafed 100644
--- a/fbgemm_gpu/src/sparse_ops_gpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops_gpu.cpp
@@ -62,15 +62,15 @@ class Jagged2DToDenseGPUOp
  public:
   static torch::autograd::variable_list forward(
       torch::autograd::AutogradContext* ctx,
-      Tensor embeddings,
+      Tensor values,
       Tensor offsets,
       int32_t max_sequence_length) {
-    int32_t total_L = embeddings.size(0);
+    int32_t total_L = values.size(0);
     ctx->save_for_backward({offsets});
     ctx->saved_data["total_L"] = total_L;
 
     return {jagged_2d_to_dense_forward_cuda(
-        embeddings, offsets, max_sequence_length)};
+        values, offsets, max_sequence_length)};
   }
 
   static torch::autograd::variable_list backward(
@@ -82,11 +82,11 @@ class Jagged2DToDenseGPUOp
     int32_t total_L = ctx->saved_data["total_L"].toInt();
 
     using torch::autograd::Variable;
-    auto grad_padded_embeddings = grad_outputs[0];
-    auto grad_embeddings = jagged_2d_to_dense_backward_cuda(
-        grad_padded_embeddings, offsets, total_L);
+    auto grad_padded_values = grad_outputs[0];
+    auto grad_values = jagged_2d_to_dense_backward_cuda(
+        grad_padded_values, offsets, total_L);
     return {
-        grad_embeddings,
+        grad_values,
         Variable(), // offsets
         Variable() // max_sequence_length
     };
@@ -94,11 +94,11 @@ class Jagged2DToDenseGPUOp
 };
 
 Tensor jagged_2d_to_dense_gpu(
-    Tensor embeddings,
+    Tensor values,
     Tensor offsets,
     int64_t max_sequence_length) {
   return Jagged2DToDenseGPUOp::apply(
-      embeddings, offsets, static_cast<int32_t>(max_sequence_length))[0];
+      values, offsets, static_cast<int32_t>(max_sequence_length))[0];
 }
 
 } // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/test/sparse_ops_test.py b/fbgemm_gpu/test/sparse_ops_test.py
index 49000de0b3..3f7bf8c311 100644
--- a/fbgemm_gpu/test/sparse_ops_test.py
+++ b/fbgemm_gpu/test/sparse_ops_test.py
@@ -817,45 +817,45 @@ def test_jagged_2d_to_dense(
         lengths = torch.from_numpy(lengths_)
         offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
 
-        ref_embeddings = torch.rand(total_lengths, D)
-        ref_output_embeddings = var_list_to_coo(
+        ref_values = torch.rand(total_lengths, D)
+        ref_output_values = var_list_to_coo(
             lengths,
-            ref_embeddings,
+            ref_values,
             max_sequence_length,
             D,
         ).to_dense()
 
         # test cpu forward
         if is_half:
-            embeddings = ref_embeddings.clone().half().detach().requires_grad_(True)
+            values = ref_values.clone().half().detach().requires_grad_(True)
         else:
-            embeddings = ref_embeddings.clone().detach().requires_grad_(True)
-        output_embeddings = torch.ops.fbgemm.jagged_2d_to_dense(
-            embeddings=embeddings,
+            values = ref_values.clone().detach().requires_grad_(True)
+        output_values = torch.ops.fbgemm.jagged_2d_to_dense(
+            values=values,
             offsets=offsets,
             max_sequence_length=max_sequence_length,
         )
-        torch.testing.assert_allclose(ref_output_embeddings, output_embeddings)
+        torch.testing.assert_allclose(ref_output_values, output_values)
 
         if torch.cuda.is_available():
             # test gpu forward
-            ref_embeddings = ref_embeddings.cuda()
+            ref_values = ref_values.cuda()
             if is_half:
-                embeddings = ref_embeddings.clone().half().detach().requires_grad_(True)
+                values = ref_values.clone().half().detach().requires_grad_(True)
             else:
-                embeddings = ref_embeddings.clone().detach().requires_grad_(True)
+                values = ref_values.clone().detach().requires_grad_(True)
             offsets = offsets.cuda()
-            ref_output_embeddings = ref_output_embeddings.cuda()
-            output_embeddings = torch.ops.fbgemm.jagged_2d_to_dense(
-                embeddings=embeddings,
+            ref_output_values = ref_output_values.cuda()
+            output_values = torch.ops.fbgemm.jagged_2d_to_dense(
+                values=values,
                 offsets=offsets,
                 max_sequence_length=max_sequence_length,
             )
-            torch.testing.assert_allclose(ref_output_embeddings, output_embeddings)
+            torch.testing.assert_allclose(ref_output_values, output_values)
 
             # test gpu backward
-            output_embeddings.backward(ref_output_embeddings)
-            torch.testing.assert_allclose(ref_embeddings, embeddings.grad)
+            output_values.backward(ref_output_values)
+            torch.testing.assert_allclose(ref_values, values.grad)
 
     def test_jagged_2d_to_dense_truncation(self) -> None:
         # Test the case where max_sequence_length < max(lengths[i])
@@ -866,42 +866,42 @@ def test_jagged_2d_to_dense_truncation(self) -> None:
         embedding_dim = 16
         max_sequence_length = 2
 
-        ref_embeddings = torch.rand(total_lengths, embedding_dim)
-        ref_output_embeddings = var_list_to_coo(
+        ref_values = torch.rand(total_lengths, embedding_dim)
+        ref_output_values = var_list_to_coo(
             lengths,
-            ref_embeddings,
+            ref_values,
             3,
             embedding_dim,
         ).to_dense()[:, :max_sequence_length, :]
 
         # test cpu forward
-        embeddings = ref_embeddings.clone().detach().requires_grad_(True)
-        output_embeddings = torch.ops.fbgemm.jagged_2d_to_dense(
-            embeddings=embeddings,
+        values = ref_values.clone().detach().requires_grad_(True)
+        output_values = torch.ops.fbgemm.jagged_2d_to_dense(
+            values=values,
             offsets=offsets,
             max_sequence_length=max_sequence_length,
         )
-        torch.testing.assert_allclose(ref_output_embeddings, output_embeddings)
+        torch.testing.assert_allclose(ref_output_values, output_values)
 
         if torch.cuda.is_available():
             # test gpu forward
-            ref_embeddings = ref_embeddings.cuda()
-            embeddings = ref_embeddings.clone().detach().requires_grad_(True)
+            ref_values = ref_values.cuda()
+            values = ref_values.clone().detach().requires_grad_(True)
             offsets = offsets.cuda()
-            ref_output_embeddings = ref_output_embeddings.cuda()
-            output_embeddings = torch.ops.fbgemm.jagged_2d_to_dense(
-                embeddings=embeddings,
+            ref_output_values = ref_output_values.cuda()
+            output_values = torch.ops.fbgemm.jagged_2d_to_dense(
+                values=values,
                 offsets=offsets,
                 max_sequence_length=max_sequence_length,
             )
-            torch.testing.assert_allclose(ref_output_embeddings, output_embeddings)
+            torch.testing.assert_allclose(ref_output_values, output_values)
 
             # test gpu backward
-            expected_grad = ref_embeddings
+            expected_grad = ref_values
             expected_grad[4, :] = 0  # due to truncation
             expected_grad = expected_grad.cuda()
-            output_embeddings.backward(ref_output_embeddings)
-            torch.testing.assert_allclose(expected_grad, embeddings.grad)
+            output_values.backward(ref_output_values)
+            torch.testing.assert_allclose(expected_grad, values.grad)
 
     @settings(
         verbosity=Verbosity.verbose,
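
For reference, a minimal usage sketch of the renamed operator, mirroring the tests above. It assumes the fbgemm_gpu extension is installed so that the torch.ops.fbgemm operators are registered; the lengths and dimensions below are illustrative only, not taken from the patch.

import torch
import fbgemm_gpu  # noqa: F401  (importing the extension registers the torch.ops.fbgemm operators)

# Three illustrative sequences with lengths 2, 0, 1; each row of `values` is a D=4 vector.
lengths = torch.tensor([2, 0, 1])
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)  # [0, 2, 2, 3]
values = torch.rand(int(lengths.sum().item()), 4)  # shape (total_L, D); this argument was previously named `embeddings`

# The dense output has shape (B, max_sequence_length, D); positions past each
# sequence's length are zero-padded, and longer sequences are truncated.
padded = torch.ops.fbgemm.jagged_2d_to_dense(
    values=values,
    offsets=offsets,
    max_sequence_length=2,
)
print(padded.shape)  # torch.Size([3, 2, 4])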