Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MCTS過程のアニメーション(#93)の実装 #106

Merged
merged 3 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions animation/animation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import sys
import select
import time


def animate_mcts(mcts, board, to_move, pv_wait_sec, move_wait_sec):
    """Run a callback-driven search and animate every path it reports.

    The search keeps going until data becomes available on standard input.

    Args:
        mcts: Search tree object; its ``search_with_callback`` drives the loop.
        board: Current position, passed through to the search.
        to_move: Color to move, passed through to the search.
        pv_wait_sec: Pause in seconds after each animated sequence.
        move_wait_sec: Pause in seconds after each animated move.
    """
    # Shared between callback invocations; _animate_path rewrites it in place.
    shown_pv = []

    def on_path(path):
        _animate_path(path, mcts, board, pv_wait_sec, move_wait_sec, shown_pv)
        # A truthy result tells the search loop to stop.
        return _stdin_has_data()

    mcts.search_with_callback(board, to_move, on_path)


def _stdin_has_data():
    """Return True when standard input can be read right now, without blocking."""
    # A timeout of 0 makes select() poll and return immediately.
    readable = select.select([sys.stdin], [], [], 0)[0]
    return len(readable) > 0


def _animate_path(path, mcts, board, pv_wait_sec, move_wait_sec, previous_pv):
    """Replay one searched path as a series of lz-analyze style outputs.

    The root move of ``path`` is forced to order 0 in a copy of the root's
    analysis status list. That list is written to stdout once per newly shown
    move, with the move's "pv" entry growing by one move each time.

    Args:
        path: List of ``(node_index, child_index)`` pairs, root first.
        mcts: Search tree; ``mcts.node`` and ``mcts.get_pv_lists`` are read.
        board: Board whose ``coordinate`` converts actions to GTP strings.
        pv_wait_sec: Seconds to sleep after the path (negative is treated as 0).
        move_wait_sec: Seconds to sleep after each output (negative is
            treated as 0).
        previous_pv: GTP moves of the previously animated path; overwritten
            in place with this path's moves.
    """
    # An empty path has no root move to show; bail out instead of raising
    # IndexError on path[0].
    if not path:
        return

    # The first edge of the path is the root move this animation highlights.
    root_index, i = path[0]
    root = mcts.node[root_index]
    if root.children_visits[i] == 0:
        return
    coordinate = board.coordinate
    move = coordinate.convert_to_gtp_format(root.action[i])
    pv = [
        coordinate.convert_to_gtp_format(mcts.node[index].action[child_index])
        for (index, child_index) in path
    ]
    pv_visits = [
        str(mcts.node[index].children_visits[child_index])
        for (index, child_index) in path
    ]
    pv_winrate = [
        str(int(10000 * _get_winrate(mcts, index, child_index, depth)))
        for depth, (index, child_index) in enumerate(path)
    ]

    # Build the lz-analyze payload from shallow copies of the status entries,
    # so the edits below do not touch the list returned by the root node.
    children_status_list = root.get_analysis_status_list(board, mcts.get_pv_lists)
    fake_status_list = [status.copy() for status in children_status_list]
    target = next((status for status in fake_status_list if status["move"] == move), None)
    if target is None:
        return  # can't happen

    # Pretend the searched move ranks first, then renumber every entry.
    target["order"] = -1
    fake_status_list.sort(key=lambda status: status["order"])
    for order, status in enumerate(fake_status_list):
        status["order"] = order

    move_pause = max(move_wait_sec, 0.0)
    pv_pause = max(pv_wait_sec, 0.0)

    # Output repeatedly with a PV that grows by one move per step, which
    # animates the sequence one move at a time.
    for k in range(1, len(pv) + 1):
        # Skip the prefix already shown for the previous path.
        if pv[:k] == previous_pv[:k]:
            continue

        target["pv"] = " ".join(pv[:k])
        target["pvVisits"] = " ".join(pv_visits[:k])
        target["pvWinrate"] = " ".join(pv_winrate[:k])

        sys.stdout.write(root.get_analysis_from_status_list("lz", fake_status_list))
        sys.stdout.flush()
        time.sleep(move_pause)

    previous_pv[:] = pv
    time.sleep(pv_pause)


def _get_winrate(mcts, index, child_index, depth):
    """Return a child's value, flipped to ``1.0 - value`` at odd depths.

    The value is the mean ``children_value_sum / children_visits`` when the
    child has visits, otherwise the stored ``children_value``.

    Args:
        mcts: Search tree; ``mcts.node[index]`` is the parent node.
        index: Index of the parent node in ``mcts.node``.
        child_index: Index of the child within the parent node.
        depth: Position of this edge in the path (0 for the root move).
    """
    parent = mcts.node[index]
    visit_count = parent.children_visits[child_index]
    if visit_count > 0:
        value = parent.children_value_sum[child_index] / visit_count
    else:
        value = parent.children_value[child_index]

    # Odd depths are flipped, presumably so every entry is expressed from the
    # same player's point of view.
    if depth % 2 == 0:
        return value
    return 1.0 - value
20 changes: 18 additions & 2 deletions gtp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from nn.policy_player import generate_move_from_policy
from nn.utility import load_network
from sgf.reader import SGFReader
from animation.animation import animate_mcts


gtp_command_id = ""
Expand All @@ -30,7 +31,8 @@ class GtpClient: # pylint: disable=R0902,R0903
def __init__(self, board_size: int, superko: bool, model_file_path: str, \
use_gpu: bool, policy_move: bool, use_sequential_halving: bool, \
komi: float, mode: TimeControl, visits: int, const_time: float, \
time: float, batch_size: int, tree_size: int, cgos_mode: bool): # pylint: disable=R0913
time: float, batch_size: int, tree_size: int, cgos_mode: bool, \
animation_pv_wait: float, animation_move_wait:float): # pylint: disable=R0913
"""Go Text Protocolクライアントの初期化をする。

Args:
Expand Down Expand Up @@ -92,6 +94,8 @@ def __init__(self, board_size: int, superko: bool, model_file_path: str, \
self.policy_move = policy_move
self.use_sequential_halving = use_sequential_halving
self.use_network = False
self.animation_pv_wait = animation_pv_wait
self.animation_move_wait = animation_move_wait

if mode is TimeControl.CONSTANT_PLAYOUT or mode is TimeControl.STRICT_PLAYOUT:
self.time_manager = TimeManager(mode=mode, constant_visits=visits)
Expand Down Expand Up @@ -406,6 +410,18 @@ def _decode_analyze_arg(self, arg_list: List[str]) -> (Stone, float):
return error_value
return (to_move, interval)

def _analyze_or_animate(self, mode: str, arg_list: List[str]) -> NoReturn:
    """Dispatch an analyze command to the animation or to the normal analysis.

    The animation is used when at least one of the two configured wait times
    is non-negative; otherwise the regular analysis runs.

    Args:
        mode (str): Analysis mode, forwarded to ``_analyze``.
        arg_list (List[str]): Arguments of the GTP command.
    """
    pv_wait = self.animation_pv_wait
    move_wait = self.animation_move_wait
    animation_enabled = max(pv_wait, move_wait) >= 0
    if not animation_enabled:
        self._analyze(mode, arg_list)
        return
    self._animate(arg_list, pv_wait, move_wait)

def _animate(self, arg_list: List[str], pv_wait: float, move_wait: float) -> NoReturn:
    """Run the MCTS animation for an analyze command.

    Args:
        arg_list (List[str]): Arguments of the GTP command.
        pv_wait (float): Wait in seconds per sequence, forwarded to animate_mcts.
        move_wait (float): Wait in seconds per move, forwarded to animate_mcts.
    """
    # Only the color is used; the interval part of the decoded arguments is ignored.
    # NOTE(review): the error value _decode_analyze_arg may return is not
    # checked here - confirm that is intended.
    to_move, _interval = self._decode_analyze_arg(arg_list)
    respond_success("", ongoing=True)
    animate_mcts(self.mcts, self.board, to_move, pv_wait, move_wait)
    print_out("")

def _analyze(self, mode: str, arg_list: List[str]) -> NoReturn:
"""analyzeコマンド(lz-analyze, cgos-analyze)を実行する。

