
Add dict obs support for PPO #559

Merged · 8 commits · Nov 27, 2024

Conversation

Andrew-Luo1 (Contributor)

Implementation of @kevinzakka's specification for dictionary observations. Dict-valued observations are useful in many contexts: privileged critic inputs and policy inputs that mix pixels and state information, to name two.

Behaviour

  1. In the case of ndarray-valued observations, this PR makes no changes.
  2. When you call ppo/train.py with a dict-valued observation and observation normalisation enabled, normalisation is applied only to obs['state'].
  3. When you call any other training agent with a dict-valued observation, it raises a NotImplementedError. Supporting other agents would be future work.
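Behaviour 2 can be sketched as follows. This is a hypothetical helper, not the PR's actual code; the running-statistics machinery is elided and `mean`/`std` are assumed to be precomputed:

```python
import jax.numpy as jnp

def normalize_obs(obs, mean, std):
    """Sketch of behaviour 2: ndarray observations are normalised directly;
    for dict observations only the 'state' entry is normalised."""
    if isinstance(obs, jnp.ndarray):
        return (obs - mean) / std
    # Pixel entries pass through untouched.
    return {**obs, 'state': (obs['state'] - mean) / std}
```

Since pixel entries pass through unchanged, only obs['state'] needs running statistics.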

Usage
An upcoming PR on pixel-based PPO training provides an example of how to train with dictionary observations. Essentially:

  1. Ensure that env.step and env.reset return the observation in a dictionary form; see the example below.
  2. Implement your environment's observation_size -> Union[int, Mapping[str, Tuple[int, ...]]] property.
  3. Change the called network, the apply method, and the dummy-observation generation in make_policy_network and make_value_network.
  4. Implement your desired network to process the dictionary observation.

Example observation

obs = {
  'state': jnp.concatenate([data.qpos, data.qvel]),
  'pixels/rgb': rgb / 255.0,
  'pixels/depth': depth / self._max_depth,
}
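Steps 1 and 2 above can be sketched with a toy environment. The class and its observation contents are hypothetical; only the observation plumbing is shown:

```python
import jax
import jax.numpy as jnp

class DictObsEnv:
    """Toy sketch: reset (and step) return a dict observation, and the
    observation_size property maps each key to its shape."""

    def reset(self, rng):
        del rng  # deterministic toy observation
        return {
            'state': jnp.zeros(10),
            'pixels/rgb': jnp.zeros((32, 32, 3)),
        }

    @property
    def observation_size(self):
        # Derive sizes from a throwaway reset observation.
        obs = self.reset(jax.random.PRNGKey(0))
        if isinstance(obs, jnp.ndarray):
            return obs.shape[-1]
        return jax.tree_util.tree_map(lambda x: x.shape, obs)
```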

Side-effects
AutoResetWrapper and ppo/losses.py needed to be slightly modified to support obs-dict training. These modifications do not appear to affect performance or training (not shown), tested on the Franka Pick-up Cube task.

| Obs Type | Commit f43727 | PR |
| --- | --- | --- |
| State | 212118 SPS | 213242 SPS |
| Dict | N/A | 212729 SPS |

@Andrew-Luo1 mentioned this pull request Nov 26, 2024
@btaba (Collaborator) left a comment:

Thanks Andrew!

brax/envs/wrappers/training.py (resolved thread)
@@ -136,8 +137,13 @@ def compute_ppo_loss(

baseline = value_apply(normalizer_params, params.value, data.observation)

if dict_obs:
Collaborator:

Can this just be:

terminal_obs = jax.tree_util.tree_map(lambda x: x[-1], data.next_observation)

and get rid of dict_obs and the if statement?
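The suggestion works because jax.tree_util.tree_map applies the function to every leaf, so a single expression covers both ndarray and dict observations (minimal sketch):

```python
import jax
import jax.numpy as jnp

def terminal_observation(next_observation):
    # tree_map treats a bare jnp.ndarray as a single leaf, so the same
    # expression slices the last timestep from arrays and dicts alike,
    # removing the need for a dict_obs flag.
    return jax.tree_util.tree_map(lambda x: x[-1], next_observation)
```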

@@ -251,7 +256,8 @@ def train(
reward_scaling=reward_scaling,
gae_lambda=gae_lambda,
clipping_epsilon=clipping_epsilon,
normalize_advantage=normalize_advantage)
normalize_advantage=normalize_advantage,
dict_obs=not ndarray_obs)
Collaborator:

let's try to get rid of dict_obs

@@ -231,12 +231,17 @@ def train(
key_envs = jnp.reshape(key_envs,
(local_devices_to_use, -1) + key_envs.shape[1:])
env_state = reset_fn(key_envs)
ndarray_obs = isinstance(env_state.obs, jnp.ndarray) # Check whether observations are in dictionary form.
if not ndarray_obs and normalize_observations:
assert "state" in env.observation_size, "Observation normalisation only supported for states."
Collaborator:

It's ok to just have a KeyError below rather than a one-off input validation here. We should either do better input validation, or fail loudly below.


obs_shape = env_state.obs.shape[-1] if ndarray_obs else env.observation_size
Collaborator:

Let's not rely on env.observation_size, I believe in many cases, that calls an env.step

Contributor (author):

Switched to using jax.tree_util.tree_map(lambda x: x.shape, env_state.obs).
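The replacement reads shapes off the already-materialised observation pytree instead of querying env.observation_size, which may trigger an env reset or step. The batch dimensions below are illustrative, not taken from the PR:

```python
import jax
import jax.numpy as jnp

# A stand-in for env_state.obs after vectorised reset (illustrative shapes).
env_state_obs = {
    'state': jnp.zeros((2, 4, 10)),
    'pixels/rgb': jnp.zeros((2, 4, 32, 32, 3)),
}

# Per-leaf shapes, computed without any extra env interaction.
obs_shapes = jax.tree_util.tree_map(lambda x: x.shape, env_state_obs)
```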

training_state.normalizer_params,
data.observation,
pmap_axis_name=_PMAP_AXIS_NAME)
if normalize_observations:
Collaborator:

why do we need this if statement, and the one that was added below it?

Contributor (author):

For line 392, this handles the case when obs doesn't have the ['state'] key. Not sure if it makes sense for us to cover this case.

For the block under 331, there's nothing to update the normaliser with if 'state' isn't in the obs dict.

Collaborator:

I think these if statements should be removed if the purpose is to validate whether "state" exists or not.

Let's fail with KeyError, as discussed previously, we don't want to fail silently in general
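The fail-loudly version is simply an unconditional lookup (hypothetical helper name, not the merged code):

```python
import jax.numpy as jnp

def normalizer_input(obs):
    # Unconditionally index 'state' for dict observations: a missing key
    # raises KeyError instead of silently skipping the normaliser update.
    return obs if isinstance(obs, jnp.ndarray) else obs['state']
```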

@@ -79,7 +79,7 @@ class NetworkFactory(Protocol[NetworkType]):

def __call__(
self,
observation_size: int,
observation_size: Union[int, Mapping[str, Tuple[int, ...]]],
Collaborator:

should this be Union[Tuple[int, ...], int] in the Mapping?

Contributor (author):

Yes, there seems to be no reason to force <=1D observation sizes to be tuples.

Contributor (author):

In the case of a dict, observation_size will always be Mapping[str, Tuple[int, ...]] the way I calculate it using tree_map, but in general having non-tuple sizes makes sense.

@btaba (Collaborator) commented Nov 26, 2024:

@Andrew-Luo1 how do AutoResetWrapper and losses need to be modified? We should probably apply those updates here as well?

@btaba self-assigned this Nov 26, 2024
@@ -79,7 +79,7 @@ class NetworkFactory(Protocol[NetworkType]):

def __call__(
self,
observation_size: int,
observation_size: Union[int, Mapping[str, Union[Tuple[int, ...], int]]],
Collaborator:

May as well make this ObservationSize like the other types above

ndarray_obs = isinstance(env_state.obs, jnp.ndarray) # Check whether observations are in dictionary form.

obs_shape = env_state.obs.shape[-1] if ndarray_obs \
else jax.tree_util.tree_map(lambda x: x.shape[2:], env_state.obs) # Discard batch axes.
Collaborator:

Why do we have [2:] for the dict, but [-1] for the jax.Array version? Presumably for pixel inputs?

Can they both be [2:] and then remove the if statement?

Can the comment be more explicit about why the first two dims are removed (one for num_envs, and one for num_devices, is that right?)
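The shape arithmetic under discussion, with the two leading batch axes assumed to be (num_devices, num_envs):

```python
import jax
import jax.numpy as jnp

# Leaves of env_state.obs are batched as (num_devices, num_envs, *obs_shape),
# so x.shape[2:] discards the two leading batch axes from every leaf.
batched_obs = {
    'state': jnp.zeros((2, 4, 10)),
    'pixels/rgb': jnp.zeros((2, 4, 32, 32, 3)),
}
per_env_shape = jax.tree_util.tree_map(lambda x: x.shape[2:], batched_obs)
```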

brax/training/agents/ppo/train.py (resolved thread)

@@ -28,6 +28,7 @@ def __init__(self, **kwargs):
self._dt = 0.02
self._reset_count = 0
self._step_count = 0
self._dict_obs = kwargs.get('dict_obs', False)
Collaborator:

nit: use_dict_obs

_dict_obs makes it seem like you're storing dictionary observations to self

brax/envs/base.py (resolved thread)
@@ -15,14 +15,15 @@
"""Network definitions."""

import dataclasses
from typing import Any, Callable, Sequence, Tuple
from typing import Any, Callable, Sequence, Tuple, Mapping
Collaborator:

nit: keep these imports in order, here and elsewhere in the PR

import warnings

from brax.training import types
from brax.training.spectral_norm import SNDense
from flax import linen
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten
Collaborator:

nit: generally we try to avoid importing members of modules directly

preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
hidden_layer_sizes: Sequence[int] = (256, 256),
activation: ActivationFn = linen.relu,
kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
layer_norm: bool = False) -> FeedForwardNetwork:
"""Creates a policy network."""
"""Creates a policy network. Only processes state in the case of dict obs."""
Collaborator:

nit: this comment is better suited where it's applied rather than in the top-level function docstring. So on L107

Similar comment for value_network below

brax/training/agents/ppo/train.py (resolved thread)
@@ -82,48 +83,55 @@ def __call__(self, data: jnp.ndarray):
hidden = self.activation(hidden)
return hidden

def canolicalize_obs_size(obs_size: types.ObservationSize) -> int:
Collaborator:

nit: "canonical" can be quite an overloaded term; let's call this get_obs_state_size
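The renamed helper might look like this. A sketch only, not the merged code; the ObservationSize alias mirrors the type in the diff above:

```python
from collections.abc import Mapping
from typing import Tuple, Union

import numpy as np

ObservationSize = Union[int, Mapping[str, Union[Tuple[int, ...], int]]]

def get_obs_state_size(obs_size: ObservationSize) -> int:
    """Flat size of the 'state' observation. Raises KeyError when a dict
    observation has no 'state' entry, failing loudly rather than silently."""
    size = obs_size['state'] if isinstance(obs_size, Mapping) else obs_size
    return int(np.prod(size)) if isinstance(size, tuple) else int(size)
```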

@btaba (Collaborator) left a comment:

Looks good, just a couple more nits


from brax import base
from brax.generalized import pipeline as g_pipeline
from brax.io import image
from brax.mjx import pipeline as m_pipeline
from brax.positional import pipeline as p_pipeline
from brax.spring import pipeline as s_pipeline
from brax.training.types import ObservationSize
@btaba (Collaborator), Nov 27, 2024:

We try to keep brax.training pretty standalone from other parts of the lib, can we remove this import?

rng = jax.random.PRNGKey(0)
reset_state = self.unwrapped.reset(rng)
return reset_state.obs.shape[-1]
obs = reset_state.obs
if isinstance(obs, jax.Array) and len(obs.shape) == 1:
Collaborator:

why do we need the len(obs.shape) == 1 ? just do what we had on the left hand side

Contributor (author):

Added clarifying comment. Does it make sense to have observation_size return a tuple for multi-dimensional obs?

Collaborator:

Please fix

obs = reset_state.obs
if isinstance(obs, jax.Array) and len(obs.shape) == 1:
return obs.shape[-1]
else:
Collaborator:

nit: remove the else and just return

@@ -100,30 +103,36 @@ def make_policy_network(
layer_norm=layer_norm)

def apply(processor_params, policy_params, obs):
obs = obs if isinstance(obs, jnp.ndarray) \
Collaborator:

nit: avoid backslash, use (...)


@btaba merged commit e615f42 into google:main Nov 27, 2024
3 checks passed