Add dict obs support for PPO #559
Conversation
…n your dict-valued obs
Thanks Andrew!
brax/training/agents/ppo/losses.py
Outdated
```
@@ -136,8 +137,13 @@ def compute_ppo_loss(
  baseline = value_apply(normalizer_params, params.value, data.observation)

  if dict_obs:
```
Can this just be:
```python
terminal_obs = jax.tree_util.tree_map(lambda x: x[-1], data.next_observation)
```
and get rid of dict_obs and the if statement?
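For reference, a minimal sketch of why the single tree_map call covers both layouts (the shapes below are assumptions, not taken from the PR): a bare jnp.ndarray is a one-leaf pytree, so the same expression works whether data.next_observation is an array or a dict, with no dict_obs flag needed.

```python
import jax
import jax.numpy as jnp

# Hypothetical rollout shapes: (unroll_length, batch, ...) per leaf.
array_obs = jnp.ones((10, 4, 8))
dict_obs = {'state': jnp.ones((10, 4, 8)), 'pixels': jnp.ones((10, 4, 64, 64, 3))}

# tree_map applies the lambda to every leaf; an ndarray is itself a single leaf.
terminal_from_array = jax.tree_util.tree_map(lambda x: x[-1], array_obs)
terminal_from_dict = jax.tree_util.tree_map(lambda x: x[-1], dict_obs)

assert terminal_from_array.shape == (4, 8)
assert terminal_from_dict['pixels'].shape == (4, 64, 64, 3)
```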
brax/training/agents/ppo/train.py
Outdated
```
@@ -251,7 +256,8 @@ def train(
      reward_scaling=reward_scaling,
      gae_lambda=gae_lambda,
      clipping_epsilon=clipping_epsilon,
      normalize_advantage=normalize_advantage)
      normalize_advantage=normalize_advantage,
      dict_obs=not ndarray_obs)
```
let's try to get rid of dict_obs
brax/training/agents/ppo/train.py
Outdated
```
@@ -231,12 +231,17 @@ def train(
  key_envs = jnp.reshape(key_envs,
                         (local_devices_to_use, -1) + key_envs.shape[1:])
  env_state = reset_fn(key_envs)
  ndarray_obs = isinstance(env_state.obs, jnp.ndarray)  # Check whether observations are in dictionary form.
  if not ndarray_obs and normalize_observations:
    assert "state" in env.observation_size, "Observation normalisation only supported for states."
```
It's ok to just have a KeyError below rather than a one-off input validation here. We should either do better input validation, or fail loudly below.
brax/training/agents/ppo/train.py
Outdated
```
  if not ndarray_obs and normalize_observations:
    assert "state" in env.observation_size, "Observation normalisation only supported for states."

  obs_shape = env_state.obs.shape[-1] if ndarray_obs else env.observation_size
```
Let's not rely on env.observation_size; I believe in many cases that calls an env.step.
switched to using jax.tree_util.tree_map(lambda x: x.shape, env_state.obs)
brax/training/agents/ppo/train.py
Outdated
```
      training_state.normalizer_params,
      data.observation,
      pmap_axis_name=_PMAP_AXIS_NAME)
  if normalize_observations:
```
why do we need this if statement, and the one that was added below it?
For line 392, this handles the case when obs doesn't have the ['state'] key. Not sure if it makes sense for us to cover this case.
For the block under 331, there's nothing to update the normaliser with if 'state' isn't in the obs dict.
I think these if statements should be removed if the purpose is to validate whether "state" exists or not.
Let's fail with a KeyError; as discussed previously, we don't want to fail silently in general.
brax/training/types.py
Outdated
```
@@ -79,7 +79,7 @@ class NetworkFactory(Protocol[NetworkType]):

  def __call__(
      self,
      observation_size: int,
      observation_size: Union[int, Mapping[str, Tuple[int, ...]]],
```
should this be Union[Tuple[int, ...], int] in the Mapping?
Yes there seems to be no reason to enforce having <=1D obs be tuple-sized
In the case of a dict, observation_size will always be Mapping[str, Tuple[int, ...]] the way I calculate it using tree_map, but in general having non-tuple sizes makes sense.
@Andrew-Luo1 how do AutoResetWrapper and losses need to be modified? We should probably apply those updates here as well?
brax/training/types.py
Outdated
```
@@ -79,7 +79,7 @@ class NetworkFactory(Protocol[NetworkType]):

  def __call__(
      self,
      observation_size: int,
      observation_size: Union[int, Mapping[str, Union[Tuple[int, ...], int]]],
```
May as well make this ObservationSize, like the other types above.
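For context, the shared alias would presumably look something like this in brax/training/types.py (a sketch, not necessarily the exact definition used in the PR):

```python
from typing import Mapping, Tuple, Union

# A flat size for array observations, or a per-key size (int or shape tuple)
# for dict observations.
ObservationSize = Union[int, Mapping[str, Union[Tuple[int, ...], int]]]
```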
brax/training/agents/ppo/train.py
Outdated
```
  ndarray_obs = isinstance(env_state.obs, jnp.ndarray)  # Check whether observations are in dictionary form.

  obs_shape = env_state.obs.shape[-1] if ndarray_obs \
      else jax.tree_util.tree_map(lambda x: x.shape[2:], env_state.obs)  # Discard batch axes.
```
Why do we have [2:] for the dict, but [-1] for the jax.Array version? Presumably for pixel inputs?
Can they both be [2:], so the if statement can be removed?
Can the comment be more explicit about why the first two dims are removed (one for num_envs and one for num_devices, is that right?)
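To make the shape bookkeeping concrete (the numbers here are assumptions): after reset_fn, every leaf of env_state.obs carries a leading device axis and an env axis, so x.shape[2:] keeps only the per-observation dimensions, while x.shape[-1] only agrees with it for 1-D observations.

```python
import jax
import jax.numpy as jnp

local_devices_to_use, num_envs = 1, 16  # hypothetical values

obs = {
    'state': jnp.zeros((local_devices_to_use, num_envs, 27)),
    'pixels': jnp.zeros((local_devices_to_use, num_envs, 64, 64, 3)),
}

# Drop the (num_devices, num_envs) batch axes from every leaf.
obs_shape = jax.tree_util.tree_map(lambda x: x.shape[2:], obs)
# {'state': (27,), 'pixels': (64, 64, 3)}
```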
brax/envs/fast.py
Outdated
```
@@ -28,6 +28,7 @@ def __init__(self, **kwargs):
    self._dt = 0.02
    self._reset_count = 0
    self._step_count = 0
    self._dict_obs = kwargs.get('dict_obs', False)
```
nit: use use_dict_obs; _dict_obs makes it seem like you're storing dictionary observations on self.
brax/training/networks.py
Outdated
```
@@ -15,14 +15,15 @@
"""Network definitions."""

import dataclasses
from typing import Any, Callable, Sequence, Tuple
from typing import Any, Callable, Sequence, Tuple, Mapping
```
nit: keep these imports in order, here and elsewhere in the PR
brax/training/networks.py
Outdated
```
import warnings

from brax.training import types
from brax.training.spectral_norm import SNDense
from flax import linen
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten
```
nit: generally we try to avoid importing members of modules directly
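i.e., keep the module prefix (illustrative only):

```python
import jax

# Reference tree_flatten through jax.tree_util rather than importing the member.
leaves, treedef = jax.tree_util.tree_flatten({'state': 1, 'pixels': 2})
```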
brax/training/networks.py
Outdated
```
    preprocess_observations_fn: types.PreprocessObservationFn = types
    .identity_observation_preprocessor,
    hidden_layer_sizes: Sequence[int] = (256, 256),
    activation: ActivationFn = linen.relu,
    kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
    layer_norm: bool = False) -> FeedForwardNetwork:
  """Creates a policy network."""
  """Creates a policy network. Only processes state in the case of dict obs."""
```
nit: this comment is better suited where it's applied rather than in the top-level function docstring, so on L107. Similar comment for value_network below.
brax/training/networks.py
Outdated
```
@@ -82,48 +83,55 @@ def __call__(self, data: jnp.ndarray):
    hidden = self.activation(hidden)
    return hidden

def canolicalize_obs_size(obs_size: types.ObservationSize) -> int:
```
nit: "canonical" can be quite an overloaded term; let's call this get_obs_state_size
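A rough sketch of what the renamed helper might look like, under the assumption that only the 'state' entry feeds the MLP (this is not the PR's exact code):

```python
import math
from typing import Mapping, Tuple, Union

ObservationSize = Union[int, Mapping[str, Union[Tuple[int, ...], int]]]

def get_obs_state_size(obs_size: ObservationSize) -> int:
  """Flat size of the 'state' entry, or obs_size itself when it is already an int."""
  if isinstance(obs_size, int):
    return obs_size
  state_size = obs_size['state']  # fails loudly with a KeyError if 'state' is missing
  return math.prod(state_size) if isinstance(state_size, tuple) else state_size
```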
Looks good, just a couple more nits
brax/envs/base.py
Outdated
```
from brax import base
from brax.generalized import pipeline as g_pipeline
from brax.io import image
from brax.mjx import pipeline as m_pipeline
from brax.positional import pipeline as p_pipeline
from brax.spring import pipeline as s_pipeline
from brax.training.types import ObservationSize
```
We try to keep brax.training pretty standalone from other parts of the lib, can we remove this import?
brax/envs/base.py
Outdated
```
    rng = jax.random.PRNGKey(0)
    reset_state = self.unwrapped.reset(rng)
    return reset_state.obs.shape[-1]
    obs = reset_state.obs
    if isinstance(obs, jax.Array) and len(obs.shape) == 1:
```
why do we need the len(obs.shape) == 1? Just do what we had on the left hand side.
Added clarifying comment. Does it make sense to have observation_size return a tuple for multi-dimensional obs?
Please fix
brax/envs/base.py
Outdated
```
    obs = reset_state.obs
    if isinstance(obs, jax.Array) and len(obs.shape) == 1:
      return obs.shape[-1]
    else:
```
nit: remove the else and just return
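A rough standalone sketch of the resulting shape logic with both nits applied (the per-key shape-tuple return convention is an assumption, not the PR's final code):

```python
import jax
import jax.numpy as jnp

def observation_size(obs):
  """Flat obs report their last dimension; dict obs report a shape tuple per key."""
  if isinstance(obs, jax.Array):
    return obs.shape[-1]
  return jax.tree_util.tree_map(lambda x: x.shape, obs)

observation_size(jnp.zeros((27,)))                    # 27
observation_size({'state': jnp.zeros((27,)),
                  'pixels': jnp.zeros((64, 64, 3))})  # {'state': (27,), 'pixels': (64, 64, 3)}
```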
brax/training/networks.py
Outdated
```
@@ -100,30 +103,36 @@ def make_policy_network(
      layer_norm=layer_norm)

  def apply(processor_params, policy_params, obs):
    obs = obs if isinstance(obs, jnp.ndarray) \
```
nit: avoid backslash, use (...)
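i.e., something along these lines; the helper name and the else branch are guesses, since the hunk cuts off at the backslash and the docstring above only says that state is processed for dict obs:

```python
import jax.numpy as jnp

def _select_policy_input(obs):
  # Parenthesized conditional instead of a backslash continuation.
  return (
      obs
      if isinstance(obs, jnp.ndarray)
      else obs['state']  # assumed continuation: feed only the state entry to the MLP
  )
```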
Implementation of @kevinzakka's specification for dictionary observations. Dictionary-valued observations are useful in a lot of contexts: privileged critic inputs and policy inputs that mix pixels and state information, to name two.
Behaviour
When ppo/train.py is run with a dict-valued observation and observation normalisation is enabled, the normalisation is applied to obs['state'].
Usage
An upcoming PR on pixels-based PPO training provides an example of how to train with dictionary observations. Essentially:
- env.step and env.reset return the observation in dictionary form; see the example below.
- The observation_size property becomes Union[int, Mapping[str, Tuple[int, ...]]].
- The apply method and the dummy observation generation in make_policy_network and make_value_network handle dict observations.
Example observation
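A minimal sketch of the kind of dict-valued observation described above; the keys and shapes here are illustrative assumptions, not the PR's example:

```python
import jax.numpy as jnp

obs = {
    'state': jnp.zeros((27,)),             # proprioceptive state; normalisation applies to this entry
    'pixels': jnp.zeros((64, 64, 3)),      # camera image
    'privileged_state': jnp.zeros((35,)),  # extra critic-only information
}

# The corresponding observation_size would be:
# {'state': (27,), 'pixels': (64, 64, 3), 'privileged_state': (35,)}
```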
Side-effects
AutoResetWrapper and ppo/losses.py needed to be slightly modified to support obs-dict training. These modifications do not appear to affect performance or training (not shown), tested on the Franka Pick-up Cube task.