diff --git a/board/go_board.py b/board/go_board.py index 93668bb..a1cbe82 100644 --- a/board/go_board.py +++ b/board/go_board.py @@ -1,6 +1,6 @@ """碁盤のデータ定義と操作処理。 """ -from typing import List, NoReturn +from typing import List, Tuple, NoReturn from collections import deque import numpy as np @@ -469,6 +469,26 @@ def get_komi(self) -> float: """ return self.komi + def get_to_move(self) -> Stone: + """手番の色を取得する。 + + Returns: + Stone: 手番の色。 + """ + if self.moves == 1: + return Stone.BLACK + else: + last_move_color, _, _ = self.record.get(self.moves - 1) + return Stone.get_opponent_color(last_move_color) + + def get_move_history(self) -> List[Tuple[Stone, int, np.array]]: + """着手の履歴を取得する。 + + Returns: + [(Stone, int, np.array), ...]: (着手の色、座標、ハッシュ値) のリスト。 + """ + return [self.record.get(m) for m in range(1, self.moves)] + def count_score(self) -> int: # pylint: disable=R0912 """領地を簡易的にカウントする。 diff --git a/gtp/client.py b/gtp/client.py index 042f35d..a595621 100644 --- a/gtp/client.py +++ b/gtp/client.py @@ -20,6 +20,7 @@ from sgf.reader import SGFReader +gtp_command_id = "" class GtpClient: # pylint: disable=R0902,R0903 """_Go Text Protocolクライアントの実装クラス @@ -163,6 +164,19 @@ def _play(self, color: str, pos: str) -> NoReturn: respond_success("") + def _undo(self) -> NoReturn: + """undoコマンドを処理する。 + """ + # 一旦クリアして初手から直前手まで打ち直す非効率実装 + history = self.board.get_move_history() + if not history: + respond_failure("cannot undo") + return + self._clear_board() + for (color, pos, _) in history[:-1]: + self.board.put_stone(pos, color) + respond_success("") + def _genmove(self, color: str) -> NoReturn: """genmoveコマンドを処理する。 入力された手番で思考し、着手を生成する。 @@ -295,6 +309,45 @@ def _load_sgf(self, arg_list: List[str]) -> NoReturn: respond_success("") + def _decode_analyze_arg(self, arg_list: List[str]) -> (Stone, float): + """analyzeコマンド(lz-analyze, cgos-analyze)の引数を解釈する。 + 不正な引数の場合は更新間隔として負値を返す。 + + Args: + arg_list (List[str]): コマンドの引数リスト。 + + Returns: + (Stone, float): 手番の色、更新間隔(秒) + """ + to_move = self.board.get_to_move() + interval = 0 + error_value = (to_move, -1.0) + # 受けつける形式の例 + # lz-analyze B 10 + # lz-analyze B + # lz-analyze 10 + # lz-analyze B interval 10 + # lz-analyze interval 10 + try: + if arg_list[0][0] in ['B', 'b']: + to_move = Stone.BLACK + arg_list.pop(0) + elif arg_list[0][0] in ['W', 'w']: + to_move = Stone.WHITE + arg_list.pop(0) + if arg_list[0] == "interval": + if len(arg_list) == 1: + return error_value + arg_list.pop(0) + if arg_list[0].isdigit(): + interval = int(arg_list[0])/100 + arg_list.pop(0) + except IndexError as e: + pass + if arg_list: + return error_value + return (to_move, interval) + def _analyze(self, mode: str, arg_list: List[str]) -> NoReturn: """analyzeコマンド(lz-analyze, cgos-analyze)を実行する。 @@ -302,18 +355,13 @@ def _analyze(self, mode: str, arg_list: List[str]) -> NoReturn: mode (str): 解析モード。値は"lz"か"cgos"。 arg_list (List[str]): コマンドの引数リスト (手番の色, 更新間隔)。 """ - interval = 0 - if len(arg_list) >= 2: - interval = int(arg_list[1])/100 - - if arg_list[0][0] in ['B', 'b']: - to_move = Stone.BLACK - elif arg_list[0][0] in ['W', 'w']: - to_move = Stone.WHITE - else: - respond_failure(f"{mode}-analyze color") + to_move, interval = self._decode_analyze_arg(arg_list) + if interval < 0: + respond_failure(f"{mode}-analyze [color] [interval]") return + respond_success("", ongoing=True) + analysis_query = { "mode" : mode, "interval" : interval, @@ -328,19 +376,13 @@ def _genmove_analyze(self, mode: str, arg_list: List[str]) -> NoReturn: mode (str): 解析モード。値は"lz"か"cgos"。 arg_list (List[str]): コマンドの引数リスト(手番の色, 更新間隔)。 """ - color = arg_list[0] - interval = 0 - if len(arg_list) >= 2: - interval = int(arg_list[1])/100 - - if color.lower()[0] == 'b': - genmove_color = Stone.BLACK - elif color.lower()[0] == 'w': - genmove_color = Stone.WHITE - else: - respond_failure(f"{mode}-genmove_analyze color") + genmove_color, interval = self._decode_analyze_arg(arg_list) + if interval < 0: + respond_failure(f"{mode}-analyze [color] [interval]") return + respond_success("", ongoing=True) + if self.use_network: # モンテカルロ木探索で着手生成 analysis_query = { @@ -369,13 +411,24 @@ def run(self) -> NoReturn: # pylint: disable=R0912,R0915 """Go Text Protocolのクライアントの実行処理。 入力されたコマンドに対応する処理を実行し、応答メッセージを表示する。 """ + global gtp_command_id while True: command = input() command_list = command.rstrip().split(' ') + gtp_command_id = "" input_gtp_command = command_list[0] + # 入力されたコマンドの冒頭が数字なら、それを id とみなす。 + # (参照) + # Specification of the Go Text Protocol, version 2, draft 2 + # の「2.5 Command Structure」 + # http://www.lysator.liu.se/~gunnar/gtp/gtp2-spec-draft2/gtp2-spec.html#SECTION00035000000000000000 + if input_gtp_command.isdigit(): + gtp_command_id = command_list.pop(0) + input_gtp_command = command_list[0] + if input_gtp_command == "version": _version() elif input_gtp_command == "protocol_version": @@ -392,6 +445,8 @@ def run(self) -> NoReturn: # pylint: disable=R0912,R0915 self._komi(command_list[1]) elif input_gtp_command == "play": self._play(command_list[1], command_list[2]) + elif input_gtp_command == "undo": + self._undo() elif input_gtp_command == "genmove": self._genmove(command_list[1]) elif input_gtp_command == "boardsize": @@ -445,18 +500,14 @@ def run(self) -> NoReturn: # pylint: disable=R0912,R0915 self.board.display_self_atari(Stone.WHITE) respond_success("") elif input_gtp_command == "lz-analyze": - print_out("= ") self._analyze("lz", command_list[1:]) print("") elif input_gtp_command == "lz-genmove_analyze": - print_out("= ") self._genmove_analyze("lz", command_list[1:]) elif input_gtp_command == "cgos-analyze": - print_out("= ") self._analyze("cgos", command_list[1:]) print("") elif input_gtp_command == "cgos-genmove_analyze": - print_out("= ") self._genmove_analyze("cgos", command_list[1:]) elif input_gtp_command == "hash_record": print_err(self.board.record.get_hash_history()) @@ -464,13 +515,15 @@ def run(self) -> NoReturn: # pylint: disable=R0912,R0915 else: respond_failure("unknown_command") -def respond_success(response: str) -> NoReturn: +def respond_success(response: str, ongoing: bool = False) -> NoReturn: """コマンド処理成功時の応答メッセージを表示する。 Args: response (str): 表示する応答メッセージ。 + ongoing (bool): 追加の応答メッセージが後に続くかどうか。 """ - print("= " + response + '\n') + terminator = "" if ongoing else '\n' + print(f"={gtp_command_id} " + response + terminator) def respond_failure(response: str) -> NoReturn: """コマンド処理失敗時の応答メッセージを表示する。 @@ -478,7 +531,7 @@ def respond_failure(response: str) -> NoReturn: Args: response (str): 表示する応答メッセージ。 """ - print("= ? " + response + '\n') + print(f"?{gtp_command_id} " + response + '\n') def _version() -> NoReturn: """versionコマンドを処理する。