-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #110 from kobanium/develop
Develop
- Loading branch information
Showing
19 changed files
with
756 additions
and
106 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,4 +3,5 @@ | |
**/*.npz | ||
**/*.bin | ||
**/*.ckpt | ||
archive/* | ||
archive/* | ||
.coverage |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import sys | ||
import select | ||
import time | ||
|
||
|
||
def animate_mcts(mcts, board, to_move, pv_wait_sec, move_wait_sec): | ||
previous_pv = [] | ||
def callback(path): | ||
_animate_path(path, mcts, board, pv_wait_sec, move_wait_sec, previous_pv) | ||
finished = _stdin_has_data() | ||
return finished | ||
mcts.search_with_callback(board, to_move, callback) | ||
|
||
|
||
def _stdin_has_data(): | ||
rlist, _, _ = select.select([sys.stdin], [], [], 0) | ||
return bool(rlist) | ||
|
||
|
||
def _animate_path(path, mcts, board, pv_wait_sec, move_wait_sec, previous_pv): | ||
# 今回探索した系列の属性値 | ||
root_index, i = path[0] | ||
root = mcts.node[root_index] | ||
if root.children_visits[i] == 0: | ||
return | ||
coordinate = board.coordinate | ||
move = coordinate.convert_to_gtp_format(root.action[i]) | ||
pv = [coordinate.convert_to_gtp_format(mcts.node[index].action[child_index]) for (index, child_index) in path] | ||
pv_visits = [str(mcts.node[index].children_visits[child_index]) for (index, child_index) in path] | ||
pv_winrate = [str(int(10000 * _get_winrate(mcts, index, child_index, depth))) for depth, (index, child_index) in enumerate(path)] | ||
|
||
# lz-analyze の本来の出力内容を加工 | ||
children_status_list = root.get_analysis_status_list(board, mcts.get_pv_lists) | ||
fake_status_list = [status.copy() for status in children_status_list] | ||
target = next((status for status in fake_status_list if status["move"] == move), None) | ||
if target is None: | ||
return # can't happen | ||
# 今回探索した系列の初手を最善手と偽って順位をふり直す | ||
target["order"] = -1 | ||
fake_status_list.sort(key=lambda status: status["order"]) | ||
for order, status in enumerate(fake_status_list): | ||
status["order"] = order | ||
|
||
# PV 欄を差しかえながら複数回出力することで一手ずつアニメーション | ||
for k in range(1, len(pv) + 1): | ||
# 前回の系列と共通な手順はスキップ | ||
if pv[:k] == previous_pv[:k]: | ||
continue | ||
|
||
target["pv"] = " ".join(pv[:k]) | ||
target["pvVisits"] = " ".join(pv_visits[:k]) | ||
target["pvWinrate"] = " ".join(pv_winrate[:k]) | ||
|
||
sys.stdout.write(root.get_analysis_from_status_list("lz", fake_status_list)) | ||
sys.stdout.flush() | ||
time.sleep(max(move_wait_sec, 0.0)) | ||
|
||
previous_pv[:] = pv | ||
time.sleep(max(pv_wait_sec, 0.0)) | ||
|
||
|
||
def _get_winrate(mcts, index, child_index, depth): | ||
node = mcts.node[index] | ||
i = child_index | ||
visits = node.children_visits[i] | ||
value = node.children_value_sum[i] / visits if visits > 0 else node.children_value[i] | ||
winrate = value if depth % 2 == 0 else 1.0 - value | ||
return winrate |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# About tree visualization | ||
TamaGo supports visualization of a search treem. | ||
|
||
## Example | ||
``` | ||
(echo 'tamago-readsgf (;SZ[9]KM[7];B[fe];W[de];B[ec])'; | ||
echo 'lz-genmove_analyze 7777777'; | ||
echo 'undo'; | ||
echo 'tamago-dump_tree') \ | ||
| python3 main.py --model model/model.bin --strict-visits 100 \ | ||
| grep dump_version | gzip > tree.json.gz | ||
python3 graph/plot_tree.py tree.json.gz tree_graph | ||
display tree_graph.svg | ||
``` | ||
|
||
![Result of search tree visualization](../../img/tree_graph.png) | ||
|
||
## Command line arguments for graph/plot_tree.py | ||
|
||
| Argument | Description | Value | Example of value | Node | | ||
|---|---|---|---|---| | ||
| INPUT_JSON_PATH | Path to a .json file which has a result of tamago-dump_tree command. | tree.json.gz | | | ||
| OUTPUT_IMAGE_PATH | Path to a image file which has a visualization result. | tree_graph | | Automatically assigned extension(.svg) | | ||
|
||
## Option for graph/plot_tree.py | ||
|
||
| Option | Description | Value | Example of value | Default value | Node | | ||
|---|---|---|---|---|---| | ||
| `--around-pv` | Flag to include all tree nodes. | true or false | true | false | | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# 探索木の可視化について | ||
TamaGoは探索木の可視化機能をサポートしています。 | ||
|
||
## 可視化機能の実行例 | ||
``` | ||
(echo 'tamago-readsgf (;SZ[9]KM[7];B[fe];W[de];B[ec])'; | ||
echo 'lz-genmove_analyze 7777777'; | ||
echo 'undo'; | ||
echo 'tamago-dump_tree') \ | ||
| python3 main.py --model model/model.bin --strict-visits 100 \ | ||
| grep dump_version | gzip > tree.json.gz | ||
python3 graph/plot_tree.py tree.json.gz tree_graph | ||
display tree_graph.svg | ||
``` | ||
|
||
![探索木の可視化結果](../../img/tree_graph.png) | ||
|
||
## graph/plot_tree.pyのコマンドライン引数 | ||
|
||
| 引数 | 概要 | 設定する値 | 設定値の例 | 備考 | | ||
|---|---|---|---|---| | ||
| INPUT_JSON_PATH | tamago-dump_treeコマンドを実行した結果のJSONファイルのパス | tree.json.gz | | | ||
| OUTPUT_IMAGE_PATH | 可視化結果を保持する画像ファイルのパス | tree_graph | | 拡張子(.svg)が自動的に付与される | | ||
|
||
## graph/plot_tree.pyのオプション | ||
|
||
| オプション | 概要 | 設定する値 | 設定値の例 | デフォルト値 | 備考 | | ||
|---|---|---|---|---|---| | ||
| `--around-pv` | 主分岐のまわりのみ表示するフラグ | true または false | true | false | | |
Oops, something went wrong.