Skip to content

Commit

Permalink
address reviewer comments (#2815)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2815

# context
* use `at::parallel_for` to run the CPU kernel in parallel
* with the previous indexing scheme the results came out wrong — compare the
  kernel output below against the reference:
```
(Pdb) outputs[0]
tensor([[-6.4628e-01, -1.9817e+00, -8.4945e-01,  1.3860e+00,  7.4463e-01,
          1.3079e-01,  8.5881e-01, -7.4804e-01],
        [-2.3989e-01,  1.2933e+00,  1.3789e+00, -1.9305e+00, -5.7734e-01,
         -4.5220e-01, -1.3703e+00, -1.9221e+00],
        [ 1.2582e+00,  1.2426e+00,  2.6749e-01,  6.8250e-01,  7.0065e-45,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-4.5382e-02, -9.4207e-01,  7.1254e-01,  7.8096e-01, -1.3482e+00,
         -1.2763e+00,  4.2996e-01, -8.9042e-01],
        [-5.8892e-02,  1.1909e+00, -1.4653e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 1.0306e-01,  2.2235e-01,  1.6044e+00, -3.0457e-01,  1.6609e+00,
          4.1478e-43,  0.0000e+00,  0.0000e+00]],
       grad_fn=<PermuteMultiEmbeddingOp>>)
(Pdb) refs[0]
tensor([[-0.6463, -1.9817, -0.8494,  0.8588, -0.7480, -1.0166,  1.3860,  0.7446],
        [ 1.3572,  0.0403, -0.2078, -0.8592,  0.4000,  1.0562, -0.2399,  1.2933],
        [-0.8118, -0.8703, -0.1429, -0.0802,  0.2706, -0.6728,  1.2582,  1.2426],
        [ 0.0026, -1.3482, -1.2763,  0.9094,  1.2502,  0.5035, -0.0454, -0.9421],
        [-1.4653,  0.8384, -0.3290, -1.2008, -0.4272, -1.0376,  1.0920,  0.2197],
        [ 1.4819,  0.1565, -0.1601, -0.8323, -0.0130,  0.4165,  0.1031,  0.2223]],
       grad_fn=<CatBackward0>)
```

Reviewed By: sryap

Differential Revision: D38300272

fbshipit-source-id: 74546cf05b619ce8175915d21ba330fbfe7bd513
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Jul 10, 2024
1 parent 1b63049 commit b903979
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 51 deletions.
10 changes: 10 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,14 @@ std::vector<Tensor> permute_multi_embedding_gpu(
const Tensor& out_shapes,
const std::vector<int64_t>& out_lengths,
const bool& reverse_permute);

// Column indices into one row of the `permutes` tensor; each row describes one
// feature-copy operation. Kept as an unscoped enum (not `enum class`) so the
// values implicitly convert to integers for row indexing, e.g.
// `permutes[i][PermuteParam::in_offset]`.
enum PermuteParam {
  in_tensor = 0, // index of the source tensor in the inputs list
  out_tensor = 1, // index of the destination tensor in the outputs list
  in_offset = 2, // element offset into the source tensor's row
  out_offset = 3, // element offset into the destination tensor's row
  length = 4, // number of elements to copy for this feature
  next = 5, // next permute row to accumulate from in the backward pass
            // (<= 0 means no follow-up row)
};

} // namespace fbgemm_gpu
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,21 @@ __global__ void permute_multi_embs_kernel(
}

// parse permutes
int32_t in_tensor, out_tensor, in_start, out_start, length, next;
int32_t in_tensor, out_tensor, in_offset, out_offset, length, next;
auto pp = permutes[permute_id];
if (reverse_permute) {
out_tensor = permutes[permute_id][0];
in_tensor = permutes[permute_id][1];
out_start = permutes[permute_id][2];
in_start = permutes[permute_id][3];
out_tensor = pp[PermuteParam::in_tensor];
in_tensor = pp[PermuteParam::out_tensor];
out_offset = pp[PermuteParam::in_offset];
in_offset = pp[PermuteParam::out_offset];
} else {
in_tensor = permutes[permute_id][0];
out_tensor = permutes[permute_id][1];
in_start = permutes[permute_id][2];
out_start = permutes[permute_id][3];
in_tensor = pp[PermuteParam::in_tensor];
out_tensor = pp[PermuteParam::out_tensor];
in_offset = pp[PermuteParam::in_offset];
out_offset = pp[PermuteParam::out_offset];
}
length = permutes[permute_id][4];
next = permutes[permute_id][5];
length = pp[PermuteParam::length];
next = pp[PermuteParam::next];