Expand Down Expand Up @@ -571,7 +587,7 @@ def run(self) -> NoReturn: # pylint: disable=R0912,R0915
self.board.display_self_atari(Stone.WHITE)
respond_success("")
elif input_gtp_command == "lz-analyze":
self._analyze("lz", command_list[1:])
self._analyze_or_animate("lz", command_list[1:])
print("")
elif input_gtp_command == "lz-genmove_analyze":
self._genmove_analyze("lz", command_list[1:])
Expand Down
11 changes: 9 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,14 @@
help=f"探索木を構成するノードの最大数。デフォルトはMCTS_TREE_SIZE = {MCTS_TREE_SIZE}。")
@click.option('--cgos-mode', type=click.BOOL, default=False, \
help="全ての石を打ち上げるまでパスしないモード設定。デフォルトはFalse。")
@click.option('--animation-pv-wait', type=click.FLOAT, default=-1.0, \
help="lz-analyzeの出力をMCTSアニメーションに差しかえて、系列ごとに指定秒停止。")
@click.option('--animation-move-wait', type=click.FLOAT, default=-1.0, \
help="lz-analyzeの出力をMCTSアニメーションに差しかえて、一手ごとに指定秒停止。")
def gtp_main(size: int, superko: bool, model:str, use_gpu: bool, sequential_halving: bool, \
policy_move: bool, komi: float, visits: int, strict_visits: int, const_time: float, time: float, \
batch_size: int, tree_size: int, cgos_mode: bool):
batch_size: int, tree_size: int, cgos_mode: bool, \
animation_pv_wait: float, animation_move_wait: float):
"""GTPクライアントの起動。

Args:
Expand All @@ -64,6 +69,8 @@ def gtp_main(size: int, superko: bool, model:str, use_gpu: bool, sequential_halv
batch_size (int): 探索実行時のニューラルネットワークのミニバッチサイズ。デフォルトはNN_BATCH_SIZE。
tree_size (int): 探索木を構成するノードの最大数。デフォルトはMCTS_TREE_SIZE。
cgos_mode (bool): 全ての石を打ち上げるまでパスしないモード設定。デフォルトはFalse。
animation_pv_wait (float): lz-analyzeの出力をMCTSアニメーションに差しかえて、系列ごとに指定秒停止。
animation_move_wait (float): lz-analyzeの出力をMCTSアニメーションに差しかえて、一手ごとに指定秒停止。
"""
mode = TimeControl.CONSTANT_PLAYOUT

