Skip to content

Commit

Permalink
Merge pull request #77 from kobanium/support#69
Browse files Browse the repository at this point in the history
support mixed value approximation and bugfix
  • Loading branch information
kobanium authored Aug 19, 2023
2 parents 0073a77 + f7e4668 commit bf47a38
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ ignore=CVS

# Add files or directories matching the regex patterns to the ignore-list. The
# regex matches against paths and can be in Posix or Windows format.
ignore-paths=LICENSE,README.md,model,requirements.txt,pipeline.sh
ignore-paths=LICENSE,README.md,model,requirements.txt,pipeline.sh,CONTRIBUTORS,test

# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths. The default value ignores emacs file
Expand Down
2 changes: 1 addition & 1 deletion CONTRIBUTORS
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ Yuki Kobayashi ("kobanium")

Contributors
============
"CGLemon", Chinese translation and GTP analyze commands implementation.
"CGLemon", Chinese translation, GTP analyze commands implementation, and various contributions.
28 changes: 25 additions & 3 deletions mcts/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(self, num_actions: int=MAX_ACTIONS):
self.node_visits = 0
self.virtual_loss = 0
self.node_value_sum = 0.0
self.raw_value = 0.0
self.action = [0] * num_actions
self.children_index = np.zeros(num_actions, dtype=np.int32)
self.children_value = np.zeros(num_actions, dtype=np.float64)
Expand All @@ -43,8 +44,9 @@ def expand(self, policy: Dict[int, float]) -> NoReturn:
policy (Dict[int, float]): 候補手に対応するPolicyのマップ。
"""
self.node_visits = 0
self.node_value_sum = 0.0
self.virtual_loss = 0
self.node_value_sum = 0.0
self.raw_value = 0.0
self.action = [0] * MAX_ACTIONS
self.children_index.fill(NOT_EXPANDED)
self.children_value.fill(0.0)
Expand Down Expand Up @@ -103,6 +105,15 @@ def set_leaf_value(self, index: int, value: float) -> NoReturn:
self.children_value[index] = value


def set_raw_value(self, value: float) -> NoReturn:
"""ノードに対応する局面のValueを設定する。
Args:
value (float): 設定するValueの値。
"""
self.raw_value = value


def update_child_value(self, index: int, value: float) -> NoReturn:
"""子ノードにValueを加算し、Virtual Lossを元に戻す。
Expand Down Expand Up @@ -214,13 +225,15 @@ def print_search_result(self, board: GoBoard, pv_dict: Dict[str, List[str]]) ->
"""
value = np.divide(self.children_value_sum, self.children_visits, \
out=np.zeros_like(self.children_value_sum), where=(self.children_visits != 0))
print_err(f"raw_value={self.raw_value:.4f}")
for i in range(self.num_children):
if self.children_visits[i] > 0:
pos = board.coordinate.convert_to_gtp_format(self.action[i])
msg = f"pos={pos}, "
msg += f"visits={self.children_visits[i]:5d}, "
msg += f"policy={self.children_policy[i]:.4f}, "
msg += f"value={value[i]:.4f}, "
msg += f"raw_value={self.children_value[i]:.4f}, "
msg += f"pv={','.join(pv_dict[pos])}"
print_err(msg)

Expand All @@ -231,9 +244,12 @@ def set_gumbel_noise(self) -> NoReturn:
self.noise = np.random.gumbel(loc=0.0, scale=1.0, size=self.noise.size)


def calculate_completed_q_value(self) -> np.array:
def calculate_completed_q_value(self, use_mixed_value :bool=True) -> np.array:
"""Completed-Q valueを計算する。
Args:
use_mixed_value (bool, optional): Mixed value approximation使用フラグ. デフォルトはTrue.
Returns:
np.array: Completed-Q value.
"""
Expand All @@ -246,7 +262,13 @@ def calculate_completed_q_value(self) -> np.array:
sum_prob = np.sum(policy)
v_pi = np.sum(policy * q_value)

return np.where(self.children_visits[:self.num_children] > 0, q_value, v_pi / sum_prob)
if use_mixed_value:
value = (self.raw_value * np.ones_like(self.children_policy[:self.num_children]) + \
self.node_visits * v_pi / sum_prob) / (self.node_visits + 1.0)
else:
value = self.raw_value

return np.where(self.children_visits[:self.num_children] > 0, q_value, value)


def calculate_improved_policy(self) -> np.array:
Expand Down
18 changes: 18 additions & 0 deletions mcts/time_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from board.stone import Stone
from mcts.constant import CONST_VISITS, CONST_TIME, REMAINING_TIME, VISITS_PER_SEC
from mcts.node import MCTSNode


class TimeControl(Enum):
Expand Down Expand Up @@ -139,3 +140,20 @@ def is_time_over(self) -> bool:
if time.time() - self.start_time > self.time_limit:
return True
return False


def is_move_decided(root: MCTSNode, threshold: int) -> bool:
"""着手が決定したか否かを判定する。
Args:
root (MCTSNode): 現局面のルートノード。
threshold (int): 探索回数の閾値。
Returns:
bool: 探索打ち切り判定結果。
"""
sorted_visits = sorted(root.children_visits)
remaining_visits = threshold - root.node_visits
if sorted_visits[-1] - sorted_visits[-2] > remaining_visits:
return True
return False
6 changes: 4 additions & 2 deletions mcts/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
MAX_CONSIDERED_NODES, RESIGN_THRESHOLD, MCTS_TREE_SIZE
from mcts.sequential_halving import get_candidates_and_visit_pairs
from mcts.node import MCTSNode
from mcts.time_manager import TimeControl, TimeManager
from mcts.time_manager import TimeControl, TimeManager, is_move_decided

class MCTSTree: # pylint: disable=R0902
"""モンテカルロ木探索の実装クラス。
Expand Down Expand Up @@ -147,7 +147,8 @@ def search(self, board: GoBoard, color: Stone, time_manager: TimeManager, \
copy_board(dst=search_board,src=board)
start_color = color
self.search_mcts(search_board, start_color, self.current_root, [])
if time_manager.is_time_over():
if time_manager.is_time_over() or \
is_move_decided(self.get_root(), threshold):
break

if len(analysis_query) > 0:
Expand Down Expand Up @@ -270,6 +271,7 @@ def process_mini_batch(self, board: GoBoard, use_logit: bool=False): # pylint: d
for policy, value_dist, path, node_index in zip(policy_data, \
value_data, self.batch_queue.path, self.batch_queue.node_index):
self.node[node_index].update_policy(policy)
self.node[node_index].set_raw_value(value_dist[1] * 0.5 + value_dist[2])

if path:
value = value_dist[0] + value_dist[1] * 0.5
Expand Down
6 changes: 3 additions & 3 deletions pipeline.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
for i in `seq 1 100` ; do
python3.6 selfplay_main.py --save-dir archive --model model/rl-model.bin --use-gpu true
python3.6 get_final_status.py
python3.6 train.py --rl true --kifu-dir archive
python3 selfplay_main.py --save-dir archive --model model/rl-model.bin --use-gpu true
python3 get_final_status.py
python3 train.py --rl true --kifu-dir archive
done
5 changes: 4 additions & 1 deletion program.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,7 @@
# Version 0.7.0 : lz-analyze, lz-genmove_analyze, cgos-analyze, cgos-genmove_analyzeコマンドのサポート。
# 強化学習に関するバグと超劫の判定処理のバグの修正。
# Version 0.7.1 : 解析コマンドのバグ修正。
VERSION="0.7.1"
# Version 0.8.0 : SHOTでMixed value approximationを使うように変更
# 持ち時間の残りが少なくなった時にプログラムが落ちる不具合を修正。
# 強化学習の棋譜生成時に経過情報の表示を追加。
VERSION="0.8.0"
22 changes: 20 additions & 2 deletions selfplay/worker.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""自己対戦実行ワーカの実装。
"""
import glob
import os
import random
from typing import List
import time
from typing import List, NoReturn
import numpy as np

from board.constant import PASS, RESIGN
Expand All @@ -17,7 +19,7 @@

# pylint: disable=R0913,R0914
def selfplay_worker(save_dir: str, model_file_path: str, index_list: List[int], \
size: int, visits: int, use_gpu: bool):
size: int, visits: int, use_gpu: bool) -> NoReturn:
"""自己対戦実行ワーカ。
Args:
Expand Down Expand Up @@ -86,3 +88,19 @@ def selfplay_worker(save_dir: str, model_file_path: str, index_list: List[int],

record.set_index(index)
record.write_record(winner, board.get_komi(), is_resign, score)


def display_selfplay_progress_worker(save_dir: str, num_data: int) -> NoReturn:
"""自己対戦の進捗を表示する。
Args:
save_dir (str): 生成した棋譜ファイルが保存されるディレクトリのパス。
"""
start_time = time.time()
while True:
time.sleep(60)
current_num_data = len(glob.glob(os.path.join(save_dir, "*.sgf")))
current_time = time.time()
msg = f"Generating {current_num_data:5d}/{num_data:5d} games "
msg += f"({3600 * current_num_data / (current_time - start_time):.4f} games/hour)."
print(msg)
6 changes: 5 additions & 1 deletion selfplay_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import glob
import math
import os
import threading
import time
from concurrent.futures import ProcessPoolExecutor
import click
from board.constant import BOARD_SIZE
from selfplay.worker import selfplay_worker
from selfplay.worker import selfplay_worker, display_selfplay_progress_worker
from learning_param import SELF_PLAY_VISITS, NUM_SELF_PLAY_WORKERS, \
NUM_SELF_PLAY_GAMES

Expand Down Expand Up @@ -57,6 +58,9 @@ def selfplay_main(save_dir: str, process: int, num_data: int, size: int, \
with ProcessPoolExecutor(max_workers=process) as executor:
futures = [executor.submit(selfplay_worker, os.path.join(save_dir, str(kifu_dir_index)), \
model, file_list, size, visits, use_gpu) for file_list in file_indice]
monitoring_worker = threading.Thread(target=display_selfplay_progress_worker, \
args=(os.path.join(save_dir, str(kifu_dir_index)), num_data, ), daemon=True)
monitoring_worker.start()
for future in futures:
future.result()

Expand Down

0 comments on commit bf47a38

Please sign in to comment.