if (worker_id >= length) {
return;
Expand All @@ -73,47 +74,48 @@ __global__ void permute_multi_embs_kernel(

// locate the batch_id
int32_t in_length = in_lengths[in_tensor];
scalar_t* input_ptr = (scalar_t*)inputs[in_tensor];
input_ptr += batch_id * in_length;
scalar_t* input_ptr = const_cast<scalar_t*>(
reinterpret_cast<const scalar_t*>(inputs[in_tensor]) +
batch_id * in_length + in_offset);

int32_t out_length = out_lengths[out_tensor];
scalar_t* output_ptr = (scalar_t*)outputs[out_tensor];
output_ptr += batch_id * out_length;
scalar_t* output_ptr = const_cast<scalar_t*>(
reinterpret_cast<const scalar_t*>(outputs[out_tensor]) +
batch_id * out_length + out_offset);

if (fbgemm_gpu::is_aligned<fbgemm_gpu::Vec4T<scalar_t>>(
&output_ptr[out_start]) &&
fbgemm_gpu::is_aligned<fbgemm_gpu::Vec4T<scalar_t>>(
&input_ptr[in_start])) {
if (fbgemm_gpu::is_aligned<fbgemm_gpu::Vec4T<scalar_t>>(output_ptr) &&
fbgemm_gpu::is_aligned<fbgemm_gpu::Vec4T<scalar_t>>(input_ptr)) {
constexpr int32_t vec_size = 4;
const int32_t loop_end = round_down(length, vec_size);
for (int32_t i = worker_id * vec_size; i < loop_end;
i += blockDim.x * vec_size) {
fbgemm_gpu::Vec4T<scalar_t>::copy(
&input_ptr[in_start + i], &output_ptr[out_start + i]);
fbgemm_gpu::Vec4T<scalar_t>::copy(&input_ptr[i], &output_ptr[i]);
}
// Use elementwise access for the last incomplete vector.
for (int32_t i = loop_end + worker_id; i < length; i += blockDim.x) {
output_ptr[out_start + i] = input_ptr[in_start + i];
output_ptr[i] = input_ptr[i];
}
} else { // Fallback if not aligned.
for (int32_t i = worker_id; i < length; i += blockDim.x) {
output_ptr[out_start + i] = input_ptr[in_start + i];
output_ptr[i] = input_ptr[i];
}
}

// for reverse_permute (backward) with next
while (reverse_permute && next > 0 && next < permute_size) {
in_tensor = permutes[next][1];
in_start = permutes[next][3];
length = permutes[next][4];
next = -permutes[next][5];
auto pp = permutes[next];
in_tensor = pp[PermuteParam::out_tensor];
in_offset = pp[PermuteParam::out_offset];
length = pp[PermuteParam::length];
next = -pp[PermuteParam::next];

int32_t in_length = in_lengths[in_tensor];
scalar_t* input_ptr = (scalar_t*)inputs[in_tensor];
input_ptr += batch_id * in_length;
scalar_t* input_ptr = const_cast<scalar_t*>(
reinterpret_cast<const scalar_t*>(inputs[in_tensor]) +
batch_id * in_length + in_offset);

for (int32_t i = worker_id; i < length; i += blockDim.x) {
output_ptr[out_start + i] += input_ptr[in_start + i];
output_ptr[i] += input_ptr[i];
}
}
}
Expand Down Expand Up @@ -155,6 +157,13 @@ std::vector<Tensor> permute_multi_embedding_gpu(
const Tensor& out_shapes,
const std::vector<int64_t>& out_lengths,
const bool& reverse_permute) {
CUDA_DEVICE_GUARD(pooled_embs[0]);
TENSORS_ON_SAME_DEVICE(permutes, pooled_embs[0]);
TENSORS_ON_SAME_DEVICE(permutes, in_shapes);
TENSORS_ON_SAME_DEVICE(permutes, out_shapes);
TORCH_CHECK(in_shapes.is_contiguous());
TORCH_CHECK(out_shapes.is_contiguous());

int32_t num_of_input_tensors = in_shapes.size(0);
int32_t num_of_output_tensors = out_lengths.size();
int32_t batch_size = pooled_embs[0].size(0);
Expand All @@ -166,12 +175,8 @@ std::vector<Tensor> permute_multi_embedding_gpu(
for (int32_t i = 0; i < num_of_input_tensors; i++) {
Tensor cont_tensor = pooled_embs[i].contiguous();
inputs.push_back(cont_tensor);
TENSORS_ON_SAME_DEVICE(cont_tensor, pooled_embs[i]);
TENSORS_ON_SAME_DEVICE(pooled_embs[i], pooled_embs[0]);
CUDA_DEVICE_GUARD(cont_tensor);
TORCH_CHECK(cont_tensor.is_contiguous());
}
TORCH_CHECK(in_shapes.is_contiguous());
TORCH_CHECK(out_shapes.is_contiguous());

// initiate output tensors
std::vector<Tensor> outputs;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,32 +37,34 @@ std::vector<Tensor> permute_multi_embedding_cpu(
outputs.push_back(at::empty({B, out_lengths[i]}, pooled_embs[0].options()));
TORCH_CHECK(outputs[i].is_contiguous());
}
int32_t in_tensor, out_tensor, in_start, out_start, length, jump;
int32_t in_tensor, out_tensor, in_offset, out_offset, length, next;
for (const auto i : c10::irange(permutes.size(0))) {
auto pp = permutes[i];
if (reverse_permute) {
out_tensor = permutes[i][0].item<int32_t>();
in_tensor = permutes[i][1].item<int32_t>();
out_start = permutes[i][2].item<int32_t>();
in_start = permutes[i][3].item<int32_t>();
jump = permutes[i][5].item<int32_t>();
out_tensor = pp[PermuteParam::in_tensor].item<int32_t>();
in_tensor = pp[PermuteParam::out_tensor].item<int32_t>();
out_offset = pp[PermuteParam::in_offset].item<int32_t>();
in_offset = pp[PermuteParam::out_offset].item<int32_t>();
next = pp[PermuteParam::next].item<int32_t>();
} else {
in_tensor = permutes[i][0].item<int32_t>();
out_tensor = permutes[i][1].item<int32_t>();
in_start = permutes[i][2].item<int32_t>();
out_start = permutes[i][3].item<int32_t>();
in_tensor = pp[PermuteParam::in_tensor].item<int32_t>();
out_tensor = pp[PermuteParam::out_tensor].item<int32_t>();
in_offset = pp[PermuteParam::in_offset].item<int32_t>();
out_offset = pp[PermuteParam::out_offset].item<int32_t>();
}
length = permutes[i][4].item<int32_t>();
if (reverse_permute && jump < 0) {
if (reverse_permute && next < 0) {
for (auto b : c10::irange(B)) {
auto outp = outputs[out_tensor][b].data_ptr<float>() + out_offset;
auto inp = inputs[in_tensor][b].data_ptr<float>() + in_offset;
for (const auto j : c10::irange(length)) {
outputs[out_tensor][b][j + out_start] +=
inputs[in_tensor][b][j + in_start];
outp[j] += inp[j];
}
}
} else {
for (auto b : c10::irange(B)) {
auto outp = outputs[out_tensor][b].data_ptr<float>() + out_start;
auto inp = inputs[in_tensor][b].data_ptr<float>() + in_start;
auto outp = outputs[out_tensor][b].data_ptr<float>() + out_offset;
auto inp = inputs[in_tensor][b].data_ptr<float>() + in_offset;
std::memcpy(outp, inp, length * pooled_embs[0].itemsize());
}
}
Expand Down Expand Up @@ -128,7 +130,7 @@ std::vector<Tensor> permute_multi_embedding_meta(
/// column is the output tensor index. the third column is the feature's offset
/// of input tensor, and the fourth column is the feature's offset of output
/// tensor. the fifth column is the length of the feature in a permute, and the
/// last column is a jump flag.
/// last column is a next permute row to operate on (used in backward only).
/// @param in_shapes a 1D tensor with each element representing the length of an
/// input KT.
/// @param out_shapes a 1D tensor with each element representing the length of
Expand Down

0 comments on commit b903979

Please sign in to comment.