Expand All @@ -78,7 +85,7 @@ def gtp_main(size: int, superko: bool, model:str, use_gpu: bool, sequential_halv
program_dir = os.path.dirname(__file__)
client = GtpClient(size, superko, os.path.join(program_dir, model), use_gpu, policy_move, \
sequential_halving, komi, mode, visits, const_time, time, batch_size, tree_size, \
cgos_mode)
cgos_mode, animation_pv_wait, animation_move_wait)
client.run()


Expand Down
13 changes: 13 additions & 0 deletions mcts/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,12 @@ def get_analysis(self, board: GoBoard, mode: str, \
Returns:
str: GTP応答用解析結果文字列。
"""
children_status_list = self.get_analysis_status_list(board, pv_lists_func)
return self.get_analysis_from_status_list(mode, children_status_list)


def get_analysis_status_list(self, board: GoBoard, \
pv_lists_func: Callable[[List[str], int], List[str]]):
sorted_list = []
for i in range(self.num_children):
sorted_list.append((self.children_visits[i], i))
Expand Down Expand Up @@ -439,7 +445,10 @@ def get_analysis(self, board: GoBoard, mode: str, \
}
)
order += 1
return children_status_list


def get_analysis_from_status_list(self, mode, children_status_list):
out = ""
if mode == "cgos":
cgos_dict = {
Expand All @@ -457,6 +466,10 @@ def get_analysis(self, board: GoBoard, mode: str, \
out += f"lcb {int(10000 * status['lcb'])} "
out += f"order {status['order']} "
out += f"pv {status['pv']}"
# if "pvVisits" in status:
# out += f" pvVisits {status['pvVisits']}"
# if "pvWinrate" in status:
# out += f" lizgobanPvWinrate {status['pvWinrate']}"
out += " "
elif mode == "cgos":
cgos_dict["moves"].append(status)
Expand Down
47 changes: 33 additions & 14 deletions mcts/tree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""モンテカルロ木探索の実装。
"""
from typing import Any, Dict, List, NoReturn, Tuple
from typing import Any, Dict, List, NoReturn, Tuple, Callable
import sys
import select
import copy
Expand Down Expand Up @@ -46,6 +46,14 @@ def __init__(self, network: DualNet, tree_size: int=MCTS_TREE_SIZE, \
self.to_move = Stone.BLACK


def _initialize_search(self, board: GoBoard, color: Stone) -> NoReturn:
    """Reset the node count, expand a new root node and evaluate it.

    Args:
        board (GoBoard): Current board position.
        color (Stone): Color to move in the current position.
    """
    self.num_nodes = 0

    # Expand the root for this position and remember its index.
    root_index = self.expand_node(board, color)
    self.current_root = root_index

    # Queue the root position with an empty path and process the batch.
    root_planes = generate_input_planes(board, color, 0)
    self.batch_queue.push(root_planes, [], root_index)
    self.process_mini_batch(board)


def search_best_move(self, board: GoBoard, color: Stone, time_manager: TimeManager, \
analysis_query: Dict[str, Any]) -> int:
"""モンテカルロ木探索を実行して最善手を返す。
Expand All @@ -58,16 +66,10 @@ def search_best_move(self, board: GoBoard, color: Stone, time_manager: TimeManag
Returns:
int: 着手する座標。
"""
self.num_nodes = 0
self._initialize_search(board, color)

time_manager.start_timer()

self.current_root = self.expand_node(board, color)
input_plane = generate_input_planes(board, color, 0)
self.batch_queue.push(input_plane, [], self.current_root)

self.process_mini_batch(board)

root = self.node[self.current_root]

# 候補手が1つしかない場合はPASSを返す
Expand Down Expand Up @@ -111,12 +113,7 @@ def ponder(self, board: GoBoard, color: Stone, analysis_query: Dict[str, Any]) -
color (Stone): 思考する手番の色。
analysis_query (Dict): 解析情報。
"""
self.num_nodes = 0

self.current_root = self.expand_node(board, color)
input_plane = generate_input_planes(board, color, 0)
self.batch_queue.push(input_plane, [], self.current_root)
self.process_mini_batch(board)
self._initialize_search(board, color)

# 探索を実行する
max_visits = 999999999
Expand Down Expand Up @@ -177,6 +174,28 @@ def search(self, board: GoBoard, color: Stone, time_manager: TimeManager, \
sys.stdout.flush()


def search_with_callback(self, board: GoBoard, color: Stone, \
    callback: Callable[[List[Tuple[int, int]]], bool]) -> NoReturn:
    """Repeatedly search one path and pass it to a callback function.

    The loop ends as soon as the callback returns a truthy value.

    Args:
        board (GoBoard): Current board position.
        color (Stone): Color to move in the current position.
        callback (Callable[[List[Tuple[int, int]]], bool]): Called with the
            path of each search; return a truthy value to stop.
    """
    original_batch_size = self.batch_size
    # The batch size is forced to 1 for the duration of this search.
    self.batch_size = 1
    try:
        self._initialize_search(board, color)
        search_board = copy.deepcopy(board)
        while True:
            path = []
            # Reset the scratch board to the current position before each search.
            copy_board(dst=search_board, src=board)
            self.search_mcts(search_board, color, self.current_root, path)
            if callback(path):
                break
    finally:
        # Restore the configured batch size even if the search or the
        # callback raises.
        self.batch_size = original_batch_size


def search_mcts(self, board: GoBoard, color: Stone, current_index: int, \
path: List[Tuple[int, int]]) -> NoReturn:
"""モンテカルロ木探索を実行する。
Expand Down