[Feature,Doc] QValue refactoring and QNet + RNN tuto #1060
Conversation
Note: there is still some magic that needs to be done when we pass 2 consecutive trajectories to the RNN, but it can be handled without too much trouble (I think). EDIT: Problem solved, feature implemented!
Looks good, I left some questions!
- I like that the 2 hidden states are kept separate and I think that is the best choice
- is_init is used nicely in my opinion. We need a flag that tells us when to reset, and it does exactly that.
f"reset the noise at the beginning of a trajectory, without it " | ||
f"the behaviour of this exploration method is undefined. " | ||
f"This is allowed for BC compatibility purposes but it will be deprecated soon! " | ||
f"To create a 'step_count' entry, simply append a StepCounter " | ||
f"transform to your environment with `env = TransformedEnv(env, StepCounter())`." | ||
f"To create a 'step_count' entry, simply append an torchrl.envs.InitTracker " |
Replace step_count with is_init?
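For reference, a minimal sketch of the suggested alternative (hypothetical wording; it assumes, as later in this thread, that `InitTracker` writes an `"is_init"` entry and that gym's Pendulum-v1 is available):

```python
# Illustration of the reviewer's suggestion: track trajectory starts with
# InitTracker / "is_init" rather than StepCounter / "step_count".
from torchrl.envs import TransformedEnv, InitTracker
from torchrl.envs.libs.gym import GymEnv

env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
td = env.reset()
# InitTracker adds a boolean "is_init" entry that is True at the start of a trajectory,
# which is the flag an exploration module can use to reset its noise.
print(td["is_init"])
```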
.. note::
  For a better integration with TorchRL's environments, the best naming
  for the output hidden key is ``("next", "hidden0")`` or similar, such
  that the hidden values are passed from step to step during a rollout.
Why are we asking for the keys like this?
In my mind I would have 3 key arguments (see the sketch after this list):
- in_keys # input keys apart from the hidden state (not constrained to 1)
- out_keys # output keys apart from the hidden state (not constrained to 1)
- hidden_keys = (hidden0, hidden1) # this is the default; the init automatically adds these as input keys, and adds their versions with "next" prepended to the output keys
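A hypothetical sketch of that proposal (the class, names and defaults are illustrative, not the actual torchrl API):

```python
from typing import Sequence, Tuple

import torch.nn as nn


class ProposedLSTMModule:
    """Sketch of the 3-key-argument scheme suggested above (hypothetical)."""

    def __init__(
        self,
        lstm: nn.LSTM,
        in_keys: Sequence[str],
        out_keys: Sequence[str],
        hidden_keys: Tuple[str, str] = ("hidden0", "hidden1"),
    ):
        self.lstm = lstm
        # the hidden keys are read from the root of the tensordict at time t...
        self.in_keys = list(in_keys) + list(hidden_keys)
        # ...and written back under "next" so they are found again at t+1
        self.out_keys = list(out_keys) + [("next", k) for k in hidden_keys]
```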
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> lstm = nn.LSTM(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True)
>>> lstm_module = LSTMModule(lstm, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")])
With the comment above this becomes:
`LSTMModule(lstm, in_keys=["observation"], out_keys=["intermediate"])`
fields={
    hidden0: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
    hidden1: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
Isn't it counterintuitive that obs -> intermediate -> action are all at the root level but the hidden state is in "next"?
What, again, is the reason for this?
The hidden is not used at time t but at t+1 (obs, intermediate and action are for t and t only). It's probably a bit cleaner to have it in "next" as this is where it belongs...
So you pack the hidden in "next" such that you find it back when calling step_mdp.
Per se we could avoid that, since I think most calls to step_mdp have keep_other=True (it is the default, it is the case for env.rollout and it is the case for the collectors!).
But if someone by mistake calls step_mdp with keep_other=False she'll lose the hidden, and since we populate the hidden with 0s when not found, no error will be raised (ouch!).
So I'd rather avoid that scenario as much as I can!
Now I'm not against the idea of having it out of "next" if you find it's clearer.
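To make this concrete, here is a small conceptual sketch using plain tensordict operations (illustrative keys and shapes; it mimics what a step_mdp-style transition does rather than calling torchrl directly):

```python
import torch
from tensordict import TensorDict

# Data collected at time t: the LSTM module wrote the updated hidden state under "next".
t_data = TensorDict(
    {
        "observation": torch.randn(3),
        "action": torch.randn(1),
        "next": TensorDict(
            {
                "observation": torch.randn(3),
                "hidden0": torch.zeros(1, 64),
                "hidden1": torch.zeros(1, 64),
            },
            batch_size=[],
        ),
    },
    batch_size=[],
)

# A step_mdp-like transition promotes the content of "next" to the root of the next
# step, so the hidden state sits at the root exactly when it is consumed (time t+1),
# independently of whether root-level extras are carried over (keep_other).
t_plus_1 = t_data.get("next").clone()
assert "hidden0" in t_plus_1.keys() and "hidden1" in t_plus_1.keys()
```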
I see. So we put it in "next" because we know for sure that a step_mdp will happen.
In other words, at time t we put the hidden state for time t+1.
One might think it's more intuitive that at time t we get the state from time t-1?
In the algo, yes, but it's your choice...
I have to double-check, but I think RNNs are usually written as
y_t, s_t = rnn(x_t, s_{t-1})
in books/papers. I agree that perhaps `prev_state, state` makes more sense than `state, next_state`.
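For what it's worth, PyTorch's nn.LSTM already follows that convention: the state you pass in is the previous one and the state you get back is the current one (a minimal sketch, unrelated to the torchrl module itself):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=3, hidden_size=64, batch_first=True)
x_t = torch.randn(1, 1, 3)       # one batch element, one time step
h_prev = torch.zeros(1, 1, 64)   # hidden state from the previous step (s_{t-1})
c_prev = torch.zeros(1, 1, 64)   # cell state from the previous step (s_{t-1})

# y_t, s_t = rnn(x_t, s_{t-1})
y_t, (h_t, c_t) = lstm(x_t, (h_prev, c_prev))
```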
There's what papers say and how it aligns with the rest of the API.
I.e., I'm not sure how we could mark the input as being "prev" and the current as being "current" if the only containers we have are "current" and "next". Maybe a proper doc that mirrors what you're saying here?
We can't escape from padding in RNNs, right?
There is the packed_sequence artifact but I'm not entirely sure that it solves the problem.
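For context, a minimal sketch of the PackedSequence utility being referred to: it lets the LSTM skip padded steps, but it does not by itself reset the state between stacked trajectories (lengths below are made up):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = nn.LSTM(input_size=3, hidden_size=64, batch_first=True)

# Two trajectories padded to a common length of 5 (true lengths: 5 and 3).
x = torch.randn(2, 5, 3)
lengths = torch.tensor([5, 3])

packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)
# Back to a padded tensor of shape (2, 5, 64); positions past each length are zero.
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
```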
Awesome, really looking forward to this! I think it makes sense to write the recurrent state to … Opinion: I think I would really prefer that …
If you want to store every …
I agree it's not ideal. First, let's put the cards on the table. There are various things that are related but not exactly identical:
Gradient and exploration are currently handled via the "exploration_type" and torch "grad_mode" global settings. If it's about naming I'm happy to reconsider the name, of course. I'm actually looking forward to having more input on this from the community: what's the best way of approaching all of these "modes" that, to me, seem very algorithm-dependent?
Got it, it's a headache on our side too, so let's just not use it.
That tensordict format looks good to me. Yeah, that's a good point about the modes; in that case it's better to err on the side of flexibility than simplicity. I think perhaps a small table in the docs mapping each of these modes to a 1-2 sentence description would help clear up any ambiguities. I do like the way exploration etc. is currently handled. Is … Nit: …
# Conflicts:
#   torchrl/modules/tensordict_module/actors.py
#   torchrl/modules/utils/utils.py
# Conflicts:
#   torchrl/data/replay_buffers/replay_buffers.py
@matteobettini @smorad
I have drafted a tutorial. It's very rough but it trains!
I have refactored the way we work with LSTMs.
The idea is to have a module that has 3 in keys (an input and the 2 hidden keys) and 3 out keys (one value and the 2 hidden keys).
Questions for you:
tensordict["next", "hidden"]
or smth like that.Is it too complex? I could make it work with just ["hidden0", "hidden1"] though, but on the torchrl side it will be less clean. If you look at the example, the in_keys are
["hidden0", "hidden1"]
(current) and the out are[("next", "hidden0"), ("next", "hidden1")]
.
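For reference, a condensed restatement of the example from the diff above, using the draft API under review (the released LSTMModule signature may differ; assumes gym's Pendulum-v1 is available):

```python
import torch.nn as nn
from torchrl.envs import TransformedEnv, InitTracker
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import LSTMModule

env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
lstm = nn.LSTM(
    input_size=env.observation_spec["observation"].shape[-1],
    hidden_size=64,
    batch_first=True,
)
# 3 in keys: the input plus the two current hidden keys (read at the root, at time t);
# 3 out keys: the value plus the two hidden keys written under "next" (for time t+1).
lstm_module = LSTMModule(
    lstm,
    in_keys=["observation", "hidden0", "hidden1"],
    out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
)
```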