Skip to content

Commit

Permalink
Fix CUDA error when grad_output is contiguous but address is not 16 b…
Browse files Browse the repository at this point in the history
…ytes aligned (#1212)

Summary:
Pull Request resolved: #1212

As title

Reviewed By: divchenko

Differential Revision: D37951520

fbshipit-source-id: a2c2ab57bb13ed750e986b0326566a8f0a8ea3ae
  • Loading branch information
xing-liu authored and facebook-github-bot committed Jul 19, 2022
1 parent 7d59e80 commit 5a15342
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
3 changes: 3 additions & 0 deletions fbgemm_gpu/codegen/embedding_backward_dense_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,9 @@ class SplitNoBagLookupFunction_Dense_Op
grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) {
grad_output = grad_output.contiguous();
}
if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0) {
grad_output = at::empty_like(grad_output).copy_(grad_output);
}

auto grad_dev_weights =
split_embedding_nobag_backward_codegen_dense_unweighted_exact_cuda(
Expand Down
6 changes: 4 additions & 2 deletions fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,12 @@ class Split{{ "NoBag" if nobag else "" }}LookupFunction_{{ optimizer }}_Op :

auto grad_output = gradient_clipping ? clamp(grad_outputs[0], -max_gradient, max_gradient) : grad_outputs[0];
if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0 ||
grad_output.stride(1) != 1 ||
grad_output.stride(0) % 4 != 0) {
grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) {
grad_output = grad_output.contiguous();
}
if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0) {
grad_output = at::empty_like(grad_output).copy_(grad_output);
}

{% if not nobag %}
if (!indice_weights.defined()) {
Expand Down

0 comments on commit 5a15342

Please sign in to comment.