Skip to content

Commit

Permalink
Merge pull request #106 from kaorahi/mcts_step2
Browse files Browse the repository at this point in the history
MCTS過程のアニメーション(#93)の実装
  • Loading branch information
kobanium authored May 2, 2024
2 parents eecde0b + b67aae2 commit 82fdd8a
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 18 deletions.
68 changes: 68 additions & 0 deletions animation/animation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import sys
import select
import time


def animate_mcts(mcts, board, to_move, pv_wait_sec, move_wait_sec):
    """Run the search in animation mode until input arrives on stdin.

    Every path handed back by ``mcts.search_with_callback`` is printed by
    ``_animate_path``; the search is asked to stop as soon as standard
    input has data waiting.

    Args:
        mcts: Search tree object providing ``search_with_callback``.
        board: Current board position.
        to_move: Color of the side to move.
        pv_wait_sec: Pause in seconds after each displayed sequence.
        move_wait_sec: Pause in seconds after each displayed move.
    """
    # Shared between callback invocations so that moves already shown for
    # the previous sequence are not printed again.
    last_shown_pv = []

    def on_path_searched(path):
        _animate_path(path, mcts, board, pv_wait_sec, move_wait_sec, last_shown_pv)
        # A truthy return value tells the search loop to finish.
        return _stdin_has_data()

    mcts.search_with_callback(board, to_move, on_path_searched)


def _stdin_has_data():
    """Report, without blocking, whether standard input is ready to be read.

    Returns:
        bool: True when ``select`` lists ``sys.stdin`` as readable.
    """
    # A timeout of 0 turns select() into a non-blocking poll.
    readable = select.select([sys.stdin], [], [], 0)[0]
    return True if readable else False


def _animate_path(path, mcts, board, pv_wait_sec, move_wait_sec, previous_pv):
    """Print one searched sequence as lz-analyze output, one move at a time.

    The first move of the sequence is presented as if it were the best
    candidate, and its PV field is extended by one move per output so a
    GUI shows the sequence growing.

    Args:
        path: List of (node index, child index) pairs for the sequence
            just searched; ``path[0]`` is treated as the root entry.
        mcts: Search tree object; its ``node`` table and ``get_pv_lists``
            are used.
        board: Board whose ``coordinate`` converts actions to GTP strings.
        pv_wait_sec: Seconds to sleep after the whole sequence
            (negative values are treated as 0).
        move_wait_sec: Seconds to sleep after each printed move
            (negative values are treated as 0).
        previous_pv: PV shown by the previous call. Updated in place with
            the PV of this call once it has been displayed.
    """
    # Nothing was searched, so there is no first move to display. Returning
    # here avoids an IndexError on path[0] that would abort the whole loop.
    if not path:
        return

    # Attributes of the sequence searched this time.
    root_index, i = path[0]
    root = mcts.node[root_index]
    if root.children_visits[i] == 0:
        return
    coordinate = board.coordinate
    move = coordinate.convert_to_gtp_format(root.action[i])
    pv = [
        coordinate.convert_to_gtp_format(mcts.node[index].action[child_index])
        for (index, child_index) in path
    ]
    pv_visits = [
        str(mcts.node[index].children_visits[child_index])
        for (index, child_index) in path
    ]
    pv_winrate = [
        str(int(10000 * _get_winrate(mcts, index, child_index, depth)))
        for depth, (index, child_index) in enumerate(path)
    ]

    # Rework what lz-analyze would normally output.
    children_status_list = root.get_analysis_status_list(board, mcts.get_pv_lists)
    # Copy each status so the edits below do not touch the originals.
    fake_status_list = [status.copy() for status in children_status_list]
    target = next((status for status in fake_status_list if status["move"] == move), None)
    if target is None:
        return  # can't happen
    # Pretend the first move of this sequence is the best move, then
    # renumber the order of every candidate accordingly.
    target["order"] = -1
    fake_status_list.sort(key=lambda status: status["order"])
    for order, status in enumerate(fake_status_list):
        status["order"] = order

    # Animate move by move: emit the analysis repeatedly while replacing
    # the PV field with a progressively longer prefix.
    for k in range(1, len(pv) + 1):
        # Skip the leading moves shared with the previously shown sequence.
        if pv[:k] == previous_pv[:k]:
            continue

        target["pv"] = " ".join(pv[:k])
        target["pvVisits"] = " ".join(pv_visits[:k])
        target["pvWinrate"] = " ".join(pv_winrate[:k])

        sys.stdout.write(root.get_analysis_from_status_list("lz", fake_status_list))
        sys.stdout.flush()
        time.sleep(max(move_wait_sec, 0.0))

    previous_pv[:] = pv
    time.sleep(max(pv_wait_sec, 0.0))


def _get_winrate(mcts, index, child_index, depth):
    """Return the win rate of one edge of the searched path.

    The value is the mean of the child's accumulated value when it has
    been visited, otherwise the child's stored ``children_value``. At odd
    depths the value is flipped to ``1.0 - value`` (presumably so that
    every entry is expressed from the same side's point of view).

    Args:
        mcts: Search tree object holding the ``node`` table.
        index: Index of the parent node in ``mcts.node``.
        child_index: Index of the child within the parent node.
        depth: Zero-based position of this edge in the path.

    Returns:
        The win rate as a float.
    """
    parent = mcts.node[index]
    visit_count = parent.children_visits[child_index]
    if visit_count > 0:
        mean_value = parent.children_value_sum[child_index] / visit_count
    else:
        # Unvisited child: fall back to the stored per-child value.
        mean_value = parent.children_value[child_index]

    if depth % 2 == 0:
        return mean_value
    return 1.0 - mean_value
20 changes: 18 additions & 2 deletions gtp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from nn.policy_player import generate_move_from_policy
from nn.utility import load_network
from sgf.reader import SGFReader
from animation.animation import animate_mcts


gtp_command_id = ""
Expand All @@ -30,7 +31,8 @@ class GtpClient: # pylint: disable=R0902,R0903
def __init__(self, board_size: int, superko: bool, model_file_path: str, \
use_gpu: bool, policy_move: bool, use_sequential_halving: bool, \
komi: float, mode: TimeControl, visits: int, const_time: float, \
time: float, batch_size: int, tree_size: int, cgos_mode: bool): # pylint: disable=R0913
time: float, batch_size: int, tree_size: int, cgos_mode: bool, \
animation_pv_wait: float, animation_move_wait:float): # pylint: disable=R0913
"""Go Text Protocolクライアントの初期化をする。
Args:
Expand Down Expand Up @@ -92,6 +94,8 @@ def __init__(self, board_size: int, superko: bool, model_file_path: str, \
self.policy_move = policy_move
self.use_sequential_halving = use_sequential_halving
self.use_network = False
self.animation_pv_wait = animation_pv_wait
self.animation_move_wait = animation_move_wait

if mode is TimeControl.CONSTANT_PLAYOUT or mode is TimeControl.STRICT_PLAYOUT:
self.time_manager = TimeManager(mode=mode, constant_visits=visits)
Expand Down Expand Up @@ -400,6 +404,18 @@ def _decode_analyze_arg(self, arg_list: List[str]) -> (Stone, float):
return error_value
return (to_move, interval)

def _analyze_or_animate(self, mode: str, arg_list: List[str]) -> NoReturn:
    """Dispatch an analyze command to the normal analysis or the MCTS animation.

    The animation is used when the larger of the two configured wait
    times is non-negative; otherwise the regular analysis runs.

    Args:
        mode (str): Analysis mode passed through to ``_analyze``.
        arg_list (List[str]): Arguments of the GTP command.
    """
    animation_enabled = max(self.animation_pv_wait, self.animation_move_wait) >= 0
    if not animation_enabled:
        self._analyze(mode, arg_list)
        return
    self._animate(arg_list, self.animation_pv_wait, self.animation_move_wait)

def _animate(self, arg_list: List[str], pv_wait: float, move_wait: float) -> NoReturn:
    """Run the MCTS animation in place of the regular analysis output.

    Args:
        arg_list (List[str]): Arguments of the GTP analyze command.
        pv_wait (float): Wait time forwarded as ``pv_wait_sec``.
        move_wait (float): Wait time forwarded as ``move_wait_sec``.
    """
    # Only the side to move is needed; the decoded interval is not used.
    to_move, _interval = self._decode_analyze_arg(arg_list)
    # The success response is sent (with ongoing=True) before the animation starts.
    respond_success("", ongoing=True)
    animate_mcts(
        self.mcts,
        self.board,
        to_move,
        pv_wait_sec=pv_wait,
        move_wait_sec=move_wait,
    )
    print_out("")

def _analyze(self, mode: str, arg_list: List[str]) -> NoReturn:
"""analyzeコマンド(lz-analyze, cgos-analyze)を実行する。
Expand Down Expand Up @@ -565,7 +581,7 @@ def run(self) -> NoReturn: # pylint: disable=R0912,R0915
self.board.display_self_atari(Stone.WHITE)
respond_success("")
elif input_gtp_command == "lz-analyze":
self._analyze("lz", command_list[1:])
self._analyze_or_animate("lz", command_list[1:])
print("")
elif input_gtp_command == "lz-genmove_analyze":
self._genmove_analyze("lz", command_list[1:])
Expand Down
11 changes: 9 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,14 @@
help=f"探索木を構成するノードの最大数。デフォルトはMCTS_TREE_SIZE = {MCTS_TREE_SIZE}。")
@click.option('--cgos-mode', type=click.BOOL, default=False, \
help="全ての石を打ち上げるまでパスしないモード設定。デフォルトはFalse。")
@click.option('--animation-pv-wait', type=click.FLOAT, default=-1.0, \
help="lz-analyzeの出力をMCTSアニメーションに差しかえて、系列ごとに指定秒停止。")
@click.option('--animation-move-wait', type=click.FLOAT, default=-1.0, \
help="lz-analyzeの出力をMCTSアニメーションに差しかえて、一手ごとに指定秒停止。")
def gtp_main(size: int, superko: bool, model:str, use_gpu: bool, sequential_halving: bool, \
policy_move: bool, komi: float, visits: int, strict_visits: int, const_time: float, time: float, \
batch_size: int, tree_size: int, cgos_mode: bool):
batch_size: int, tree_size: int, cgos_mode: bool, \
animation_pv_wait: float, animation_move_wait: float):
"""GTPクライアントの起動。
Args:
Expand All @@ -64,6 +69,8 @@ def gtp_main(size: int, superko: bool, model:str, use_gpu: bool, sequential_halv
batch_size (int): 探索実行時のニューラルネットワークのミニバッチサイズ。デフォルトはNN_BATCH_SIZE。
tree_size (int): 探索木を構成するノードの最大数。デフォルトはMCTS_TREE_SIZE。
cgos_mode (bool): 全ての石を打ち上げるまでパスしないモード設定。デフォルトはFalse。
animation_pv_wait (float): lz-analyzeの出力をMCTSアニメーションに差しかえて、系列ごとに指定秒停止。
animation_move_wait (float): lz-analyzeの出力をMCTSアニメーションに差しかえて、一手ごとに指定秒停止。
"""
mode = TimeControl.CONSTANT_PLAYOUT

