[BUG] Loading losses with modules that have no parameters #1593

Closed
matteobettini opened this issue Oct 2, 2023 · 4 comments · Fixed by pytorch/tensordict#650

matteobettini commented Oct 2, 2023

Reloading the state dict of a loss whose neural network has no parameters fails:

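  import torch
  from torchrl.modules import QValueActor
  from torchrl.objectives import DQNLoss

  model = torch.nn.Tanh()  # no parameters: reloading fails
  # model = torch.nn.Linear(1, 1)  # has parameters: reloading works
  value = QValueActor(module=model, in_keys="obs", action_space="one_hot")
  loss = DQNLoss(value_network=model, action_space="one_hot")
  state = loss.state_dict()

  # Rebuild the loss and reload the saved state
  loss = DQNLoss(value_network=model, action_space="one_hot")
  loss.load_state_dict(state)  # raises the RuntimeError below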
Traceback (most recent call last):
  File "/Users/matbet/PycharmProjects/rl/prova.py", line 16, in <module>
    loss.load_state_dict(state)
  File "/Users/matbet/miniconda3/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2027, in load_state_dict
    load(self, state_dict)
  File "/Users/matbet/miniconda3/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2015, in load
    load(child, child_state_dict, child_prefix)
  File "/Users/matbet/miniconda3/envs/torchrl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2009, in load
    module._load_from_state_dict(
  File "/Users/matbet/PycharmProjects/tensordict/tensordict/nn/params.py", line 792, in _load_from_state_dict
    self.data.load_state_dict(data)
  File "/Users/matbet/PycharmProjects/tensordict/tensordict/tensordict.py", line 834, in load_state_dict
    raise RuntimeError(
RuntimeError: Cannot load state-dict because the key sets don't match: got state_dict extra keys 
set()
 and tensordict extra keys
{'module'}

An example use case is the VDN module in MARL, which is just a sum of its input and therefore has no parameters; it triggers this error in the QMixerLoss.
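To make the VDN case concrete, here is a minimal sketch in plain PyTorch. `SumMixer` is a hypothetical stand-in, not TorchRL's actual VDN implementation, and the `(batch, n_agents, 1)` layout is an assumption. It only illustrates that a module which sums its input has an empty state dict, and that a plain `torch.nn.Module` round-trips that empty state dict without error, which is the behavior one would expect from the loss as well.

```python
import torch


class SumMixer(torch.nn.Module):
    """Hypothetical VDN-style mixer: sums per-agent values, holds no parameters."""

    def forward(self, agent_values: torch.Tensor) -> torch.Tensor:
        # Sum over the agent dimension (assumed to be the second-to-last one).
        return agent_values.sum(dim=-2)


mixer = SumMixer()
state = mixer.state_dict()  # empty: there is nothing to save

# A fresh instance accepts the empty state dict without complaint.
result = SumMixer().load_state_dict(state)

values = torch.ones(4, 3, 1)  # assumed layout: (batch, n_agents, 1)
mixed = mixer(values)         # shape (4, 1)
```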

matteobettini added the bug label Oct 2, 2023

vmoens commented Oct 4, 2023

Thanks.
I think moving to torch.func.functional_call will solve this issue. For that, pytorch/tensordict#526 needs to mature first.
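For context, a minimal sketch of `torch.func.functional_call` (available in PyTorch 2.0 and later) on plain modules. This is not the TorchRL loss code; it only shows that the call receives parameters as an explicit dictionary, so a parameterless module is handled by passing an empty one.

```python
import torch
from torch.func import functional_call

x = torch.randn(3, 1)

# Module with parameters: pass them explicitly as a name -> tensor dict.
linear = torch.nn.Linear(1, 1)
linear_params = dict(linear.named_parameters())
y_linear = functional_call(linear, linear_params, (x,))

# Module without parameters: the dict is simply empty.
tanh = torch.nn.Tanh()
tanh_params = dict(tanh.named_parameters())
y_tanh = functional_call(tanh, tanh_params, (x,))
```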


vmoens commented Feb 1, 2024

I would suggest using:

import tensordict
sd = tensordict.TensorDict.from_module(loss)
sd.to_module(loss)

matteobettini commented

I see, thanks.

For backward compatibility and interchangeability with other components, I still need to use the state_dict() interface though.


vmoens commented Feb 1, 2024

Yes, I'm on it, but serializing with tensordict is more efficient, faster, and safer.
