From aa561177d17b200a9e9c2d6b6dbe4afc58a82719 Mon Sep 17 00:00:00 2001 From: Hiraoka Date: Fri, 19 Jan 2024 22:21:55 +0900 Subject: [PATCH 1/3] cleaning (_initialize_search) --- mcts/tree.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/mcts/tree.py b/mcts/tree.py index 28c02f6..dd5c98b 100644 --- a/mcts/tree.py +++ b/mcts/tree.py @@ -46,6 +46,14 @@ def __init__(self, network: DualNet, tree_size: int=MCTS_TREE_SIZE, \ self.to_move = Stone.BLACK + def _initialize_search(self, board: GoBoard, color: Stone) -> NoReturn: + self.num_nodes = 0 + self.current_root = self.expand_node(board, color) + input_plane = generate_input_planes(board, color, 0) + self.batch_queue.push(input_plane, [], self.current_root) + self.process_mini_batch(board) + + def search_best_move(self, board: GoBoard, color: Stone, time_manager: TimeManager, \ analysis_query: Dict[str, Any]) -> int: """モンテカルロ木探索を実行して最善手を返す。 @@ -58,16 +66,10 @@ def search_best_move(self, board: GoBoard, color: Stone, time_manager: TimeManag Returns: int: 着手する座標。 """ - self.num_nodes = 0 + self._initialize_search(board, color) time_manager.start_timer() - self.current_root = self.expand_node(board, color) - input_plane = generate_input_planes(board, color, 0) - self.batch_queue.push(input_plane, [], self.current_root) - - self.process_mini_batch(board) - root = self.node[self.current_root] # 候補手が1つしかない場合はPASSを返す @@ -111,12 +113,7 @@ def ponder(self, board: GoBoard, color: Stone, analysis_query: Dict[str, Any]) - color (Stone): 思考する手番の色。 analysis_query (Dict): 解析情報。 """ - self.num_nodes = 0 - - self.current_root = self.expand_node(board, color) - input_plane = generate_input_planes(board, color, 0) - self.batch_queue.push(input_plane, [], self.current_root) - self.process_mini_batch(board) + self._initialize_search(board, color) # 探索を実行する max_visits = 999999999 From 50d7598385875892be040d2a199ce072aeb85d27 Mon Sep 17 00:00:00 2001 From: Hiraoka Date: Fri, 19 Jan 2024 22:21:55 +0900 Subject: [PATCH 2/3] cleaning (get_analysis_status_list) --- mcts/node.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mcts/node.py b/mcts/node.py index a2269e3..5f8526f 100644 --- a/mcts/node.py +++ b/mcts/node.py @@ -408,6 +408,12 @@ def get_analysis(self, board: GoBoard, mode: str, \ Returns: str: GTP応答用解析結果文字列。 """ + children_status_list = self.get_analysis_status_list(board, pv_lists_func) + return self.get_analysis_from_status_list(mode, children_status_list) + + + def get_analysis_status_list(self, board: GoBoard, \ + pv_lists_func: Callable[[List[str], int], List[str]]): sorted_list = [] for i in range(self.num_children): sorted_list.append((self.children_visits[i], i)) @@ -439,7 +445,10 @@ def get_analysis(self, board: GoBoard, mode: str, \ } ) order += 1 + return children_status_list + + def get_analysis_from_status_list(self, mode, children_status_list): out = "" if mode == "cgos": cgos_dict = { From b67aae25686a419a1393fc9b19e7a8763c88c008 Mon Sep 17 00:00:00 2001 From: Hiraoka Date: Fri, 19 Jan 2024 22:21:55 +0900 Subject: [PATCH 3/3] Add MCTS animation feature in "fake lz-analyze" approach (#93). --- animation/animation.py | 68 ++++++++++++++++++++++++++++++++++++++++++ gtp/client.py | 20 +++++++++++-- main.py | 11 +++++-- mcts/node.py | 4 +++ mcts/tree.py | 24 ++++++++++++++- 5 files changed, 122 insertions(+), 5 deletions(-) create mode 100644 animation/animation.py diff --git a/animation/animation.py b/animation/animation.py new file mode 100644 index 0000000..d396705 --- /dev/null +++ b/animation/animation.py @@ -0,0 +1,68 @@ +import sys +import select +import time + + +def animate_mcts(mcts, board, to_move, pv_wait_sec, move_wait_sec): + previous_pv = [] + def callback(path): + _animate_path(path, mcts, board, pv_wait_sec, move_wait_sec, previous_pv) + finished = _stdin_has_data() + return finished + mcts.search_with_callback(board, to_move, callback) + + +def _stdin_has_data(): + rlist, _, _ = select.select([sys.stdin], [], [], 0) + return bool(rlist) + + +def _animate_path(path, mcts, board, pv_wait_sec, move_wait_sec, previous_pv): + # ����õ�����������°���� + root_index, i = path[0] + root = mcts.node[root_index] + if root.children_visits[i] == 0: + return + coordinate = board.coordinate + move = coordinate.convert_to_gtp_format(root.action[i]) + pv = [coordinate.convert_to_gtp_format(mcts.node[index].action[child_index]) for (index, child_index) in path] + pv_visits = [str(mcts.node[index].children_visits[child_index]) for (index, child_index) in path] + pv_winrate = [str(int(10000 * _get_winrate(mcts, index, child_index, depth))) for depth, (index, child_index) in enumerate(path)] + + # lz-analyze ������ν������Ƥ�ù� + children_status_list = root.get_analysis_status_list(board, mcts.get_pv_lists) + fake_status_list = [status.copy() for status in children_status_list] + target = next((status for status in fake_status_list if status["move"] == move), None) + if target is None: + return # can't happen + # ����õ����������ν��������ȵ��äƽ�̤�դ�ľ�� + target["order"] = -1 + fake_status_list.sort(key=lambda status: status["order"]) + for order, status in enumerate(fake_status_list): + status["order"] = order + + # PV ��򺹤������ʤ���ʣ������Ϥ��뤳�Ȥǰ�ꤺ�ĥ��˥᡼����� + for k in range(1, len(pv) + 1): + # ����η���ȶ��̤ʼ��ϥ����å� + if pv[:k] == previous_pv[:k]: + continue + + target["pv"] = " ".join(pv[:k]) + target["pvVisits"] = " ".join(pv_visits[:k]) + target["pvWinrate"] = " ".join(pv_winrate[:k]) + + sys.stdout.write(root.get_analysis_from_status_list("lz", fake_status_list)) + sys.stdout.flush() + time.sleep(max(move_wait_sec, 0.0)) + + previous_pv[:] = pv + time.sleep(max(pv_wait_sec, 0.0)) + + +def _get_winrate(mcts, index, child_index, depth): + node = mcts.node[index] + i = child_index + visits = node.children_visits[i] + value = node.children_value_sum[i] / visits if visits > 0 else node.children_value[i] + winrate = value if depth % 2 == 0 else 1.0 - value + return winrate diff --git a/gtp/client.py b/gtp/client.py index dd40960..22f5662 100644 --- a/gtp/client.py +++ b/gtp/client.py @@ -19,6 +19,7 @@ from nn.policy_player import generate_move_from_policy from nn.utility import load_network from sgf.reader import SGFReader +from animation.animation import animate_mcts gtp_command_id = "" @@ -30,7 +31,8 @@ class GtpClient: # pylint: disable=R0902,R0903 def __init__(self, board_size: int, superko: bool, model_file_path: str, \ use_gpu: bool, policy_move: bool, use_sequential_halving: bool, \ komi: float, mode: TimeControl, visits: int, const_time: float, \ - time: float, batch_size: int, tree_size: int, cgos_mode: bool): # pylint: disable=R0913 + time: float, batch_size: int, tree_size: int, cgos_mode: bool, \ + animation_pv_wait: float, animation_move_wait:float): # pylint: disable=R0913 """Go Text Protocolクライアントの初期化をする。 Args: @@ -92,6 +94,8 @@ def __init__(self, board_size: int, superko: bool, model_file_path: str, \ self.policy_move = policy_move self.use_sequential_halving = use_sequential_halving self.use_network = False + self.animation_pv_wait = animation_pv_wait + self.animation_move_wait = animation_move_wait if mode is TimeControl.CONSTANT_PLAYOUT or mode is TimeControl.STRICT_PLAYOUT: self.time_manager = TimeManager(mode=mode, constant_visits=visits) @@ -406,6 +410,18 @@ def _decode_analyze_arg(self, arg_list: List[str]) -> (Stone, float): return error_value return (to_move, interval) + def _analyze_or_animate(self, mode: str, arg_list: List[str]) -> NoReturn: + if max(self.animation_pv_wait, self.animation_move_wait) >= 0: + self._animate(arg_list, self.animation_pv_wait, self.animation_move_wait) + else: + self._analyze(mode, arg_list) + + def _animate(self, arg_list: List[str], pv_wait: float, move_wait: float) -> NoReturn: + to_move, _ = self._decode_analyze_arg(arg_list) + respond_success("", ongoing=True) + animate_mcts(self.mcts, self.board, to_move, pv_wait, move_wait) + print_out("") + def _analyze(self, mode: str, arg_list: List[str]) -> NoReturn: """analyzeコマンド(lz-analyze, cgos-analyze)を実行する。 @@ -571,7 +587,7 @@ def run(self) -> NoReturn: # pylint: disable=R0912,R0915 self.board.display_self_atari(Stone.WHITE) respond_success("") elif input_gtp_command == "lz-analyze": - self._analyze("lz", command_list[1:]) + self._analyze_or_animate("lz", command_list[1:]) print("") elif input_gtp_command == "lz-genmove_analyze": self._genmove_analyze("lz", command_list[1:]) diff --git a/main.py b/main.py index 2298ba2..68bad0f 100755 --- a/main.py +++ b/main.py @@ -44,9 +44,14 @@ help=f"探索木を構成するノードの最大数。デフォルトはMCTS_TREE_SIZE = {MCTS_TREE_SIZE}。") @click.option('--cgos-mode', type=click.BOOL, default=False, \ help="全ての石を打ち上げるまでパスしないモード設定。デフォルトはFalse。") +@click.option('--animation-pv-wait', type=click.FLOAT, default=-1.0, \ + help="lz-analyzeの出力をMCTSアニメーションに差しかえて、系列ごとに指定秒停止。") +@click.option('--animation-move-wait', type=click.FLOAT, default=-1.0, \ + help="lz-analyzeの出力をMCTSアニメーションに差しかえて、一手ごとに指定秒停止。") def gtp_main(size: int, superko: bool, model:str, use_gpu: bool, sequential_halving: bool, \ policy_move: bool, komi: float, visits: int, strict_visits: int, const_time: float, time: float, \ - batch_size: int, tree_size: int, cgos_mode: bool): + batch_size: int, tree_size: int, cgos_mode: bool, \ + animation_pv_wait: float, animation_move_wait: float): """GTPクライアントの起動。 Args: @@ -64,6 +69,8 @@ def gtp_main(size: int, superko: bool, model:str, use_gpu: bool, sequential_halv batch_size (int): 探索実行時のニューラルネットワークのミニバッチサイズ。デフォルトはNN_BATCH_SIZE。 tree_size (int): 探索木を構成するノードの最大数。デフォルトはMCTS_TREE_SIZE。 cgos_mode (bool): 全ての石を打ち上げるまでパスしないモード設定。デフォルトはFalse。 + animation_pv_wait (float): lz-analyzeの出力をMCTSアニメーションに差しかえて、系列ごとに指定秒停止。 + animation_move_wait (float): lz-analyzeの出力をMCTSアニメーションに差しかえて、一手ごとに指定秒停止。 """ mode = TimeControl.CONSTANT_PLAYOUT @@ -78,7 +85,7 @@ def gtp_main(size: int, superko: bool, model:str, use_gpu: bool, sequential_halv program_dir = os.path.dirname(__file__) client = GtpClient(size, superko, os.path.join(program_dir, model), use_gpu, policy_move, \ sequential_halving, komi, mode, visits, const_time, time, batch_size, tree_size, \ - cgos_mode) + cgos_mode, animation_pv_wait, animation_move_wait) client.run() diff --git a/mcts/node.py b/mcts/node.py index 5f8526f..2a07da0 100644 --- a/mcts/node.py +++ b/mcts/node.py @@ -466,6 +466,10 @@ def get_analysis_from_status_list(self, mode, children_status_list): out += f"lcb {int(10000 * status['lcb'])} " out += f"order {status['order']} " out += f"pv {status['pv']}" + # if "pvVisits" in status: + # out += f" pvVisits {status['pvVisits']}" + # if "pvWinrate" in status: + # out += f" lizgobanPvWinrate {status['pvWinrate']}" out += " " elif mode == "cgos": cgos_dict["moves"].append(status) diff --git a/mcts/tree.py b/mcts/tree.py index dd5c98b..dca9da1 100644 --- a/mcts/tree.py +++ b/mcts/tree.py @@ -1,6 +1,6 @@ """モンテカルロ木探索の実装。 """ -from typing import Any, Dict, List, NoReturn, Tuple +from typing import Any, Dict, List, NoReturn, Tuple, Callable import sys import select import copy @@ -174,6 +174,28 @@ def search(self, board: GoBoard, color: Stone, time_manager: TimeManager, \ sys.stdout.flush() + def search_with_callback(self, board: GoBoard, color: Stone, callback: Callable[List[Tuple[int, int]], bool]) -> NoReturn: + """探索を実行し、探索系列をコールバック関数へ渡す動作をくり返す。 +コールバック関数の戻り値が真になれば終了する。 + Args: + board (GoBoard): 現在の局面情報。 + color (Stone): 現局面の手番の色。 + callback (Callable[List[Tuple[int, int]], bool]): コールバック関数。 + """ + original_batch_size = self.batch_size + self.batch_size = 1 + self._initialize_search(board, color) + search_board = copy.deepcopy(board) + while True: + path = [] + copy_board(dst=search_board, src=board) + self.search_mcts(search_board, color, self.current_root, path) + finished = callback(path) + if finished: + break + self.batch_size = original_batch_size + + def search_mcts(self, board: GoBoard, color: Stone, current_index: int, \ path: List[Tuple[int, int]]) -> NoReturn: """モンテカルロ木探索を実行する。