
fix set_cuda arg mismatch issue #3089

Closed
wants to merge 1 commit

Commits on Sep 6, 2024

  1. fix set_cuda arg mismatch issue (pytorch#3089)

    Summary:
    X-link: facebookresearch/FBGEMM#181
    
    Pull Request resolved: pytorch#3089
    
    ```
    RuntimeError: set_cuda() expected at most 5 argument(s) but received 6 argument(s). Declaration: set_cuda(__torch__.torch.classes.fbgemm.EmbeddingParameterServerWrapper _0, Tensor _1, Tensor _2, Tensor _3, int _4) -> NoneType _0
      File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torch/nn/modules/module.py", line 2203, in state_dict
        module.state_dict(
      [Previous line repeated 1 more time]
      File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torch/nn/modules/module.py", line 2199, in state_dict
        hook(self, prefix, keep_vars)
      File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torchrec/distributed/embedding.py", line 605, in _pre_state_dict_hook
        lookup.flush()
      File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torchrec/distributed/embedding_lookup.py", line 350, in flush
        emb_module.flush()
      File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torchrec/distributed/batched_embedding_kernel.py", line 812, in flush
        self.emb_module.flush()
      File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/fbgemm_gpu/tbe/ssd/training.py", line 1548, in flush
        self.ssd_db.set_cuda(
    RuntimeError: set_cuda() expected at most 5 argument(s) but received 6 argument(s). Declaration: set_cuda(__torch__.torch.classes.fbgemm.EmbeddingParameterServerWrapper _0, Tensor _1, Tensor _2, Tensor _3, int _4) -> NoneType _0
    ```
    
    This appears to be a regression from D60413462: the call site in fbgemm_gpu/tbe/ssd/training.py now passes an extra is_bwd argument to set_cuda() (6 arguments in total, counting the wrapper itself), but the parameter-server (PS) wrapper's declaration still accepts only 5.
    
    The PS path needs to be changed so that its set_cuda() accepts is_bwd as well.
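    
    For illustration, a minimal sketch of the shape of the fix, assuming the PS wrapper is a TorchBind custom class as the declaration in the trace suggests; everything except the class name, the method name, and the argument types visible in the trace is hypothetical:
    
    ```cpp
    // Sketch only -- not the actual FBGEMM source. It shows a TorchBind
    // binding whose set_cuda() gains a trailing is_bwd flag so the
    // declared arity matches the 6-argument call made from
    // fbgemm_gpu/tbe/ssd/training.py. Parameter names are assumed.
    #include <torch/custom_class.h>
    #include <torch/torch.h>
    
    struct EmbeddingParameterServerWrapper : torch::CustomClassHolder {
      void set_cuda(
          at::Tensor indices,
          at::Tensor weights,
          at::Tensor count,
          int64_t timestep,
          bool is_bwd) {  // newly accepted flag; previously missing on the PS path
        // Forward to the parameter-server backend (elided).
      }
    };
    
    TORCH_LIBRARY(fbgemm, m) {
      m.class_<EmbeddingParameterServerWrapper>("EmbeddingParameterServerWrapper")
          .def("set_cuda", &EmbeddingParameterServerWrapper::set_cuda);
    }
    ```
    
    With the extra bool in the binding, the registered schema for set_cuda would accept 6 arguments (the wrapper plus five), matching what the caller now passes.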
    
    Differential Revision: D62247158
    Jianbo Liu authored and facebook-github-bot committed Sep 6, 2024
    Commit e11a31d