Skip to content

Commit

Permalink
Add new TimeControl mode "STRICT_PLAYOUT"
Browse files Browse the repository at this point in the history
  • Loading branch information
kaorahi committed Jan 3, 2024
1 parent 3606c22 commit c454642
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 18 deletions.
2 changes: 1 addition & 1 deletion gtp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(self, board_size: int, superko: bool, model_file_path: str, \
self.use_sequential_halving = use_sequential_halving
self.use_network = False

if mode is TimeControl.CONSTANT_PLAYOUT:
if mode is TimeControl.CONSTANT_PLAYOUT or mode is TimeControl.STRICT_PLAYOUT:
self.time_manager = TimeManager(mode=mode, constant_visits=visits)
if mode is TimeControl.CONSTANT_TIME:
self.time_manager = TimeManager(mode=mode, constant_time=const_time)
Expand Down
9 changes: 8 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
help="コミの値の設定。デフォルトは7.0。")
@click.option('--visits', type=click.IntRange(min=1), default=1000, \
help="1手あたりの探索回数の指定。デフォルトは1000。\
--strict-visitsオプション、--const-timeオプション、または--timeオプションが指定された時は無視する。")
@click.option('--strict-visits', type=click.IntRange(min=1), \
help="1手あたりの探索回数の厳密指定(着手が確定しても打ち切らない)。\
--const-timeオプション、または--timeオプションが指定された時は無視する。")
@click.option('--const-time', type=click.FLOAT, \
help="1手あたりの探索時間の指定。--timeオプションが指定された時は無視する。")
Expand All @@ -42,7 +45,7 @@
@click.option('--cgos-mode', type=click.BOOL, default=False, \
help="全ての石を打ち上げるまでパスしないモード設定。デフォルトはFalse。")
def gtp_main(size: int, superko: bool, model:str, use_gpu: bool, sequential_halving: bool, \
policy_move: bool, komi: float, visits: int, const_time: float, time: float, \
policy_move: bool, komi: float, visits: int, strict_visits: int, const_time: float, time: float, \
batch_size: int, tree_size: int, cgos_mode: bool):
"""GTPクライアントの起動。
Expand All @@ -55,6 +58,7 @@ def gtp_main(size: int, superko: bool, model:str, use_gpu: bool, sequential_halv
sequential_halving (bool): Gumbel AlphaZeroの探索手法で着手生成するフラグ。デフォルトはFalse。
komi (float): コミの値。デフォルトは7.0。
visits (int): 1手あたりの探索回数。デフォルトは1000。
strict_visits (int): 1手あたりの厳密な探索回数(着手が確定しても打ち切らない)。
const_time (float): 1手あたりの探索時間。
time (float): 対局時の持ち時間。
batch_size (int): 探索実行時のニューラルネットワークのミニバッチサイズ。デフォルトはNN_BATCH_SIZE。
Expand All @@ -63,6 +67,9 @@ def gtp_main(size: int, superko: bool, model:str, use_gpu: bool, sequential_halv
"""
mode = TimeControl.CONSTANT_PLAYOUT

if strict_visits is not None:
mode = TimeControl.STRICT_PLAYOUT
visits = strict_visits
if const_time is not None:
mode = TimeControl.CONSTANT_TIME
if time is not None:
Expand Down
32 changes: 18 additions & 14 deletions mcts/time_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class TimeControl(Enum):
CONSTANT_PLAYOUT = 0
CONSTANT_TIME = 1
TIME_CONTROL = 2
STRICT_PLAYOUT = 3 # 着手が確定しても打ち切らず指定回数までプレイアウト


class TimeManager:
Expand Down Expand Up @@ -66,7 +67,7 @@ def get_num_visits_threshold(self, color: Stone) -> int:
Returns:
int: 探索回数の閾値。
"""
if self.mode == TimeControl.CONSTANT_PLAYOUT:
if self.mode == TimeControl.CONSTANT_PLAYOUT or self.mode == TimeControl.STRICT_PLAYOUT:
self.time_limit = 10000.0
return int(self.constant_visits)
if self.mode == TimeControl.CONSTANT_TIME:
Expand Down Expand Up @@ -142,18 +143,21 @@ def is_time_over(self) -> bool:
return False


def is_move_decided(root: MCTSNode, threshold: int) -> bool:
"""着手が決定したか否かを判定する。
def is_move_decided(self, root: MCTSNode, threshold: int) -> bool:
"""着手が決定したか否かを判定する。
Args:
root (MCTSNode): 現局面のルートノード。
threshold (int): 探索回数の閾値。
Args:
root (MCTSNode): 現局面のルートノード。
threshold (int): 探索回数の閾値。
Returns:
bool: 探索打ち切り判定結果。
"""
sorted_visits = sorted(root.children_visits)
remaining_visits = threshold - root.node_visits
if sorted_visits[-1] - sorted_visits[-2] > remaining_visits:
return True
return False
Returns:
bool: 探索打ち切り判定結果。
"""
sorted_visits = sorted(root.children_visits)
remaining_visits = threshold - root.node_visits
cutoff = sorted_visits[-1] - sorted_visits[-2]
if self.mode == TimeControl.STRICT_PLAYOUT:
cutoff = 0
if remaining_visits < cutoff:
return True
return False
4 changes: 2 additions & 2 deletions mcts/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
MAX_CONSIDERED_NODES, RESIGN_THRESHOLD, MCTS_TREE_SIZE
from mcts.sequential_halving import get_candidates_and_visit_pairs
from mcts.node import MCTSNode
from mcts.time_manager import TimeControl, TimeManager, is_move_decided
from mcts.time_manager import TimeControl, TimeManager

class MCTSTree: # pylint: disable=R0902
"""モンテカルロ木探索の実装クラス。
Expand Down Expand Up @@ -148,7 +148,7 @@ def search(self, board: GoBoard, color: Stone, time_manager: TimeManager, \
start_color = color
self.search_mcts(search_board, start_color, self.current_root, [])
if time_manager.is_time_over() or \
is_move_decided(self.get_root(), threshold):
time_manager.is_move_decided(self.get_root(), threshold):
break

if len(analysis_query) > 0:
Expand Down

0 comments on commit c454642

Please sign in to comment.