Skip to content

Commit

Permalink
Register meta dispatcher for permute_pooled_embedding (#1853)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1853

TorchDynamo needs the meta dispatcher registered for this operator so it can capture the graph without executing real computation.

Differential Revision: D47064694

fbshipit-source-id: b2bca4a7362b4dfd0937e9636a398e95a06726db
  • Loading branch information
kflu authored and facebook-github-bot committed Jun 27, 2023
1 parent 555ad07 commit f9b3af3
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 0 deletions.
3 changes: 3 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ inline bool torch_tensor_empty_or_on_cpu_check(
// Register `function` as the CPU-backend implementation of operator `name`.
#define DISPATCH_TO_CPU(name, function) \
m.impl(name, torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(function)))

// Register `function` as the Meta-backend implementation of operator `name`.
// Meta kernels run on shape/dtype-only tensors (no real storage); Torch Dynamo
// relies on them to capture the graph without executing real computation.
#define DISPATCH_TO_META(name, function) \
m.impl(name, torch::dispatch(c10::DispatchKey::Meta, TORCH_FN(function)))

// Register `function` for every dispatch key (catch-all fallback).
#define DISPATCH_TO_ALL(name, function) \
m.impl(name, torch::dispatch(c10::DispatchKey::CatchAll, TORCH_FN(function)))

Expand Down
12 changes: 12 additions & 0 deletions fbgemm_gpu/src/permute_pooled_embedding_ops_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,15 @@ Tensor permute_pooled_embs_auto_grad_cpu(
inv_offset_dim_list,
inv_permute_list);
}

/// Meta-backend implementation of `permute_pooled_embs_auto_grad`.
///
/// Meta kernels only propagate output metadata (shape and dtype) so that
/// Torch Dynamo can capture the graph without real data. Permuting pooled
/// embeddings reorders features but does not change the overall tensor
/// shape, so `empty_like(pooled_embs)` describes the output exactly.
///
/// The permutation-descriptor arguments are required by the operator schema
/// but are irrelevant to shape inference; their names are commented out to
/// avoid -Wunused-parameter warnings.
Tensor permute_pooled_embs_auto_grad_meta(
    const Tensor& pooled_embs,
    const Tensor& /*offset_dim_list*/,
    const Tensor& /*permute_list*/,
    const Tensor& /*inv_offset_dim_list*/,
    const Tensor& /*inv_permute_list*/) {
  return torch::empty_like(pooled_embs);
}
} // namespace fbgemm_gpu

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
Expand All @@ -152,4 +161,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
DISPATCH_TO_CUDA(
"permute_pooled_embs_auto_grad",
fbgemm_gpu::permute_pooled_embs_auto_grad_gpu);
DISPATCH_TO_META(
"permute_pooled_embs_auto_grad",
fbgemm_gpu::permute_pooled_embs_auto_grad_meta);
}
12 changes: 12 additions & 0 deletions fbgemm_gpu/test/permute_pooled_embedding_modules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import fbgemm_gpu
import torch
import torch._dynamo
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
from hypothesis import given, HealthCheck, settings
from torch import nn, Tensor
Expand Down Expand Up @@ -158,6 +159,17 @@ def test_pooled_table_batched_embedding(self) -> None:
ref_permuted_pooled_emb.to(self.device), permuted_pooled_emb
)

def test_permutation_autograd_meta(self) -> None:
    """
    Test that permute_pooled_embeddings_autograd works with meta tensors
    and with torch.dynamo export mode.

    Meta tensors carry only shape/dtype metadata, so this exercises the
    meta-dispatch registration of permute_pooled_embs_auto_grad rather
    than any real computation.
    """
    net = Net().to("meta")
    # `inputs` rather than `input` to avoid shadowing the builtin.
    inputs = torch.randn(2, 1).to("meta")
    output = net(inputs)
    # The permutation reorders features but preserves the overall shape.
    self.assertEqual(inputs.shape, output.shape)
    # Dynamo export must be able to trace the net through the meta kernel.
    torch._dynamo.export(net, inputs)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit f9b3af3

Please sign in to comment.