Skip to content

Commit

Permalink
Merge pull request kobanium#80 from kaorahi/support_lizzie
Browse files Browse the repository at this point in the history
GTPの互換性向上(Lizzie用)
  • Loading branch information
kobanium authored Dec 28, 2023
2 parents bf47a38 + 79483d2 commit fa80979
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 29 deletions.
22 changes: 21 additions & 1 deletion board/go_board.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""碁盤のデータ定義と操作処理。
"""
from typing import List, NoReturn
from typing import List, Tuple, NoReturn
from collections import deque
import numpy as np

Expand Down Expand Up @@ -469,6 +469,26 @@ def get_komi(self) -> float:
"""
return self.komi

def get_to_move(self) -> Stone:
"""手番の色を取得する。
Returns:
Stone: 手番の色。
"""
if self.moves == 1:
return Stone.BLACK
else:
last_move_color, _, _ = self.record.get(self.moves - 1)
return Stone.get_opponent_color(last_move_color)

def get_move_history(self) -> List[Tuple[Stone, int, np.array]]:
"""着手の履歴を取得する。
Returns:
[(Stone, int, np.array), ...]: (着手の色、座標、ハッシュ値) のリスト。
"""
return [self.record.get(m) for m in range(1, self.moves)]

def count_score(self) -> int: # pylint: disable=R0912
"""領地を簡易的にカウントする。
Expand Down
109 changes: 81 additions & 28 deletions gtp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sgf.reader import SGFReader


gtp_command_id = ""

class GtpClient: # pylint: disable=R0902,R0903
"""_Go Text Protocolクライアントの実装クラス
Expand Down Expand Up @@ -163,6 +164,19 @@ def _play(self, color: str, pos: str) -> NoReturn:

respond_success("")

def _undo(self) -> NoReturn:
"""undoコマンドを処理する。
"""
# 一旦クリアして初手から直前手まで打ち直す非効率実装
history = self.board.get_move_history()
if not history:
respond_failure("cannot undo")
return
self._clear_board()
for (color, pos, _) in history[:-1]:
self.board.put_stone(pos, color)
respond_success("")

def _genmove(self, color: str) -> NoReturn:
"""genmoveコマンドを処理する。
入力された手番で思考し、着手を生成する。
Expand Down Expand Up @@ -295,25 +309,59 @@ def _load_sgf(self, arg_list: List[str]) -> NoReturn:

respond_success("")

def _decode_analyze_arg(self, arg_list: List[str]) -> (Stone, float):
"""analyzeコマンド(lz-analyze, cgos-analyze)の引数を解釈する。
不正な引数の場合は更新間隔として負値を返す。
Args:
arg_list (List[str]): コマンドの引数リスト。
Returns:
(Stone, float): 手番の色、更新間隔(秒)
"""
to_move = self.board.get_to_move()
interval = 0
error_value = (to_move, -1.0)
# 受けつける形式の例
# lz-analyze B 10
# lz-analyze B
# lz-analyze 10
# lz-analyze B interval 10
# lz-analyze interval 10
try:
if arg_list[0][0] in ['B', 'b']:
to_move = Stone.BLACK
arg_list.pop(0)
elif arg_list[0][0] in ['W', 'w']:
to_move = Stone.WHITE
arg_list.pop(0)
if arg_list[0] == "interval":
if len(arg_list) == 1:
return error_value
arg_list.pop(0)
if arg_list[0].isdigit():
interval = int(arg_list[0])/100
arg_list.pop(0)
except IndexError as e:
pass
if arg_list:
return error_value
return (to_move, interval)

def _analyze(self, mode: str, arg_list: List[str]) -> NoReturn:
"""analyzeコマンド(lz-analyze, cgos-analyze)を実行する。
Args:
mode (str): 解析モード。値は"lz"か"cgos"。
arg_list (List[str]): コマンドの引数リスト (手番の色, 更新間隔)。
"""
interval = 0
if len(arg_list) >= 2:
interval = int(arg_list[1])/100

if arg_list[0][0] in ['B', 'b']:
to_move = Stone.BLACK
elif arg_list[0][0] in ['W', 'w']:
to_move = Stone.WHITE
else:
respond_failure(f"{mode}-analyze color")
to_move, interval = self._decode_analyze_arg(arg_list)
if interval < 0:
respond_failure(f"{mode}-analyze [color] [interval]")
return

respond_success("", ongoing=True)

analysis_query = {
"mode" : mode,
"interval" : interval,
Expand All @@ -328,19 +376,13 @@ def _genmove_analyze(self, mode: str, arg_list: List[str]) -> NoReturn:
mode (str): 解析モード。値は"lz"か"cgos"。
arg_list (List[str]): コマンドの引数リスト(手番の色, 更新間隔)。
"""
color = arg_list[0]
interval = 0
if len(arg_list) >= 2:
interval = int(arg_list[1])/100

