Skip to content

Commit

Permalink
Improve GTP compatibility for supporting Lizzie
Browse files Browse the repository at this point in the history
  • Loading branch information
kaorahi committed Dec 22, 2023
1 parent 3a5461d commit c5002e1
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 28 deletions.
12 changes: 12 additions & 0 deletions board/go_board.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,18 @@ def get_komi(self) -> float:
"""
return self.komi

def get_to_move(self) -> Stone:
"""手番の色を取得する。
Returns:
Stone: 手番の色。
"""
if self.moves == 1:
return Stone.BLACK
else:
last_move_color, _, _ = self.record.get(self.moves - 1)
return Stone.get_opponent_color(last_move_color)

def count_score(self) -> int: # pylint: disable=R0912
"""領地を簡易的にカウントする。
Expand Down
94 changes: 66 additions & 28 deletions gtp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sgf.reader import SGFReader


gtp_command_id = ""

class GtpClient: # pylint: disable=R0902,R0903
"""_Go Text Protocolクライアントの実装クラス
Expand Down Expand Up @@ -295,25 +296,59 @@ def _load_sgf(self, arg_list: List[str]) -> NoReturn:

respond_success("")

def _decode_analyze_arg(self, arg_list: List[str]) -> (Stone, float):
"""analyzeコマンド(lz-analyze, cgos-analyze)の引数を解釈する。
不正な引数の場合は更新間隔として負値を返す。
Args:
arg_list (List[str]): コマンドの引数リスト。
Returns:
(Stone, float): 手番の色、更新間隔(秒)
"""
to_move = self.board.get_to_move()
interval = 0
error_value = (to_move, -1.0)
# 受けつける形式の例
# lz-analyze B 10
# lz-analyze B
# lz-analyze 10
# lz-analyze B interval 10
# lz-analyze interval 10
try:
if arg_list[0][0] in ['B', 'b']:
to_move = Stone.BLACK
arg_list.pop(0)
elif arg_list[0][0] in ['W', 'w']:
to_move = Stone.WHITE
arg_list.pop(0)
if arg_list[0] == "interval":
if len(arg_list) == 1:
return error_value
arg_list.pop(0)
if arg_list[0].isdigit():
interval = int(arg_list[0])/100
arg_list.pop(0)
except IndexError as e:
pass
if arg_list:
return error_value
return (to_move, interval)

def _analyze(self, mode: str, arg_list: List[str]) -> NoReturn:
"""analyzeコマンド(lz-analyze, cgos-analyze)を実行する。
Args:
mode (str): 解析モード。値は"lz"か"cgos"。
arg_list (List[str]): コマンドの引数リスト (手番の色, 更新間隔)。
"""
interval = 0
if len(arg_list) >= 2:
interval = int(arg_list[1])/100

if arg_list[0][0] in ['B', 'b']:
to_move = Stone.BLACK
elif arg_list[0][0] in ['W', 'w']:
to_move = Stone.WHITE
else:
respond_failure(f"{mode}-analyze color")
to_move, interval = self._decode_analyze_arg(arg_list)
if interval < 0:
respond_failure(f"{mode}-analyze [color] [interval]")
return

respond_success("", ongoing=True)

analysis_query = {
"mode" : mode,
"interval" : interval,
Expand All @@ -328,19 +363,13 @@ def _genmove_analyze(self, mode: str, arg_list: List[str]) -> NoReturn:
mode (str): 解析モード。値は"lz"か"cgos"。
arg_list (List[str]): コマンドの引数リスト(手番の色, 更新間隔)。
"""
color = arg_list[0]
interval = 0
if len(arg_list) >= 2:
interval = int(arg_list[1])/100

if color.lower()[0] == 'b':
genmove_color = Stone.BLACK
elif color.lower()[0] == 'w':
genmove_color = Stone.WHITE
else:
respond_failure(f"{mode}-genmove_analyze color")
genmove_color, interval = self._decode_analyze_arg(arg_list)
if interval < 0:
respond_failure(f"{mode}-analyze [color] [interval]")
return

respond_success("", ongoing=True)

if self.use_network:
# モンテカルロ木探索で着手生成
analysis_query = {
Expand Down Expand Up @@ -369,13 +398,24 @@ def run(self) -> NoReturn: # pylint: disable=R0912,R0915
"""Go Text Protocolのクライアントの実行処理。
入力されたコマンドに対応する処理を実行し、応答メッセージを表示する。
"""
global gtp_command_id
while True:
command = input()

command_list = command.rstrip().split(' ')

gtp_command_id = ""
input_gtp_command = command_list[0]

# 入力されたコマンドの冒頭が数字なら、それを id とみなす。
# (参照)
# Specification of the Go Text Protocol, version 2, draft 2
# の「2.5 Command Structure」
# http://www.lysator.liu.se/~gunnar/gtp/gtp2-spec-draft2/gtp2-spec.html#SECTION00035000000000000000
if input_gtp_command.isdigit():
gtp_command_id = command_list.pop(0)
input_gtp_command = command_list[0]

if input_gtp_command == "version":
_version()
elif input_gtp_command == "protocol_version":
Expand Down Expand Up @@ -445,40 +485,38 @@ def run(self) -> NoReturn: # pylint: disable=R0912,R0915
self.board.display_self_atari(Stone.WHITE)
respond_success("")
elif input_gtp_command == "lz-analyze":
print_out("= ")
self._analyze("lz", command_list[1:])
print("")
elif input_gtp_command == "lz-genmove_analyze":
print_out("= ")
self._genmove_analyze("lz", command_list[1:])
elif input_gtp_command == "cgos-analyze":
print_out("= ")
self._analyze("cgos", command_list[1:])
print("")
elif input_gtp_command == "cgos-genmove_analyze":
print_out("= ")
self._genmove_analyze("cgos", command_list[1:])
elif input_gtp_command == "hash_record":
print_err(self.board.record.get_hash_history())
respond_success("")
else:
respond_failure("unknown_command")

def respond_success(response: str) -> NoReturn:
def respond_success(response: str, ongoing: bool = False) -> NoReturn:
"""コマンド処理成功時の応答メッセージを表示する。
Args:
response (str): 表示する応答メッセージ。
ongoing (bool): 追加の応答メッセージが後に続くかどうか。
"""
print("= " + response + '\n')
terminator = "" if ongoing else '\n'
print(f"={gtp_command_id} " + response + terminator)

def respond_failure(response: str) -> NoReturn:
"""コマンド処理失敗時の応答メッセージを表示する。
Args:
response (str): 表示する応答メッセージ。
"""
print("= ? " + response + '\n')
print(f"?{gtp_command_id} " + response + '\n')

def _version() -> NoReturn:
"""versionコマンドを処理する。
Expand Down

0 comments on commit c5002e1

Please sign in to comment.