Updated basic APG algorithm #476

Merged
merged 7 commits into from
Apr 18, 2024
Changes from 4 commits
10 changes: 7 additions & 3 deletions brax/training/agents/apg/networks.py
@@ -22,6 +22,7 @@
from brax.training.types import PRNGKey
import flax
from flax import linen
from flax.linen.initializers import orthogonal


@flax.struct.dataclass
@@ -55,15 +56,18 @@ def make_apg_networks(
preprocess_observations_fn: types.PreprocessObservationFn = types
.identity_observation_preprocessor,
hidden_layer_sizes: Sequence[int] = (32,) * 4,
activation: networks.ActivationFn = linen.swish) -> APGNetworks:
activation: networks.ActivationFn = linen.elu,
layer_norm: bool = True) -> APGNetworks:
"""Make APG networks."""
parametric_action_distribution = distribution.NormalTanhDistribution(
event_size=action_size)
event_size=action_size, var_scale=0.1)
policy_network = networks.make_policy_network(
parametric_action_distribution.param_size,
observation_size,
preprocess_observations_fn=preprocess_observations_fn,
hidden_layer_sizes=hidden_layer_sizes, activation=activation)
hidden_layer_sizes=hidden_layer_sizes, activation=activation,
kernel_init = orthogonal(0.01),
layer_norm=layer_norm)
return APGNetworks(
policy_network=policy_network,
parametric_action_distribution=parametric_action_distribution)
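
For context, a minimal sketch of how the updated factory could be invoked. The sizes below are placeholders, and the leading observation_size/action_size arguments are assumed from the surrounding brax code rather than shown in this hunk:

from brax.training.agents.apg import networks as apg_networks

# Placeholder sizes for illustration only.
nets = apg_networks.make_apg_networks(observation_size=87, action_size=12)

# With the new defaults this builds an ELU policy MLP with LayerNorm after
# each hidden activation, orthogonal(0.01) kernel initialization, and a
# NormalTanhDistribution whose standard deviation is scaled by var_scale=0.1.
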
118 changes: 83 additions & 35 deletions brax/training/agents/apg/train.py
@@ -22,6 +22,7 @@
from brax import base
from brax import envs
from brax.training import acting
from brax.training import gradients
from brax.training import pmap
from brax.training import types
from brax.training.acme import running_statistics
@@ -56,15 +57,20 @@ def _unpmap(v):
def train(
environment: Union[envs_v1.Env, envs.Env],
episode_length: int,
action_repeat: int = 1,
policy_updates: int,
horizon_length: int = 32,
num_envs: int = 1,
num_evals: int = 1,
action_repeat: int = 1,
max_devices_per_host: Optional[int] = None,
num_eval_envs: int = 128,
learning_rate: float = 1e-4,
adam_b: list = [0.7, 0.95],
use_schedule: bool = True,
use_float64: bool = True,
schedule_decay: float = 0.997,
seed: int = 0,
truncation_length: Optional[int] = None,
max_gradient_norm: float = 1e9,
num_evals: int = 1,
normalize_observations: bool = False,
deterministic_eval: bool = False,
network_factory: types.NetworkFactory[
@@ -91,10 +97,9 @@ def train(
process_id, local_device_count, local_devices_to_use)
device_count = local_devices_to_use * process_count

if truncation_length is not None:
assert truncation_length > 0

num_updates = policy_updates
num_evals_after_init = max(num_evals - 1, 1)
updates_per_epoch = jnp.round(num_updates / (num_evals_after_init))

assert num_envs % device_count == 0
env = environment
@@ -120,6 +125,9 @@
action_repeat=action_repeat,
randomization_fn=v_randomiation_fn,
)

reset_fn = jax.jit(jax.vmap(env.reset))
step_fn = jax.jit(jax.vmap(env.step))

normalize = lambda x, y: x
if normalize_observations:
@@ -129,8 +137,24 @@
env.action_size,
preprocess_observations_fn=normalize)
make_policy = apg_networks.make_inference_fn(apg_network)

if use_schedule:
learning_rate = optax.exponential_decay(
init_value=learning_rate,
transition_steps=1,
decay_rate=schedule_decay
)

optimizer = optax.chain(
optax.clip(1.0),
optax.adam(learning_rate=learning_rate, b1=adam_b[0], b2=adam_b[1])
)

optimizer = optax.adam(learning_rate=learning_rate)
def scramble_times(state, key):
state.info['steps'] = jnp.round(
jax.random.uniform(key, (local_devices_to_use, num_envs,),
maxval=episode_length))
return state
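
The scramble_times helper above appears to desynchronize the per-environment episode clocks: state.info['steps'] is the counter that brax's episode wrapper compares against episode_length (an assumption about the wrapper, noted here for context), so randomizing it once at startup spreads automatic resets across training iterations instead of truncating every environment at the same step. A toy illustration of the spread:

import jax
import jax.numpy as jnp

# Illustration only: scrambled step counters for 8 envs on a single device.
key = jax.random.PRNGKey(0)
steps = jnp.round(jax.random.uniform(key, (1, 8), maxval=1000.0))
# Each environment now reaches the episode_length boundary at a different
# training iteration, so resets are staggered rather than synchronized.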

def env_step(
carry: Tuple[Union[envs.State, envs_v1.State], PRNGKey],
@@ -141,23 +165,17 @@ def env_step(
key, key_sample = jax.random.split(key)
actions = policy(env_state.obs, key_sample)[0]
nstate = env.step(env_state, actions)
if truncation_length is not None:
nstate = jax.lax.cond(
jnp.mod(step_index + 1, truncation_length) == 0.,
jax.lax.stop_gradient, lambda x: x, nstate)

return (nstate, key), (nstate.reward, env_state.obs)

def loss(policy_params, normalizer_params, key):
key_reset, key_scan = jax.random.split(key)
env_state = env.reset(
jax.random.split(key_reset, num_envs // process_count))
def loss(policy_params, normalizer_params, env_state, key):
f = functools.partial(
env_step, policy=make_policy((normalizer_params, policy_params)))
(rewards,
obs) = jax.lax.scan(f, (env_state, key_scan),
(jnp.array(range(episode_length // action_repeat))))[1]
return -jnp.mean(rewards), obs
(state_h, _), (rewards,
obs) = jax.lax.scan(f, (env_state, key),
(jnp.arange(horizon_length // action_repeat)))

return -jnp.mean(rewards), (obs, state_h)
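
The loss now differentiates through a jax.lax.scan of horizon_length // action_repeat environment steps and returns the final state, so each update is a short truncated-BPTT rollout that resumes where the previous one stopped. A toy example (not brax code) of why gradients through a scanned rollout motivate short horizons and clipping:

import jax
import jax.numpy as jnp

# Toy dynamics: differentiating a scanned rollout backpropagates through
# every step of the horizon, which is why the default horizon is short (32)
# and updates are clipped (optax.clip / max_gradient_norm).
def toy_loss(theta, x0, horizon=32):
  def step(x, _):
    x = 0.9 * x + theta   # stand-in for env.step driven by the policy
    return x, -x          # per-step cost
  _, costs = jax.lax.scan(step, x0, None, length=horizon)
  return jnp.mean(costs)

grad = jax.grad(toy_loss)(jnp.float32(0.1), jnp.float32(1.0))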

loss_grad = jax.grad(loss, has_aux=True)

@@ -168,62 +186,83 @@ def clip_by_global_norm(updates):
lambda t: jnp.where(trigger, t, (t / g_norm) * max_gradient_norm),
updates)

def training_epoch(training_state: TrainingState, key: PRNGKey):
def minibatch_step(
carry, epoch_step_index: int):
(optimizer_state, normalizer_params,
policy_params, key, state) = carry

key, key_grad = jax.random.split(key)
grad, obs = loss_grad(training_state.policy_params,
training_state.normalizer_params, key_grad)
grad, (obs, state_h) = loss_grad(policy_params,
normalizer_params,
state,
key_grad)

grad = clip_by_global_norm(grad)
grad = jax.lax.pmean(grad, axis_name='i')
params_update, optimizer_state = optimizer.update(
grad, training_state.optimizer_state)
policy_params = optax.apply_updates(training_state.policy_params,
grad, optimizer_state)
policy_params = optax.apply_updates(policy_params,
params_update)

normalizer_params = running_statistics.update(
training_state.normalizer_params, obs, pmap_axis_name=_PMAP_AXIS_NAME)
normalizer_params, obs, pmap_axis_name=_PMAP_AXIS_NAME)

metrics = {
'grad_norm': optax.global_norm(grad),
'params_norm': optax.global_norm(policy_params)
}

return (optimizer_state, normalizer_params, policy_params, key, state_h), metrics

def training_epoch(training_state: TrainingState, env_state: Union[envs.State, envs_v1.State], key: PRNGKey):

(optimizer_state, normalizer_params,
policy_params, key, state_h), metrics = jax.lax.scan(
minibatch_step,
(training_state.optimizer_state, training_state.normalizer_params,
training_state.policy_params, key, env_state),
jnp.arange(updates_per_epoch))

return TrainingState(
optimizer_state=optimizer_state,
normalizer_params=normalizer_params,
policy_params=policy_params), metrics
policy_params=policy_params), state_h, metrics, key

training_epoch = jax.pmap(training_epoch, axis_name=_PMAP_AXIS_NAME)

training_walltime = 0

# Note that this is NOT a pure jittable method.
def training_epoch_with_timing(training_state: TrainingState,
env_state: Union[envs.State, envs_v1.State],
key: PRNGKey) -> Tuple[TrainingState, Metrics]:
nonlocal training_walltime
t = time.time()
(training_state, metrics) = training_epoch(training_state, key)
(training_state, env_state, metrics, key) = training_epoch(training_state, env_state, key)
metrics = jax.tree_util.tree_map(jnp.mean, metrics)
jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics)

epoch_training_time = time.time() - t
training_walltime += epoch_training_time
sps = (episode_length * num_envs) / epoch_training_time
sps = (updates_per_epoch * num_envs * horizon_length) / epoch_training_time
metrics = {
'training/sps': sps,
'training/walltime': training_walltime,
**{f'training/{name}': value for name, value in metrics.items()}
}
return training_state, metrics # pytype: disable=bad-return-type # py311-upgrade
return training_state, env_state, metrics, key # pytype: disable=bad-return-type # py311-upgrade

# The network key should be global, so that networks are initialized the same
# way for different processes.
policy_params = apg_network.policy_network.init(global_key)
del global_key

dtype = 'float64' if use_float64 else 'float32'
training_state = TrainingState(
optimizer_state=optimizer.init(policy_params),
policy_params=policy_params,
normalizer_params=running_statistics.init_state(
specs.Array((env.observation_size,), jnp.dtype('float32'))))
specs.Array((env.observation_size,), jnp.dtype(dtype))))
training_state = jax.device_put_replicated(
training_state,
jax.local_devices()[:local_devices_to_use])
@@ -251,6 +290,7 @@ def training_epoch_with_timing(training_state: TrainingState,

# Run initial eval
metrics = {}

if process_id == 0 and num_evals > 1:
metrics = evaluator.run_evaluation(
_unpmap(
@@ -259,14 +299,21 @@ def training_epoch_with_timing(training_state: TrainingState,
logging.info(metrics)
progress_fn(0, metrics)

init_key, scramble_key, local_key = jax.random.split(local_key, 3)
init_key = jax.random.split(init_key, (local_devices_to_use, num_envs // process_count))
env_state = reset_fn(init_key)
env_state = scramble_times(env_state, scramble_key)
env_state = step_fn(env_state, jnp.zeros((local_devices_to_use, num_envs // process_count,
env.action_size))) # Prevent recompilation on the second epoch

epoch_key, local_key = jax.random.split(local_key)
epoch_key = jax.random.split(epoch_key, local_devices_to_use)

for it in range(num_evals_after_init):
logging.info('starting iteration %s %s', it, time.time() - xt)

# optimization
epoch_key, local_key = jax.random.split(local_key)
epoch_keys = jax.random.split(epoch_key, local_devices_to_use)
(training_state,
training_metrics) = training_epoch_with_timing(training_state, epoch_keys)
(training_state, env_state,
training_metrics, epoch_key) = training_epoch_with_timing(training_state, env_state, epoch_key)

if process_id == 0:
# Run evals.
@@ -284,3 +331,4 @@ def training_epoch_with_timing(training_state: TrainingState,
(training_state.normalizer_params, training_state.policy_params))
pmap.synchronize_hosts()
return (make_policy, params, metrics)
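
Taken together, a hedged sketch of how the updated entry point might be called after this change; the environment name and all argument values below are illustrative only:

from brax import envs
from brax.training.agents.apg import train as apg

env = envs.get_environment('ant')   # any registered brax environment

make_policy, params, metrics = apg.train(
    environment=env,
    episode_length=1000,
    policy_updates=500,        # total number of gradient updates (new, required)
    horizon_length=32,         # steps per differentiable rollout (replaces full-episode rollouts)
    num_envs=64,
    num_evals=10,
    learning_rate=1e-4,
    use_schedule=True,         # exponential learning-rate decay with schedule_decay
    normalize_observations=True,
)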

5 changes: 3 additions & 2 deletions brax/training/distribution.py
@@ -131,7 +131,7 @@ def forward_log_det_jacobian(self, x):
class NormalTanhDistribution(ParametricDistribution):
"""Normal distribution followed by tanh."""

def __init__(self, event_size, min_std=0.001):
def __init__(self, event_size, min_std=0.001, var_scale=1):
Reviewer comment (Collaborator): nit: add var_scale to Args section of docstring

"""Initialize the distribution.

Args:
@@ -151,8 +151,9 @@ def __init__(self, event_size, min_std=0.001):
event_ndims=1,
reparametrizable=True)
self._min_std = min_std
self._var_scale = var_scale

def create_dist(self, parameters):
loc, scale = jnp.split(parameters, 2, axis=-1)
scale = jax.nn.softplus(scale) + self._min_std
scale = (jax.nn.softplus(scale) + self._min_std) * self._var_scale
return NormalDistribution(loc=loc, scale=scale)
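
Numerically, the new var_scale rescales the post-softplus standard deviation; for the APG default of var_scale=0.1 and a near-zero raw scale output (which the orthogonal(0.01) init above makes likely early in training), the initial std drops roughly tenfold:

import jax.numpy as jnp
from jax import nn

raw_scale = jnp.float32(0.0)   # roughly what a freshly initialized head emits
min_std, var_scale = 0.001, 0.1

std_old = nn.softplus(raw_scale) + min_std                # ~0.694
std_new = (nn.softplus(raw_scale) + min_std) * var_scale  # ~0.069
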
12 changes: 9 additions & 3 deletions brax/training/networks.py
@@ -41,7 +41,8 @@ class MLP(linen.Module):
kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
activate_final: bool = False
bias: bool = True

layer_norm: bool = False

@linen.compact
def __call__(self, data: jnp.ndarray):
hidden = data
@@ -54,6 +55,8 @@ def __call__(self, data: jnp.ndarray):
hidden)
if i != len(self.layer_sizes) - 1 or self.activate_final:
hidden = self.activation(hidden)
if self.layer_norm:
hidden = linen.LayerNorm()(hidden)
return hidden


@@ -86,12 +89,15 @@ def make_policy_network(
preprocess_observations_fn: types.PreprocessObservationFn = types
.identity_observation_preprocessor,
hidden_layer_sizes: Sequence[int] = (256, 256),
activation: ActivationFn = linen.relu) -> FeedForwardNetwork:
activation: ActivationFn = linen.relu,
kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
layer_norm: bool = False) -> FeedForwardNetwork:
"""Creates a policy network."""
policy_module = MLP(
layer_sizes=list(hidden_layer_sizes) + [param_size],
activation=activation,
kernel_init=jax.nn.initializers.lecun_uniform())
kernel_init=kernel_init,
layer_norm=layer_norm)

def apply(processor_params, policy_params, obs):
obs = preprocess_observations_fn(obs, processor_params)
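
For callers of brax.training.networks directly, the new arguments are opt-in: kernel_init still defaults to lecun_uniform() and layer_norm to False, so existing behavior is unchanged. A minimal sketch of the extended MLP (shapes are placeholders):

import jax
import jax.numpy as jnp
from flax import linen
from flax.linen.initializers import orthogonal
from brax.training.networks import MLP

# LayerNorm is applied after each hidden activation, never after the final
# linear output layer.
mlp = MLP(layer_sizes=[32, 32, 8],
          activation=linen.elu,
          kernel_init=orthogonal(0.01),
          layer_norm=True)
params = mlp.init(jax.random.PRNGKey(0), jnp.zeros((1, 16)))
out = mlp.apply(params, jnp.zeros((1, 16)))   # shape (1, 8)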