Skip to content

Commit

Permalink
refactor grad output non-contiguous handler
Browse files Browse the repository at this point in the history
Summary:
This is a follow-up on D37951520 (5a15342)

- Minor clean-up and refactoring for non-contiguous grad output.
- Add more comments.
- Add unit test coverage

TODO: add the 16 alignment unit test coverage.

Differential Revision: D37988742

fbshipit-source-id: e0f6a565ec22fcf9a1135053d708dca1b33688db
  • Loading branch information
jianyuh authored and facebook-github-bot committed Jul 20, 2022
1 parent 5a15342 commit 30ad4b7
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 16 deletions.
16 changes: 10 additions & 6 deletions fbgemm_gpu/codegen/embedding_backward_dense_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,12 @@ class SplitLookupFunction_Dense_Op
using torch::autograd::Variable;

auto grad_output = grad_outputs[0];
if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0 ||
grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) {

// FIXME: to support aligned memory access in Vec4T load/store function
// 16 for FP32 and 8 for FP16
if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0) {
grad_output = at::empty_like(grad_output).copy_(grad_output);
} else if (!grad_output.is_contiguous()) {
grad_output = grad_output.contiguous();
}

Expand Down Expand Up @@ -324,12 +328,12 @@ class SplitNoBagLookupFunction_Dense_Op
using torch::autograd::Variable;

auto grad_output = grad_outputs[0];
if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0 ||
grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) {
grad_output = grad_output.contiguous();
}
// FIXME: to support aligned memory access in Vec4T load/store function
// 16 for FP32 and 8 for FP16
if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0) {
grad_output = at::empty_like(grad_output).copy_(grad_output);
} else if (!grad_output.is_contiguous()) {
grad_output = grad_output.contiguous();
}

auto grad_dev_weights =
Expand Down
8 changes: 4 additions & 4 deletions fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,12 @@ class Split{{ "NoBag" if nobag else "" }}LookupFunction_{{ optimizer }}_Op :
using torch::autograd::Variable;

auto grad_output = gradient_clipping ? clamp(grad_outputs[0], -max_gradient, max_gradient) : grad_outputs[0];
if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0 ||
grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) {
grad_output = grad_output.contiguous();
}
// FIXME: to support aligned memory access in Vec4T load/store function
// 16 for FP32 and 8 for FP16
if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0) {
grad_output = at::empty_like(grad_output).copy_(grad_output);
} else if (!grad_output.is_contiguous()) {
grad_output = grad_output.contiguous();
}

{% if not nobag %}
Expand Down
12 changes: 6 additions & 6 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,9 +1324,9 @@ def test_backward_dense(
rtol=5.0e-3 if weights_precision == SparseType.FP16 else 1.0e-5,
)
if do_pooling:
goc = torch.cat([go.view(B, -1) for go in gos], dim=1).contiguous()
goc = torch.cat([go.view(B, -1) for go in gos], dim=1)
else:
goc = torch.cat(gos, dim=0).contiguous()
goc = torch.cat(gos, dim=0)
fc2.backward(goc)
torch.testing.assert_close(
cc.weights.grad,
Expand Down Expand Up @@ -1584,9 +1584,9 @@ def test_backward_sgd( # noqa C901
else cc(indices, offsets, to_device(xw.contiguous().view(-1), use_cpu))
)
if do_pooling:
goc = torch.cat([go.view(B, -1) for go in gos], dim=1).contiguous()
goc = torch.cat([go.view(B, -1) for go in gos], dim=1)
else:
goc = torch.cat(gos, dim=0).contiguous()
goc = torch.cat(gos, dim=0)
fc2.backward(goc)
if use_cache:
cc.flush()
Expand Down Expand Up @@ -1817,7 +1817,7 @@ def execute_backward_adagrad_( # noqa C901
if do_pooling:
goc = torch.cat([go.view(B, -1) for go in gos], dim=1)
else:
goc = torch.cat(gos, dim=0).contiguous()
goc = torch.cat(gos, dim=0)
fc2.backward(goc)
cc.flush()
split_optimizer_states = [s for (s,) in cc.split_optimizer_states()]
Expand Down Expand Up @@ -2637,7 +2637,7 @@ def execute_backward_optimizers_( # noqa C901
if do_pooling:
goc = torch.cat([go.view(B, -1) for go in gos], dim=1)
else:
goc = torch.cat(gos, dim=0).contiguous()
goc = torch.cat(gos, dim=0)
fc2.backward(goc)
cc.flush()

Expand Down

0 comments on commit 30ad4b7

Please sign in to comment.