From f7e4668651b4bca12bcdff43a77eed804f888dfc Mon Sep 17 00:00:00 2001
From: Yuki Kobayashi
Date: Sat, 19 Aug 2023 14:37:48 +0900
Subject: [PATCH] support mixed value approximation and bugfix

---
 .pylintrc            |  2 +-
 CONTRIBUTORS         |  2 +-
 mcts/node.py         | 28 +++++++++++++++++++++++++---
 mcts/time_manager.py | 18 ++++++++++++++++++
 mcts/tree.py         |  6 ++++--
 pipeline.sh          |  6 +++---
 program.py           |  5 ++++-
 selfplay/worker.py   | 22 ++++++++++++++++++++--
 selfplay_main.py     |  6 +++++-
 9 files changed, 81 insertions(+), 14 deletions(-)

diff --git a/.pylintrc b/.pylintrc
index 4ae1670..50aea86 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -24,7 +24,7 @@ ignore=CVS
 
 # Add files or directories matching the regex patterns to the ignore-list. The
 # regex matches against paths and can be in Posix or Windows format.
-ignore-paths=LICENSE,README.md,model,requirements.txt,pipeline.sh
+ignore-paths=LICENSE,README.md,model,requirements.txt,pipeline.sh,CONTRIBUTORS,test
 
 # Files or directories matching the regex patterns are skipped. The regex
 # matches against base names, not paths. The default value ignores emacs file
diff --git a/CONTRIBUTORS b/CONTRIBUTORS
index f8897f4..5d21e5e 100644
--- a/CONTRIBUTORS
+++ b/CONTRIBUTORS
@@ -5,4 +5,4 @@ Yuki Kobayashi ("kobanium")
 
 Contributors
 ============
-"CGLemon", Chinese translation and GTP analyze commands implementation.
+"CGLemon", Chinese translation, GTP analyze commands implementation, and various contributions.
diff --git a/mcts/node.py b/mcts/node.py
index 93dde0c..7109240 100644
--- a/mcts/node.py
+++ b/mcts/node.py
@@ -26,6 +26,7 @@ def __init__(self, num_actions: int=MAX_ACTIONS):
         self.node_visits = 0
         self.virtual_loss = 0
         self.node_value_sum = 0.0
+        self.raw_value = 0.0
         self.action = [0] * num_actions
         self.children_index = np.zeros(num_actions, dtype=np.int32)
         self.children_value = np.zeros(num_actions, dtype=np.float64)
@@ -43,8 +44,9 @@ def expand(self, policy: Dict[int, float]) -> NoReturn:
             policy (Dict[int, float]): Map of policy values for the candidate moves.
         """
         self.node_visits = 0
-        self.node_value_sum = 0.0
         self.virtual_loss = 0
+        self.node_value_sum = 0.0
+        self.raw_value = 0.0
         self.action = [0] * MAX_ACTIONS
         self.children_index.fill(NOT_EXPANDED)
         self.children_value.fill(0.0)
@@ -103,6 +105,15 @@ def set_leaf_value(self, index: int, value: float) -> NoReturn:
         self.children_value[index] = value
 
+    def set_raw_value(self, value: float) -> NoReturn:
+        """Set the raw value of the position corresponding to this node.
+
+        Args:
+            value (float): Value to set.
+        """
+        self.raw_value = value
+
+
     def update_child_value(self, index: int, value: float) -> NoReturn:
         """Add the value to the child node and restore the virtual loss.
 
@@ -214,6 +225,7 @@ def print_search_result(self, board: GoBoard, pv_dict: Dict[str, List[str]]) ->
         """
         value = np.divide(self.children_value_sum, self.children_visits, \
             out=np.zeros_like(self.children_value_sum), where=(self.children_visits != 0))
+        print_err(f"raw_value={self.raw_value:.4f}")
         for i in range(self.num_children):
             if self.children_visits[i] > 0:
                 pos = board.coordinate.convert_to_gtp_format(self.action[i])
@@ -221,6 +233,7 @@ def print_search_result(self, board: GoBoard, pv_dict: Dict[str, List[str]]) ->
                 msg += f"visits={self.children_visits[i]:5d}, "
                 msg += f"policy={self.children_policy[i]:.4f}, "
                 msg += f"value={value[i]:.4f}, "
+                msg += f"raw_value={self.children_value[i]:.4f}, "
                 msg += f"pv={','.join(pv_dict[pos])}"
                 print_err(msg)
 
@@ -231,9 +244,12 @@ def set_gumbel_noise(self) -> NoReturn:
         self.noise = np.random.gumbel(loc=0.0, scale=1.0, size=self.noise.size)
 
 
-    def calculate_completed_q_value(self) -> np.array:
+    def calculate_completed_q_value(self, use_mixed_value: bool=True) -> np.array:
         """Calculate the Completed-Q value.
 
+        Args:
+            use_mixed_value (bool, optional): Flag for using the mixed value approximation. Defaults to True.
+
         Returns:
             np.array: Completed-Q value.
         """
@@ -246,7 +262,13 @@
         sum_prob = np.sum(policy)
         v_pi = np.sum(policy * q_value)
 
-        return np.where(self.children_visits[:self.num_children] > 0, q_value, v_pi / sum_prob)
+        if use_mixed_value:
+            value = (self.raw_value * np.ones_like(self.children_policy[:self.num_children]) + \
+                self.node_visits * v_pi / sum_prob) / (self.node_visits + 1.0)
+        else:
+            value = self.raw_value
+
+        return np.where(self.children_visits[:self.num_children] > 0, q_value, value)
 
 
     def calculate_improved_policy(self) -> np.array:
diff --git a/mcts/time_manager.py b/mcts/time_manager.py
index c90ca73..10a24da 100644
--- a/mcts/time_manager.py
+++ b/mcts/time_manager.py
@@ -6,6 +6,7 @@
 from board.stone import Stone
 from mcts.constant import CONST_VISITS, CONST_TIME, REMAINING_TIME, VISITS_PER_SEC
+from mcts.node import MCTSNode
 
 
 class TimeControl(Enum):
@@ -139,3 +140,20 @@ def is_time_over(self) -> bool:
         if time.time() - self.start_time > self.time_limit:
             return True
         return False
+
+
+def is_move_decided(root: MCTSNode, threshold: int) -> bool:
+    """Determine whether the move to play has already been decided.
+
+    Args:
+        root (MCTSNode): Root node for the current position.
+        threshold (int): Threshold on the number of search visits.
+
+    Returns:
+        bool: Whether the search can be terminated early.
+    """
+    sorted_visits = sorted(root.children_visits)
+    remaining_visits = threshold - root.node_visits
+    if sorted_visits[-1] - sorted_visits[-2] > remaining_visits:
+        return True
+    return False
diff --git a/mcts/tree.py b/mcts/tree.py
index 8f77a48..3bd7d1c 100644
--- a/mcts/tree.py
+++ b/mcts/tree.py
@@ -20,7 +20,7 @@ MAX_CONSIDERED_NODES, RESIGN_THRESHOLD, MCTS_TREE_SIZE
 from mcts.sequential_halving import get_candidates_and_visit_pairs
 from mcts.node import MCTSNode
-from mcts.time_manager import TimeControl, TimeManager
+from mcts.time_manager import TimeControl, TimeManager, is_move_decided
 
 
 class MCTSTree: # pylint: disable=R0902
     """Monte Carlo tree search implementation class.
@@ -147,7 +147,8 @@ def search(self, board: GoBoard, color: Stone, time_manager: TimeManager, \
             copy_board(dst=search_board,src=board)
             start_color = color
             self.search_mcts(search_board, start_color, self.current_root, [])
-            if time_manager.is_time_over():
+            if time_manager.is_time_over() or \
+                is_move_decided(self.get_root(), threshold):
                 break
 
         if len(analysis_query) > 0:
@@ -270,6 +271,7 @@ def process_mini_batch(self, board: GoBoard, use_logit: bool=False): # pylint: d
         for policy, value_dist, path, node_index in zip(policy_data, \
             value_data, self.batch_queue.path, self.batch_queue.node_index):
             self.node[node_index].update_policy(policy)
+            self.node[node_index].set_raw_value(value_dist[1] * 0.5 + value_dist[2])
 
             if path:
                 value = value_dist[0] + value_dist[1] * 0.5
diff --git a/pipeline.sh b/pipeline.sh
index 57d7001..f32697e 100644
--- a/pipeline.sh
+++ b/pipeline.sh
@@ -1,5 +1,5 @@
 for i in `seq 1 100` ; do
-    python3.6 selfplay_main.py --save-dir archive --model model/rl-model.bin --use-gpu true
-    python3.6 get_final_status.py
-    python3.6 train.py --rl true --kifu-dir archive
+    python3 selfplay_main.py --save-dir archive --model model/rl-model.bin --use-gpu true
+    python3 get_final_status.py
+    python3 train.py --rl true --kifu-dir archive
 done
diff --git a/program.py b/program.py
index f74b9a6..72e26e6 100644
--- a/program.py
+++ b/program.py
@@ -24,4 +24,7 @@
 # Version 0.7.0 : lz-analyze, lz-genmove_analyze, cgos-analyze, and cgos-genmove_analyze command support.
 #                  Fixed bugs in reinforcement learning and in the super-ko detection logic.
 # Version 0.7.1 : Bug fixes for the analysis commands.
-VERSION="0.7.1"
+# Version 0.8.0 : Changed SHOT to use the mixed value approximation.
+#                  Fixed a crash that occurred when little thinking time remained.
+#                  Added progress reporting during game record generation for reinforcement learning.
+VERSION="0.8.0"
diff --git a/selfplay/worker.py b/selfplay/worker.py
index caaa1b2..ac18750 100644
--- a/selfplay/worker.py
+++ b/selfplay/worker.py
@@ -1,8 +1,10 @@
 """Implementation of the self-play worker.
 """
+import glob
 import os
 import random
-from typing import List
+import time
+from typing import List, NoReturn
 
 import numpy as np
 from board.constant import PASS, RESIGN
@@ -17,7 +19,7 @@
 
 # pylint: disable=R0913,R0914
 def selfplay_worker(save_dir: str, model_file_path: str, index_list: List[int], \
-    size: int, visits: int, use_gpu: bool):
+    size: int, visits: int, use_gpu: bool) -> NoReturn:
     """Self-play worker.
 
     Args:
@@ -86,3 +88,19 @@ def selfplay_worker(save_dir: str, model_file_path: str, index_list: List[int], \
 
         record.set_index(index)
         record.write_record(winner, board.get_komi(), is_resign, score)
+
+
+def display_selfplay_progress_worker(save_dir: str, num_data: int) -> NoReturn:
+    """Display self-play progress.
+
+    Args:
+        save_dir (str): Path of the directory where the generated game records are saved.
+    """
+    start_time = time.time()
+    while True:
+        time.sleep(60)
+        current_num_data = len(glob.glob(os.path.join(save_dir, "*.sgf")))
+        current_time = time.time()
+        msg = f"Generating {current_num_data:5d}/{num_data:5d} games "
+        msg += f"({3600 * current_num_data / (current_time - start_time):.4f} games/hour)."
+        print(msg)
diff --git a/selfplay_main.py b/selfplay_main.py
index c63d498..635d0b8 100644
--- a/selfplay_main.py
+++ b/selfplay_main.py
@@ -3,11 +3,12 @@
 import glob
 import math
 import os
+import threading
 import time
 from concurrent.futures import ProcessPoolExecutor
 import click
 from board.constant import BOARD_SIZE
-from selfplay.worker import selfplay_worker
+from selfplay.worker import selfplay_worker, display_selfplay_progress_worker
 from learning_param import SELF_PLAY_VISITS, NUM_SELF_PLAY_WORKERS, \
     NUM_SELF_PLAY_GAMES
 
@@ -57,6 +58,9 @@ def selfplay_main(save_dir: str, process: int, num_data: int, size: int, \
     with ProcessPoolExecutor(max_workers=process) as executor:
         futures = [executor.submit(selfplay_worker, os.path.join(save_dir, str(kifu_dir_index)), \
             model, file_list, size, visits, use_gpu) for file_list in file_indice]
+        monitoring_worker = threading.Thread(target=display_selfplay_progress_worker, \
+            args=(os.path.join(save_dir, str(kifu_dir_index)), num_data, ), daemon=True)
+        monitoring_worker.start()
 
         for future in futures:
             future.result()
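
Note for reviewers: the sketch below restates, outside the MCTSNode class, the completed-Q calculation that calculate_completed_q_value performs after this patch. It is an illustration only and not part of the patch; the helper name completed_q_values and the explicit visited-children mask are assumptions made here (in the patched method, policy and q_value are prepared from the visited children earlier in the method body), and it assumes at least one child has been visited, as the patched code does.

import numpy as np

def completed_q_values(children_visits: np.ndarray, q_value: np.ndarray,
                       policy: np.ndarray, raw_value: float,
                       node_visits: int, use_mixed_value: bool = True) -> np.ndarray:
    """Illustrative completed-Q calculation (hypothetical standalone helper).

    Visited children keep their searched Q value. Unvisited children are
    filled with either the mixed value approximation
        (raw_value + node_visits * v_pi) / (node_visits + 1)
    or, when use_mixed_value is False, the network's raw value alone.
    """
    visited = children_visits > 0
    # Policy-weighted average of the Q values over the visited children only.
    sum_prob = np.sum(policy[visited])
    v_pi = np.sum(policy[visited] * q_value[visited])

    if use_mixed_value:
        fill = (raw_value + node_visits * v_pi / sum_prob) / (node_visits + 1.0)
    else:
        fill = raw_value

    return np.where(visited, q_value, fill)


# Example: two of four children visited, raw network value 0.55, 10 node visits.
visits = np.array([6, 4, 0, 0])
q = np.array([0.62, 0.48, 0.0, 0.0])
pi = np.array([0.5, 0.3, 0.1, 0.1])
print(completed_q_values(visits, q, pi, raw_value=0.55, node_visits=10))

As the visit count of the node grows, the filled-in value for unvisited children moves from the network's raw value toward the policy-weighted average of the searched Q values, which is the intended behavior of the mixed value approximation enabled by default in this patch.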