From 6a62109907ca174b6a88cd0eda4a675af12d33ee Mon Sep 17 00:00:00 2001 From: Brax Team Date: Mon, 14 Oct 2024 16:15:17 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 685868876 Change-Id: Ied7d20aa50889b6b611fc355a56f084ab212b125 --- brax/training/agents/apg/train.py | 53 +++++++++++++------------- brax/training/agents/ars/train.py | 54 ++++++++++++++------------- brax/training/agents/es/train.py | 53 +++++++++++++------------- brax/training/agents/ppo/train.py | 60 ++++++++++++++++-------------- brax/training/agents/sac/train.py | 59 +++++++++++++++-------------- docs/release-notes/next-release.md | 4 +- 6 files changed, 150 insertions(+), 133 deletions(-) diff --git a/brax/training/agents/apg/train.py b/brax/training/agents/apg/train.py index 1743869d..351a7166 100644 --- a/brax/training/agents/apg/train.py +++ b/brax/training/agents/apg/train.py @@ -58,6 +58,7 @@ def train( environment: Union[envs_v1.Env, envs.Env], episode_length: int, policy_updates: int, + wrap_env: bool = True, horizon_length: int = 32, num_envs: int = 1, num_evals: int = 1, @@ -102,29 +103,30 @@ def train( updates_per_epoch = jnp.round(num_updates / (num_evals_after_init)) assert num_envs % device_count == 0 - env = environment - if isinstance(env, envs.Env): - wrap_for_training = envs.training.wrap - else: - wrap_for_training = envs_v1.wrappers.wrap_for_training - key = jax.random.PRNGKey(seed) global_key, local_key = jax.random.split(key) rng, global_key = jax.random.split(global_key, 2) local_key = jax.random.fold_in(local_key, process_id) local_key, eval_key = jax.random.split(local_key) - v_randomiation_fn = None - if randomization_fn is not None: - v_randomiation_fn = functools.partial( - randomization_fn, rng=jax.random.split(rng, num_envs // process_count) + env = environment + if wrap_env: + if isinstance(env, envs.Env): + wrap_for_training = envs.training.wrap + else: + wrap_for_training = envs_v1.wrappers.wrap_for_training + + v_randomization_fn = None + if randomization_fn is not None: + v_randomization_fn = functools.partial( + randomization_fn, rng=jax.random.split(rng, num_envs // process_count) + ) + env = wrap_for_training( + env, + episode_length=episode_length, + action_repeat=action_repeat, + randomization_fn=v_randomization_fn, ) - env = wrap_for_training( - env, - episode_length=episode_length, - action_repeat=action_repeat, - randomization_fn=v_randomiation_fn, - ) reset_fn = jax.jit(jax.vmap(env.reset)) step_fn = jax.jit(jax.vmap(env.step)) @@ -298,16 +300,17 @@ def training_epoch_with_timing( if not eval_env: eval_env = environment - if randomization_fn is not None: - v_randomiation_fn = functools.partial( - randomization_fn, rng=jax.random.split(eval_key, num_eval_envs) + if wrap_env: + if randomization_fn is not None: + v_randomization_fn = functools.partial( + randomization_fn, rng=jax.random.split(eval_key, num_eval_envs) + ) + eval_env = wrap_for_training( + eval_env, + episode_length=episode_length, + action_repeat=action_repeat, + randomization_fn=v_randomization_fn, ) - eval_env = wrap_for_training( - eval_env, - episode_length=episode_length, - action_repeat=action_repeat, - randomization_fn=v_randomiation_fn, - ) evaluator = acting.Evaluator( eval_env, diff --git a/brax/training/agents/ars/train.py b/brax/training/agents/ars/train.py index 93c48481..70d8a003 100644 --- a/brax/training/agents/ars/train.py +++ b/brax/training/agents/ars/train.py @@ -52,6 +52,7 @@ class TrainingState: # TODO: Pass the network as argument. def train( environment: Union[envs_v1.Env, envs.Env], + wrap_env: bool = True, num_timesteps: int = 100, episode_length: int = 1000, action_repeat: int = 1, @@ -98,23 +99,24 @@ def train( assert num_envs % local_devices_to_use == 0 env = environment - if isinstance(env, envs.Env): - wrap_for_training = envs.training.wrap - else: - wrap_for_training = envs_v1.wrappers.wrap_for_training - - v_randomization_fn = None - if randomization_fn is not None: - v_randomization_fn = functools.partial( - randomization_fn, - rng=jax.random.split(rng_key, num_envs // local_devices_to_use), + if wrap_env: + if isinstance(env, envs.Env): + wrap_for_training = envs.training.wrap + else: + wrap_for_training = envs_v1.wrappers.wrap_for_training + + v_randomization_fn = None + if randomization_fn is not None: + v_randomization_fn = functools.partial( + randomization_fn, + rng=jax.random.split(rng_key, num_envs // local_devices_to_use), + ) + env = wrap_for_training( + env, + episode_length=episode_length, + action_repeat=action_repeat, + randomization_fn=v_randomization_fn, ) - env = wrap_for_training( - env, - episode_length=episode_length, - action_repeat=action_repeat, - randomization_fn=v_randomization_fn, - ) obs_size = env.observation_size @@ -273,17 +275,17 @@ def training_epoch_with_timing(training_state: TrainingState, if not eval_env: eval_env = environment - - if randomization_fn is not None: - v_randomization_fn = functools.partial( - randomization_fn, rng=jax.random.split(eval_key, num_eval_envs) + if wrap_env: + if randomization_fn is not None: + v_randomization_fn = functools.partial( + randomization_fn, rng=jax.random.split(eval_key, num_eval_envs) + ) + eval_env = wrap_for_training( + eval_env, + episode_length=episode_length, + action_repeat=action_repeat, + randomization_fn=v_randomization_fn, ) - eval_env = wrap_for_training( - eval_env, - episode_length=episode_length, - action_repeat=action_repeat, - randomization_fn=v_randomization_fn, - ) # Evaluator function evaluator = acting.Evaluator( diff --git a/brax/training/agents/es/train.py b/brax/training/agents/es/train.py index 8204015e..7351f3dc 100644 --- a/brax/training/agents/es/train.py +++ b/brax/training/agents/es/train.py @@ -75,6 +75,7 @@ class FitnessShaping(enum.Enum): # TODO: Pass the network as argument. def train( environment: Union[envs_v1.Env, envs.Env], + wrap_env: bool = True, num_timesteps: int = 100, episode_length: int = 1000, action_repeat: int = 1, @@ -125,23 +126,24 @@ def train( assert num_envs % local_devices_to_use == 0 env = environment - if isinstance(env, envs.Env): - wrap_for_training = envs.training.wrap - else: - wrap_for_training = envs_v1.wrappers.wrap_for_training - - v_randomization_fn = None - if randomization_fn is not None: - v_randomization_fn = functools.partial( - randomization_fn, - rng=jax.random.split(rng_key, num_envs // local_devices_to_use), + if wrap_env: + if isinstance(env, envs.Env): + wrap_for_training = envs.training.wrap + else: + wrap_for_training = envs_v1.wrappers.wrap_for_training + + v_randomization_fn = None + if randomization_fn is not None: + v_randomization_fn = functools.partial( + randomization_fn, + rng=jax.random.split(rng_key, num_envs // local_devices_to_use), + ) + env = wrap_for_training( + env, + episode_length=episode_length, + action_repeat=action_repeat, + randomization_fn=v_randomization_fn, ) - env = wrap_for_training( - env, - episode_length=episode_length, - action_repeat=action_repeat, - randomization_fn=v_randomization_fn, - ) obs_size = env.observation_size @@ -325,16 +327,17 @@ def training_epoch_with_timing(training_state: TrainingState, if not eval_env: eval_env = environment - if randomization_fn is not None: - v_randomization_fn = functools.partial( - randomization_fn, rng=jax.random.split(eval_key, num_eval_envs) + if wrap_env: + if randomization_fn is not None: + v_randomization_fn = functools.partial( + randomization_fn, rng=jax.random.split(eval_key, num_eval_envs) + ) + eval_env = wrap_for_training( + eval_env, + episode_length=episode_length, + action_repeat=action_repeat, + randomization_fn=v_randomization_fn, ) - eval_env = wrap_for_training( - eval_env, - episode_length=episode_length, - action_repeat=action_repeat, - randomization_fn=v_randomization_fn, - ) # Evaluator function evaluator = acting.Evaluator( diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index 90ce5cd4..00cc0186 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -76,6 +76,7 @@ def train( environment: Union[envs_v1.Env, envs.Env], num_timesteps: int, episode_length: int, + wrap_env: bool = True, action_repeat: int = 1, num_envs: int = 1, max_devices_per_host: Optional[int] = None, @@ -113,6 +114,8 @@ def train( environment: the environment to train num_timesteps: the total number of environment steps to use during training episode_length: the length of an environment episode + wrap_env: If True, wrap the environment for training. Otherwise use the + environment as is. action_repeat: the number of timesteps to repeat an action num_envs: the number of parallel environments to use for rollouts NOTE: `num_envs` must be divisible by the total number of chips since each @@ -202,27 +205,27 @@ def train( assert num_envs % device_count == 0 - v_randomization_fn = None - if randomization_fn is not None: - randomization_batch_size = num_envs // local_device_count - # all devices gets the same randomization rng - randomization_rng = jax.random.split(key_env, randomization_batch_size) - v_randomization_fn = functools.partial( - randomization_fn, rng=randomization_rng + env = environment + if wrap_env: + v_randomization_fn = None + if randomization_fn is not None: + randomization_batch_size = num_envs // local_device_count + # all devices gets the same randomization rng + randomization_rng = jax.random.split(key_env, randomization_batch_size) + v_randomization_fn = functools.partial( + randomization_fn, rng=randomization_rng + ) + if isinstance(environment, envs.Env): + wrap_for_training = envs.training.wrap + else: + wrap_for_training = envs_v1.wrappers.wrap_for_training + env = wrap_for_training( + environment, + episode_length=episode_length, + action_repeat=action_repeat, + randomization_fn=v_randomization_fn, ) - if isinstance(environment, envs.Env): - wrap_for_training = envs.training.wrap - else: - wrap_for_training = envs_v1.wrappers.wrap_for_training - - env = wrap_for_training( - environment, - episode_length=episode_length, - action_repeat=action_repeat, - randomization_fn=v_randomization_fn, - ) - reset_fn = jax.jit(jax.vmap(env.reset)) key_envs = jax.random.split(key_env, num_envs // process_count) key_envs = jnp.reshape(key_envs, @@ -409,16 +412,17 @@ def training_epoch_with_timing( if not eval_env: eval_env = environment - if randomization_fn is not None: - v_randomization_fn = functools.partial( - randomization_fn, rng=jax.random.split(eval_key, num_eval_envs) + if wrap_env: + if randomization_fn is not None: + v_randomization_fn = functools.partial( + randomization_fn, rng=jax.random.split(eval_key, num_eval_envs) + ) + eval_env = wrap_for_training( + eval_env, + episode_length=episode_length, + action_repeat=action_repeat, + randomization_fn=v_randomization_fn, ) - eval_env = wrap_for_training( - eval_env, - episode_length=episode_length, - action_repeat=action_repeat, - randomization_fn=v_randomization_fn, - ) evaluator = acting.Evaluator( eval_env, diff --git a/brax/training/agents/sac/train.py b/brax/training/agents/sac/train.py index 55713eba..64053a81 100644 --- a/brax/training/agents/sac/train.py +++ b/brax/training/agents/sac/train.py @@ -108,6 +108,7 @@ def train( environment: Union[envs_v1.Env, envs.Env], num_timesteps, episode_length: int, + wrap_env: bool = True, action_repeat: int = 1, num_envs: int = 1, num_eval_envs: int = 128, @@ -167,26 +168,27 @@ def train( assert num_envs % device_count == 0 env = environment - if isinstance(env, envs.Env): - wrap_for_training = envs.training.wrap - else: - wrap_for_training = envs_v1.wrappers.wrap_for_training - - rng = jax.random.PRNGKey(seed) - rng, key = jax.random.split(rng) - v_randomization_fn = None - if randomization_fn is not None: - v_randomization_fn = functools.partial( - randomization_fn, - rng=jax.random.split( - key, num_envs // jax.process_count() // local_devices_to_use), + if wrap_env: + if isinstance(env, envs.Env): + wrap_for_training = envs.training.wrap + else: + wrap_for_training = envs_v1.wrappers.wrap_for_training + + rng = jax.random.PRNGKey(seed) + rng, key = jax.random.split(rng) + v_randomization_fn = None + if randomization_fn is not None: + v_randomization_fn = functools.partial( + randomization_fn, + rng=jax.random.split( + key, num_envs // jax.process_count() // local_devices_to_use), + ) + env = wrap_for_training( + env, + episode_length=episode_length, + action_repeat=action_repeat, + randomization_fn=v_randomization_fn, ) - env = wrap_for_training( - env, - episode_length=episode_length, - action_repeat=action_repeat, - randomization_fn=v_randomization_fn, - ) obs_size = env.observation_size action_size = env.action_size @@ -431,16 +433,17 @@ def training_epoch_with_timing( if not eval_env: eval_env = environment - if randomization_fn is not None: - v_randomization_fn = functools.partial( - randomization_fn, rng=jax.random.split(eval_key, num_eval_envs) + if wrap_env: + if randomization_fn is not None: + v_randomization_fn = functools.partial( + randomization_fn, rng=jax.random.split(eval_key, num_eval_envs) + ) + eval_env = wrap_for_training( + eval_env, + episode_length=episode_length, + action_repeat=action_repeat, + randomization_fn=v_randomization_fn, ) - eval_env = wrap_for_training( - eval_env, - episode_length=episode_length, - action_repeat=action_repeat, - randomization_fn=v_randomization_fn, - ) evaluator = acting.Evaluator( eval_env, diff --git a/docs/release-notes/next-release.md b/docs/release-notes/next-release.md index cbefb69c..60a8358e 100644 --- a/docs/release-notes/next-release.md +++ b/docs/release-notes/next-release.md @@ -1 +1,3 @@ -# Brax Release Notes \ No newline at end of file +# Brax Release Notes + +* Add boolean `wrap_env` to all brax `train` functions, which optionally wraps the env for training, or uses the env as is. \ No newline at end of file