Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Go] Extract game logic functions #1130

Merged
merged 21 commits into from
Dec 27, 2023
206 changes: 107 additions & 99 deletions pgx/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class GameState:
_ko: Array = jnp.int32(-1) # by SSK
_komi: Array = jnp.float32(7.5)
_black_player: Array = jnp.int32(0)
is_terminal: Array = FALSE
is_psk: Array = FALSE


@dataclass
Expand Down Expand Up @@ -154,97 +156,118 @@ def _observe(state: State, player_id, size, history_length):
my_turn = jax.lax.select(
player_id == state.current_player, state._x._turn, 1 - state._x._turn
)
current_player_color = _my_color(state) # -1 or 1
my_color, opp_color = jax.lax.cond(
player_id == state.current_player,
lambda: (current_player_color, -1 * current_player_color),
lambda: (-1 * current_player_color, current_player_color),
)
return _observe_game_state(state._x, my_turn, size, history_length)


def _observe_game_state(x: GameState, my_turn, size, history_length):
my_color = jnp.int32([1, -1])[my_turn]

@jax.vmap
def _make(i):
color = jnp.int32([1, -1])[i % 2] * my_color
return state._x._board_history[i // 2] == color
return x._board_history[i // 2] == color

log = _make(jnp.arange(history_length * 2))
color = jnp.full_like(log[0], my_turn) # black=0, white=1

return jnp.vstack([log, color]).transpose().reshape((size, size, -1))


def _init(key: PRNGKey, size: int, komi: float = 7.5) -> State:
black_player = jnp.int32(jax.random.bernoulli(key))
current_player = black_player
_x = GameState(
def _init_game_state(size: int, komi: float, black_player: Array) -> GameState:
return GameState(
_size=jnp.int32(size),
_chain_id_board=jnp.zeros(size**2, dtype=jnp.int32),
_board_history=jnp.full((8, size**2), 2, dtype=jnp.int32),
_komi=jnp.float32(komi),
_black_player=black_player,
)


def _init(key: PRNGKey, size: int, komi: float = 7.5) -> State:
black_player = jnp.int32(jax.random.bernoulli(key))
current_player = black_player
return State( # type:ignore
legal_action_mask=jnp.ones(size**2 + 1, dtype=jnp.bool_),
current_player=current_player,
_x=_x,
_x=_init_game_state(size, komi, black_player),
)


def _step(state: State, action: int, size: int) -> State:
state = state.replace(_x=state._x.replace(_ko=jnp.int32(-1))) # type: ignore
def _step_game_state(x: GameState, action: int, size: int) -> GameState:
x = x.replace(_ko=jnp.int32(-1)) # type: ignore

# update state
state = jax.lax.cond(
x = jax.lax.cond(
(action < size * size),
lambda: _not_pass_move(state, action, size),
lambda: _pass_move(state, size),
lambda: _not_pass_move(x, action, size),
lambda: _pass_move(x),
)

# increment turns
state = state.replace(_x=state._x.replace(_turn=(state._x._turn + 1) % 2)) # type: ignore
state = state.replace(current_player=(state.current_player + 1) % 2) # type: ignore

# add legal action mask
state = state.replace( # type:ignore
legal_action_mask=state.legal_action_mask.at[:-1]
.set(legal_actions(state, size))
.at[-1]
.set(TRUE)
)
x = x.replace(_turn=(x._turn + 1) % 2) # type: ignore

# update board history
board_history = jnp.roll(state._x._board_history, size**2)
board_history = jnp.roll(x._board_history, size**2)
board_history = board_history.at[0].set(
jnp.clip(state._x._chain_id_board, -1, 1).astype(jnp.int32)
jnp.clip(x._chain_id_board, -1, 1).astype(jnp.int32)
)
x = x.replace(_board_history=board_history) # type: ignore

# PSK
is_psk = _check_PSK(x)
x = x.replace(is_psk=is_psk, is_terminal=(x.is_terminal | is_psk)) # type: ignore

return x


def _step(state: State, action: int, size: int) -> State:
x = _step_game_state(state._x, action, size)

current_player = (state.current_player + 1) % 2 # player to act

rewards = jax.lax.cond(
x.is_terminal,
lambda: _get_reward(state, size),
lambda: jnp.zeros_like(state.rewards),
)
state = state.replace( # type:ignore
_x=state._x.replace(_board_history=board_history) # type: ignore

rewards = jax.lax.select(
x.is_psk, jnp.float32([-1, -1]).at[current_player].set(1.0), rewards
)

# check PSK up to 8-steps before
state = _check_PSK(state)
return state
return state.replace( # type:ignore
current_player=current_player,
terminated=x.is_terminal,
rewards=rewards,
legal_action_mask=state.legal_action_mask.at[:-1]
.set(legal_actions(x, size))
.at[-1]
.set(TRUE),
_x=x,
)


def _pass_move(state: State, size) -> State:
def _pass_move(state: GameState) -> GameState:
return jax.lax.cond(
state._x._passed,
state._passed,
# consecutive passes results in the game end
lambda: state.replace(terminated=TRUE, rewards=_get_reward(state, size)), # type: ignore
lambda: state.replace(is_terminal=TRUE), # type: ignore
# One pass continues the game
lambda: state.replace(_x=state._x.replace(_passed=TRUE), rewards=jnp.zeros(2, dtype=jnp.float32)), # type: ignore
lambda: state.replace(_passed=TRUE), # type: ignore
)


def _not_pass_move(state: State, action, size) -> State:
state = state.replace(_x=state._x.replace(_passed=FALSE)) # type: ignore
def _not_pass_move(state: GameState, action, size) -> GameState:
state = state.replace(_passed=FALSE) # type: ignore
xy = action
num_captured_stones_before = state._x._num_captured_stones[state._x._turn]
num_captured_stones_before = state._num_captured_stones[state._turn]

ko_may_occur = _ko_may_occur(state, xy)

# Remove killed stones
adj_xy = _neighbour(xy, size)
oppo_color = _opponent_color(state)
chain_id = state._x._chain_id_board[adj_xy]
chain_id = state._chain_id_board[adj_xy]
num_pseudo, idx_sum, idx_squared_sum = _count(state, size)
chain_ix = jnp.abs(chain_id) - 1
is_atari = (idx_sum[chain_ix] ** 2) == idx_squared_sum[
Expand Down Expand Up @@ -277,20 +300,20 @@ def _not_pass_move(state: State, action, size) -> State:
# Check Ko
# fmt: off
state = jax.lax.cond(
state._x._num_captured_stones[state._x._turn] - num_captured_stones_before == 1,
state._num_captured_stones[state._turn] - num_captured_stones_before == 1,
lambda: state,
lambda: state.replace(_x=state._x.replace(_ko=jnp.int32(-1))) # type:ignore
lambda: state.replace(_ko=jnp.int32(-1)) # type:ignore
)
# fmt: on

return state.replace(rewards=jnp.zeros(2, dtype=jnp.float32)) # type: ignore
return state


def _merge_around_xy(i, state: State, xy, size):
def _merge_around_xy(i, state: GameState, xy, size):
my_color = _my_color(state)
adj_xy = _neighbour(xy, size)[i]
is_off = adj_xy == -1
is_my_chain = state._x._chain_id_board[adj_xy] * my_color > 0
is_my_chain = state._chain_id_board[adj_xy] * my_color > 0
state = jax.lax.cond(
((~is_off) & is_my_chain),
lambda: _merge_chain(state, xy, adj_xy),
Expand All @@ -299,70 +322,64 @@ def _merge_around_xy(i, state: State, xy, size):
return state


def _set_stone(state: State, xy) -> State:
def _set_stone(state: GameState, xy) -> GameState:
my_color = _my_color(state)
return state.replace( # type: ignore
_x=state._x.replace( # type:ignore
_chain_id_board=state._x._chain_id_board.at[xy].set(
(xy + 1) * my_color
),
)
_chain_id_board=state._chain_id_board.at[xy].set((xy + 1) * my_color),
)


def _merge_chain(state: State, xy, adj_xy):
def _merge_chain(state: GameState, xy, adj_xy):
my_color = _my_color(state)
new_id = jnp.abs(state._x._chain_id_board[xy])
adj_chain_id = jnp.abs(state._x._chain_id_board[adj_xy])
new_id = jnp.abs(state._chain_id_board[xy])
adj_chain_id = jnp.abs(state._chain_id_board[adj_xy])
small_id = jnp.minimum(new_id, adj_chain_id) * my_color
large_id = jnp.maximum(new_id, adj_chain_id) * my_color

# Keep larger chain ID and connect to the chain with smaller ID
chain_id_board = jnp.where(
state._x._chain_id_board == large_id,
state._chain_id_board == large_id,
small_id,
state._x._chain_id_board,
state._chain_id_board,
)

return state.replace(_x=state._x.replace(_chain_id_board=chain_id_board)) # type: ignore
return state.replace(_chain_id_board=chain_id_board) # type: ignore


def _remove_stones(
state: State, rm_chain_id, rm_stone_xy, ko_may_occur
) -> State:
surrounded_stones = state._x._chain_id_board == rm_chain_id
state: GameState, rm_chain_id, rm_stone_xy, ko_may_occur
) -> GameState:
surrounded_stones = state._chain_id_board == rm_chain_id
num_captured_stones = jnp.count_nonzero(surrounded_stones)
chain_id_board = jnp.where(surrounded_stones, 0, state._x._chain_id_board)
chain_id_board = jnp.where(surrounded_stones, 0, state._chain_id_board)
ko = jax.lax.cond(
ko_may_occur & (num_captured_stones == 1),
lambda: jnp.int32(rm_stone_xy),
lambda: state._x._ko,
lambda: state._ko,
)
return state.replace( # type: ignore
_x=state._x.replace( # type:ignore
_chain_id_board=chain_id_board,
_num_captured_stones=state._x._num_captured_stones.at[
state._x._turn
].add(num_captured_stones),
_ko=ko,
)
_chain_id_board=chain_id_board,
_num_captured_stones=state._num_captured_stones.at[state._turn].add(
num_captured_stones
),
_ko=ko,
)


def legal_actions(state: State, size: int) -> Array:
def legal_actions(state: GameState, size: int) -> Array:
"""Logic is highly inspired by OpenSpiel's Go implementation"""
is_empty = state._x._chain_id_board == 0
is_empty = state._chain_id_board == 0

my_color = _my_color(state)
opp_color = _opponent_color(state)
num_pseudo, idx_sum, idx_squared_sum = _count(state, size)

chain_ix = jnp.abs(state._x._chain_id_board) - 1
chain_ix = jnp.abs(state._chain_id_board) - 1
# fmt: off
in_atari = (idx_sum[chain_ix] ** 2) == idx_squared_sum[chain_ix] * num_pseudo[chain_ix]
# fmt: on
has_liberty = (state._x._chain_id_board * my_color > 0) & ~in_atari
kills_opp = (state._x._chain_id_board * opp_color > 0) & in_atari
has_liberty = (state._chain_id_board * my_color > 0) & ~in_atari
kills_opp = (state._chain_id_board * opp_color > 0) & in_atari

@jax.vmap
def is_neighbor_ok(xy):
Expand All @@ -381,15 +398,15 @@ def is_neighbor_ok(xy):
legal_action_mask = is_empty & neighbor_ok

return jax.lax.cond(
(state._x._ko == -1),
(state._ko == -1),
lambda: legal_action_mask,
lambda: legal_action_mask.at[state._x._ko].set(FALSE),
lambda: legal_action_mask.at[state._ko].set(FALSE),
)


def _count(state: State, size):
def _count(state: GameState, size):
ZERO = jnp.int32(0)
chain_id_board = jnp.abs(state._x._chain_id_board)
chain_id_board = jnp.abs(state._chain_id_board)
is_empty = chain_id_board == 0
idx_sum = jnp.where(is_empty, jnp.arange(1, size**2 + 1), ZERO)
idx_squared_sum = jnp.where(
Expand Down Expand Up @@ -426,22 +443,22 @@ def _idx_squared_sum(x):
return _num_pseudo(idx), _idx_sum(idx), _idx_squared_sum(idx)


def _my_color(state: State):
return jnp.int32([1, -1])[state._x._turn]
def _my_color(state: GameState):
return jnp.int32([1, -1])[state._turn]


def _opponent_color(state: State):
return jnp.int32([-1, 1])[state._x._turn]
def _opponent_color(state: GameState):
return jnp.int32([-1, 1])[state._turn]


def _ko_may_occur(state: State, xy: int) -> Array:
size = state._x._size
def _ko_may_occur(state: GameState, xy: int) -> Array:
size = state._size
x = xy // size
y = xy % size
oob = jnp.bool_([x - 1 < 0, x + 1 >= size, y - 1 < 0, y + 1 >= size])
oppo_color = _opponent_color(state)
is_occupied_by_opp = (
state._x._chain_id_board[_neighbour(xy, size)] * oppo_color > 0
state._chain_id_board[_neighbour(xy, size)] * oppo_color > 0
)
return (oob | is_occupied_by_opp).all()

Expand Down Expand Up @@ -513,7 +530,7 @@ def fill_opp(x):
return (b == 0).sum()


def _check_PSK(state):
def _check_PSK(state: GameState):
"""On PSK implementations.

Tromp-Taylor rule employ PSK. However, implementing strict PSK is inefficient because
Expand Down Expand Up @@ -541,18 +558,9 @@ def _check_PSK(state):
Anyway, we believe it's effect is very small as PSK rarely happens, especially in 19x19 board.
"""
# fmt: off
is_psk = ~state._x._passed & (jnp.abs(state._x._board_history[0] - state._x._board_history[1:]).sum(axis=1) == 0).any()
winner = state.current_player
state = jax.lax.cond(
is_psk,
lambda: state.replace( # type: ignore
terminated=TRUE,
rewards=jnp.float32([-1, -1]).at[winner].set(1.0),
),
lambda: state,
)
is_psk = ~state._passed & (jnp.abs(state._board_history[0] - state._board_history[1:]).sum(axis=1) == 0).any()
# fmt: on
return state
return is_psk


# only for debug
Expand Down
Loading