From e11a31d2dfc06e378e5db6a3aa534499b83f6aec Mon Sep 17 00:00:00 2001
From: Jianbo Liu
Date: Thu, 5 Sep 2024 17:37:44 -0700
Subject: [PATCH] fix set_cuda arg mismatch issue (#3089)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/181

Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3089

```
RuntimeError: set_cuda() expected at most 5 argument(s) but received 6 argument(s).
Declaration: set_cuda(__torch__.torch.classes.fbgemm.EmbeddingParameterServerWrapper _0, Tensor _1, Tensor _2, Tensor _3, int _4) -> NoneType _0

  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torch/nn/modules/module.py", line 2203, in state_dict
    module.state_dict(
  [Previous line repeated 1 more time]
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torch/nn/modules/module.py", line 2203, in state_dict
    module.state_dict(
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torch/nn/modules/module.py", line 2199, in state_dict
    hook(self, prefix, keep_vars)
  [Previous line repeated 1 more time]
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torchrec/distributed/embedding.py", line 605, in _pre_state_dict_hook
    lookup.flush()
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torch/nn/modules/module.py", line 2199, in state_dict
    hook(self, prefix, keep_vars)
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torchrec/distributed/embedding_lookup.py", line 350, in flush
    emb_module.flush()
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torchrec/distributed/embedding.py", line 605, in _pre_state_dict_hook
    lookup.flush()
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torchrec/distributed/batched_embedding_kernel.py", line 812, in flush
    self.emb_module.flush()
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torchrec/distributed/embedding_lookup.py", line 350, in flush
    emb_module.flush()
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/fbgemm_gpu/tbe/ssd/training.py", line 1548, in flush
    self.ssd_db.set_cuda(
RuntimeError: set_cuda() expected at most 5 argument(s) but received 6 argument(s).
Declaration: set_cuda(__torch__.torch.classes.fbgemm.EmbeddingParameterServerWrapper _0, Tensor _1, Tensor _2, Tensor _3, int _4) -> NoneType _0

  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torchrec/distributed/batched_embedding_kernel.py", line 812, in flush
    self.emb_module.flush()
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/fbgemm_gpu/tbe/ssd/training.py", line 1548, in flush
    self.ssd_db.set_cuda(
RuntimeError: set_cuda() expected at most 5 argument(s) but received 6 argument(s).
Declaration: set_cuda(__torch__.torch.classes.fbgemm.EmbeddingParameterServerWrapper _0, Tensor _1, Tensor _2, Tensor _3, int _4) -> NoneType _0
RuntimeError: set_cuda() expected at most 5 argument(s) but received 6 argument(s).
Declaration: set_cuda(__torch__.torch.classes.fbgemm.EmbeddingParameterServerWrapper _0, Tensor _1, Tensor _2, Tensor _3, int _4) -> NoneType _0
```

This appears to be caused by D60413462: the PS path needs to accept `is_bwd` for `set_cuda` too.
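For context, the failing schema in the trace above comes from the TorchScript custom-class registration of `EmbeddingParameterServerWrapper`. Because the methods are bound by member-function pointer, the registered declaration follows the C++ signature, so widening the C++ methods as in the diff below also widens the schema to the 6 arguments that the caller in `fbgemm_gpu/tbe/ssd/training.py` now passes. A rough sketch of such a registration (illustrative only; the exact registration code and names in FBGEMM may differ):

```
// Illustrative sketch only -- the real registration lives alongside the
// wrapper in ps_split_table_batched_embeddings.cpp and may differ in detail.
#include <torch/custom_class.h>

static auto embedding_parameter_server_wrapper =
    torch::class_<EmbeddingParameterServerWrapper>(
        "fbgemm", "EmbeddingParameterServerWrapper")
        // def() derives the TorchScript schema from the bound C++ member
        // function, so adding `bool is_bwd` to set_cuda()/set() extends the
        // set_cuda declaration from 5 to 6 arguments.
        .def("set_cuda", &EmbeddingParameterServerWrapper::set_cuda)
        .def("set", &EmbeddingParameterServerWrapper::set);
```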
Differential Revision: D62247158
---
 .../ps_split_table_batched_embeddings.cpp          | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/fbgemm_gpu/src/ps_split_embeddings_cache/ps_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ps_split_embeddings_cache/ps_split_table_batched_embeddings.cpp
index 692030c40f..395062a2b9 100644
--- a/fbgemm_gpu/src/ps_split_embeddings_cache/ps_split_table_batched_embeddings.cpp
+++ b/fbgemm_gpu/src/ps_split_embeddings_cache/ps_split_table_batched_embeddings.cpp
@@ -43,17 +43,21 @@ class EmbeddingParameterServerWrapper : public torch::jit::CustomClassHolder {
         max_D);
   }
 
-  void
-  set_cuda(Tensor indices, Tensor weights, Tensor count, int64_t timestep) {
-    return impl_->set_cuda(indices, weights, count, timestep);
+  void set_cuda(
+      Tensor indices,
+      Tensor weights,
+      Tensor count,
+      int64_t timestep,
+      bool is_bwd = false) {
+    return impl_->set_cuda(indices, weights, count, timestep, is_bwd);
   }
 
   void get_cuda(Tensor indices, Tensor weights, Tensor count) {
     return impl_->get_cuda(indices, weights, count);
   }
 
-  void set(Tensor indices, Tensor weights, Tensor count) {
-    return impl_->set(indices, weights, count);
+  void set(Tensor indices, Tensor weights, Tensor count, bool is_bwd = false) {
+    return impl_->set(indices, weights, count, is_bwd);
   }
 
   void get(Tensor indices, Tensor weights, Tensor count) {