-
Notifications
You must be signed in to change notification settings - Fork 261
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Initial commit of vision PPO networks * implement vision wrappers * change ppo loss and the autoresetwrapper to support dictionary-valued observations * add random image shifts * support normalising observations, clean up train_pixels.py * vision networks * fix bug in state normalisation * add channel-wise layer norm in CNN * remove old file * clean up imports * enforce FrozenDict to avoid incorrect gradients * refactor the vision wrappers as flags in envs.training wrappers * support asymmetric actor critic on pixels, clean up normalisation logic * rename networks files * write basic pixels ppo test, make remove_pixels() check for non-dict obs * update test for ppo on pixels to test pixel-only observations and cast to frozen dict (does not decrease performance) * fix bug for aac on pixels * remove old file * linting * clean up logic for toy testing env * small code placement and logic clean-up * for vision networks, only normalize as needed * move vision networks around * remove scan parameter for wrapping but switch wrapping order * linting * add acknowledgement * replace boolean args to testing env with obs_mode enum * write docstring for toy testing env and clean up * make pixels functions private * update sac test --------- Co-authored-by: Mustafa <[email protected]>
- Loading branch information
1 parent
417465c
commit 68906bc
Showing
8 changed files
with
506 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# Copyright 2024 The Brax Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""PPO vision networks.""" | ||
|
||
from typing import Any, Callable, Mapping, Sequence, Tuple | ||
|
||
import flax | ||
from flax import linen | ||
import jax.numpy as jp | ||
|
||
from brax.training import distribution | ||
from brax.training import networks | ||
from brax.training import types | ||
|
||
|
||
# Loose alias for a module definition; deliberately untyped.
ModuleDef = Any
# An activation maps an array to an array.
ActivationFn = Callable[[jp.ndarray], jp.ndarray]
# A parameter initializer: any callable, with no constraint on its arguments.
Initializer = Callable[..., Any]
|
||
|
||
@flax.struct.dataclass
class PPONetworks:
  """Bundle of the networks and action distribution used by vision PPO."""

  # Network that produces the parameters of the action distribution.
  policy_network: networks.FeedForwardNetwork
  # Network used as the critic / value estimator.
  value_network: networks.FeedForwardNetwork
  # Distribution that turns policy outputs into actions.
  parametric_action_distribution: distribution.ParametricDistribution
|
||
|
||
def make_ppo_networks_vision(
    observation_size: Mapping[str, Tuple[int, ...]],
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    policy_hidden_layer_sizes: Sequence[int] = (256, 256),
    value_hidden_layer_sizes: Sequence[int] = (256, 256),
    activation: ActivationFn = linen.swish,
    normalise_channels: bool = False,
    policy_obs_key: str = "",
    value_obs_key: str = "",
) -> PPONetworks:
  """Make Vision PPO networks with preprocessor.

  Builds a policy network and a value network with the vision network
  builders in `brax.training.networks`, and pairs them with a
  `NormalTanhDistribution` over actions.

  Args:
    observation_size: Mapping from observation key to that observation's
      shape; forwarded unchanged to both network builders.
    action_size: Event size of the action distribution. The policy network's
      output size is the distribution's `param_size`.
    preprocess_observations_fn: Observation preprocessor shared by the policy
      and value networks. Defaults to the identity preprocessor.
    policy_hidden_layer_sizes: Hidden layer sizes for the policy network.
    value_hidden_layer_sizes: Hidden layer sizes for the value network.
    activation: Activation function used by both networks.
    normalise_channels: Forwarded to both network builders. Presumably toggles
      channel-wise normalisation in the vision encoder -- confirm against
      `networks.make_policy_network_vision`.
    policy_obs_key: Forwarded to the policy network builder as
      `state_obs_key`.
    value_obs_key: Forwarded to the value network builder as `state_obs_key`.

  Returns:
    A `PPONetworks` holding the policy network, the value network and the
    parametric action distribution.
  """
  parametric_action_distribution = distribution.NormalTanhDistribution(
      event_size=action_size
  )

  # The defaults are immutable tuples so no list object is shared between
  # calls. Each builder still gets its own fresh list, matching what the
  # previous list defaults handed over.
  policy_network = networks.make_policy_network_vision(
      observation_size=observation_size,
      output_size=parametric_action_distribution.param_size,
      preprocess_observations_fn=preprocess_observations_fn,
      activation=activation,
      hidden_layer_sizes=list(policy_hidden_layer_sizes),
      state_obs_key=policy_obs_key,
      normalise_channels=normalise_channels,
  )

  value_network = networks.make_value_network_vision(
      observation_size=observation_size,
      preprocess_observations_fn=preprocess_observations_fn,
      activation=activation,
      hidden_layer_sizes=list(value_hidden_layer_sizes),
      state_obs_key=value_obs_key,
      normalise_channels=normalise_channels,
  )

  return PPONetworks(
      policy_network=policy_network,
      value_network=value_network,
      parametric_action_distribution=parametric_action_distribution,
  )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.