if color.lower()[0] == 'b':
genmove_color = Stone.BLACK
elif color.lower()[0] == 'w':
genmove_color = Stone.WHITE
else:
respond_failure(f"{mode}-genmove_analyze color")
genmove_color, interval = self._decode_analyze_arg(arg_list)
if interval < 0:
respond_failure(f"{mode}-analyze [color] [interval]")
return

respond_success("", ongoing=True)

if self.use_network:
# モンテカルロ木探索で着手生成
analysis_query = {
Expand Down Expand Up @@ -369,13 +411,24 @@ def run(self) -> NoReturn: # pylint: disable=R0912,R0915
"""Go Text Protocolのクライアントの実行処理。
入力されたコマンドに対応する処理を実行し、応答メッセージを表示する。
"""
global gtp_command_id
while True:
command = input()

command_list = command.rstrip().split(' ')

gtp_command_id = ""
input_gtp_command = command_list[0]

# 入力されたコマンドの冒頭が数字なら、それを id とみなす。
# (参照)
# Specification of the Go Text Protocol, version 2, draft 2
# の「2.5 Command Structure」
# http://www.lysator.liu.se/~gunnar/gtp/gtp2-spec-draft2/gtp2-spec.html#SECTION00035000000000000000
if input_gtp_command.isdigit():
gtp_command_id = command_list.pop(0)
input_gtp_command = command_list[0]

if input_gtp_command == "version":
_version()
elif input_gtp_command == "protocol_version":
Expand All @@ -392,6 +445,8 @@ def run(self) -> NoReturn: # pylint: disable=R0912,R0915
self._komi(command_list[1])
elif input_gtp_command == "play":
self._play(command_list[1], command_list[2])
elif input_gtp_command == "undo":
self._undo()
elif input_gtp_command == "genmove":
self._genmove(command_list[1])
elif input_gtp_command == "boardsize":
Expand Down Expand Up @@ -445,40 +500,38 @@ def run(self) -> NoReturn: # pylint: disable=R0912,R0915
self.board.display_self_atari(Stone.WHITE)
respond_success("")
elif input_gtp_command == "lz-analyze":
print_out("= ")
self._analyze("lz", command_list[1:])
print("")
elif input_gtp_command == "lz-genmove_analyze":
print_out("= ")
self._genmove_analyze("lz", command_list[1:])
elif input_gtp_command == "cgos-analyze":
print_out("= ")
self._analyze("cgos", command_list[1:])
print("")
elif input_gtp_command == "cgos-genmove_analyze":
print_out("= ")
self._genmove_analyze("cgos", command_list[1:])
elif input_gtp_command == "hash_record":
print_err(self.board.record.get_hash_history())
respond_success("")
else:
respond_failure("unknown_command")

def respond_success(response: str) -> NoReturn:
def respond_success(response: str, ongoing: bool = False) -> NoReturn:
"""コマンド処理成功時の応答メッセージを表示する。
Args:
response (str): 表示する応答メッセージ。
ongoing (bool): 追加の応答メッセージが後に続くかどうか。
"""
print("= " + response + '\n')
terminator = "" if ongoing else '\n'
print(f"={gtp_command_id} " + response + terminator)

def respond_failure(response: str) -> NoReturn:
"""コマンド処理失敗時の応答メッセージを表示する。
Args:
response (str): 表示する応答メッセージ。
"""
print("= ? " + response + '\n')
print(f"?{gtp_command_id} " + response + '\n')

def _version() -> NoReturn:
"""versionコマンドを処理する。
Expand Down

0 comments on commit fa80979

Please sign in to comment.