From b2e09add697294c2793506b8be074b128b166286 Mon Sep 17 00:00:00 2001
From: carlosgmartin
Date: Sun, 8 Oct 2023 16:04:14 -0400
Subject: [PATCH] Automatically determine num_actions and num_chance_outcomes
 in stochastic_muzero_policy.

---
 mctx/_src/policies.py            | 12 +++++-------
 mctx/_src/tests/policies_test.py |  2 --
 2 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/mctx/_src/policies.py b/mctx/_src/policies.py
index 0e75b50..624d92b 100644
--- a/mctx/_src/policies.py
+++ b/mctx/_src/policies.py
@@ -238,8 +238,6 @@ def stochastic_muzero_policy(
     decision_recurrent_fn: base.DecisionRecurrentFn,
     chance_recurrent_fn: base.ChanceRecurrentFn,
     num_simulations: int,
-    num_actions: int,
-    num_chance_outcomes: int,
     invalid_actions: Optional[chex.Array] = None,
     max_depth: Optional[int] = None,
     loop_fn: base.LoopFn = jax.lax.fori_loop,
@@ -271,8 +269,6 @@
       `(params, rng_key, chance_outcome, afterstate_embedding)` and returns a
       `(ChanceRecurrentFnOutput, state_embedding)`.
     num_simulations: the number of simulations.
-    num_actions: number of environment actions.
-    num_chance_outcomes: number of chance outcomes following an afterstate.
     invalid_actions: a mask with invalid actions. Invalid actions have ones,
       valid actions have zeros in the mask. Shape `[B, num_actions]`.
     max_depth: maximum search tree depth allowed during simulation.
@@ -293,6 +289,8 @@
     search tree.
   """
 
+  num_actions = root.prior_logits.shape[-1]
+
   rng_key, dirichlet_rng_key, search_rng_key = jax.random.split(rng_key, 3)
 
   # Adding Dirichlet noise.
@@ -309,9 +307,9 @@
   # construct a dummy afterstate embedding
   batch_size = jax.tree_util.tree_leaves(root.embedding)[0].shape[0]
   dummy_action = jnp.zeros([batch_size], dtype=jnp.int32)
-  _, dummy_afterstate_embedding = decision_recurrent_fn(params, rng_key,
-                                                        dummy_action,
-                                                        root.embedding)
+  dummy_output, dummy_afterstate_embedding = decision_recurrent_fn(
+      params, rng_key, dummy_action, root.embedding)
+  num_chance_outcomes = dummy_output.chance_logits.shape[-1]
 
   root = root.replace(
       # pad action logits with num_chance_outcomes so dim is A + C
diff --git a/mctx/_src/tests/policies_test.py b/mctx/_src/tests/policies_test.py
index cde5b7a..c08a093 100644
--- a/mctx/_src/tests/policies_test.py
+++ b/mctx/_src/tests/policies_test.py
@@ -347,8 +347,6 @@ def test_stochastic_muzero_policy(self):
         decision_recurrent_fn=decision_rec_fn,
         chance_recurrent_fn=chance_rec_fn,
         num_simulations=2 * num_simulations,
-        num_actions=4,
-        num_chance_outcomes=num_chance_outcomes,
         invalid_actions=invalid_actions,
         dirichlet_fraction=0.0)
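
Note: after this patch, `num_actions` is inferred from
`root.prior_logits.shape[-1]`, and `num_chance_outcomes` from the
`chance_logits` returned by `decision_recurrent_fn`, so callers drop both
arguments. This matches `muzero_policy`, which likewise takes no
`num_actions` argument. A minimal usage sketch against the new signature
follows; the batch/model sizes and the constant recurrent functions are
illustrative placeholders, not code from this patch:

import jax
import jax.numpy as jnp
import mctx

batch_size = 2
num_actions = 4          # read by the policy from root.prior_logits
num_chance_outcomes = 3  # read from DecisionRecurrentFnOutput.chance_logits

def decision_recurrent_fn(params, rng_key, action, state_embedding):
  del params, rng_key, action  # a real model would use these
  output = mctx.DecisionRecurrentFnOutput(
      chance_logits=jnp.zeros([batch_size, num_chance_outcomes]),
      afterstate_value=jnp.zeros([batch_size]))
  # The second element is the afterstate embedding.
  return output, state_embedding

def chance_recurrent_fn(params, rng_key, chance_outcome, afterstate_embedding):
  del params, rng_key, chance_outcome
  output = mctx.ChanceRecurrentFnOutput(
      action_logits=jnp.zeros([batch_size, num_actions]),
      value=jnp.zeros([batch_size]),
      reward=jnp.zeros([batch_size]),
      discount=jnp.ones([batch_size]))
  return output, afterstate_embedding

root = mctx.RootFnOutput(
    prior_logits=jnp.zeros([batch_size, num_actions]),
    value=jnp.zeros([batch_size]),
    embedding=jnp.zeros([batch_size, 1]))

# num_actions= and num_chance_outcomes= are no longer accepted here.
policy_output = mctx.stochastic_muzero_policy(
    params=(),
    rng_key=jax.random.PRNGKey(0),
    root=root,
    decision_recurrent_fn=decision_recurrent_fn,
    chance_recurrent_fn=chance_recurrent_fn,
    num_simulations=4)
print(policy_output.action)  # shape [batch_size]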