From 68906bcc19eebca739e42d8388e5af08ab203252 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Wed, 4 Dec 2024 14:31:54 -0500 Subject: [PATCH] PPO on Pixels (#560) * Initial commit of vision PPO networks * implement vision wrappers * change ppo loss and the autoresetwrapper to support dictionary-valued observations * add random image shifts * support normalising observations, clean up train_pixels.py * vision networks * fix bug in state normalisation * add channel-wise layer norm in CNN * remove old file * clean up imports * enforce FrozenDict to avoid incorrect gradients * refactor the vision wrappers as flags in envs.training wrappers * support asymmetric actor critic on pixels, clean up normalisation logic * rename networks files * write basic pixels ppo test, make remove_pixels() check for non-dict obs * update test for ppo on pixels to test pixel-only observations and cast to frozen dict (does not decrease performance) * fix bug for aac on pixels * remove old file * linting * clean up logic for toy testing env * small code placement and logic clean-up * for vision networks, only normalize as needed * move vision networks around * remove scan parameter for wrapping but switch wrapping order * linting * add acknowledgement * replace boolean args to testing env with obs_mode enum * write docstring for toy testing env and clean up * make pixels functions private * update sac test --------- Co-authored-by: Mustafa --- brax/envs/fast.py | 105 +++++++++---- brax/envs/wrappers/training.py | 4 +- brax/training/acme/running_statistics.py | 5 +- brax/training/agents/ppo/networks_vision.py | 81 ++++++++++ brax/training/agents/ppo/train.py | 75 ++++++++- brax/training/agents/ppo/train_test.py | 154 ++++++++++++------ brax/training/agents/sac/train_test.py | 2 +- brax/training/networks.py | 165 +++++++++++++++++++- 8 files changed, 506 insertions(+), 85 deletions(-) create mode 100644 brax/training/agents/ppo/networks_vision.py diff --git a/brax/envs/fast.py b/brax/envs/fast.py index 7a87001a..397d9758 100644 --- a/brax/envs/fast.py +++ b/brax/envs/fast.py @@ -15,37 +15,76 @@ # pylint:disable=g-multiple-import """Gotta go fast! This trivial Env is for unit testing.""" -from brax import base -from brax.envs.base import PipelineEnv, State +from enum import Enum + import jax from jax import numpy as jp +from flax.core import FrozenDict + +from brax import base +from brax.envs.base import PipelineEnv, State + + +class ObservationMode(Enum): + """ + Describes observation formats. + + Attributes: + NDARRAY: Flat NumPy array of state info. + DICT_STATE: Dictionary of state info. + DICT_PIXELS: Dictionary of pixel observations. + DICT_PIXELS_STATE: Dictionary of both state and pixel info. 
+ """ + NDARRAY = "ndarray" + DICT_STATE = "dict_state" + DICT_PIXELS = "dict_pixels" + DICT_PIXELS_STATE = "dict_pixels_state" class Fast(PipelineEnv): """Trains an agent to go fast.""" - def __init__(self, **kwargs): + def __init__( + self, + asymmetric_obs: bool = False, + obs_mode: ObservationMode = ObservationMode.NDARRAY, + **kwargs, + ): self._dt = 0.02 self._reset_count = 0 self._step_count = 0 - self._use_dict_obs = kwargs.get('use_dict_obs', False) - self._asymmetric_obs = kwargs.get('asymmetric_obs', False) - if self._asymmetric_obs and not self._use_dict_obs: - raise ValueError('asymmetric_obs requires use_dict_obs=True') + self._asymmetric_obs = asymmetric_obs + self._obs_mode = ObservationMode(obs_mode) + + if self._asymmetric_obs and self._obs_mode == ObservationMode.NDARRAY: + raise ValueError("asymmetric_obs requires dictionary observations") def reset(self, rng: jax.Array) -> State: self._reset_count += 1 pipeline_state = base.State( - q=jp.zeros(1), - qd=jp.zeros(1), - x=base.Transform.create(pos=jp.zeros(3)), - xd=base.Motion.create(vel=jp.zeros(3)), - contact=None + q=jp.zeros(1), + qd=jp.zeros(1), + x=base.Transform.create(pos=jp.zeros(3)), + xd=base.Motion.create(vel=jp.zeros(3)), + contact=None, ) - obs = jp.zeros(2) - obs = {'state': obs} if self._use_dict_obs else obs + obs = {"state": jp.zeros(2)} if self._asymmetric_obs: - obs['privileged_state'] = jp.zeros(4) # Dummy privileged state. + obs["privileged_state"] = jp.zeros(4) # Dummy privileged state. + pixels = { + "pixels/view_0": jp.zeros((4, 4, 3)), + "pixels/view_1": jp.zeros((4, 4, 3)), + } + + if self._obs_mode == ObservationMode.DICT_STATE: + obs = FrozenDict(obs) + elif self._obs_mode == ObservationMode.DICT_PIXELS: + obs = FrozenDict(pixels) + elif self._obs_mode == ObservationMode.DICT_PIXELS_STATE: + obs = FrozenDict({**obs, **pixels}) + elif self._obs_mode == ObservationMode.NDARRAY: + obs = obs["state"] + reward, done = jp.array(0.0), jp.array(0.0) return State(pipeline_state, obs, reward, done) @@ -56,13 +95,26 @@ def step(self, state: State, action: jax.Array) -> State: pos = state.pipeline_state.x.pos + vel * self._dt qp = state.pipeline_state.replace( - x=state.pipeline_state.x.replace(pos=pos), - xd=state.pipeline_state.xd.replace(vel=vel), + x=state.pipeline_state.x.replace(pos=pos), + xd=state.pipeline_state.xd.replace(vel=vel), ) - obs = jp.array([pos[0], vel[0]]) - obs = {'state': obs} if self._use_dict_obs else obs + obs = {"state": jp.array([pos[0], vel[0]])} if self._asymmetric_obs: - obs['privileged_state'] = jp.zeros(4) # Dummy privileged state. + obs["privileged_state"] = jp.zeros(4) # Dummy privileged state. + pixels = { + "pixels/view_0": jp.zeros((4, 4, 3)), + "pixels/view_1": jp.zeros((4, 4, 3)), + } + + if self._obs_mode == ObservationMode.DICT_STATE: + obs = FrozenDict(obs) + elif self._obs_mode == ObservationMode.DICT_PIXELS: + obs = FrozenDict(pixels) + elif self._obs_mode == ObservationMode.DICT_PIXELS_STATE: + obs = FrozenDict({**obs, **pixels}) + elif self._obs_mode == ObservationMode.NDARRAY: + obs = obs["state"] + reward = pos[0] return state.replace(pipeline_state=qp, obs=obs, reward=reward) @@ -77,14 +129,11 @@ def step_count(self): @property def observation_size(self): - if not self._use_dict_obs: - return 2 - - obs = {'state': 2} - if self._asymmetric_obs: - obs['privileged_state'] = 4 - - return obs + ret = super().observation_size + if self._obs_mode == ObservationMode.NDARRAY: + return ret + # Turn 1-D tuples to ints. 
+ return {key: value[0] if len(value) == 1 else value for key, value in ret.items()} @property def action_size(self): diff --git a/brax/envs/wrappers/training.py b/brax/envs/wrappers/training.py index d4a364d2..b6f6403b 100644 --- a/brax/envs/wrappers/training.py +++ b/brax/envs/wrappers/training.py @@ -30,7 +30,7 @@ def wrap( action_repeat: int = 1, randomization_fn: Optional[ Callable[[System], Tuple[System, System]] - ] = None, + ] = None ) -> Wrapper: """Common wrapper pattern for all training agents. @@ -46,11 +46,11 @@ def wrap( environment did not already have batch dimensions, it is additional Vmap wrapped. """ - env = EpisodeWrapper(env, episode_length, action_repeat) if randomization_fn is None: env = VmapWrapper(env) else: env = DomainRandomizationVmapWrapper(env, randomization_fn) + env = EpisodeWrapper(env, episode_length, action_repeat) env = AutoResetWrapper(env) return env diff --git a/brax/training/acme/running_statistics.py b/brax/training/acme/running_statistics.py index 33c94ec2..23e4f414 100644 --- a/brax/training/acme/running_statistics.py +++ b/brax/training/acme/running_statistics.py @@ -124,7 +124,10 @@ def update(state: RunningStatisticsState, # We require exactly the same structure to avoid issues when flattened # batch and state have different order of elements. assert jax.tree_util.tree_structure(batch) == jax.tree_util.tree_structure(state.mean) - batch_shape = jax.tree_util.tree_leaves(batch)[0].shape + batch_leaves = jax.tree_util.tree_leaves(batch) + if not batch_leaves: # State and batch are both empty. Nothing to normalize. + return state + batch_shape = batch_leaves[0].shape # We assume the batch dimensions always go first. batch_dims = batch_shape[:len(batch_shape) - jax.tree_util.tree_leaves(state.mean)[0].ndim] diff --git a/brax/training/agents/ppo/networks_vision.py b/brax/training/agents/ppo/networks_vision.py new file mode 100644 index 00000000..5c156629 --- /dev/null +++ b/brax/training/agents/ppo/networks_vision.py @@ -0,0 +1,81 @@ +# Copyright 2024 The Brax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""PPO vision networks.""" + +from typing import Any, Callable, Mapping, Sequence, Tuple + +import flax +from flax import linen +import jax.numpy as jp + +from brax.training import distribution +from brax.training import networks +from brax.training import types + + +ModuleDef = Any +ActivationFn = Callable[[jp.ndarray], jp.ndarray] +Initializer = Callable[..., Any] + + +@flax.struct.dataclass +class PPONetworks: + policy_network: networks.FeedForwardNetwork + value_network: networks.FeedForwardNetwork + parametric_action_distribution: distribution.ParametricDistribution + + +def make_ppo_networks_vision( + # channel_size: int, + observation_size: Mapping[str, Tuple[int, ...]], + action_size: int, + preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor, + policy_hidden_layer_sizes: Sequence[int] = [256, 256], + value_hidden_layer_sizes: Sequence[int] = [256, 256], + activation: ActivationFn = linen.swish, + normalise_channels: bool = False, + policy_obs_key: str = "", + value_obs_key: str = "", +) -> PPONetworks: + """Make Vision PPO networks with preprocessor.""" + + parametric_action_distribution = distribution.NormalTanhDistribution( + event_size=action_size + ) + + policy_network = networks.make_policy_network_vision( + observation_size=observation_size, + output_size=parametric_action_distribution.param_size, + preprocess_observations_fn=preprocess_observations_fn, + activation=activation, + hidden_layer_sizes=policy_hidden_layer_sizes, + state_obs_key=policy_obs_key, + normalise_channels=normalise_channels, + ) + + value_network = networks.make_value_network_vision( + observation_size=observation_size, + preprocess_observations_fn=preprocess_observations_fn, + activation=activation, + hidden_layer_sizes=value_hidden_layer_sizes, + state_obs_key=value_obs_key, + normalise_channels=normalise_channels, + ) + + return PPONetworks( + policy_network=policy_network, + value_network=value_network, + parametric_action_distribution=parametric_action_distribution, + ) diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index cc84adbd..449c481a 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -19,7 +19,7 @@ import functools import time -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Mapping, Optional, Tuple, Union from absl import logging from brax import base @@ -37,6 +37,7 @@ from brax.v1 import envs as envs_v1 from etils import epath import flax +from flax.core import FrozenDict import jax import jax.numpy as jnp import numpy as np @@ -72,6 +73,47 @@ def f(leaf): return jax.tree_util.tree_map(f, tree) +def _random_translate_pixels(obs: Mapping[str, jax.Array], key: PRNGKey): + """Apply random translations to B x T x ... pixel observations. + The same shift is applied across the unroll_length (T) dimension.""" + obs = FrozenDict(obs) + + @jax.vmap + def rt_all_views(ub_obs: Mapping[str, jax.Array], key: PRNGKey) -> Mapping[str, jax.Array]: + # Expects dictionary of unbatched observations. + def rt_view(img: jax.Array, padding: int, key: PRNGKey) -> jax.Array: # TxHxWxC + # Randomly translates a set of pixel inputs. 
+ # Adapted from https://github.com/ikostrikov/jaxrl/blob/main/jaxrl/agents/drq/augmentations.py + crop_from = jax.random.randint(key, (2,), 0, 2 * padding + 1) + zero = jnp.zeros((1,), dtype=jnp.int32) + crop_from = jnp.concatenate([zero, crop_from, zero]) + padded_img = jnp.pad( + img, ((0, 0), (padding, padding), (padding, padding), (0, 0)), mode="edge" + ) + return jax.lax.dynamic_slice(padded_img, crop_from, img.shape) + + out = {} + for k_view, v_view in ub_obs.items(): + if k_view.startswith("pixels/"): + key, key_shift = jax.random.split(key) + out[k_view] = rt_view(v_view, 4, key_shift) + ub_obs = ub_obs.copy(out) # Update the shifted fields + return ub_obs + + bdim = next(iter(obs.items()), None)[1].shape[0] + keys = jax.random.split(key, bdim) + obs = rt_all_views(obs, keys) + return obs + + +def _remove_pixels(obs: Union[jnp.ndarray, Mapping]) -> Union[jnp.ndarray, Mapping]: + """Removes pixel observations from the observation dict. + FrozenDicts are used to avoid incorrect gradients.""" + if not isinstance(obs, Mapping): + return obs + return FrozenDict({k: v for k, v in obs.items() if not k.startswith("pixels/")}) + + def train( environment: Union[envs_v1.Env, envs.Env], num_timesteps: int, @@ -108,6 +150,8 @@ def train( ] = None, restore_checkpoint_path: Optional[str] = None, max_grad_norm: Optional[float] = None, + madrona_backend: bool = False, + augment_pixels: bool = False ): """PPO training. @@ -164,6 +208,14 @@ def train( Returns: Tuple of (make_policy function, network params, metrics) """ + if madrona_backend: + if eval_env: + raise ValueError("Madrona-MJX doesn't support multiple env instances") + if num_eval_envs != num_envs: + raise ValueError("Madrona-MJX requires a fixed batch size") + if action_repeat != 1: + raise ValueError("Implement action_repeat using PipelineEnv's _n_frames to avoid unnecessary rendering!") + assert batch_size * num_minibatches % num_envs == 0 xt = time.time() @@ -225,7 +277,7 @@ def train( environment, episode_length=episode_length, action_repeat=action_repeat, - randomization_fn=v_randomization_fn, + randomization_fn=v_randomization_fn ) reset_fn = jax.jit(jax.vmap(env.reset)) @@ -285,6 +337,18 @@ def sgd_step(carry, unused_t, data: types.Transition, optimizer_state, params, key = carry key, key_perm, key_grad = jax.random.split(key, 3) + if augment_pixels: + key, key_rt = jax.random.split(key) + r_translate = functools.partial(_random_translate_pixels, key=key_rt) + data = types.Transition( + observation=r_translate(data.observation), + action=data.action, + reward=data.reward, + discount=data.discount, + next_observation=r_translate(data.next_observation), + extras=data.extras + ) + def convert_data(x: jnp.ndarray): x = jax.random.permutation(key_perm, x) x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:]) @@ -334,8 +398,9 @@ def f(carry, unused_t): # Update normalization params and normalize observations. 
normalizer_params = running_statistics.update( training_state.normalizer_params, - data.observation, - pmap_axis_name=_PMAP_AXIS_NAME) + _remove_pixels(data.observation), + pmap_axis_name=_PMAP_AXIS_NAME + ) (optimizer_state, params, _), metrics = jax.lax.scan( functools.partial( @@ -397,7 +462,7 @@ def training_epoch_with_timing( training_state = TrainingState( # pytype: disable=wrong-arg-types # jax-ndarray optimizer_state=optimizer.init(init_params), # pytype: disable=wrong-arg-types # numpy-scalars params=init_params, - normalizer_params=running_statistics.init_state(obs_shape), + normalizer_params=running_statistics.init_state(_remove_pixels(obs_shape)), env_steps=0) if ( diff --git a/brax/training/agents/ppo/train_test.py b/brax/training/agents/ppo/train_test.py index ec8669bf..9eb6ca0a 100644 --- a/brax/training/agents/ppo/train_test.py +++ b/brax/training/agents/ppo/train_test.py @@ -21,19 +21,18 @@ from brax import envs from brax.training.acme import running_statistics from brax.training.agents.ppo import networks as ppo_networks +from brax.training.agents.ppo import networks_vision as ppo_networks_vision from brax.training.agents.ppo import train as ppo import jax -import jax.numpy as jp class PPOTest(parameterized.TestCase): """Tests for PPO module.""" - - @parameterized.parameters(True, False) - def testTrain(self, use_dict_obs): + @parameterized.parameters("ndarray", "dict_state") + def testTrain(self, obs_mode): """Test PPO with a simple env.""" - fast = envs.get_environment('fast', use_dict_obs=use_dict_obs) + fast = envs.get_environment('fast', obs_mode=obs_mode) _, _, metrics = ppo.train( fast, num_timesteps=2**15, @@ -75,46 +74,6 @@ def testTrainV2(self): normalize_advantage=False, ) - def testTrainAsymmetricActorCritic(self): - """Test PPO with asymmetric actor critic.""" - env = envs.get_environment('fast', asymmetric_obs=True, use_dict_obs=True) - - network_factory = functools.partial( - ppo_networks.make_ppo_networks, - policy_hidden_layer_sizes=(32,), - value_hidden_layer_sizes=(32,), - policy_obs_key='state', - value_obs_key='privileged_state', - ) - - _, (_, policy_params, value_params), _ = ppo.train( - env, - num_timesteps=2**15, - episode_length=1000, - num_envs=64, - learning_rate=3e-4, - entropy_cost=1e-2, - discounting=0.95, - unroll_length=5, - batch_size=64, - num_minibatches=8, - num_updates_per_batch=4, - normalize_observations=False, - seed=2, - reward_scaling=10, - normalize_advantage=False, - network_factory=network_factory, - ) - - self.assertEqual( - policy_params['params']['hidden_0']['kernel'].shape, - (env.observation_size['state'], 32), - ) - self.assertEqual( - value_params['params']['hidden_0']['kernel'].shape, - (env.observation_size['privileged_state'], 32), - ) - @parameterized.parameters(True, False) def testNetworkEncoding(self, normalize_observations): env = envs.get_environment('fast') @@ -175,6 +134,111 @@ def get_offset(rng): randomization_fn=rand_fn, ) + def testTrainAsymmetricActorCritic(self): + """Test PPO with asymmetric actor critic.""" + env = envs.get_environment('fast', asymmetric_obs=True, obs_mode="dict_state") + + network_factory = functools.partial( + ppo_networks.make_ppo_networks, + policy_hidden_layer_sizes=(32,), + value_hidden_layer_sizes=(32,), + policy_obs_key='state', + value_obs_key='privileged_state' + ) + + _, (_, policy_params, value_params), _ = ppo.train( + env, + num_timesteps=2**15, + episode_length=1000, + num_envs=64, + learning_rate=3e-4, + entropy_cost=1e-2, + discounting=0.95, + unroll_length=5, + 
+        batch_size=64,
+        num_minibatches=8,
+        num_updates_per_batch=4,
+        normalize_observations=False,
+        seed=2,
+        reward_scaling=10,
+        normalize_advantage=False,
+        network_factory=network_factory,
+    )
+
+    self.assertEqual(
+        policy_params['params']['hidden_0']['kernel'].shape,
+        (env.observation_size['state'], 32),
+    )
+    self.assertEqual(
+        value_params['params']['hidden_0']['kernel'].shape,
+        (env.observation_size['privileged_state'], 32),
+    )
+
+  @parameterized.parameters(
+      {"asymmetric_obs": True, "obs_mode": "dict_pixels_state"},
+      {"asymmetric_obs": False, "obs_mode": "dict_pixels_state"},
+      {"asymmetric_obs": False, "obs_mode": "dict_pixels"},
+  )
+  def testPixelsPPO(self, asymmetric_obs, obs_mode):
+    """Test PPO with pixel observations."""
+    env = envs.get_environment(
+        "fast",
+        asymmetric_obs=asymmetric_obs,
+        obs_mode=obs_mode,
+    )
+    if obs_mode == "dict_pixels":
+      policy_obs_key = ""
+      value_obs_key = ""
+    else:
+      policy_obs_key = "state"
+      value_obs_key = "privileged_state" if asymmetric_obs else "state"
+
+    network_factory = functools.partial(
+        ppo_networks_vision.make_ppo_networks_vision,
+        policy_hidden_layer_sizes=(32,),
+        value_hidden_layer_sizes=(32,),
+        policy_obs_key=policy_obs_key,
+        value_obs_key=value_obs_key,
+    )
+
+    _, (_, policy_params, value_params), _ = ppo.train(
+        env,
+        num_timesteps=2**15,
+        episode_length=1000,
+        num_envs=64,
+        learning_rate=3e-4,
+        entropy_cost=1e-2,
+        discounting=0.95,
+        unroll_length=5,
+        batch_size=64,
+        num_minibatches=8,
+        num_updates_per_batch=4,
+        normalize_observations=True,
+        seed=2,
+        reward_scaling=10,
+        normalize_advantage=False,
+        network_factory=network_factory,
+        augment_pixels=True,
+    )
+    num_views = 2
+    cnn_features = 64
+
+    if asymmetric_obs:
+      self.assertEqual(
+          policy_params["params"]["MLP_0"]["hidden_0"]["kernel"].shape,
+          (num_views * cnn_features + env.observation_size["state"], 32),
+      )
+      self.assertEqual(
+          value_params["params"]["MLP_0"]["hidden_0"]["kernel"].shape,
+          (num_views * cnn_features + env.observation_size["privileged_state"], 32),
+      )
+    if obs_mode == "dict_pixels":
+      self.assertEqual(
+          policy_params["params"]["MLP_0"]["hidden_0"]["kernel"].shape,
+          (num_views * cnn_features, 32),
+      )
+
 if __name__ == '__main__':
   absltest.main()
diff --git a/brax/training/agents/sac/train_test.py b/brax/training/agents/sac/train_test.py
index c603d138..a03e7b08 100644
--- a/brax/training/agents/sac/train_test.py
+++ b/brax/training/agents/sac/train_test.py
@@ -46,7 +46,7 @@ def testTrain(self):
         num_evals=3,
         seed=0)
     self.assertGreater(metrics['eval/episode_reward'], 140 * 0.995)
-    self.assertEqual(fast.reset_count, 2)  # type: ignore
+    self.assertEqual(fast.reset_count, 3)  # type: ignore
+    # once for prefill, once for train, once for eval
     self.assertEqual(fast.step_count, 3)  # type: ignore


diff --git a/brax/training/networks.py b/brax/training/networks.py
index d7c69fe7..a3069e4e 100644
--- a/brax/training/networks.py
+++ b/brax/training/networks.py
@@ -15,15 +15,20 @@
 """Network definitions."""
 
 import dataclasses
+import functools
 from typing import Any, Callable, Mapping, Sequence, Tuple
 import warnings
 
-from brax.training import types
-from brax.training.spectral_norm import SNDense
 from flax import linen
+from flax.core import FrozenDict
 import jax
 import jax.numpy as jnp
 
+from brax.training import types
+from brax.training.acme.running_statistics import RunningStatisticsState
+from brax.training.spectral_norm import SNDense
+
+
 ActivationFn = Callable[[jnp.ndarray], jnp.ndarray]
 Initializer = Callable[..., Any]


@@ -83,11 +88,88 @@ def __call__(self, data: jnp.ndarray):
     return hidden
 
 
+class CNN(linen.Module):
+  """CNN module. Inputs are expected in Batch * HWC format."""
+
+  num_filters: Sequence[int]
+  kernel_sizes: Sequence[Tuple]
+  strides: Sequence[Tuple]
+  activation: ActivationFn = linen.relu
+  use_bias: bool = True
+
+  @linen.compact
+  def __call__(self, data: jnp.ndarray):
+    hidden = data
+    for i, (num_filter, kernel_size, stride) in enumerate(
+        zip(self.num_filters, self.kernel_sizes, self.strides)
+    ):
+      hidden = linen.Conv(
+          num_filter, kernel_size=kernel_size, strides=stride, use_bias=self.use_bias
+      )(hidden)
+
+      hidden = self.activation(hidden)
+    return hidden
+
+
+class VisionMLP(linen.Module):
+  """
+  Applies a CNN backbone then an MLP.
+
+  The CNN architecture originates from the paper:
+  "Human-level control through deep reinforcement learning",
+  Nature 518, no. 7540 (2015): 529-533
+  """
+  layer_sizes: Sequence[int]
+  activation: ActivationFn = linen.relu
+  kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
+  activate_final: bool = False
+  layer_norm: bool = False
+  normalise_channels: bool = False
+  state_obs_key: str = ""
+
+  @linen.compact
+  def __call__(self, data: dict):
+    pixels_hidden = {k: v for k, v in data.items() if k.startswith("pixels/")}
+    if self.normalise_channels:
+      # Calculates shared statistics over an entire 2D image.
+      image_layernorm = functools.partial(
+          linen.LayerNorm, use_bias=False, use_scale=False, reduction_axes=(-1, -2)
+      )
+      def ln_per_chan(v: jax.Array):
+        normalised = [image_layernorm()(v[..., chan]) for chan in range(v.shape[-1])]
+        return jnp.stack(normalised, axis=-1)
+
+      pixels_hidden = jax.tree.map(ln_per_chan, pixels_hidden)
+
+    natureCNN = functools.partial(
+        CNN,
+        num_filters=[32, 64, 64],
+        kernel_sizes=[(8, 8), (4, 4), (3, 3)],
+        strides=[(4, 4), (2, 2), (1, 1)],
+        activation=linen.relu,
+        use_bias=False,
+    )
+    cnn_outs = [natureCNN()(pixels_hidden[key]) for key in pixels_hidden]
+    cnn_outs = [jnp.mean(cnn_out, axis=(-2, -3)) for cnn_out in cnn_outs]
+    if self.state_obs_key:
+      cnn_outs.append(
+          data[self.state_obs_key]
+      )  # TODO: Try with dedicated state network
+
+    hidden = jnp.concatenate(cnn_outs, axis=-1)
+    return MLP(
+        layer_sizes=self.layer_sizes,
+        activation=self.activation,
+        kernel_init=self.kernel_init,
+        activate_final=self.activate_final,
+        layer_norm=self.layer_norm,
+    )(hidden)
+
+
 def _get_obs_state_size(obs_size: types.ObservationSize, obs_key: str) -> int:
   obs_size = obs_size[obs_key] if isinstance(obs_size, Mapping) else obs_size
   return jax.tree_util.tree_flatten(obs_size)[0][-1]
 
-
 def make_policy_network(
     param_size: int,
     obs_size: types.ObservationSize,
@@ -140,6 +222,83 @@ def apply(processor_params, value_params, obs):
     init=lambda key: value_module.init(key, dummy_obs), apply=apply)
 
 
+def normalizer_select(
+    processor_params: RunningStatisticsState, obs_key: str
+) -> RunningStatisticsState:
+  return RunningStatisticsState(
+      count=processor_params.count,
+      mean=processor_params.mean[obs_key],
+      summed_variance=processor_params.summed_variance[obs_key],
+      std=processor_params.std[obs_key],
+  )
+
+def make_policy_network_vision(
+    observation_size: Mapping[str, Tuple[int, ...]],
+    output_size: int,
+    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
+    hidden_layer_sizes: Sequence[int] = [256, 256],
+    activation: ActivationFn = linen.swish,
+    kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
+    layer_norm: bool = False,
+ state_obs_key: str = "", + normalise_channels: bool = False, +) -> FeedForwardNetwork: + module = VisionMLP( + layer_sizes=list(hidden_layer_sizes) + [output_size], + activation=activation, + kernel_init=kernel_init, + layer_norm=layer_norm, + normalise_channels=normalise_channels, + state_obs_key=state_obs_key, + ) + + def apply(processor_params, policy_params, obs): + obs = FrozenDict(obs) + if state_obs_key: + state_obs = preprocess_observations_fn( + obs[state_obs_key], normalizer_select(processor_params, state_obs_key) + ) + obs = obs.copy({state_obs_key: state_obs}) + return module.apply(policy_params, obs) + + dummy_obs = {key: jnp.zeros((1,) + shape) for key, shape in observation_size.items()} + return FeedForwardNetwork( + init=lambda key: module.init(key, dummy_obs), apply=apply + ) + + +def make_value_network_vision( + observation_size: Mapping[str, Tuple[int, ...]], + preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor, + hidden_layer_sizes: Sequence[int] = [256, 256], + activation: ActivationFn = linen.swish, + kernel_init: Initializer = jax.nn.initializers.lecun_uniform(), + state_obs_key: str = "", + normalise_channels: bool = False, +) -> FeedForwardNetwork: + value_module = VisionMLP( + layer_sizes=list(hidden_layer_sizes) + [1], + activation=activation, + kernel_init=kernel_init, + normalise_channels=normalise_channels, + state_obs_key=state_obs_key, + ) + + def apply(processor_params, policy_params, obs): + obs = FrozenDict(obs) + if state_obs_key: + state_obs = preprocess_observations_fn( + obs[state_obs_key], normalizer_select(processor_params, state_obs_key) + ) + obs = obs.copy({state_obs_key: state_obs}) + return jnp.squeeze(value_module.apply(policy_params, obs), axis=-1) + + dummy_obs = {key: jnp.zeros((1,) + shape) for key, shape in observation_size.items()} + return FeedForwardNetwork( + init=lambda key: value_module.init(key, dummy_obs), apply=apply + ) + + def make_q_network( obs_size: types.ObservationSize, action_size: int,