Expand All @@ -78,7 +85,7 @@ def gtp_main(size: int, superko: bool, model:str, use_gpu: bool, sequential_halv
program_dir = os.path.dirname(__file__)
client = GtpClient(size, superko, os.path.join(program_dir, model), use_gpu, policy_move, \
sequential_halving, komi, mode, visits, const_time, time, batch_size, tree_size, \
cgos_mode)
cgos_mode, animation_pv_wait, animation_move_wait)
client.run()


Expand Down
13 changes: 13 additions & 0 deletions mcts/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,12 @@ def get_analysis(self, board: GoBoard, mode: str, \
Returns:
str: GTP応答用解析結果文字列。
"""
children_status_list = self.get_analysis_status_list(board, pv_lists_func)
return self.get_analysis_from_status_list(mode, children_status_list)


def get_analysis_status_list(self, board: GoBoard, \
pv_lists_func: Callable[[List[str], int], List[str]]):
sorted_list = []
for i in range(self.num_children):
sorted_list.append((self.children_visits[i], i))
Expand Down Expand Up @@ -439,7 +445,10 @@ def get_analysis(self, board: GoBoard, mode: str, \
}
)
order += 1
return children_status_list


def get_analysis_from_status_list(self, mode, children_status_list):
out = ""
if mode == "cgos":
cgos_dict = {
Expand All @@ -457,6 +466,10 @@ def get_analysis(self, board: GoBoard, mode: str, \
out += f"lcb {int(10000 * status['lcb'])} "
out += f"order {status['order']} "
out += f"pv {status['pv']}"
# if "pvVisits" in status:
# out += f" pvVisits {status['pvVisits']}"
# if "pvWinrate" in status:
# out += f" lizgobanPvWinrate {status['pvWinrate']}"
out += " "
elif mode == "cgos":
cgos_dict["moves"].append(status)
Expand Down
47 changes: 33 additions & 14 deletions mcts/tree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""モンテカルロ木探索の実装。
"""
from typing import Any, Dict, List, NoReturn, Tuple
from typing import Any, Dict, List, NoReturn, Tuple, Callable
import sys
import select
import copy
Expand Down Expand Up @@ -46,6 +46,14 @@ def __init__(self, network: DualNet, tree_size: int=MCTS_TREE_SIZE, \
self.to_move = Stone.BLACK


def _initialize_search(self, board: GoBoard, color: Stone) -> NoReturn:
    """Reset the tree and prepare a new root node for a search.

    Clears the node counter, expands a root for the given position,
    queues the root's input planes and processes that mini-batch.

    Args:
        board (GoBoard): Current board position.
        color (Stone): Color of the side to move.
    """
    self.num_nodes = 0
    root_index = self.expand_node(board, color)
    self.current_root = root_index
    # The root is queued with an empty path.
    root_planes = generate_input_planes(board, color, 0)
    self.batch_queue.push(root_planes, [], root_index)
    self.process_mini_batch(board)


def search_best_move(self, board: GoBoard, color: Stone, time_manager: TimeManager, \
analysis_query: Dict[str, Any]) -> int:
"""モンテカルロ木探索を実行して最善手を返す。
Expand All @@ -58,16 +66,10 @@ def search_best_move(self, board: GoBoard, color: Stone, time_manager: TimeManag
Returns:
int: 着手する座標。
"""
self.num_nodes = 0
self._initialize_search(board, color)

time_manager.start_timer()

self.current_root = self.expand_node(board, color)
input_plane = generate_input_planes(board, color, 0)
self.batch_queue.push(input_plane, [], self.current_root)

self.process_mini_batch(board)

root = self.node[self.current_root]

# 候補手が1つしかない場合はPASSを返す
Expand Down Expand Up @@ -111,12 +113,7 @@ def ponder(self, board: GoBoard, color: Stone, analysis_query: Dict[str, Any]) -
color (Stone): 思考する手番の色。
analysis_query (Dict): 解析情報。
"""
self.num_nodes = 0

self.current_root = self.expand_node(board, color)
input_plane = generate_input_planes(board, color, 0)
self.batch_queue.push(input_plane, [], self.current_root)
self.process_mini_batch(board)
self._initialize_search(board, color)

# 探索を実行する
max_visits = 999999999
Expand Down Expand Up @@ -177,6 +174,28 @@ def search(self, board: GoBoard, color: Stone, time_manager: TimeManager, \
sys.stdout.flush()


def search_with_callback(self, board: GoBoard, color: Stone, \
    callback: Callable[[List[Tuple[int, int]]], bool]) -> NoReturn:
    """Repeatedly run one search and hand the searched path to a callback.

    The loop ends when the callback returns a truthy value.

    Args:
        board (GoBoard): Current board position.
        color (Stone): Color of the side to move in the current position.
        callback (Callable[[List[Tuple[int, int]]], bool]): Called after
            every search with the path that was filled in by
            ``search_mcts``; return True to stop searching.
    """
    original_batch_size = self.batch_size
    # The batch size is forced to 1 while animating (presumably so each
    # search is evaluated immediately and the reported path is current).
    self.batch_size = 1
    try:
        self._initialize_search(board, color)
        # One scratch board is allocated up front and refreshed from the
        # real position before every search.
        search_board = copy.deepcopy(board)
        while True:
            path = []
            copy_board(dst=search_board, src=board)
            self.search_mcts(search_board, color, self.current_root, path)
            if callback(path):
                break
    finally:
        # Restore the configured batch size even when the search or the
        # callback raises, so later searches are not left at size 1.
        self.batch_size = original_batch_size


def search_mcts(self, board: GoBoard, color: Stone, current_index: int, \
path: List[Tuple[int, int]]) -> NoReturn:
"""モンテカルロ木探索を実行する。
Expand Down

0 comments on commit 82fdd8a

Please sign in to comment.