
[BUG] It's not clear how to call an advantage module with batched envs and pixel observations. #1522

Open
skandermoalla opened this issue Sep 13, 2023 · 7 comments

@skandermoalla
Contributor

skandermoalla commented Sep 13, 2023

Describe the bug

When you get a tensordict rollout of shape (N_envs, N_steps, C, H, W) out of a collector and you want to apply an advantage module whose value network starts with conv2d layers:

  1. Directly applying the module crashes, with the conv2d layer complaining about the input size, e.g. RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [2, 128, 4, 84, 84].
  2. Flattening the tensordict first with rollout.reshape(-1) so that it has shape [B, C, H, W] and then calling the advantage module runs, but it issues the warning torchrl/objectives/value/advantages.py:99: UserWarning: Got a tensordict without a time-marked dimension, assuming time is along the last dimension., leaving you unsure of whether the advantages were computed correctly.

So it's not clear how one should proceed.
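
For concreteness, here is a rough sketch of the two attempts. The value network, the TensorDict keys, and the shapes below are illustrative placeholders for a typical pixel-based setup, not the exact code from my pipeline:

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.objectives.value import GAE

# Placeholder value network whose first layer is a conv2d.
# Assumptions: keys "pixels"/"state_value", observations of shape (4, 84, 84),
# N_envs=2, N_steps=128 (matching the error message above).
value_net = TensorDictModule(
    torch.nn.Sequential(
        torch.nn.Conv2d(4, 32, kernel_size=8, stride=4),  # 84x84 -> 20x20
        torch.nn.ReLU(),
        torch.nn.Flatten(),
        torch.nn.Linear(32 * 20 * 20, 1),
    ),
    in_keys=["pixels"],
    out_keys=["state_value"],
)
adv_module = GAE(gamma=0.99, lmbda=0.95, value_network=value_net)

# Fake rollout with batch size (N_envs, N_steps) = (2, 128), as a collector would return.
rollout = TensorDict(
    {
        "pixels": torch.rand(2, 128, 4, 84, 84),
        "next": {
            "pixels": torch.rand(2, 128, 4, 84, 84),
            "reward": torch.rand(2, 128, 1),
            "done": torch.zeros(2, 128, 1, dtype=torch.bool),
            "terminated": torch.zeros(2, 128, 1, dtype=torch.bool),
        },
    },
    batch_size=[2, 128],
)

try:
    adv_module(rollout)  # 1. conv2d rejects the 5D input [2, 128, 4, 84, 84]
except RuntimeError as err:
    print(err)

adv_module(rollout.reshape(-1))  # 2. runs, but warns about the missing time dimension
```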

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
skandermoalla added the bug (Something isn't working) label on Sep 13, 2023
@vmoens
Contributor

vmoens commented Sep 14, 2023

Good point there.
Regarding reshaping: you should reshape and refine_names; I believe the last dim will still be time-compliant (but you need to make sure you have truncated signals at the end of each trajectory segment, so that values are not bootstrapped across trajectory boundaries).
Other than that, we could consider falling back on vmap / first-class dimensions whenever this situation is encountered. I will give it a look and ping you once it's on its way, as usual.
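
In code, that suggestion would look roughly like the following. This is a sketch, not something validated in the thread; it reuses the flattened rollout and advantage module from the example above and assumes truncated/done flags are set wherever one trajectory ends and the next begins:

```python
# Flatten (N_envs, N_steps) into a single leading dim and mark it as the
# time dimension so the advantage module no longer has to guess.
flat = rollout.reshape(-1)        # (2, 128) -> (256,)
flat = flat.refine_names("time")  # tell downstream modules which dim is time
adv_module(flat)
```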

@matteobettini
Contributor

@vmoens In some cases the env data may have an arbitrary batch size (*B) before the time dimension.

Is the current approach, before we land something like pytorch/tensordict#525, to flatten all these dims into one, making sure to add terminations when doing so?

@vmoens
Contributor

vmoens commented Sep 14, 2023

I don't think so; as I said in my answer, the proper approach should be to vmap over the leading dims up to the time dim. Wdyt?
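
To illustrate, vmapping over the leading env dim would look roughly like this. This is a sketch of the idea, not what TorchRL does internally; conv_stack stands in for a plain conv-based value head:

```python
import torch

# Map the conv stack over the leading env dimension: each mapped call sees a
# (N_steps, C, H, W) tensor, which conv2d accepts as a regular 4D batch.
conv_stack = torch.nn.Sequential(
    torch.nn.Conv2d(4, 32, kernel_size=8, stride=4),  # 84x84 -> 20x20
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(32 * 20 * 20, 1),
)
pixels = torch.rand(2, 128, 4, 84, 84)       # (N_envs, N_steps, C, H, W)
values = torch.vmap(conv_stack)(pixels)      # -> (2, 128, 1)
```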

@skandermoalla
Contributor Author

skandermoalla commented Sep 14, 2023

Somehow, in the PPO example, the advantage module is called on the full rollout batch shape

data = adv_module(data.to(model_device)).cpu()
and it doesn't crash with conv2d complaining.

def make_ppo_modules_pixels(proof_environment):

I also managed to reproduce this with the ConvNet and MLP modules of TorchRL, and my advantage module now runs without reshaping.

I'm sending more details to compare the settings.

@skandermoalla
Contributor Author

Okay, so the ConvNet of TorchRL actually flattens the leading batch dims before running the forward pass and then unflattens them back.

def forward(self, inputs: torch.Tensor) -> torch.Tensor:

Maybe this could be made clearer to the user, so that when designing custom models they know they have to do something similar (roughly the pattern sketched below).
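
For reference, the pattern for a custom model boils down to something like this. This is a sketch of the idea, not TorchRL's actual ConvNet code:

```python
import torch
from torch import nn

class LeadingDimAgnosticConv(nn.Module):
    """Sketch: collapse all leading batch dims before the conv stack, restore them after."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Flatten(),
        )

    def forward(self, pixels: torch.Tensor) -> torch.Tensor:
        *batch, c, h, w = pixels.shape
        out = self.conv(pixels.reshape(-1, c, h, w))  # (prod(batch), features)
        return out.reshape(*batch, out.shape[-1])     # restore leading batch dims
```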

Otherwise, vmapping would be the way to go. I'm just concerned about memory requirements compared to flattening the tensordict.

@vmoens
Contributor

vmoens commented Mar 5, 2024

Otherwise, vmapping would be the way to go. I'm just concerned about memory requirements compared to flattening the tensordict.

@skandermoalla Looking back at this comment, I wonder why vmap should have higher mem requirements?

@skandermoalla
Contributor Author

I'm not very familiar with vmap, but does the memory taken by the model weights stay the same when you vmap it?
