Skip to content

Commit

Permalink
fix set_cuda arg mismatch issue
Browse files Browse the repository at this point in the history
Summary:
```
RuntimeError: set_cuda() expected at most 5 argument(s) but received 6 argument(s). Declaration: set_cuda(__torch__.torch.classes.fbgemm.EmbeddingParameterServerWrapper _0, Tensor _1, Tensor _2, Tensor _3, int _4) -> NoneType _0
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torch/nn/modules/module.py", line 2203, in state_dict
    module.state_dict(
  [Previous line repeated 1 more time]
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torch/nn/modules/module.py", line 2203, in state_dict
    module.state_dict(
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torch/nn/modules/module.py", line 2199, in state_dict
    hook(self, prefix, keep_vars)
  [Previous line repeated 1 more time]
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torchrec/distributed/embedding.py", line 605, in _pre_state_dict_hook
    lookup.flush()
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torch/nn/modules/module.py", line 2199, in state_dict
    hook(self, prefix, keep_vars)
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torchrec/distributed/embedding_lookup.py", line 350, in flush
    emb_module.flush()
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torchrec/distributed/embedding.py", line 605, in _pre_state_dict_hook
    lookup.flush()
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torchrec/distributed/batched_embedding_kernel.py", line 812, in flush
    self.emb_module.flush()
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torchrec/distributed/embedding_lookup.py", line 350, in flush
    emb_module.flush()
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/fbgemm_gpu/tbe/ssd/training.py", line 1548, in flush
    self.ssd_db.set_cuda(
RuntimeError: set_cuda() expected at most 5 argument(s) but received 6 argument(s). Declaration: set_cuda(__torch__.torch.classes.fbgemm.EmbeddingParameterServerWrapper _0, Tensor _1, Tensor _2, Tensor _3, int _4) -> NoneType _0
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/torchrec/distributed/batched_embedding_kernel.py", line 812, in flush
    self.emb_module.flush()
  File "/mnt/xarfuse/uid-434849/4ef91778-seed-nspid4026531836_cgpid144089292-ns-4026531840/fbgemm_gpu/tbe/ssd/training.py", line 1548, in flush
    self.ssd_db.set_cuda(
RuntimeError: set_cuda() expected at most 5 argument(s) but received 6 argument(s). Declaration: set_cuda(__torch__.torch.classes.fbgemm.EmbeddingParameterServerWrapper _0, Tensor _1, Tensor _2, Tensor _3, int _4) -> NoneType _0
RuntimeError: set_cuda() expected at most 5 argument(s) but received 6 argument(s). Declaration: set_cuda(__torch__.torch.classes.fbgemm.EmbeddingParameterServerWrapper _0, Tensor _1, Tensor _2, Tensor _3, int _4) -> NoneType _0
```

seemed to be from D60413462

need to change PS path to accept is_bwd for set_cuda too.

Differential Revision: D62247158
  • Loading branch information
Jianbo Liu authored and facebook-github-bot committed Sep 5, 2024
1 parent 53d84ad commit c71dfad
Showing 1 changed file with 9 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,21 @@ class EmbeddingParameterServerWrapper : public torch::jit::CustomClassHolder {
max_D);
}

void
set_cuda(Tensor indices, Tensor weights, Tensor count, int64_t timestep) {
return impl_->set_cuda(indices, weights, count, timestep);
void set_cuda(
Tensor indices,
Tensor weights,
Tensor count,
int64_t timestep,
bool is_bwd = false) {
return impl_->set_cuda(indices, weights, count, timestep, is_bwd);
}

void get_cuda(Tensor indices, Tensor weights, Tensor count) {
return impl_->get_cuda(indices, weights, count);
}

void set(Tensor indices, Tensor weights, Tensor count) {
return impl_->set(indices, weights, count);
void set(Tensor indices, Tensor weights, Tensor count, bool is_bwd = false) {
return impl_->set(indices, weights, count, is_bwd);
}

void get(Tensor indices, Tensor weights, Tensor count) {
Expand Down

0 comments on commit c71dfad

Please sign in to comment.