[Feature,Doc] QValue refactoring and QNet + RNN tuto #1060

Merged
vmoens merged 30 commits into main from rnn_example on Apr 28, 2023

Conversation

Contributor

@vmoens vmoens commented Apr 14, 2023

@matteobettini @smorad

I have drafted a tutorial. It's very rough but it trains!

I have refactored the way we work with LSTMs.
The idea is to have a module that has 3 in_keys (an input and the two hidden keys) and 3 out_keys (one value and the two hidden keys).

Questions for you:

  • We could by default use the same key, but the problem here is that we want to write the "next" hidden state in tensordict["next", "hidden"] or something like that.
    Is it too complex? I could make it work with just ["hidden0", "hidden1"], but on the torchrl side it will be less clean. If you look at the example, the in_keys are ["hidden0", "hidden1"] (current) and the out_keys are [("next", "hidden0"), ("next", "hidden1")].
  • Any comment on the use of "is_init"?
  • The current implementation allows you to use the same lstm both for executing the policy step by step and for batched execution. The default is step-by-step. I think the docstring is informative regarding this feature (see the sketch after this list). Does this make sense? The switch from one behaviour to the other is a bit clunky, but specifying this in the constructor doesn't really solve our problem as you'll need to call the constructor twice (which I think is worse than calling a method to convert the module).
  • One thing it cannot do is this: if your loop is a bit more complicated (e.g. it takes the previous action + current observation as input) and if you want to backprop through the chain of events, it will not work as intended; but for these kinds of things you'll need a custom loop anyway, I guess.
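
Below is a rough sketch of how such a module is meant to be used, based on the docstring example quoted further down in this thread (the exact constructor signature and the name of the mode-switching method may differ from what finally lands):

import torch
from torch import nn
from torchrl.envs import GymEnv, InitTracker, TransformedEnv
from torchrl.modules import LSTMModule

env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
lstm = nn.LSTM(
    input_size=env.observation_spec["observation"].shape[-1],
    hidden_size=64,
    batch_first=True,
)
# 3 in_keys (the input plus the two hidden keys) and 3 out_keys
# (one value plus the two hidden keys written under "next")
lstm_module = LSTMModule(
    lstm,
    in_keys=["observation", "hidden0", "hidden1"],
    out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
)
td = env.reset()
td = lstm_module(td)  # default behaviour: single-step execution; missing
                      # hidden keys are filled with zeros
# For batched (temporal) execution during training, a method call converts
# the module, e.g. a set_recurrent_mode(True)-style call (set_temporal_mode
# in earlier drafts; see the discussion below).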

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 14, 2023
@vmoens vmoens added documentation Improvements or additions to documentation enhancement New feature or request labels Apr 21, 2023
Contributor Author

vmoens commented Apr 21, 2023

Note: there is still some magic that needs to be done when we pass 2 consecutive trajectories to the RNN, but it can be handled without too much trouble (I think)
@matteobettini I will do that with padding. Don't tell me, I know!

EDIT: Problem solved, feature implemented!!
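
For context, this is roughly what padding two trajectories of different lengths looks like with plain PyTorch (this is not the PR's own padding utility, just the generic torch tool, shown for illustration):

import torch
from torch.nn.utils.rnn import pad_sequence

traj_a = torch.randn(5, 3)  # 5 steps, 3 features
traj_b = torch.randn(8, 3)  # 8 steps, 3 features
padded = pad_sequence([traj_a, traj_b], batch_first=True)  # shape [2, 8, 3]
mask = torch.zeros(2, 8, dtype=torch.bool)
mask[0, :5] = True  # mark the valid (non-padded) steps
mask[1, :8] = True
# the mask is what lets a loss ignore the padded steps later on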

Contributor

@matteobettini matteobettini left a comment

Looks good, I left some questions!

  • I like that the 2 hidden states are kept separate and I think that is the best choice.
  • is_init is used nicely in my opinion. We need a flag that tells us when to reset, and it does exactly that.

f"reset the noise at the beginning of a trajectory, without it "
f"the behaviour of this exploration method is undefined. "
f"This is allowed for BC compatibility purposes but it will be deprecated soon! "
f"To create a 'step_count' entry, simply append a StepCounter "
f"transform to your environment with `env = TransformedEnv(env, StepCounter())`."
f"To create a 'step_count' entry, simply append an torchrl.envs.InitTracker "
Contributor

Replace step_count with is_init?

.. note::
For a better integration with TorchRL's environments, the best naming
for the output hidden key is ``("next", "hidden0")`` or similar, such
that the hidden values are passed from step to step during a rollout.
Contributor

Why are we asking the keys like this?

In my mind I would have 3 key arguments:

  • in_keys # in keys apart from hidden state (not constrained to 1)
  • out_keys # out keys apart from hidden state (not constrained to 1)
  • hidden_keys = (hidden0, hidden1) # this is the default, and the init automatically adds these to the input keys, and their versions with "next" prepended to the output keys

>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> lstm = nn.LSTM(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True)
>>> lstm_module = LSTMModule(lstm, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")])
Contributor

With the comment above this becomes
LSTMModule(lstm, in_keys=["observation"], out_keys=["intermediate"])
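
A hypothetical sketch of what the proposed hidden_keys handling could look like (purely illustrative, not what this PR implements):

from typing import Sequence


class LSTMModuleSketch:
    """Illustration of the proposal: hidden keys are appended to the in_keys
    as-is, and to the out_keys with "next" prepended."""

    def __init__(
        self,
        lstm,
        in_keys: Sequence[str],
        out_keys: Sequence[str],
        hidden_keys: Sequence[str] = ("hidden0", "hidden1"),
    ):
        self.lstm = lstm
        # the hidden state is read from the root of the tensordict...
        self.in_keys = list(in_keys) + list(hidden_keys)
        # ...and written back under "next" so that it survives step_mdp
        self.out_keys = list(out_keys) + [("next", k) for k in hidden_keys]


# LSTMModuleSketch(lstm, in_keys=["observation"], out_keys=["intermediate"])
# would then read ["observation", "hidden0", "hidden1"] and write
# ["intermediate", ("next", "hidden0"), ("next", "hidden1")].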

fields={
hidden0: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
hidden1: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
Contributor

Isn't it counterintuitive that
obs -> intermediate -> action are all at root level but the hidden state is in next?

What again is the reason for this?

Contributor Author

The hidden is not used at time t but at t+1 (obs, intermediate and action are for t and t only). It's probably a bit cleaner to have it in next as this is where it belongs...

So you pack hidden in next such that you find it back when calling step_mdp.
Per se we could avoid that, since I think most calls to step_mdp have keep_other=True (it is the default, it is the case for env.rollout and it is the case for the collectors!).
But if someone by mistake calls step_mdp with keep_other=False she'll lose the hidden, and since we populate the hidden with 0s when not found, no error will be raised (ouch!).

So I'd rather avoid that scenario as much as I can!

Now I'm not against the idea of having it out of "next" if you find it's clearer.
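
A minimal sketch of that scenario with dummy tensors (step_mdp is the existing torchrl helper; the key layout follows the example discussed in this thread):

import torch
from tensordict import TensorDict
from torchrl.envs.utils import step_mdp

td = TensorDict(
    {
        "observation": torch.randn(3),
        "action": torch.randn(1),
        "next": {
            "observation": torch.randn(3),
            "reward": torch.zeros(1),
            "done": torch.zeros(1, dtype=torch.bool),
            "hidden0": torch.randn(1, 64),
            "hidden1": torch.randn(1, 64),
        },
    },
    batch_size=[],
)
# step_mdp promotes the content of "next" to the root of the new tensordict,
# so a hidden state written under ("next", "hidden0") at time t is found back
# at the root at time t+1, whatever value keep_other takes.
td_tp1 = step_mdp(td, keep_other=False)
assert "hidden0" in td_tp1.keys() and "hidden1" in td_tp1.keys()
# Had the hidden state been written at the root only, keep_other=False would
# have dropped it silently and it would be re-initialised to zeros at the
# next call (the footgun described above).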

Contributor

@matteobettini matteobettini Apr 21, 2023

I see. So we put it in next because we know for sure that a step_mdp will happen.

In other words, at time t we write the hidden state for time t+1.
One might think it is more intuitive instead that at time t we get the state from time t-1?

Contributor Author

in the algo yes but it's your choice...

Contributor

I have to double-check, but I think RNNs are usually written as

y_t, s_t = rnn(x_t, s_{t-1})

in books/papers. I agree that perhaps prev_state, state makes more sense than state, next_state.
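
For concreteness, PyTorch's own nn.LSTM already follows that convention: the state you pass in is the previous one and the state you get back is the current one.

import torch
from torch import nn

lstm = nn.LSTM(input_size=3, hidden_size=8, batch_first=True)
x_t = torch.randn(1, 1, 3)  # one step of input: (batch, time, feature)
s_prev = (torch.zeros(1, 1, 8), torch.zeros(1, 1, 8))  # (h_{t-1}, c_{t-1})
y_t, s_t = lstm(x_t, s_prev)  # y_t: output at t, s_t: (h_t, c_t)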

Contributor Author

There's what papers say and there's how it aligns with the rest of the API.
I.e. I'm not sure how we could mark the input as being "prev" and the current as being "current" if the only containers we have are "current" and "next". Maybe a proper doc that mirrors what you're saying here?

@matteobettini
Contributor

EDIT: Problem solved, feature implemented!!

We can't escape from padding in RNNs, right?

Contributor Author

vmoens commented Apr 21, 2023

We can't escape from padding in RNNs, right?

There is the packed_sequence artifact, but I'm not entirely sure that it solves the problem.
My implementation of padding / unpadding is 2x faster than the other torch tools, btw.

Contributor

smorad commented Apr 21, 2023

Awesome, really looking forward to this!

I think it makes sense to write the recurrent state to next. It just might make things tricky regarding when/who is calling step_mdp. Are there any footguns here, e.g. the collector calling step_mdp and reversing state and (next, state) before placing them into the replay buffer?

Opinion: I think set_temporal_mode could be a little confusing. Is this different than inference (collector) mode, or just a synonym for inference? If they are indeed the same, maybe it makes sense to have if self.training: ... in the module forward pass instead of an explicit time-batch/single-step mode.

I would really prefer that packed_sequence is not used here. It's brought me nothing but pain.

Contributor Author

vmoens commented Apr 22, 2023

Awesome, really looking forward to this!

I think it makes sense to write the recurrent state to next. It just might make things tricky regarding when/who is calling step_mdp. Are there any footguns here, e.g. the collector calling step_mdp and reversing state and (next, state) before placing them into the replay buffer?

If you want to store every h_{t-1}, s_t, a_t, h_t, s_{t+1} following the nomenclature above, this will correspond to

TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([4, 16]), device=cpu, dtype=torch.int64, is_shared=False),
        action: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        observation: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                hidden: Tensor(shape=torch.Size([4, 16]), device=cpu, dtype=torch.int64, is_shared=False),
                observation: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([4]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False)

Opinion: I think set_temporal_mode could be a little confusing. Is this different than inference (collector) mode, or just a synonym for inference? If they are indeed the same, maybe it makes sense to have if self.training: ... in the module forward pass instead of an explicit time-batch/single-step mode.

I agree it's not ideal.
We had a similar discussion here if you want to have a look.

First, let's put the cards on the table. There are various things that are related but not exactly identical:

  • Train / eval mode
  • Explorative / exploitative mode
  • Temporal / static mode (e.g. example here)
  • Grad on / off mode (e.g. example here)

When doing inference and collecting data for training you will be in
eval + explorative + static + grad off
When testing the policy in a "real" setting you will be in
eval + exploitative + static + grad off
When training the policy you will be in
train + exploitative* + temporal** + grad on***

* not always true: sometimes you want to sample the action, sometimes not
** not always true: in Dreamer you simulate trajectories during training that are fully reparametrized, and you're in static mode because the RNN feeds itself with the action that results from the previous call.
*** depending on the loss (policy / critic / model in MBRL) you may want to keep some ops out of the graph or just detach some variables (not talking about meta-RL!)

Gradient and exploration are currently handled via "exploration_type" and the torch "grad_mode" environment variables.
Train and eval are handled via the module attribute. I hesitated for temporal, but I can see cases where you may want to use either static or temporal mode at both training and inference time. I thought that giving the user full power over what happens internally was the easiest option to avoid any kind of confusion.

If it's about naming I'm happy to reconsider the name, of course. I'm actually looking forward to having more input on this from the community: what's the best way of approaching all of these "modes" that, to me, seem very algorithm-dependent?
Just think about SARSA vs QMax, where the action selection is completely different. We can easily do that via "exploration_type", but it's not related to eval/train, gradient off/on or anything else.
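
For illustration, this is roughly what toggling those modes looks like in code today (set_exploration_type / ExplorationType are taken as the entry point for the "exploration_type" mentioned above; the helper names may differ across torchrl versions, and the temporal / static switch is discussed further below):

import torch
from torch import nn
from torchrl.envs.utils import ExplorationType, set_exploration_type

policy_net = nn.Linear(4, 2)  # stand-in for a policy module

# inference / data collection: eval + explorative + static + grad off
policy_net.eval()
with set_exploration_type(ExplorationType.RANDOM), torch.no_grad():
    out = policy_net(torch.randn(1, 4))

# training: train + exploitative + temporal + grad on
policy_net.train()
with set_exploration_type(ExplorationType.MODE):
    out = policy_net(torch.randn(1, 4))
# the temporal / static switch lives on the recurrent module itself
# (set_temporal_mode / set_recurrent_mode in the discussion below)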

I would really prefer that packed_sequence is not used here. It's brought me nothing but pain.

Got it, it's a headache on our side too, so let's just not use it.

Contributor

smorad commented Apr 24, 2023

That tensordict format looks good to me.

Yeah, that's a good point with the modes. In that case it's better to err on the side of flexibility rather than simplicity. I think perhaps a small table in the docs mapping each of these modes to a 1-2 sentence description would help clear up any ambiguities. I do like the way exploration etc. is currently handled. Is set_recurrent_mode more specific than set_temporal_mode? Could either be batch_time or sequential_time.

nit: hidden/state are super overloaded terms. I think we should be explicit with recurrent_state_h, recurrent_state_c or rstate_h, rstate_c, which will help new users when they are grepping through the codebase.

# Conflicts:
#	torchrl/modules/tensordict_module/actors.py
#	torchrl/modules/utils/utils.py
@vmoens vmoens merged commit 7ca3547 into main Apr 28, 2023
@vmoens vmoens deleted the rnn_example branch April 28, 2023 12:41