From f9b3af3701dba74da4db193d837cf82f031e156c Mon Sep 17 00:00:00 2001
From: Kefei Lu
Date: Tue, 27 Jun 2023 14:09:09 -0700
Subject: [PATCH] Register meta dispatcher for permute_pooled_embedding (#1853)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1853

Torch dynamo needs the meta dispatcher for graph capturing.

Differential Revision: D47064694

fbshipit-source-id: b2bca4a7362b4dfd0937e9636a398e95a06726db
---
 fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h      |  3 +++
 fbgemm_gpu/src/permute_pooled_embedding_ops_gpu.cpp   | 12 ++++++++++++
 .../test/permute_pooled_embedding_modules_test.py     | 12 ++++++++++++
 3 files changed, 27 insertions(+)

diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h
index 8b9da8ab24..5aa812c624 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h
@@ -104,6 +104,9 @@ inline bool torch_tensor_empty_or_on_cpu_check(
 #define DISPATCH_TO_CPU(name, function) \
   m.impl(name, torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(function)))
 
+#define DISPATCH_TO_META(name, function) \
+  m.impl(name, torch::dispatch(c10::DispatchKey::Meta, TORCH_FN(function)))
+
 #define DISPATCH_TO_ALL(name, function) \
   m.impl(name, torch::dispatch(c10::DispatchKey::CatchAll, TORCH_FN(function)))
 
diff --git a/fbgemm_gpu/src/permute_pooled_embedding_ops_gpu.cpp b/fbgemm_gpu/src/permute_pooled_embedding_ops_gpu.cpp
index 757b884186..d2af89b4e3 100644
--- a/fbgemm_gpu/src/permute_pooled_embedding_ops_gpu.cpp
+++ b/fbgemm_gpu/src/permute_pooled_embedding_ops_gpu.cpp
@@ -138,6 +138,15 @@ Tensor permute_pooled_embs_auto_grad_cpu(
       inv_offset_dim_list,
       inv_permute_list);
 }
+
+Tensor permute_pooled_embs_auto_grad_meta(
+    const Tensor& pooled_embs,
+    const Tensor& offset_dim_list,
+    const Tensor& permute_list,
+    const Tensor& inv_offset_dim_list,
+    const Tensor& inv_permute_list) {
+  return torch::empty_like(pooled_embs);
+}
 } // namespace fbgemm_gpu
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
@@ -152,4 +161,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   DISPATCH_TO_CUDA(
       "permute_pooled_embs_auto_grad",
       fbgemm_gpu::permute_pooled_embs_auto_grad_gpu);
+  DISPATCH_TO_META(
+      "permute_pooled_embs_auto_grad",
+      fbgemm_gpu::permute_pooled_embs_auto_grad_meta);
 }
diff --git a/fbgemm_gpu/test/permute_pooled_embedding_modules_test.py b/fbgemm_gpu/test/permute_pooled_embedding_modules_test.py
index 9fc42fbce1..ae231b6df5 100644
--- a/fbgemm_gpu/test/permute_pooled_embedding_modules_test.py
+++ b/fbgemm_gpu/test/permute_pooled_embedding_modules_test.py
@@ -11,6 +11,7 @@
 
 import fbgemm_gpu
 import torch
+import torch._dynamo
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
 from hypothesis import given, HealthCheck, settings
 from torch import nn, Tensor
@@ -158,6 +159,17 @@ def test_pooled_table_batched_embedding(self) -> None:
             ref_permuted_pooled_emb.to(self.device), permuted_pooled_emb
         )
 
+    def test_permutation_autograd_meta(self) -> None:
+        """
+        Test that permute_pooled_embeddings_autograd works with meta tensor and
+        dynamo export mode
+        """
+        net = Net().to("meta")
+        input = torch.randn(2, 1).to("meta")
+        output = net(input)
+        assert input.shape == output.shape
+        torch._dynamo.export(net, input)
+
 
 if __name__ == "__main__":
     unittest.main()
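For context on what the new `DISPATCH_TO_META` registration enables: once a Meta-backend kernel is registered for `permute_pooled_embs_auto_grad`, the op can be traced with meta tensors (shape-only, no real storage), which is what Dynamo relies on during graph capture. The snippet below is a minimal sketch of exercising that path directly, assuming an FBGEMM build that includes this patch; the shapes, offset lists, and permutation values are made-up illustration inputs, not taken from the PR.

```python
import torch
import fbgemm_gpu  # noqa: F401  # loads the torch.ops.fbgemm op registrations

# Pooled embeddings for two hypothetical tables of dims 3 and 2, batch size 4,
# created on the meta device so no real memory is allocated.
pooled = torch.randn(4, 5, device="meta")
offset_dim_list = torch.tensor([0, 3, 5])      # prefix sums of table dims (illustrative)
permute_list = torch.tensor([1, 0])            # swap the two tables (illustrative)
inv_offset_dim_list = torch.tensor([0, 2, 5])  # prefix sums after the permutation
inv_permute_list = torch.tensor([1, 0])        # inverse permutation

# With the meta kernel registered, this call returns a meta tensor of the same
# shape (the kernel is just torch::empty_like) instead of failing dispatch for
# the Meta backend, so Dynamo can capture graphs that contain this op.
out = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
    pooled, offset_dim_list, permute_list, inv_offset_dim_list, inv_permute_list
)
assert out.shape == pooled.shape and out.device.type == "meta"
```

The added unit test covers the same idea end to end by running a small module on meta tensors and passing it through `torch._dynamo.export`.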