Normalization Wrappers #113
Conversation
…s. Ultimately, I'd like the update to occur outside of vmap?
Hi David - thank you for opening this PR! Please see #89 for ongoing discussion. This is an area that is still quite in flux. For what it's worth, I would like to see most of these env wrappers moved out of Brax and under the purview of gym, so I'm hesitant to make any big changes in the Brax repo until that is all sorted out.
In particular, things like observation normalization are handled in the brax learners outside of the environments, and that's a deliberate design decision. The main reason being that we often run on multiple devices and want to be careful to aggregate running statistics across device, which requires special handling in the trainer.
So I think I defer to @vwxyzjn and @jkterry1 with what to do with this particular PR.
That said! I'd be happy to accept a minimal PR here just to get in a tight version of observation normalization, so that we could simplify this pytorch training colab (and possibly improve its perf a tiny bit): https://colab.sandbox.google.com/github/google/brax/blob/main/notebooks/training_torch.ipynb
If you'd like to have a go at making just an observation normalization wrapper, and then updating the colab to show it working, I'd be happy to shepherd that in while the grander gym planning is underway.
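For what a "tight version of observation normalization" could look like, here is a minimal NumPy sketch in the gym/StableBaselines style: Welford's algorithm tracks running mean/variance and each observation is standardized against them. Class and parameter names here are hypothetical, not code from this PR.

```python
import numpy as np


class NormalizeObservationSketch:
    """Track running mean/std with Welford's algorithm and return
    standardized observations (gym/StableBaselines style sketch)."""

    def __init__(self, shape, epsilon=1e-8):
        self.mean = np.zeros(shape)
        self.m2 = np.zeros(shape)  # running sum of squared deviations
        self.count = 0
        self.epsilon = epsilon  # avoids division by zero early on

    def __call__(self, obs):
        # Welford single-sample update of the running statistics.
        self.count += 1
        delta = obs - self.mean
        self.mean = self.mean + delta / self.count
        self.m2 = self.m2 + delta * (obs - self.mean)
        std = np.sqrt(self.m2 / self.count) + self.epsilon
        return (obs - self.mean) / std
```

A vmap-friendly Brax version would carry these statistics in env state rather than Python attributes, which is exactly the tension discussed elsewhere in this thread.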
```python
from brax import jumpy as jp
from brax.envs import env as brax_env
import gym
from gym import spaces
from gym.vector import utils
import jax
from dataclasses import dataclass
```
unused?
```python
return state.replace(obs=normalize_with_rmstd(state.obs, state.info['running_obs']))
```
```python
class NormalizeRewardWrapper(brax_env.Wrapper):
```
Would one expect for returns with discounting to be stored in this same wrapper? It seems orthogonal to storing the normalized reward, but I don't know enough about general purpose RL to say whether these should all go together.
Also, what about truncations? Would that be handled in this wrapper or somewhere else?
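For context on the first question: gym's NormalizeReward wrapper does keep the discounted return in the same wrapper. It accumulates a discounted return per step and scales each reward by the running std of that return (the mean is not subtracted). A minimal NumPy sketch of that scheme, with hypothetical class names rather than this PR's code:

```python
import numpy as np


class RunningStd:
    """Running variance via Welford-style single-sample updates."""

    def __init__(self):
        self.mean, self.m2, self.count = 0.0, 0.0, 1e-4

    def update(self, x):
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    @property
    def std(self):
        return np.sqrt(self.m2 / self.count) + 1e-8


class NormalizeRewardSketch:
    """Gym-style reward normalization: divide rewards by the running
    std of the discounted return, resetting the return on done."""

    def __init__(self, gamma=0.99):
        self.gamma = gamma
        self.ret = 0.0  # discounted return accumulator
        self.rms = RunningStd()

    def step_reward(self, reward, done):
        self.ret = self.ret * self.gamma + reward
        self.rms.update(self.ret)
        if done:
            self.ret = 0.0
        return reward / self.rms.std
```

So the discounted-return bookkeeping is conventionally bundled with reward normalization rather than kept orthogonal, since the std of returns (not raw rewards) is the scaling target.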
```python
def step(self, state: brax_env.State, action: jp.ndarray) -> brax_env.State:
    running_obs = state.info.pop('running_obs')  # hack to avoid vmap :(
```
I think this is an OK solution for now.
```python
def update_running_mean_std(std_state: RunningMeanStdState, batch: jp.ndarray) -> RunningMeanStdState:
    """Update running statistics with a batch of observations (Welford's algorithm)."""
    batch = jp.atleast_2d(batch)[0]  # Account for unbatched environments
```
Did you forget to add some changes to jumpy for this to work?
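For reference, the batch merge such a function performs can be sketched in plain NumPy using Chan et al.'s parallel form of Welford's update. This is a sketch, not the PR's implementation; the state fields (mean, var, count) are assumed from the function's name and usage.

```python
import numpy as np
from dataclasses import dataclass


@dataclass
class RunningMeanStdState:
    mean: np.ndarray
    var: np.ndarray
    count: float


def merge_batch(state, batch):
    """Merge a batch of observations into running statistics using
    the parallel (Chan et al.) variant of Welford's algorithm."""
    batch = np.atleast_2d(batch)
    batch_mean = batch.mean(axis=0)
    batch_var = batch.var(axis=0)
    batch_count = batch.shape[0]

    delta = batch_mean - state.mean
    total = state.count + batch_count
    new_mean = state.mean + delta * batch_count / total
    # Combine the two sums of squared deviations (M2 terms).
    m_a = state.var * state.count
    m_b = batch_var * batch_count
    m2 = m_a + m_b + delta ** 2 * state.count * batch_count / total
    return RunningMeanStdState(new_mean, m2 / total, total)
```

Seeding the count with a small value like 1e-4 (as gym does) keeps the first merge well defined without noticeably biasing the statistics.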
Partial #89
This is an initial attempt at doing performant normalization wrappers akin to the standard gym normalization wrappers. This includes NormalizeObservation and NormalizeReward, using the same approach as Gym/StableBaselines.
I've implemented it at the brax_env.env level instead of the gym level and tested that it's compatible with JIT and VectorWrapper environments.
Open to ideas on how to make this better... I'm not a fan of the hack I did to include the running statistics in the environment state while keeping them out of the vmap logic! But I wanted to get these started :)