diff --git a/pgx/_src/games/hex.py b/pgx/_src/games/hex.py index c5342cc1e..c5cc2d01e 100644 --- a/pgx/_src/games/hex.py +++ b/pgx/_src/games/hex.py @@ -69,17 +69,14 @@ def is_terminal(self, state: GameState) -> Array: # ... -def _is_terminal(state: GameState, size: int) -> Array: +def _is_terminal(state: GameState, action: Array, size: int) -> Array: top, bottom = jax.lax.cond( state.color == 0, lambda: (state.board[::size], state.board[size - 1 :: size]), lambda: (state.board[:size], state.board[-size:]), ) - - def check_same_id_exist(_id): - return (_id < 0) & (_id == bottom).any() - - return jax.vmap(check_same_id_exist)(top).any() + target_id = state.board[action] # target_id != 0 + return (top == target_id).any() & (bottom == target_id).any() def _step(state: GameState, action: Array, size: int) -> GameState: @@ -101,9 +98,7 @@ def merge(i, b): step_count=state.step_count + 1, board=board * -1, ) - - terminated = _is_terminal(state, size) - return state._replace(terminated=terminated) + return state._replace(terminated=_is_terminal(state, action, size)) def _swap(state: GameState, size: int) -> GameState: