Add dict obs support for PPO #559

Merged (8 commits) on Nov 27, 2024
17 changes: 11 additions & 6 deletions brax/envs/base.py
@@ -16,7 +16,7 @@
"""A brax environment for training and inference."""

import abc
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union

from brax import base
from brax.generalized import pipeline as g_pipeline
@@ -28,13 +28,14 @@
import jax
import numpy as np

ObservationSize = Union[Union[Tuple, int], Mapping[str, Union[Tuple[int, ...], int]]]

@struct.dataclass
class State(base.Base):
"""Environment state for training and inference."""

pipeline_state: Optional[base.State]
obs: jax.Array
obs: Union[jax.Array, Mapping[str, jax.Array]]
reward: jax.Array
done: jax.Array
metrics: Dict[str, jax.Array] = struct.field(default_factory=dict)
@@ -54,7 +55,7 @@ def step(self, state: State, action: jax.Array) -> State:

@property
@abc.abstractmethod
def observation_size(self) -> int:
def observation_size(self) -> ObservationSize:
"""The size of the observation vector returned in step and reset."""

@property
@@ -139,10 +140,14 @@ def dt(self) -> jax.Array:
return self.sys.opt.timestep * self._n_frames # pytype: disable=attribute-error

@property
def observation_size(self) -> int:
def observation_size(self) -> ObservationSize:
rng = jax.random.PRNGKey(0)
reset_state = self.unwrapped.reset(rng)
return reset_state.obs.shape[-1]
obs = reset_state.obs
# Compatibility with existing training agents for vector ndarray obs
if isinstance(obs, jax.Array) and len(obs.shape) == 1:
Collaborator:
Why do we need the len(obs.shape) == 1 check? Just do what we had on the left-hand side.

Contributor Author:
Added a clarifying comment. Does it make sense to have observation_size return a tuple for multi-dimensional obs?

Collaborator:
Please fix.

return obs.shape[-1]
return jax.tree_util.tree_map(lambda x: x.shape, obs)

@property
def action_size(self) -> int:
Expand Down Expand Up @@ -176,7 +181,7 @@ def step(self, state: State, action: jax.Array) -> State:
return self.env.step(state, action)

@property
def observation_size(self) -> int:
def observation_size(self) -> ObservationSize:
return self.env.observation_size

@property
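For context (not part of the diff): a minimal sketch of what the new observation_size logic returns for plain array observations versus dict observations, using dummy data instead of a real brax environment.

import jax
import jax.numpy as jp

def observation_size(obs):
  # Mirrors the new property: keep the old int behaviour for 1-D array obs,
  # otherwise return a pytree of per-leaf shapes.
  if isinstance(obs, jax.Array) and len(obs.shape) == 1:
    return obs.shape[-1]
  return jax.tree_util.tree_map(lambda x: x.shape, obs)

print(observation_size(jp.zeros(8)))  # 8
print(observation_size({'state': jp.zeros(8), 'pixels': jp.zeros((64, 64, 3))}))
# e.g. {'pixels': (64, 64, 3), 'state': (8,)}
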
3 changes: 3 additions & 0 deletions brax/envs/fast.py
@@ -28,6 +28,7 @@ def __init__(self, **kwargs):
self._dt = 0.02
self._reset_count = 0
self._step_count = 0
self._use_dict_obs = kwargs.get('use_dict_obs', False)

def reset(self, rng: jax.Array) -> State:
self._reset_count += 1
@@ -39,6 +40,7 @@ def reset(self, rng: jax.Array) -> State:
contact=None
)
obs = jp.zeros(2)
obs = {'state': obs} if self._use_dict_obs else obs
reward, done = jp.array(0.0), jp.array(0.0)
return State(pipeline_state, obs, reward, done)

@@ -53,6 +55,7 @@ def step(self, state: State, action: jax.Array) -> State:
xd=state.pipeline_state.xd.replace(vel=vel),
)
obs = jp.array([pos[0], vel[0]])
obs = {'state': obs} if self._use_dict_obs else obs
reward = pos[0]

return state.replace(pipeline_state=qp, obs=obs, reward=reward)
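For context (not part of the diff), here is how the toy 'fast' environment could be exercised with the new flag; this assumes a brax build that includes this change, and the printed shapes are illustrative.

import jax
from brax import envs

env = envs.get_environment('fast', use_dict_obs=True)
state = env.reset(jax.random.PRNGKey(0))
print(type(state.obs))           # dict with a single 'state' entry
print(state.obs['state'].shape)  # (2,)
print(env.observation_size)      # {'state': (2,)}
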
2 changes: 1 addition & 1 deletion brax/envs/wrappers/training.py
@@ -130,7 +130,7 @@ def where_done(x, y):
pipeline_state = jax.tree.map(
where_done, state.info['first_pipeline_state'], state.pipeline_state
)
obs = where_done(state.info['first_obs'], state.obs)
obs = jax.tree.map(where_done, state.info['first_obs'], state.obs)
return state.replace(pipeline_state=pipeline_state, obs=obs)


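For context (not part of the diff): the where_done reset logic is now applied leaf-wise via jax.tree.map, so it works unchanged for both array and dict observations. A minimal sketch with dummy data and a simplified where_done:

import jax
import jax.numpy as jp

done = jp.array([1.0, 0.0])  # env 0 finished, env 1 did not

def where_done(x, y):
  # Simplified version of the wrapper's helper: broadcast done over obs dims.
  d = jp.reshape(done, [x.shape[0]] + [1] * (x.ndim - 1))
  return jp.where(d, x, y)

first_obs = {'state': jp.zeros((2, 3))}  # obs captured at episode start
obs = {'state': jp.ones((2, 3))}         # current obs
reset_obs = jax.tree.map(where_done, first_obs, obs)
print(reset_obs['state'])  # row 0 reset to zeros, row 1 left as ones
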
6 changes: 5 additions & 1 deletion brax/training/agents/apg/train.py
@@ -131,11 +131,15 @@ def train(
reset_fn = jax.jit(jax.vmap(env.reset))
step_fn = jax.jit(jax.vmap(env.step))

obs_size = env.observation_size
if isinstance(obs_size, Dict):
raise NotImplementedError("Dictionary observations not implemented in APG")

normalize = lambda x, y: x
if normalize_observations:
normalize = running_statistics.normalize
apg_network = network_factory(
env.observation_size,
obs_size,
env.action_size,
preprocess_observations_fn=normalize)
make_policy = apg_networks.make_inference_fn(apg_network)
2 changes: 2 additions & 0 deletions brax/training/agents/ars/train.py
@@ -119,6 +119,8 @@ def train(
)

obs_size = env.observation_size
if isinstance(obs_size, Dict):
raise NotImplementedError("Dictionary observations not implemented in ARS")

normalize_fn = lambda x, y: x
if normalize_observations:
4 changes: 3 additions & 1 deletion brax/training/agents/es/train.py
@@ -146,7 +146,9 @@ def train(
)

obs_size = env.observation_size

if isinstance(obs_size, Dict):
raise NotImplementedError("Dictionary observations not implemented in ES")

normalize_fn = lambda x, y: x
if normalize_observations:
normalize_fn = running_statistics.normalize
4 changes: 2 additions & 2 deletions brax/training/agents/ppo/losses.py
@@ -135,9 +135,9 @@ def compute_ppo_loss(
data.observation)

baseline = value_apply(normalizer_params, params.value, data.observation)

terminal_obs = jax.tree_util.tree_map(lambda x: x[-1], data.next_observation)
bootstrap_value = value_apply(normalizer_params, params.value,
data.next_observation[-1])
terminal_obs)

rewards = data.reward * reward_scaling
truncation = data.extras['state_extras']['truncation']
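For context (not part of the diff): the bootstrap value now indexes the terminal observation through tree_map, so taking the last unroll step works for dict observations as well as arrays. A sketch with dummy shapes:

import jax
import jax.numpy as jp

T, B, obs_dim = 5, 4, 8  # unroll length, batch size, observation size
next_obs_array = jp.zeros((T, B, obs_dim))
next_obs_dict = {'state': jp.zeros((T, B, obs_dim))}

terminal = jax.tree_util.tree_map(lambda x: x[-1], next_obs_array)
print(terminal.shape)           # (4, 8): observation at the final unroll step
terminal = jax.tree_util.tree_map(lambda x: x[-1], next_obs_dict)
print(terminal['state'].shape)  # (4, 8)
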
11 changes: 8 additions & 3 deletions brax/training/agents/ppo/train.py
@@ -231,12 +231,16 @@ def train(
key_envs = jnp.reshape(key_envs,
(local_devices_to_use, -1) + key_envs.shape[1:])
env_state = reset_fn(key_envs)
ndarray_obs = isinstance(env_state.obs, jnp.ndarray)  # True when observations are a flat ndarray rather than a dict.

# Discard the batch axes over devices and envs.
obs_shape = jax.tree_util.tree_map(lambda x: x.shape[2:], env_state.obs)

normalize = lambda x, y: x
if normalize_observations:
normalize = running_statistics.normalize
ppo_network = network_factory(
env_state.obs.shape[-1],
obs_shape,
env.action_size,
preprocess_observations_fn=normalize)
make_policy = ppo_networks.make_inference_fn(ppo_network)
@@ -324,7 +328,7 @@ def f(carry, unused_t):
# Update normalization params and normalize observations.
normalizer_params = running_statistics.update(
training_state.normalizer_params,
data.observation,
data.observation if ndarray_obs else data.observation['state'],
pmap_axis_name=_PMAP_AXIS_NAME)

(optimizer_state, params, _), metrics = jax.lax.scan(
@@ -381,11 +385,12 @@ def training_epoch_with_timing(
value=ppo_network.value_network.init(key_value),
)

obs_shape = env_state.obs.shape if ndarray_obs else env_state.obs['state'].shape
training_state = TrainingState( # pytype: disable=wrong-arg-types # jax-ndarray
optimizer_state=optimizer.init(init_params), # pytype: disable=wrong-arg-types # numpy-scalars
params=init_params,
normalizer_params=running_statistics.init_state(
specs.Array(env_state.obs.shape[-1:], jnp.dtype('float32'))),
specs.Array(obs_shape[-1:], jnp.dtype('float32'))),
env_steps=0)

if (
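For context (not part of the diff): a sketch of the new shape bookkeeping in ppo/train.py. The observation shape is computed per leaf with the device and env batch axes stripped, and the running-statistics normalizer keeps tracking only the 'state' entry when observations are a dict. Shapes are illustrative.

import jax
import jax.numpy as jnp

# obs as produced by reset_fn: leading (devices, envs) batch axes.
obs = {'state': jnp.zeros((1, 16, 8))}

ndarray_obs = isinstance(obs, jnp.ndarray)                      # False here
obs_shape = jax.tree_util.tree_map(lambda x: x.shape[2:], obs)
print(obs_shape)                                                # {'state': (8,)}

state_obs = obs if ndarray_obs else obs['state']
print(state_obs.shape[-1:])  # (8,): spec passed to running_statistics.init_state
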
6 changes: 3 additions & 3 deletions brax/training/agents/ppo/train_test.py
@@ -27,10 +27,10 @@
class PPOTest(parameterized.TestCase):
"""Tests for PPO module."""


def testTrain(self):
@parameterized.parameters(True, False)
def testTrain(self, use_dict_obs):
"""Test PPO with a simple env."""
fast = envs.get_environment('fast')
fast = envs.get_environment('fast', use_dict_obs=use_dict_obs)
_, _, metrics = ppo.train(
fast,
num_timesteps=2**15,
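For context (not part of the diff): an end-to-end smoke run of PPO on the dict-obs 'fast' env, mirroring the parameterized test above. The hyperparameters here are illustrative rather than the test's exact values, and this assumes a brax build with this change.

from brax import envs
from brax.training.agents.ppo import train as ppo

env = envs.get_environment('fast', use_dict_obs=True)
make_policy, params, metrics = ppo.train(
    env,
    num_timesteps=2**15,
    episode_length=128,
    num_envs=64,
    learning_rate=3e-4,
    entropy_cost=1e-2,
    discounting=0.95,
    unroll_length=5,
    batch_size=64,
    num_minibatches=8,
    normalize_observations=True,
    seed=2)
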
3 changes: 3 additions & 0 deletions brax/training/agents/sac/train.py
@@ -191,6 +191,9 @@ def train(
)

obs_size = env.observation_size
if isinstance(obs_size, Dict):
raise NotImplementedError("Dictionary observations not implemented in SAC")

action_size = env.action_size

normalize_fn = lambda x, y: x
21 changes: 15 additions & 6 deletions brax/training/networks.py
@@ -15,7 +15,7 @@
"""Network definitions."""

import dataclasses
from typing import Any, Callable, Sequence, Tuple
from typing import Any, Callable, Mapping, Sequence, Tuple
import warnings

from brax.training import types
@@ -82,10 +82,13 @@ def __call__(self, data: jnp.ndarray):
hidden = self.activation(hidden)
return hidden

def get_obs_state_size(obs_size: types.ObservationSize) -> int:
obs_size = obs_size['state'] if isinstance(obs_size, Mapping) else obs_size
return jax.tree_util.tree_flatten(obs_size)[0][-1] # Size can be tuple or int.

def make_policy_network(
param_size: int,
obs_size: int,
obs_size: types.ObservationSize,
preprocess_observations_fn: types.PreprocessObservationFn = types
.identity_observation_preprocessor,
hidden_layer_sizes: Sequence[int] = (256, 256),
@@ -100,30 +103,36 @@ def make_policy_network(
layer_norm=layer_norm)

def apply(processor_params, policy_params, obs):
obs = (obs if isinstance(obs, jnp.ndarray)
else obs['state']) # state-only in the case of dict obs.
obs = preprocess_observations_fn(obs, processor_params)
return policy_module.apply(policy_params, obs)

obs_size = get_obs_state_size(obs_size)
dummy_obs = jnp.zeros((1, obs_size))
return FeedForwardNetwork(
init=lambda key: policy_module.init(key, dummy_obs), apply=apply)


def make_value_network(
obs_size: int,
obs_size: types.ObservationSize,
preprocess_observations_fn: types.PreprocessObservationFn = types
.identity_observation_preprocessor,
hidden_layer_sizes: Sequence[int] = (256, 256),
activation: ActivationFn = linen.relu) -> FeedForwardNetwork:
"""Creates a policy network."""
"""Creates a value network."""
value_module = MLP(
layer_sizes=list(hidden_layer_sizes) + [1],
activation=activation,
kernel_init=jax.nn.initializers.lecun_uniform())

def apply(processor_params, policy_params, obs):
def apply(processor_params, value_params, obs):
obs = (obs if isinstance(obs, jnp.ndarray)
else obs['state']) # state-only in the case of dict obs.
obs = preprocess_observations_fn(obs, processor_params)
return jnp.squeeze(value_module.apply(policy_params, obs), axis=-1)
return jnp.squeeze(value_module.apply(value_params, obs), axis=-1)

obs_size = get_obs_state_size(obs_size)
dummy_obs = jnp.zeros((1, obs_size))
return FeedForwardNetwork(
init=lambda key: value_module.init(key, dummy_obs), apply=apply)
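For context (not part of the diff): how the new get_obs_state_size helper resolves the dummy-input size for the three kinds of ObservationSize it may receive. Values are illustrative.

import jax
from typing import Mapping

def get_obs_state_size(obs_size) -> int:
  # Mirrors the helper above: for dict sizes, only the 'state' entry feeds the MLP.
  obs_size = obs_size['state'] if isinstance(obs_size, Mapping) else obs_size
  return jax.tree_util.tree_flatten(obs_size)[0][-1]  # size can be a tuple or an int

print(get_obs_state_size(8))                                       # 8
print(get_obs_state_size((8,)))                                    # 8
print(get_obs_state_size({'state': (8,), 'pixels': (64, 64, 3)}))  # 8
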
7 changes: 4 additions & 3 deletions brax/training/types.py
@@ -14,7 +14,7 @@

"""Brax training types."""

from typing import Any, Mapping, NamedTuple, Tuple, TypeVar
from typing import Any, Mapping, NamedTuple, Tuple, TypeVar, Union

from brax.training.acme.types import NestedArray
import jax.numpy as jnp
@@ -30,7 +30,8 @@
Params = Any
PRNGKey = jnp.ndarray
Metrics = Mapping[str, jnp.ndarray]
Observation = jnp.ndarray
Observation = Union[jnp.ndarray, Mapping[str, jnp.ndarray]]
ObservationSize = Union[Union[Tuple, int], Mapping[str, Union[Tuple[int, ...], int]]]
Action = jnp.ndarray
Extra = Mapping[str, Any]
PolicyParams = Any
@@ -79,7 +80,7 @@ class NetworkFactory(Protocol[NetworkType]):

def __call__(
self,
observation_size: int,
observation_size: ObservationSize,
action_size: int,
preprocess_observations_fn:
PreprocessObservationFn = identity_observation_preprocessor
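For context (not part of the diff): examples of values the widened ObservationSize alias now admits; the keys and sizes are illustrative.

obs_size_int = 8                                     # plain vector observation
obs_size_tuple = (64, 64, 3)                         # multi-dimensional observation
obs_size_dict = {'state': 8, 'pixels': (64, 64, 3)}  # per-key sizes for dict observations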