From b903979543526edbcde7b9fa512dbe3244f494e9 Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Wed, 10 Jul 2024 08:20:05 -0700 Subject: [PATCH] address reviewer comments (#2815) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2815 # context * use `at::parallel_for` to run parallel threads in the CPU kernel * somehow the results are wrong (compare `outputs[0]` with `refs[0]` below): ``` (Pdb) outputs[0] tensor([[-6.4628e-01, -1.9817e+00, -8.4945e-01, 1.3860e+00, 7.4463e-01, 1.3079e-01, 8.5881e-01, -7.4804e-01], [-2.3989e-01, 1.2933e+00, 1.3789e+00, -1.9305e+00, -5.7734e-01, -4.5220e-01, -1.3703e+00, -1.9221e+00], [ 1.2582e+00, 1.2426e+00, 2.6749e-01, 6.8250e-01, 7.0065e-45, 0.0000e+00, 0.0000e+00, 0.0000e+00], [-4.5382e-02, -9.4207e-01, 7.1254e-01, 7.8096e-01, -1.3482e+00, -1.2763e+00, 4.2996e-01, -8.9042e-01], [-5.8892e-02, 1.1909e+00, -1.4653e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [ 1.0306e-01, 2.2235e-01, 1.6044e+00, -3.0457e-01, 1.6609e+00, 4.1478e-43, 0.0000e+00, 0.0000e+00]], grad_fn=>) (Pdb) refs[0] tensor([[-0.6463, -1.9817, -0.8494, 0.8588, -0.7480, -1.0166, 1.3860, 0.7446], [ 1.3572, 0.0403, -0.2078, -0.8592, 0.4000, 1.0562, -0.2399, 1.2933], [-0.8118, -0.8703, -0.1429, -0.0802, 0.2706, -0.6728, 1.2582, 1.2426], [ 0.0026, -1.3482, -1.2763, 0.9094, 1.2502, 0.5035, -0.0454, -0.9421], [-1.4653, 0.8384, -0.3290, -1.2008, -0.4272, -1.0376, 1.0920, 0.2197], [ 1.4819, 0.1565, -0.1601, -0.8323, -0.0130, 0.4165, 0.1031, 0.2223]], grad_fn=) ``` Reviewed By: sryap Differential Revision: D38300272 fbshipit-source-id: 74546cf05b619ce8175915d21ba330fbfe7bd513 --- .../permute_multi_embedding_function.h | 10 +++ .../permute_multi_embedding_ops.cu | 75 ++++++++++--------- .../permute_multi_embedding_ops_cpu.cpp | 34 +++++---- 3 files changed, 68 insertions(+), 51 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h b/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h index fbda97d47a..0653554c60 100644 --- 
a/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h +++ b/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h @@ -66,4 +66,14 @@ std::vector permute_multi_embedding_gpu( const Tensor& out_shapes, const std::vector& out_lengths, const bool& reverse_permute); + +enum PermuteParam { + in_tensor = 0, + out_tensor = 1, + in_offset = 2, + out_offset = 3, + length = 4, + next = 5, +}; + } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu index c5ea34f5e2..569d38469f 100644 --- a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu +++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu @@ -49,20 +49,21 @@ __global__ void permute_multi_embs_kernel( } // parse permutes - int32_t in_tensor, out_tensor, in_start, out_start, length, next; + int32_t in_tensor, out_tensor, in_offset, out_offset, length, next; + auto pp = permutes[permute_id]; if (reverse_permute) { - out_tensor = permutes[permute_id][0]; - in_tensor = permutes[permute_id][1]; - out_start = permutes[permute_id][2]; - in_start = permutes[permute_id][3]; + out_tensor = pp[PermuteParam::in_tensor]; + in_tensor = pp[PermuteParam::out_tensor]; + out_offset = pp[PermuteParam::in_offset]; + in_offset = pp[PermuteParam::out_offset]; } else { - in_tensor = permutes[permute_id][0]; - out_tensor = permutes[permute_id][1]; - in_start = permutes[permute_id][2]; - out_start = permutes[permute_id][3]; + in_tensor = pp[PermuteParam::in_tensor]; + out_tensor = pp[PermuteParam::out_tensor]; + in_offset = pp[PermuteParam::in_offset]; + out_offset = pp[PermuteParam::out_offset]; } - length = permutes[permute_id][4]; - next = permutes[permute_id][5]; + length = pp[PermuteParam::length]; + next = pp[PermuteParam::next]; if (worker_id >= length) { return; @@ -73,47 +74,48 @@ __global__ void permute_multi_embs_kernel( // locate 
the batch_id int32_t in_length = in_lengths[in_tensor]; - scalar_t* input_ptr = (scalar_t*)inputs[in_tensor]; - input_ptr += batch_id * in_length; + scalar_t* input_ptr = const_cast( + reinterpret_cast(inputs[in_tensor]) + + batch_id * in_length + in_offset); int32_t out_length = out_lengths[out_tensor]; - scalar_t* output_ptr = (scalar_t*)outputs[out_tensor]; - output_ptr += batch_id * out_length; + scalar_t* output_ptr = const_cast( + reinterpret_cast(outputs[out_tensor]) + + batch_id * out_length + out_offset); - if (fbgemm_gpu::is_aligned>( - &output_ptr[out_start]) && - fbgemm_gpu::is_aligned>( - &input_ptr[in_start])) { + if (fbgemm_gpu::is_aligned>(output_ptr) && + fbgemm_gpu::is_aligned>(input_ptr)) { constexpr int32_t vec_size = 4; const int32_t loop_end = round_down(length, vec_size); for (int32_t i = worker_id * vec_size; i < loop_end; i += blockDim.x * vec_size) { - fbgemm_gpu::Vec4T::copy( - &input_ptr[in_start + i], &output_ptr[out_start + i]); + fbgemm_gpu::Vec4T::copy(&input_ptr[i], &output_ptr[i]); } // Use elementwise access for the last incomplete vector. for (int32_t i = loop_end + worker_id; i < length; i += blockDim.x) { - output_ptr[out_start + i] = input_ptr[in_start + i]; + output_ptr[i] = input_ptr[i]; } } else { // Fallback if not aligned. 
for (int32_t i = worker_id; i < length; i += blockDim.x) { - output_ptr[out_start + i] = input_ptr[in_start + i]; + output_ptr[i] = input_ptr[i]; } } // for reverse_permute (backward) with next while (reverse_permute && next > 0 && next < permute_size) { - in_tensor = permutes[next][1]; - in_start = permutes[next][3]; - length = permutes[next][4]; - next = -permutes[next][5]; + auto pp = permutes[next]; + in_tensor = pp[PermuteParam::out_tensor]; + in_offset = pp[PermuteParam::out_offset]; + length = pp[PermuteParam::length]; + next = -pp[PermuteParam::next]; int32_t in_length = in_lengths[in_tensor]; - scalar_t* input_ptr = (scalar_t*)inputs[in_tensor]; - input_ptr += batch_id * in_length; + scalar_t* input_ptr = const_cast( + reinterpret_cast(inputs[in_tensor]) + + batch_id * in_length + in_offset); for (int32_t i = worker_id; i < length; i += blockDim.x) { - output_ptr[out_start + i] += input_ptr[in_start + i]; + output_ptr[i] += input_ptr[i]; } } } @@ -155,6 +157,13 @@ std::vector permute_multi_embedding_gpu( const Tensor& out_shapes, const std::vector& out_lengths, const bool& reverse_permute) { + CUDA_DEVICE_GUARD(pooled_embs[0]); + TENSORS_ON_SAME_DEVICE(permutes, pooled_embs[0]); + TENSORS_ON_SAME_DEVICE(permutes, in_shapes); + TENSORS_ON_SAME_DEVICE(permutes, out_shapes); + TORCH_CHECK(in_shapes.is_contiguous()); + TORCH_CHECK(out_shapes.is_contiguous()); + int32_t num_of_input_tensors = in_shapes.size(0); int32_t num_of_output_tensors = out_lengths.size(); int32_t batch_size = pooled_embs[0].size(0); @@ -166,12 +175,8 @@ std::vector permute_multi_embedding_gpu( for (int32_t i = 0; i < num_of_input_tensors; i++) { Tensor cont_tensor = pooled_embs[i].contiguous(); inputs.push_back(cont_tensor); - TENSORS_ON_SAME_DEVICE(cont_tensor, pooled_embs[i]); - TENSORS_ON_SAME_DEVICE(pooled_embs[i], pooled_embs[0]); - CUDA_DEVICE_GUARD(cont_tensor); + TORCH_CHECK(cont_tensor.is_contiguous()); } - TORCH_CHECK(in_shapes.is_contiguous()); - 
TORCH_CHECK(out_shapes.is_contiguous()); // initiate output tensors std::vector outputs; diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp index a0225e743a..80a70a874c 100644 --- a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp +++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp @@ -37,32 +37,34 @@ std::vector permute_multi_embedding_cpu( outputs.push_back(at::empty({B, out_lengths[i]}, pooled_embs[0].options())); TORCH_CHECK(outputs[i].is_contiguous()); } - int32_t in_tensor, out_tensor, in_start, out_start, length, jump; + int32_t in_tensor, out_tensor, in_offset, out_offset, length, next; for (const auto i : c10::irange(permutes.size(0))) { + auto pp = permutes[i]; if (reverse_permute) { - out_tensor = permutes[i][0].item(); - in_tensor = permutes[i][1].item(); - out_start = permutes[i][2].item(); - in_start = permutes[i][3].item(); - jump = permutes[i][5].item(); + out_tensor = pp[PermuteParam::in_tensor].item(); + in_tensor = pp[PermuteParam::out_tensor].item(); + out_offset = pp[PermuteParam::in_offset].item(); + in_offset = pp[PermuteParam::out_offset].item(); + next = pp[PermuteParam::next].item(); } else { - in_tensor = permutes[i][0].item(); - out_tensor = permutes[i][1].item(); - in_start = permutes[i][2].item(); - out_start = permutes[i][3].item(); + in_tensor = pp[PermuteParam::in_tensor].item(); + out_tensor = pp[PermuteParam::out_tensor].item(); + in_offset = pp[PermuteParam::in_offset].item(); + out_offset = pp[PermuteParam::out_offset].item(); } length = permutes[i][4].item(); - if (reverse_permute && jump < 0) { + if (reverse_permute && next < 0) { for (auto b : c10::irange(B)) { + auto outp = outputs[out_tensor][b].data_ptr() + out_offset; + auto inp = inputs[in_tensor][b].data_ptr() + in_offset; for (const auto j : c10::irange(length)) { - 
outputs[out_tensor][b][j + out_start] += - inputs[in_tensor][b][j + in_start]; + outp[j] += inp[j]; } } } else { for (auto b : c10::irange(B)) { - auto outp = outputs[out_tensor][b].data_ptr() + out_start; - auto inp = inputs[in_tensor][b].data_ptr() + in_start; + auto outp = outputs[out_tensor][b].data_ptr() + out_offset; + auto inp = inputs[in_tensor][b].data_ptr() + in_offset; std::memcpy(outp, inp, length * pooled_embs[0].itemsize()); } } @@ -128,7 +130,7 @@ std::vector permute_multi_embedding_meta( /// column is the output tensor index. the third column is the feature's offset /// of input tensor, and the fourth column is the feature's offset of output /// tensor. the fifth column is the length of the feature in a permute, and the -/// last column is a jump flag. +/// last column is a next permute row to operate on (used in backward only). /// @param in_shapes a 1D tensor with each element representing the length of an /// input KT. /// @param out_shapes a 1D tensor with each element representing the length of