
Normalization Wrappers #113

Closed · wants to merge 2 commits into from
Conversation

@DavidSlayback (Contributor) commented Nov 11, 2021

Partially addresses #89

This is an initial attempt at performant normalization wrappers akin to the standard Gym ones. It includes NormalizeObservation and NormalizeReward, using the same approach as Gym/StableBaselines.

I've implemented it at the brax_env.env level instead of the gym level and tested that it's compatible with JIT and VectorWrapper environments.

Open to ideas on how to make this better... I'm not a fan of the hack I used to include the running statistics in the environment state while keeping them out of the vmap logic! But I wanted to get these started :)
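For readers of this thread, here is a rough, hedged sketch of the approach described above, not the actual diff: the RunningMeanStdState fields, the parallel-variance update, and the use of jax.numpy instead of brax.jumpy for the array math are all assumptions made for illustration.

# Rough sketch only: names, the update rule, and the use of jax.numpy for the
# array math are illustrative assumptions, not the code in this PR.
from typing import NamedTuple

import jax.numpy as jnp
from brax import jumpy as jp
from brax.envs import env as brax_env


class RunningMeanStdState(NamedTuple):
  """Running first/second moments of the observations."""
  mean: jp.ndarray
  var: jp.ndarray
  count: jp.ndarray


def update_running_mean_std(rms: RunningMeanStdState,
                            batch: jp.ndarray) -> RunningMeanStdState:
  """Folds a batch of observations into the running mean/var (parallel update)."""
  batch = batch.reshape((-1, batch.shape[-1]))  # treat unbatched obs as a batch of one
  batch_mean = jnp.mean(batch, axis=0)
  batch_var = jnp.var(batch, axis=0)
  batch_count = batch.shape[0]
  delta = batch_mean - rms.mean
  total = rms.count + batch_count
  new_mean = rms.mean + delta * batch_count / total
  m2 = (rms.var * rms.count + batch_var * batch_count
        + delta ** 2 * rms.count * batch_count / total)
  return RunningMeanStdState(new_mean, m2 / total, total)


def normalize_with_rmstd(x: jp.ndarray, rms: RunningMeanStdState,
                         eps: float = 1e-8) -> jp.ndarray:
  """Normalizes x with the running mean/std tracked in rms."""
  return (x - rms.mean) / jnp.sqrt(rms.var + eps)


class NormalizeObservationWrapper(brax_env.Wrapper):
  """Keeps running obs statistics in state.info and returns normalized obs."""

  def reset(self, rng: jp.ndarray) -> brax_env.State:
    state = self.env.reset(rng)
    obs_dim = state.obs.shape[-1]
    rms = RunningMeanStdState(
        mean=jnp.zeros(obs_dim), var=jnp.ones(obs_dim), count=jnp.asarray(1e-4))
    rms = update_running_mean_std(rms, state.obs)
    state.info['running_obs'] = rms
    return state.replace(obs=normalize_with_rmstd(state.obs, rms))

  def step(self, state: brax_env.State, action: jp.ndarray) -> brax_env.State:
    rms = state.info.pop('running_obs')  # keep the stats out of any vmapped fields
    state = self.env.step(state, action)
    rms = update_running_mean_std(rms, state.obs)
    state.info['running_obs'] = rms
    return state.replace(obs=normalize_with_rmstd(state.obs, rms))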

The google-cla bot added the cla: yes label on Nov 11, 2021.
@erikfrey (Collaborator) left a comment

Hi David - thank you for opening this PR! Please see #89 for ongoing discussion. This is an area that is still quite in flux. For what it's worth, I would like to see most of these env wrappers moved out of Brax and under the purview of gym, so I'm hesitant to make any big changes in the Brax repo until that is all sorted out.

In particular, things like observation normalization are handled in the brax learners outside of the environments, and that's a deliberate design decision. The main reason is that we often run on multiple devices and want to be careful to aggregate running statistics across devices, which requires special handling in the trainer (see the sketch after this comment).

So I think I'll defer to @vwxyzjn and @jkterry1 on what to do with this particular PR.

That said! I'd be happy to accept a minimal PR here just to get in a tight version of observation normalization, so that we could simplify this PyTorch training colab (and possibly improve its performance a tiny bit): https://colab.sandbox.google.com/github/google/brax/blob/main/notebooks/training_torch.ipynb

If you'd like to have a go at making just an observation normalization wrapper, and then updating the colab to show it working, I'd be happy to shepherd that in while the grander gym planning is underway.
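To make the multi-device point above concrete, here is a hedged sketch (not Brax's actual trainer code) of aggregating per-device batch moments with jax.lax.pmean inside a pmapped training step; the axis name 'i' is an assumption.

import jax
import jax.numpy as jnp


def device_aggregated_moments(obs_batch: jnp.ndarray):
  """Call inside a jax.pmap(..., axis_name='i') step: averages moments across devices."""
  mean = jax.lax.pmean(jnp.mean(obs_batch, axis=0), axis_name='i')
  mean_sq = jax.lax.pmean(jnp.mean(obs_batch ** 2, axis=0), axis_name='i')
  var = mean_sq - mean ** 2
  # Every device sees the same moments, so the running statistics stay in sync.
  return mean, var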


from brax import jumpy as jp
from brax.envs import env as brax_env
import gym
from gym import spaces
from gym.vector import utils
import jax
from dataclasses import dataclass
Collaborator:

unused?

return state.replace(obs=normalize_with_rmstd(state.obs, state.info['running_obs']))


class NormalizeRewardWrapper(brax_env.Wrapper):
Collaborator:

Would one expect returns with discounting to be stored in this same wrapper? It seems orthogonal to storing the normalized reward, but I don't know enough about general-purpose RL to say whether these should all go together.

Also, what about truncations? Would those be handled in this wrapper or somewhere else?
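For context on the question above: in the Gym/StableBaselines scheme this PR mirrors, the reward wrapper does keep a per-environment discounted return accumulator and rescales rewards by the running standard deviation of that return. Below is a rough numpy sketch of that scheme; the names and the simplified update are illustrative, not the code in this PR.

import numpy as np


class RewardNormalizer:
  """Tracks a discounted return per env and scales rewards by its running std."""

  def __init__(self, num_envs: int, gamma: float = 0.99, eps: float = 1e-8):
    self.returns = np.zeros(num_envs)
    self.gamma, self.eps = gamma, eps
    self.mean, self.var, self.count = 0.0, 1.0, 1e-4

  def __call__(self, reward: np.ndarray, done: np.ndarray) -> np.ndarray:
    # Discounted return accumulator, reset wherever an episode ended.
    self.returns = self.returns * self.gamma * (1.0 - done) + reward
    # Running mean/var of the returns (parallel-variance form).
    b_mean, b_var, b_count = self.returns.mean(), self.returns.var(), len(self.returns)
    delta = b_mean - self.mean
    total = self.count + b_count
    self.mean += delta * b_count / total
    m2 = (self.var * self.count + b_var * b_count
          + delta ** 2 * self.count * b_count / total)
    self.var, self.count = m2 / total, total
    # Gym/SB3 scale the reward by the std of the return without centering it.
    return reward / np.sqrt(self.var + self.eps)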



  def step(self, state: brax_env.State, action: jp.ndarray) -> brax_env.State:
    running_obs = state.info.pop('running_obs')  # hack to avoid vmap :(
Collaborator:

I think this is an OK solution for now.


def update_running_mean_std(std_state: RunningMeanStdState, batch: jp.ndarray) -> RunningMeanStdState:
  """Update running statistics with a batch of observations (Welford's algorithm)."""
  batch = jp.atleast_2d(batch)[0]  # Account for unbatched environments
Collaborator:

Did you forget to add some changes to jumpy for this to work?
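If jumpy indeed lacks atleast_2d, one illustrative workaround (an assumption, not the author's actual fix) is to branch on the array's static rank, which is safe under jit:

def as_batch(obs: jp.ndarray) -> jp.ndarray:
  """Treats an unbatched observation as a batch of one; ndim is static under jit."""
  return obs if obs.ndim >= 2 else obs.reshape((1,) + obs.shape)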
