Make it possible to reload PPO params for inference. (#553)
kevinzakka authored Nov 13, 2024
1 parent bf616ce commit eb66604
Showing 2 changed files with 7 additions and 5 deletions.
4 changes: 3 additions & 1 deletion brax/training/agents/ppo/networks.py
@@ -41,7 +41,9 @@ def make_policy(params: types.PolicyParams,
 
     def policy(observations: types.Observation,
                key_sample: PRNGKey) -> Tuple[types.Action, types.Extra]:
-      logits = policy_network.apply(*params, observations)
+      # Discard the value function.
+      param_subset = (params[0], params[1].policy)
+      logits = policy_network.apply(*param_subset, observations)
       if deterministic:
         return ppo_networks.parametric_action_distribution.mode(logits), {}
       raw_actions = parametric_action_distribution.sample_no_postprocessing(
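
The policy closure now accepts the full training tuple (normalizer params plus the combined PPONetworkParams) and slices out the policy sub-tree itself, discarding the value head. So a checkpoint of the full training params can be handed straight back to make_policy. A minimal sketch of reloading saved params for inference, assuming the checkpoint was written with brax.io.model.save_params and that the network sizes and observation-normalization settings match training; the path and environment name are placeholders, not part of this commit:

# Minimal reload-for-inference sketch (not part of this commit).
# Assumes /tmp/ppo_params holds the full (normalizer_params, PPONetworkParams)
# tuple saved during training, and that hidden-layer sizes and
# preprocess_observations_fn here match the training configuration.
import jax
from brax import envs
from brax.io import model
from brax.training.agents.ppo import networks as ppo_networks

env = envs.get_environment('ant')  # placeholder environment
nets = ppo_networks.make_ppo_networks(env.observation_size, env.action_size)
make_policy = ppo_networks.make_inference_fn(nets)

params = model.load_params('/tmp/ppo_params')  # hypothetical checkpoint path
policy = make_policy(params, deterministic=True)

state = jax.jit(env.reset)(jax.random.PRNGKey(0))
action, _ = policy(state.obs, jax.random.PRNGKey(1))
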
8 changes: 4 additions & 4 deletions brax/training/agents/ppo/train.py
@@ -295,7 +295,7 @@ def training_step(
     key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3)
 
     policy = make_policy(
-        (training_state.normalizer_params, training_state.params.policy))
+        (training_state.normalizer_params, training_state.params))
 
     def f(carry, unused_t):
       current_state, current_key = carry
@@ -437,7 +437,7 @@ def training_epoch_with_timing(
   if process_id == 0 and num_evals > 1:
     metrics = evaluator.run_evaluation(
         _unpmap(
-            (training_state.normalizer_params, training_state.params.policy)),
+            (training_state.normalizer_params, training_state.params)),
         training_metrics={})
     logging.info(metrics)
     progress_fn(0, metrics)
@@ -467,7 +467,7 @@ def training_epoch_with_timing(
       # Run evals.
       metrics = evaluator.run_evaluation(
           _unpmap(
-              (training_state.normalizer_params, training_state.params.policy)),
+              (training_state.normalizer_params, training_state.params)),
           training_metrics)
       logging.info(metrics)
       progress_fn(current_step, metrics)
@@ -483,7 +483,7 @@ def training_epoch_with_timing(
   # devices.
   pmap.assert_is_replicated(training_state)
   params = _unpmap(
-      (training_state.normalizer_params, training_state.params.policy))
+      (training_state.normalizer_params, training_state.params))
   logging.info('total steps: %s', total_steps)
   pmap.synchronize_hosts()
   return (make_policy, params, metrics)
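
Because train now evaluates and returns the full params (value network included) instead of only the policy sub-tree, the returned tuple can be checkpointed and fed back into make_policy unchanged. A hedged sketch of that round trip; the hyperparameters and save path are illustrative, not from the commit:

# Round-trip sketch (illustrative values only, not from this commit).
from brax import envs
from brax.io import model
from brax.training.agents.ppo import train as ppo

env = envs.get_environment('ant')
make_policy, params, _ = ppo.train(
    environment=env, num_timesteps=1_000_000, episode_length=1000)

model.save_params('/tmp/ppo_params', params)   # full params, incl. value head
restored = model.load_params('/tmp/ppo_params')
policy = make_policy(restored, deterministic=True)  # value head discarded inside
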
