-
Notifications
You must be signed in to change notification settings - Fork 1
/
yaml_easy_train.py
51 lines (45 loc) · 1.74 KB
/
yaml_easy_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from glob import glob
import os
import random
import re
import subprocess
import sys

import yaml
# Helper script whose output's first field names the best native engine arch.
NATIVE_PROPERTIES_SCRIPT = "/root/misc/get_native_properties.sh"
NETWORK_TESTING_BOOK_URL = (
    "https://github.com/official-stockfish/books/raw/master/"
    "UHO_Lichess_4852_v1.epd.zip"
)


def _detect_engine_arch():
    """Return the engine architecture reported by the native-properties script.

    Mirrors ``sh <script> | awk '{print $1}'``: keeps the first
    whitespace-separated field of every output line, then strips the result.
    The script is run directly rather than through a shell pipeline because a
    pipeline's exit status is that of its last command (awk), which hid
    failures and produced an empty ``--build-engine-arch`` value. Exits with a
    message on stderr if the script cannot be run or fails.
    """
    try:
        raw = subprocess.check_output(["sh", NATIVE_PROPERTIES_SCRIPT])
    except (OSError, subprocess.CalledProcessError) as err:
        sys.exit(
            f"error: could not detect engine arch with "
            f"{NATIVE_PROPERTIES_SCRIPT}: {err}"
        )
    first_fields = [(line.split() or [""])[0] for line in raw.decode().splitlines()]
    return "\n".join(first_fields).strip()


def _default_args(config):
    """Return easy_train.py default options to be overridden by *config*.

    A fresh random seed is drawn on every run. The engine architecture is only
    probed (which spawns a subprocess) when *config* does not set
    ``build-engine-arch`` itself.
    """
    defaults = {
        'seed': random.randint(9999, 9_999_999),
        'tui': False,
        'workspace-path': '/root/easy-train-data',
        'network-testing-book': NETWORK_TESTING_BOOK_URL,
    }
    if 'build-engine-arch' not in config:
        defaults['build-engine-arch'] = _detect_engine_arch()
    return defaults


def _load_config(path):
    """Parse the YAML file at *path* and return its top-level mapping.

    An empty file yields ``{}``. Exits with a message on stderr (status 1)
    when the file cannot be read, is not valid YAML, or its top level is not
    a mapping. Errors go to stderr, not stdout, so a caller capturing stdout
    never receives error text in place of a command.
    """
    try:
        with open(path, "r", encoding="utf-8") as stream:
            config = yaml.safe_load(stream)
    except OSError as exc:
        sys.exit(f"error: could not read {path}: {exc}")
    except yaml.YAMLError as exc:
        sys.exit(f"error: could not parse {path}: {exc}")
    if config is None:
        return {}
    if not isinstance(config, dict):
        sys.exit(
            f"error: {path} must contain a mapping of easy_train.py options, "
            f"got {type(config).__name__}"
        )
    return config


def _gpus_from_filename(path):
    """Return comma-separated GPU ids encoded as ``gpu<digits>`` in the file name.

    Each digit is one GPU id, e.g. ``run_gpu013.yaml`` -> ``"0,1,3"``.
    Only the base name is searched, so a directory named like ``gpu0`` does
    not leak into the GPU selection. Returns ``None`` when there is no match.
    """
    match = re.search(r"gpu(\d+)", os.path.basename(path))
    return ",".join(match.group(1)) if match else None


def _expand_dataset(components):
    """Yield training-dataset entries, expanding those containing ``*`` via glob.

    Glob matches are sorted so the generated command is reproducible. A
    pattern that matches nothing is reported on stderr instead of vanishing
    silently from the command.
    """
    for component in components:
        if "*" not in component:
            yield component
            continue
        matches = sorted(glob(component))
        if not matches:
            print(
                f"warning: training-dataset pattern {component!r} matched no files",
                file=sys.stderr,
            )
        yield from matches


def _build_command(args):
    """Return the easy_train.py command as a list of pieces, options sorted by key.

    The first piece is the program. Each following piece is `` --<key> <value>``
    with a leading space, kept so the printed output is unchanged. A
    list-valued ``training-dataset`` becomes one ``--training-dataset`` flag
    per (glob-expanded) entry.
    """
    # NOTE(review): values are inserted unquoted, so paths containing spaces
    # or shell metacharacters will break the printed command. Decide on
    # quoting once it is confirmed how the output is consumed (eval vs. $(...)).
    command = ["python3 easy_train.py"]
    for key, value in sorted(args.items()):
        if key == "training-dataset" and isinstance(value, list):
            command.extend(f" --{key} {entry}" for entry in _expand_dataset(value))
        else:
            command.append(f" --{key} {value}")
    return command


def main(argv=None):
    """Print an easy_train.py command line built from a YAML config file.

    Usage: ``yaml_easy_train.py CONFIG.yaml [print]``. Options in the YAML
    file override the built-in defaults, and ``gpu<digits>`` in the config
    file name overrides ``gpus``. With ``print``, the command is split over
    lines joined by backslash continuations. Otherwise it is printed on one
    line.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        prog = argv[0] if argv else "yaml_easy_train.py"
        sys.exit(f"usage: {prog} CONFIG.yaml [print]")
    yaml_config_file = argv[1]

    config = _load_config(yaml_config_file)
    args = {**_default_args(config), **config}
    # If the config file name carries GPU ids, they take precedence.
    gpus = _gpus_from_filename(yaml_config_file)
    if gpus:
        args["gpus"] = gpus

    command = _build_command(args)
    pretty = len(argv) > 2 and argv[2] == "print"
    print((" \\\n" if pretty else " ").join(command))


if __name__ == "__main__":
    main()