Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 703632132
Change-Id: Idf7a3bd72c0097a9028ca537186c08b03a1969ce
  • Loading branch information
Brax Team authored and btaba committed Dec 6, 2024
1 parent 8becede commit c59da3f
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 7 deletions.
3 changes: 0 additions & 3 deletions brax/training/agents/ppo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from brax.v1 import envs as envs_v1
from etils import epath
import flax
from flax import core
import jax
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -89,8 +88,6 @@ def _random_translate_pixels(
Returns:
A dictionary of observations with translated pixels
"""
obs = core.FrozenDict(obs)

@jax.vmap
def rt_all_views(
ub_obs: Mapping[str, jax.Array], key: PRNGKey
Expand Down
6 changes: 2 additions & 4 deletions brax/training/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from brax.training import types
from brax.training.acme import running_statistics
from brax.training.spectral_norm import SNDense
from flax import core
from flax import linen
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -370,12 +369,11 @@ def make_policy_network_vision(
)

def apply(processor_params, policy_params, obs):
obs = core.FrozenDict(obs)
if state_obs_key:
state_obs = preprocess_observations_fn(
obs[state_obs_key], normalizer_select(processor_params, state_obs_key)
)
obs = core.copy(obs, {state_obs_key: state_obs})
obs = {**obs, state_obs_key: state_obs}
return module.apply(policy_params, obs)

dummy_obs = {
Expand Down Expand Up @@ -409,7 +407,7 @@ def apply(processor_params, policy_params, obs):
state_obs = preprocess_observations_fn(
obs[state_obs_key], normalizer_select(processor_params, state_obs_key)
)
obs = core.copy(obs, {state_obs_key: state_obs})
obs = {**obs, state_obs_key: state_obs}
return jnp.squeeze(value_module.apply(policy_params, obs), axis=-1)

dummy_obs = {
Expand Down

0 comments on commit c59da3f

Please sign in to comment.