diff --git a/pgx/go.py b/pgx/go.py index 46e277759..e054efc09 100644 --- a/pgx/go.py +++ b/pgx/go.py @@ -41,6 +41,8 @@ class GameState: _ko: Array = jnp.int32(-1) # by SSK _komi: Array = jnp.float32(7.5) _black_player: Array = jnp.int32(0) + is_terminal: Array = FALSE + is_psk: Array = FALSE @dataclass @@ -154,17 +156,16 @@ def _observe(state: State, player_id, size, history_length): my_turn = jax.lax.select( player_id == state.current_player, state._x._turn, 1 - state._x._turn ) - current_player_color = _my_color(state) # -1 or 1 - my_color, opp_color = jax.lax.cond( - player_id == state.current_player, - lambda: (current_player_color, -1 * current_player_color), - lambda: (-1 * current_player_color, current_player_color), - ) + return _observe_game_state(state._x, my_turn, size, history_length) + + +def _observe_game_state(x: GameState, my_turn, size, history_length): + my_color = jnp.int32([1, -1])[my_turn] @jax.vmap def _make(i): color = jnp.int32([1, -1])[i % 2] * my_color - return state._x._board_history[i // 2] == color + return x._board_history[i // 2] == color log = _make(jnp.arange(history_length * 2)) color = jnp.full_like(log[0], my_turn) # black=0, white=1 @@ -172,79 +173,101 @@ def _make(i): return jnp.vstack([log, color]).transpose().reshape((size, size, -1)) -def _init(key: PRNGKey, size: int, komi: float = 7.5) -> State: - black_player = jnp.int32(jax.random.bernoulli(key)) - current_player = black_player - _x = GameState( +def _init_game_state(size: int, komi: float, black_player: Array) -> GameState: + return GameState( _size=jnp.int32(size), _chain_id_board=jnp.zeros(size**2, dtype=jnp.int32), _board_history=jnp.full((8, size**2), 2, dtype=jnp.int32), _komi=jnp.float32(komi), _black_player=black_player, ) + + +def _init(key: PRNGKey, size: int, komi: float = 7.5) -> State: + black_player = jnp.int32(jax.random.bernoulli(key)) + current_player = black_player return State( # type:ignore legal_action_mask=jnp.ones(size**2 + 1, dtype=jnp.bool_), current_player=current_player, - _x=_x, + _x=_init_game_state(size, komi, black_player), ) -def _step(state: State, action: int, size: int) -> State: - state = state.replace(_x=state._x.replace(_ko=jnp.int32(-1))) # type: ignore +def _step_game_state(x: GameState, action: int, size: int) -> GameState: + x = x.replace(_ko=jnp.int32(-1)) # type: ignore + # update state - state = jax.lax.cond( + x = jax.lax.cond( (action < size * size), - lambda: _not_pass_move(state, action, size), - lambda: _pass_move(state, size), + lambda: _not_pass_move(x, action, size), + lambda: _pass_move(x), ) # increment turns - state = state.replace(_x=state._x.replace(_turn=(state._x._turn + 1) % 2)) # type: ignore - state = state.replace(current_player=(state.current_player + 1) % 2) # type: ignore - - # add legal action mask - state = state.replace( # type:ignore - legal_action_mask=state.legal_action_mask.at[:-1] - .set(legal_actions(state, size)) - .at[-1] - .set(TRUE) - ) + x = x.replace(_turn=(x._turn + 1) % 2) # type: ignore # update board history - board_history = jnp.roll(state._x._board_history, size**2) + board_history = jnp.roll(x._board_history, size**2) board_history = board_history.at[0].set( - jnp.clip(state._x._chain_id_board, -1, 1).astype(jnp.int32) + jnp.clip(x._chain_id_board, -1, 1).astype(jnp.int32) + ) + x = x.replace(_board_history=board_history) # type: ignore + + # PSK + is_psk = _check_PSK(x) + x = x.replace(is_psk=is_psk, is_terminal=(x.is_terminal | is_psk)) # type: ignore + + return x + + +def _step(state: State, action: int, size: int) -> State: + x = _step_game_state(state._x, action, size) + + current_player = (state.current_player + 1) % 2 # player to act + + rewards = jax.lax.cond( + x.is_terminal, + lambda: _get_reward(state, size), + lambda: jnp.zeros_like(state.rewards), ) - state = state.replace( # type:ignore - _x=state._x.replace(_board_history=board_history) # type: ignore + + rewards = jax.lax.select( + x.is_psk, jnp.float32([-1, -1]).at[current_player].set(1.0), rewards ) - # check PSK up to 8-steps before - state = _check_PSK(state) - return state + return state.replace( # type:ignore + current_player=current_player, + terminated=x.is_terminal, + rewards=rewards, + legal_action_mask=state.legal_action_mask.at[:-1] + .set(legal_actions(x, size)) + .at[-1] + .set(TRUE), + _x=x, + ) -def _pass_move(state: State, size) -> State: +def _pass_move(state: GameState) -> GameState: return jax.lax.cond( - state._x._passed, + state._passed, # consecutive passes results in the game end - lambda: state.replace(terminated=TRUE, rewards=_get_reward(state, size)), # type: ignore + lambda: state.replace(is_terminal=TRUE), # type: ignore # One pass continues the game - lambda: state.replace(_x=state._x.replace(_passed=TRUE), rewards=jnp.zeros(2, dtype=jnp.float32)), # type: ignore + lambda: state.replace(_passed=TRUE), # type: ignore ) -def _not_pass_move(state: State, action, size) -> State: - state = state.replace(_x=state._x.replace(_passed=FALSE)) # type: ignore +def _not_pass_move(state: GameState, action, size) -> GameState: + state = state.replace(_passed=FALSE) # type: ignore xy = action - num_captured_stones_before = state._x._num_captured_stones[state._x._turn] + num_captured_stones_before = state._num_captured_stones[state._turn] ko_may_occur = _ko_may_occur(state, xy) # Remove killed stones adj_xy = _neighbour(xy, size) oppo_color = _opponent_color(state) - chain_id = state._x._chain_id_board[adj_xy] + chain_id = state._chain_id_board[adj_xy] num_pseudo, idx_sum, idx_squared_sum = _count(state, size) chain_ix = jnp.abs(chain_id) - 1 is_atari = (idx_sum[chain_ix] ** 2) == idx_squared_sum[ @@ -277,20 +300,20 @@ def _not_pass_move(state: State, action, size) -> State: # Check Ko # fmt: off state = jax.lax.cond( - state._x._num_captured_stones[state._x._turn] - num_captured_stones_before == 1, + state._num_captured_stones[state._turn] - num_captured_stones_before == 1, lambda: state, - lambda: state.replace(_x=state._x.replace(_ko=jnp.int32(-1))) # type:ignore + lambda: state.replace(_ko=jnp.int32(-1)) # type:ignore ) # fmt: on - return state.replace(rewards=jnp.zeros(2, dtype=jnp.float32)) # type: ignore + return state -def _merge_around_xy(i, state: State, xy, size): +def _merge_around_xy(i, state: GameState, xy, size): my_color = _my_color(state) adj_xy = _neighbour(xy, size)[i] is_off = adj_xy == -1 - is_my_chain = state._x._chain_id_board[adj_xy] * my_color > 0 + is_my_chain = state._chain_id_board[adj_xy] * my_color > 0 state = jax.lax.cond( ((~is_off) & is_my_chain), lambda: _merge_chain(state, xy, adj_xy), @@ -299,70 +322,64 @@ def _merge_around_xy(i, state: State, xy, size): return state -def _set_stone(state: State, xy) -> State: +def _set_stone(state: GameState, xy) -> GameState: my_color = _my_color(state) return state.replace( # type: ignore - _x=state._x.replace( # type:ignore - _chain_id_board=state._x._chain_id_board.at[xy].set( - (xy + 1) * my_color - ), - ) + _chain_id_board=state._chain_id_board.at[xy].set((xy + 1) * my_color), ) -def _merge_chain(state: State, xy, adj_xy): +def _merge_chain(state: GameState, xy, adj_xy): my_color = _my_color(state) - new_id = jnp.abs(state._x._chain_id_board[xy]) - adj_chain_id = jnp.abs(state._x._chain_id_board[adj_xy]) + new_id = jnp.abs(state._chain_id_board[xy]) + adj_chain_id = jnp.abs(state._chain_id_board[adj_xy]) small_id = jnp.minimum(new_id, adj_chain_id) * my_color large_id = jnp.maximum(new_id, adj_chain_id) * my_color # Keep larger chain ID and connect to the chain with smaller ID chain_id_board = jnp.where( - state._x._chain_id_board == large_id, + state._chain_id_board == large_id, small_id, - state._x._chain_id_board, + state._chain_id_board, ) - return state.replace(_x=state._x.replace(_chain_id_board=chain_id_board)) # type: ignore + return state.replace(_chain_id_board=chain_id_board) # type: ignore def _remove_stones( - state: State, rm_chain_id, rm_stone_xy, ko_may_occur -) -> State: - surrounded_stones = state._x._chain_id_board == rm_chain_id + state: GameState, rm_chain_id, rm_stone_xy, ko_may_occur +) -> GameState: + surrounded_stones = state._chain_id_board == rm_chain_id num_captured_stones = jnp.count_nonzero(surrounded_stones) - chain_id_board = jnp.where(surrounded_stones, 0, state._x._chain_id_board) + chain_id_board = jnp.where(surrounded_stones, 0, state._chain_id_board) ko = jax.lax.cond( ko_may_occur & (num_captured_stones == 1), lambda: jnp.int32(rm_stone_xy), - lambda: state._x._ko, + lambda: state._ko, ) return state.replace( # type: ignore - _x=state._x.replace( # type:ignore - _chain_id_board=chain_id_board, - _num_captured_stones=state._x._num_captured_stones.at[ - state._x._turn - ].add(num_captured_stones), - _ko=ko, - ) + _chain_id_board=chain_id_board, + _num_captured_stones=state._num_captured_stones.at[state._turn].add( + num_captured_stones + ), + _ko=ko, ) -def legal_actions(state: State, size: int) -> Array: +def legal_actions(state: GameState, size: int) -> Array: """Logic is highly inspired by OpenSpiel's Go implementation""" - is_empty = state._x._chain_id_board == 0 + is_empty = state._chain_id_board == 0 my_color = _my_color(state) opp_color = _opponent_color(state) num_pseudo, idx_sum, idx_squared_sum = _count(state, size) - chain_ix = jnp.abs(state._x._chain_id_board) - 1 + chain_ix = jnp.abs(state._chain_id_board) - 1 # fmt: off in_atari = (idx_sum[chain_ix] ** 2) == idx_squared_sum[chain_ix] * num_pseudo[chain_ix] # fmt: on - has_liberty = (state._x._chain_id_board * my_color > 0) & ~in_atari - kills_opp = (state._x._chain_id_board * opp_color > 0) & in_atari + has_liberty = (state._chain_id_board * my_color > 0) & ~in_atari + kills_opp = (state._chain_id_board * opp_color > 0) & in_atari @jax.vmap def is_neighbor_ok(xy): @@ -381,15 +398,15 @@ def is_neighbor_ok(xy): legal_action_mask = is_empty & neighbor_ok return jax.lax.cond( - (state._x._ko == -1), + (state._ko == -1), lambda: legal_action_mask, - lambda: legal_action_mask.at[state._x._ko].set(FALSE), + lambda: legal_action_mask.at[state._ko].set(FALSE), ) -def _count(state: State, size): +def _count(state: GameState, size): ZERO = jnp.int32(0) - chain_id_board = jnp.abs(state._x._chain_id_board) + chain_id_board = jnp.abs(state._chain_id_board) is_empty = chain_id_board == 0 idx_sum = jnp.where(is_empty, jnp.arange(1, size**2 + 1), ZERO) idx_squared_sum = jnp.where( @@ -426,22 +443,22 @@ def _idx_squared_sum(x): return _num_pseudo(idx), _idx_sum(idx), _idx_squared_sum(idx) -def _my_color(state: State): - return jnp.int32([1, -1])[state._x._turn] +def _my_color(state: GameState): + return jnp.int32([1, -1])[state._turn] -def _opponent_color(state: State): - return jnp.int32([-1, 1])[state._x._turn] +def _opponent_color(state: GameState): + return jnp.int32([-1, 1])[state._turn] -def _ko_may_occur(state: State, xy: int) -> Array: - size = state._x._size +def _ko_may_occur(state: GameState, xy: int) -> Array: + size = state._size x = xy // size y = xy % size oob = jnp.bool_([x - 1 < 0, x + 1 >= size, y - 1 < 0, y + 1 >= size]) oppo_color = _opponent_color(state) is_occupied_by_opp = ( - state._x._chain_id_board[_neighbour(xy, size)] * oppo_color > 0 + state._chain_id_board[_neighbour(xy, size)] * oppo_color > 0 ) return (oob | is_occupied_by_opp).all() @@ -513,7 +530,7 @@ def fill_opp(x): return (b == 0).sum() -def _check_PSK(state): +def _check_PSK(state: GameState): """On PSK implementations. Tromp-Taylor rule employ PSK. However, implementing strict PSK is inefficient because @@ -541,18 +558,9 @@ def _check_PSK(state): Anyway, we believe it's effect is very small as PSK rarely happens, especially in 19x19 board. """ # fmt: off - is_psk = ~state._x._passed & (jnp.abs(state._x._board_history[0] - state._x._board_history[1:]).sum(axis=1) == 0).any() - winner = state.current_player - state = jax.lax.cond( - is_psk, - lambda: state.replace( # type: ignore - terminated=TRUE, - rewards=jnp.float32([-1, -1]).at[winner].set(1.0), - ), - lambda: state, - ) + is_psk = ~state._passed & (jnp.abs(state._board_history[0] - state._board_history[1:]).sum(axis=1) == 0).any() # fmt: on - return state + return is_psk # only for debug