Skip to content

Commit

Permalink
Automatically determine num_actions and num_chance_outcomes in stocha…
Browse files Browse the repository at this point in the history
…stic_muzero_policy.
  • Loading branch information
carlosgmartin committed Oct 8, 2023
1 parent 545b8ee commit b2e09ad
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 9 deletions.
12 changes: 5 additions & 7 deletions mctx/_src/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,6 @@ def stochastic_muzero_policy(
decision_recurrent_fn: base.DecisionRecurrentFn,
chance_recurrent_fn: base.ChanceRecurrentFn,
num_simulations: int,
num_actions: int,
num_chance_outcomes: int,
invalid_actions: Optional[chex.Array] = None,
max_depth: Optional[int] = None,
loop_fn: base.LoopFn = jax.lax.fori_loop,
Expand Down Expand Up @@ -271,8 +269,6 @@ def stochastic_muzero_policy(
`(params, rng_key, chance_outcome, afterstate_embedding)` and returns a
`(ChanceRecurrentFnOutput, state_embedding)`.
num_simulations: the number of simulations.
num_actions: number of environment actions.
num_chance_outcomes: number of chance outcomes following an afterstate.
invalid_actions: a mask with invalid actions. Invalid actions have ones,
valid actions have zeros in the mask. Shape `[B, num_actions]`.
max_depth: maximum search tree depth allowed during simulation.
Expand All @@ -293,6 +289,8 @@ def stochastic_muzero_policy(
search tree.
"""

num_actions = root.prior_logits.shape[-1]

rng_key, dirichlet_rng_key, search_rng_key = jax.random.split(rng_key, 3)

# Adding Dirichlet noise.
Expand All @@ -309,9 +307,9 @@ def stochastic_muzero_policy(
# construct a dummy afterstate embedding
batch_size = jax.tree_util.tree_leaves(root.embedding)[0].shape[0]
dummy_action = jnp.zeros([batch_size], dtype=jnp.int32)
_, dummy_afterstate_embedding = decision_recurrent_fn(params, rng_key,
dummy_action,
root.embedding)
dummy_output, dummy_afterstate_embedding = decision_recurrent_fn(
params, rng_key, dummy_action, root.embedding)
num_chance_outcomes = dummy_output.chance_logits.shape[-1]

root = root.replace(
# pad action logits with num_chance_outcomes so dim is A + C
Expand Down
2 changes: 0 additions & 2 deletions mctx/_src/tests/policies_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,6 @@ def test_stochastic_muzero_policy(self):
decision_recurrent_fn=decision_rec_fn,
chance_recurrent_fn=chance_rec_fn,
num_simulations=2 * num_simulations,
num_actions=4,
num_chance_outcomes=num_chance_outcomes,
invalid_actions=invalid_actions,
dirichlet_fraction=0.0)

Expand Down

0 comments on commit b2e09ad

Please sign in to comment.