diff --git a/pgx/_src/games/shogi.py b/pgx/_src/games/shogi.py index d17149e52..3583e84b2 100644 --- a/pgx/_src/games/shogi.py +++ b/pgx/_src/games/shogi.py @@ -80,6 +80,20 @@ class GameState(NamedTuple): cache_king: Array = jnp.int32(44) +class Game: + def init(self) -> GameState: + return GameState() + + def step(self, state: GameState, action: Array) -> GameState: + return _step(state, action) + + def observe(self, state: GameState) -> Array: + return _observe(state, False) + + def legal_action_mask(self, state: GameState) -> Array: + return _legal_action_mask(state) + + @dataclass class Action: is_drop: Array diff --git a/pgx/shogi.py b/pgx/shogi.py index 2d99ec1ff..57c29fdad 100644 --- a/pgx/shogi.py +++ b/pgx/shogi.py @@ -24,7 +24,7 @@ ) from pgx._src.struct import dataclass from pgx._src.types import Array, PRNGKey -from pgx._src.games.shogi import MAX_TERMINATION_STEPS, GameState, _step, _legal_action_mask, _observe, _flip +from pgx._src.games.shogi import MAX_TERMINATION_STEPS, GameState, Game, _observe, _flip TRUE = jnp.bool_(True) @@ -54,7 +54,7 @@ def _from_board(turn, piece_board: Array, hand: Array): # fmt: off state = jax.lax.cond(turn % 2 == 1, lambda: state.replace(_x=_flip(state._x)), lambda: state) # type: ignore # fmt: on - return state.replace(legal_action_mask=_legal_action_mask(state._x)) # type: ignore + return state.replace(legal_action_mask=Game().legal_action_mask(state._x)) # type: ignore @staticmethod def _from_sfen(sfen): @@ -67,8 +67,10 @@ def _to_sfen(self): class Shogi(core.Env): + def __init__(self): super().__init__() + self._game = Game() def _init(self, key: PRNGKey) -> State: state = State() @@ -79,13 +81,13 @@ def _step(self, state: core.State, action: Array, key) -> State: del key assert isinstance(state, State) # Note: Assume that illegal action is already filtered by Env.step - x = _step(state._x, action) + x = self._game.step(state._x, action) state = state.replace( # type: ignore current_player=(state.current_player + 1) % 2, _x=x, ) del x - legal_action_mask = _legal_action_mask(state._x) + legal_action_mask = self._game.legal_action_mask(state._x) terminated = ~legal_action_mask.any() # fmt: off reward = jax.lax.select(