fix set_cuda arg mismatch issue (#3089)
Summary:
X-link: facebookresearch/FBGEMM#181
Pull Request resolved: #3089

Calling `state_dict()` on an SSD TBE module fails because the parameter server wrapper's `set_cuda()` declaration accepts only 5 arguments while the flush path now passes 6:

```
File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torch/nn/modules/module.py", line 2203, in state_dict
    module.state_dict(
[Previous line repeated 1 more time]
File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torch/nn/modules/module.py", line 2199, in state_dict
    hook(self, prefix, keep_vars)
File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torchrec/distributed/embedding.py", line 605, in _pre_state_dict_hook
    lookup.flush()
File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torchrec/distributed/embedding_lookup.py", line 350, in flush
    emb_module.flush()
File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torchrec/distributed/batched_embedding_kernel.py", line 812, in flush
    self.emb_module.flush()
File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/fbgemm_gpu/tbe/ssd/training.py", line 1548, in flush
    self.ssd_db.set_cuda(
RuntimeError: set_cuda() expected at most 5 argument(s) but received 6 argument(s).
Declaration: set_cuda(__torch__.torch.classes.fbgemm.EmbeddingParameterServerWrapper _0, Tensor _1, Tensor _2, Tensor _3, int _4) -> NoneType _0
```

This appears to stem from D60413462, which added the extra argument at the call site; the parameter server (PS) path needs to accept `is_bwd` in `set_cuda` as well.

Differential Revision: D62247158
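The failure mode above can be sketched in plain Python: a bound method declared with one fewer parameter than the caller supplies raises on the extra argument, and the fix is to let the receiving path accept the new flag too. This is a minimal illustration only; the class and parameter names here are hypothetical stand-ins, not the actual FBGEMM binding.

```python
# Minimal sketch of the set_cuda() arity mismatch. All names below are
# illustrative; the real set_cuda() is a TorchScript-bound C++ method on
# EmbeddingParameterServerWrapper, not a Python class.

class PSWrapperSketch:
    # Before the fix: no is_bwd parameter, so a 6-argument call fails,
    # mirroring "expected at most 5 argument(s) but received 6".
    def set_cuda_old(self, indices, weights, count, timestep):
        return None

    # After the fix: the PS path also accepts is_bwd (defaulted so that
    # older 5-argument call sites keep working).
    def set_cuda(self, indices, weights, count, timestep, is_bwd=False):
        return is_bwd


ps = PSWrapperSketch()

# The old signature rejects the extra flag the caller now passes.
try:
    ps.set_cuda_old([0, 1], [[0.0], [0.0]], [2], 0, True)
except TypeError as exc:
    print("old signature rejects is_bwd:", exc)

# The fixed signature accepts both call shapes.
print(ps.set_cuda([0, 1], [[0.0], [0.0]], [2], 0))        # False
print(ps.set_cuda([0, 1], [[0.0], [0.0]], [2], 0, True))  # True
```

Defaulting the new parameter is the conservative choice here, since it keeps any remaining 5-argument callers of the PS path working unchanged.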