Skip to content

Commit

Permalink
PPO on Pixels (#560)
Browse files Browse the repository at this point in the history
* Initial commit of vision PPO networks

* implement vision wrappers

* change ppo loss and the autoresetwrapper to support dictionary-valued observations

* add random image shifts

* support normalising observations, clean up train_pixels.py

* vision networks

* fix bug in state normalisation

* add channel-wise layer norm in CNN

* remove old file

* clean up imports

* enforce FrozenDict to avoid incorrect gradients

* refactor the vision wrappers as flags in envs.training wrappers

* support asymmetric actor critic on pixels, clean up normalisation logic

* rename networks files

* write basic pixels ppo test, make remove_pixels() check for non-dict obs

* update test for ppo on pixels to test pixel-only observations and cast to frozen dict (does not decrease performance)

* fix bug for aac on pixels

* remove old file

* linting

* clean up logic for toy testing env

* small code placement and logic clean-up

* for vision networks, only normalize as needed

* move vision networks around

* remove scan parameter for wrapping but switch wrapping order

* linting

* add acknowledgement

* replace boolean args to testing env with obs_mode enum

* write docstring for toy testing env and clean up

* make pixels functions private

* update sac test

---------

Co-authored-by: Mustafa <[email protected]>
  • Loading branch information
Andrew-Luo1 and StafaH authored Dec 4, 2024
1 parent 417465c commit 68906bc
Show file tree
Hide file tree
Showing 8 changed files with 506 additions and 85 deletions.
105 changes: 77 additions & 28 deletions brax/envs/fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,76 @@
# pylint:disable=g-multiple-import
"""Gotta go fast! This trivial Env is for unit testing."""

from brax import base
from brax.envs.base import PipelineEnv, State
from enum import Enum

import jax
from jax import numpy as jp
from flax.core import FrozenDict

from brax import base
from brax.envs.base import PipelineEnv, State


class ObservationMode(Enum):
  """Observation formats that the toy `Fast` testing env can emit.

  Attributes:
    NDARRAY: Flat array holding only the state observation (the bare "state"
      entry, not wrapped in a dictionary).
    DICT_STATE: Dictionary of state info.
    DICT_PIXELS: Dictionary of pixel observations.
    DICT_PIXELS_STATE: Dictionary of both state and pixel info.
  """
  NDARRAY = "ndarray"
  DICT_STATE = "dict_state"
  DICT_PIXELS = "dict_pixels"
  DICT_PIXELS_STATE = "dict_pixels_state"


class Fast(PipelineEnv):
"""Trains an agent to go fast."""

def __init__(self, **kwargs):
def __init__(
self,
asymmetric_obs: bool = False,
obs_mode: ObservationMode = ObservationMode.NDARRAY,
**kwargs,
):
self._dt = 0.02
self._reset_count = 0
self._step_count = 0
self._use_dict_obs = kwargs.get('use_dict_obs', False)
self._asymmetric_obs = kwargs.get('asymmetric_obs', False)
if self._asymmetric_obs and not self._use_dict_obs:
raise ValueError('asymmetric_obs requires use_dict_obs=True')
self._asymmetric_obs = asymmetric_obs
self._obs_mode = ObservationMode(obs_mode)

if self._asymmetric_obs and self._obs_mode == ObservationMode.NDARRAY:
raise ValueError("asymmetric_obs requires dictionary observations")

def reset(self, rng: jax.Array) -> State:
self._reset_count += 1
pipeline_state = base.State(
q=jp.zeros(1),
qd=jp.zeros(1),
x=base.Transform.create(pos=jp.zeros(3)),
xd=base.Motion.create(vel=jp.zeros(3)),
contact=None
q=jp.zeros(1),
qd=jp.zeros(1),
x=base.Transform.create(pos=jp.zeros(3)),
xd=base.Motion.create(vel=jp.zeros(3)),
contact=None,
)
obs = jp.zeros(2)
obs = {'state': obs} if self._use_dict_obs else obs
obs = {"state": jp.zeros(2)}
if self._asymmetric_obs:
obs['privileged_state'] = jp.zeros(4) # Dummy privileged state.
obs["privileged_state"] = jp.zeros(4) # Dummy privileged state.
pixels = {
"pixels/view_0": jp.zeros((4, 4, 3)),
"pixels/view_1": jp.zeros((4, 4, 3)),
}

if self._obs_mode == ObservationMode.DICT_STATE:
obs = FrozenDict(obs)
elif self._obs_mode == ObservationMode.DICT_PIXELS:
obs = FrozenDict(pixels)
elif self._obs_mode == ObservationMode.DICT_PIXELS_STATE:
obs = FrozenDict({**obs, **pixels})
elif self._obs_mode == ObservationMode.NDARRAY:
obs = obs["state"]

reward, done = jp.array(0.0), jp.array(0.0)
return State(pipeline_state, obs, reward, done)

Expand All @@ -56,13 +95,26 @@ def step(self, state: State, action: jax.Array) -> State:
pos = state.pipeline_state.x.pos + vel * self._dt

qp = state.pipeline_state.replace(
x=state.pipeline_state.x.replace(pos=pos),
xd=state.pipeline_state.xd.replace(vel=vel),
x=state.pipeline_state.x.replace(pos=pos),
xd=state.pipeline_state.xd.replace(vel=vel),
)
obs = jp.array([pos[0], vel[0]])
obs = {'state': obs} if self._use_dict_obs else obs
obs = {"state": jp.array([pos[0], vel[0]])}
if self._asymmetric_obs:
obs['privileged_state'] = jp.zeros(4) # Dummy privileged state.
obs["privileged_state"] = jp.zeros(4) # Dummy privileged state.
pixels = {
"pixels/view_0": jp.zeros((4, 4, 3)),
"pixels/view_1": jp.zeros((4, 4, 3)),
}

if self._obs_mode == ObservationMode.DICT_STATE:
obs = FrozenDict(obs)
elif self._obs_mode == ObservationMode.DICT_PIXELS:
obs = FrozenDict(pixels)
elif self._obs_mode == ObservationMode.DICT_PIXELS_STATE:
obs = FrozenDict({**obs, **pixels})
elif self._obs_mode == ObservationMode.NDARRAY:
obs = obs["state"]

reward = pos[0]

return state.replace(pipeline_state=qp, obs=obs, reward=reward)
Expand All @@ -77,14 +129,11 @@ def step_count(self):

@property
def observation_size(self):
if not self._use_dict_obs:
return 2

obs = {'state': 2}
if self._asymmetric_obs:
obs['privileged_state'] = 4

return obs
ret = super().observation_size
if self._obs_mode == ObservationMode.NDARRAY:
return ret
# Turn 1-D tuples to ints.
return {key: value[0] if len(value) == 1 else value for key, value in ret.items()}

@property
def action_size(self):
Expand Down
4 changes: 2 additions & 2 deletions brax/envs/wrappers/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def wrap(
action_repeat: int = 1,
randomization_fn: Optional[
Callable[[System], Tuple[System, System]]
] = None,
] = None
) -> Wrapper:
"""Common wrapper pattern for all training agents.
Expand All @@ -46,11 +46,11 @@ def wrap(
environment did not already have batch dimensions, it is additional Vmap
wrapped.
"""
env = EpisodeWrapper(env, episode_length, action_repeat)
if randomization_fn is None:
env = VmapWrapper(env)
else:
env = DomainRandomizationVmapWrapper(env, randomization_fn)
env = EpisodeWrapper(env, episode_length, action_repeat)
env = AutoResetWrapper(env)
return env

Expand Down
5 changes: 4 additions & 1 deletion brax/training/acme/running_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,10 @@ def update(state: RunningStatisticsState,
# We require exactly the same structure to avoid issues when flattened
# batch and state have different order of elements.
assert jax.tree_util.tree_structure(batch) == jax.tree_util.tree_structure(state.mean)
batch_shape = jax.tree_util.tree_leaves(batch)[0].shape
batch_leaves = jax.tree_util.tree_leaves(batch)
if not batch_leaves: # State and batch are both empty. Nothing to normalize.
return state
batch_shape = batch_leaves[0].shape
# We assume the batch dimensions always go first.
batch_dims = batch_shape[:len(batch_shape) -
jax.tree_util.tree_leaves(state.mean)[0].ndim]
Expand Down
81 changes: 81 additions & 0 deletions brax/training/agents/ppo/networks_vision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2024 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PPO vision networks."""

from typing import Any, Callable, Mapping, Sequence, Tuple

import flax
from flax import linen
import jax.numpy as jp

from brax.training import distribution
from brax.training import networks
from brax.training import types


ModuleDef = Any
ActivationFn = Callable[[jp.ndarray], jp.ndarray]
Initializer = Callable[..., Any]


@flax.struct.dataclass
class PPONetworks:
  """Container for the networks built by `make_ppo_networks_vision`."""
  # Policy network; its output size is the action distribution's param_size.
  policy_network: networks.FeedForwardNetwork
  # Value network.
  value_network: networks.FeedForwardNetwork
  # Distribution object used to interpret the policy network's outputs.
  parametric_action_distribution: distribution.ParametricDistribution


def make_ppo_networks_vision(
    observation_size: Mapping[str, Tuple[int, ...]],
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    policy_hidden_layer_sizes: Sequence[int] = [256, 256],
    value_hidden_layer_sizes: Sequence[int] = [256, 256],
    activation: ActivationFn = linen.swish,
    normalise_channels: bool = False,
    policy_obs_key: str = "",
    value_obs_key: str = "",
) -> PPONetworks:
  """Builds the vision policy and value networks used by PPO.

  Args:
    observation_size: mapping from observation key to that entry's shape.
    action_size: event size of the tanh-normal action distribution.
    preprocess_observations_fn: observation preprocessor handed to both
      networks.
    policy_hidden_layer_sizes: hidden layer sizes for the policy network.
    value_hidden_layer_sizes: hidden layer sizes for the value network.
    activation: activation function handed to both networks.
    normalise_channels: forwarded unchanged to both network builders.
    policy_obs_key: forwarded as `state_obs_key` to the policy network builder.
    value_obs_key: forwarded as `state_obs_key` to the value network builder.

  Returns:
    A PPONetworks holding the policy network, the value network and the
    action distribution.
  """
  action_dist = distribution.NormalTanhDistribution(event_size=action_size)

  # Settings that the policy and the value network have in common.
  shared_kwargs = dict(
      observation_size=observation_size,
      preprocess_observations_fn=preprocess_observations_fn,
      activation=activation,
      normalise_channels=normalise_channels,
  )

  actor = networks.make_policy_network_vision(
      output_size=action_dist.param_size,
      hidden_layer_sizes=policy_hidden_layer_sizes,
      state_obs_key=policy_obs_key,
      **shared_kwargs,
  )
  critic = networks.make_value_network_vision(
      hidden_layer_sizes=value_hidden_layer_sizes,
      state_obs_key=value_obs_key,
      **shared_kwargs,
  )

  return PPONetworks(
      policy_network=actor,
      value_network=critic,
      parametric_action_distribution=action_dist,
  )
75 changes: 70 additions & 5 deletions brax/training/agents/ppo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import functools
import time
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Mapping, Optional, Tuple, Union

from absl import logging
from brax import base
Expand All @@ -37,6 +37,7 @@
from brax.v1 import envs as envs_v1
from etils import epath
import flax
from flax.core import FrozenDict
import jax
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -72,6 +73,47 @@ def f(leaf):
return jax.tree_util.tree_map(f, tree)


def _random_translate_pixels(
    obs: Mapping[str, jax.Array], key: PRNGKey, padding: int = 4
) -> Mapping[str, jax.Array]:
  """Applies random translations to B x T x ... pixel observations.

  Every entry whose key starts with "pixels/" is edge-padded by `padding`
  pixels along H and W and then cropped back to its original shape at a random
  offset. Each batch element and each view draws its own offset; the same
  shift is applied across the unroll_length (T) dimension. All other entries
  are passed through unchanged.

  Args:
    obs: observation dictionary. Pixel entries are expected to be
      B x T x H x W x C.
    key: PRNG key; split once per batch element and once more per view.
    padding: maximum shift, in pixels, in each direction. Must be a static
      Python int because it determines the padded shape.

  Returns:
    A FrozenDict with the same keys as `obs` and the pixel entries shifted.
  """
  obs = FrozenDict(obs)

  pixel_keys = [k for k in obs if k.startswith("pixels/")]
  if not pixel_keys:
    # Nothing to augment (this also covers an empty observation dict, which
    # previously failed with an opaque TypeError when reading the batch size).
    return obs

  def rt_view(img: jax.Array, key: PRNGKey) -> jax.Array:  # TxHxWxC
    # Randomly translates a set of pixel inputs.
    # Adapted from https://github.com/ikostrikov/jaxrl/blob/main/jaxrl/agents/drq/augmentations.py
    crop_from = jax.random.randint(key, (2,), 0, 2 * padding + 1)
    zero = jnp.zeros((1,), dtype=jnp.int32)
    # Only H and W are shifted; the T and C offsets stay at zero.
    crop_from = jnp.concatenate([zero, crop_from, zero])
    padded_img = jnp.pad(
        img, ((0, 0), (padding, padding), (padding, padding), (0, 0)), mode="edge"
    )
    return jax.lax.dynamic_slice(padded_img, crop_from, img.shape)

  @jax.vmap
  def rt_all_views(
      ub_obs: Mapping[str, jax.Array], key: PRNGKey
  ) -> Mapping[str, jax.Array]:
    # Expects dictionary of unbatched observations.
    out = {}
    for k_view, v_view in ub_obs.items():
      if k_view.startswith("pixels/"):
        key, key_shift = jax.random.split(key)
        out[k_view] = rt_view(v_view, key_shift)
    return ub_obs.copy(out)  # Update the shifted fields

  # One key per batch element; vmap requires all leaves to share this size.
  bdim = obs[pixel_keys[0]].shape[0]
  keys = jax.random.split(key, bdim)
  return rt_all_views(obs, keys)


def _remove_pixels(obs: Union[jnp.ndarray, Mapping]) -> Union[jnp.ndarray, Mapping]:
"""Removes pixel observations from the observation dict.
FrozenDicts are used to avoid incorrect gradients."""
if not isinstance(obs, Mapping):
return obs
return FrozenDict({k: v for k, v in obs.items() if not k.startswith("pixels/")})


def train(
environment: Union[envs_v1.Env, envs.Env],
num_timesteps: int,
Expand Down Expand Up @@ -108,6 +150,8 @@ def train(
] = None,
restore_checkpoint_path: Optional[str] = None,
max_grad_norm: Optional[float] = None,
madrona_backend: bool = False,
augment_pixels: bool = False
):
"""PPO training.
Expand Down Expand Up @@ -164,6 +208,14 @@ def train(
Returns:
Tuple of (make_policy function, network params, metrics)
"""
if madrona_backend:
if eval_env:
raise ValueError("Madrona-MJX doesn't support multiple env instances")
if num_eval_envs != num_envs:
raise ValueError("Madrona-MJX requires a fixed batch size")
if action_repeat != 1:
raise ValueError("Implement action_repeat using PipelineEnv's _n_frames to avoid unnecessary rendering!")

assert batch_size * num_minibatches % num_envs == 0
xt = time.time()

Expand Down Expand Up @@ -225,7 +277,7 @@ def train(
environment,
episode_length=episode_length,
action_repeat=action_repeat,
randomization_fn=v_randomization_fn,
randomization_fn=v_randomization_fn
)

reset_fn = jax.jit(jax.vmap(env.reset))
Expand Down Expand Up @@ -285,6 +337,18 @@ def sgd_step(carry, unused_t, data: types.Transition,
optimizer_state, params, key = carry
key, key_perm, key_grad = jax.random.split(key, 3)

if augment_pixels:
key, key_rt = jax.random.split(key)
r_translate = functools.partial(_random_translate_pixels, key=key_rt)
data = types.Transition(
observation=r_translate(data.observation),
action=data.action,
reward=data.reward,
discount=data.discount,
next_observation=r_translate(data.next_observation),
extras=data.extras
)

def convert_data(x: jnp.ndarray):
x = jax.random.permutation(key_perm, x)
x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:])
Expand Down Expand Up @@ -334,8 +398,9 @@ def f(carry, unused_t):
# Update normalization params and normalize observations.
normalizer_params = running_statistics.update(
training_state.normalizer_params,
data.observation,
pmap_axis_name=_PMAP_AXIS_NAME)
_remove_pixels(data.observation),
pmap_axis_name=_PMAP_AXIS_NAME
)

(optimizer_state, params, _), metrics = jax.lax.scan(
functools.partial(
Expand Down Expand Up @@ -397,7 +462,7 @@ def training_epoch_with_timing(
training_state = TrainingState( # pytype: disable=wrong-arg-types # jax-ndarray
optimizer_state=optimizer.init(init_params), # pytype: disable=wrong-arg-types # numpy-scalars
params=init_params,
normalizer_params=running_statistics.init_state(obs_shape),
normalizer_params=running_statistics.init_state(_remove_pixels(obs_shape)),
env_steps=0)

if (
Expand Down
Loading

0 comments on commit 68906bc

Please sign in to comment.