-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsweep.py
133 lines (114 loc) · 4.13 KB
/
sweep.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import argparse
import shlex
import shutil
import socket
import subprocess
from pathlib import Path

import run
import util
def get_sweep_argss():
    """Yield one run.py argument namespace per sweep configuration.

    The grid covers:
    - seeds 0, 1, 2, varying slowest;
    - num_hidden_units_1 in {10, 20};
    - num_hidden_units_2 in {10, 20}.

    Each namespace is freshly parsed from an empty argv with run.py's parser,
    then the swept fields are overwritten.
    """
    grid = (
        (seed, width_1, width_2)
        for seed in range(3)
        for width_1 in [10, 20]
        for width_2 in [10, 20]
    )
    for seed, width_1, width_2 in grid:
        sweep_args = run.get_args_parser().parse_args([])
        sweep_args.seed = seed
        sweep_args.num_hidden_units_1 = width_1
        sweep_args.num_hidden_units_2 = width_2
        yield sweep_args
def args_to_str(args):
    """Serialize a namespace back into a shell-safe command-line string.

    Every attribute ``foo_bar`` of ``args`` becomes an option ``--foo-bar``:

    * ``None`` values are skipped;
    * ``True`` booleans become a bare flag; ``False`` booleans are skipped;
    * lists become the flag followed by their space-separated items;
    * anything else becomes ``--flag value``.

    Each value is passed through ``shlex.quote`` because the result is
    spliced into a command executed with ``shell=True``. Without quoting,
    values containing spaces or shell metacharacters would be split or
    expanded by the shell. An empty string would also vanish and leave a
    dangling flag.

    Args:
        args: an ``argparse.Namespace`` (anything accepted by ``vars``).

    Returns:
        The argument string. Every option is preceded by a single space, so a
        non-empty result starts with a space; no options gives ``""``.
    """
    parts = []
    for key, value in vars(args).items():
        flag = "--" + key.replace("_", "-")
        if value is None:
            continue
        if isinstance(value, bool):
            # Checked before other types because bool is a subclass of int.
            if value:
                parts.append(flag)
            continue
        if isinstance(value, list):
            value_str = " ".join(shlex.quote(str(item)) for item in value)
        else:
            # f"{value}" matches the original "{}".format(value) rendering.
            value_str = shlex.quote(f"{value}")
        parts.append(f"{flag} {value_str}")
    return "".join(f" {part}" for part in parts)
def _build_sbatch_cmd(sweep_args, args, hostname):
    """Return the shell command that submits one run of run.py to SLURM.

    Args:
        sweep_args: namespace for a single run, as yielded by get_sweep_argss.
        args: this script's own options. Reads --no-repeat, --priority and
            --titan-x.
        hostname: host this script runs on. The GPU-memory constraint is
            only requested when it contains "openmind".

    Returns:
        A single command string, meant to be executed with ``shell=True``.

    Side effects:
        Creates the run's log directory ``<util.get_save_dir(sweep_args)>/logs``.
    """
    # NOTE(review): this sets MKL_THREADING_LAYER to the literal "INTEL=1";
    # MKL documents INTEL, SEQUENTIAL, GNU and TBB as values - confirm intent.
    run_cmd = f"MKL_THREADING_LAYER=INTEL=1 python -u run.py {args_to_str(sweep_args)}"
    if args.no_repeat:
        # Plain sbatch job. The whole command is shell-quoted into a single
        # --wrap word, so quotes / `$` in run arguments reach sbatch intact.
        sbatch_cmd = "sbatch"
        time_option = "12:0:0"
        python_cmd = "--wrap=" + shlex.quote(run_cmd)
    else:
        # om-repeat mode: per the --no-repeat help, this queues 2h jobs with
        # dependencies until the script finishes.
        sbatch_cmd = "om-repeat sbatch"
        time_option = "2:0:0"
        python_cmd = run_cmd

    logs_dir = f"{util.get_save_dir(sweep_args)}/logs"
    Path(logs_dir).mkdir(parents=True, exist_ok=True)
    job_name = util.get_save_job_name_from_args(sweep_args)

    gpu_memory_gb = 22
    cpu_memory_gb = 16
    partition_option = "--partition=tenenbaum " if args.priority else ""
    gpu_option = ":titan-x" if args.titan_x else ""
    gpu_memory_option = f"--constraint={gpu_memory_gb}GB " if "openmind" in hostname else ""
    sbatch_options = (
        f"--time={time_option} "
        + "--ntasks=1 "
        + f"--gres=gpu{gpu_option}:1 "
        + gpu_memory_option
        + f"--mem={cpu_memory_gb}G "
        + partition_option
        + f"-J {shlex.quote(str(job_name))} "
        + f"-o {shlex.quote(logs_dir + '/%j.out')} "
        + f"-e {shlex.quote(logs_dir + '/%j.err')} "
    )
    return " ".join([sbatch_cmd, sbatch_options, python_cmd])


def main(args):
    """Launch every run produced by get_sweep_argss.

    With --cluster, each run is submitted as a SLURM job through a shell
    command. Otherwise the runs execute one after another in this process
    via run.main.

    Args:
        args: namespace returned by get_parser().parse_args().
    """
    hostname = socket.gethostname() if args.cluster else None
    if args.rm:
        # ignore_errors also covers the case where save/ does not exist yet.
        shutil.rmtree("save/", ignore_errors=True)

    # Materialize once, so the logged count and the launched runs come
    # from the same pass over the sweep.
    sweep_argss = list(get_sweep_argss())
    separator = "-------------------------------"
    for _ in range(3):
        print(separator)
    util.logging.info(
        "Launching {} runs on {}".format(
            len(sweep_argss), f"cluster ({hostname})" if args.cluster else "local",
        )
    )
    for _ in range(3):
        print(separator)

    for sweep_args in sweep_argss:
        if not args.cluster:
            run.main(sweep_args)
            continue
        cmd = _build_sbatch_cmd(sweep_args, args, hostname)
        util.logging.info(cmd)
        returncode = subprocess.call(cmd, shell=True)
        if returncode != 0:
            # Keep submitting the rest of the sweep, but surface the failure.
            util.logging.warning(
                "Submission failed with exit code {}: {}".format(returncode, cmd)
            )
def get_parser():
    """Build the command-line parser for this sweep launcher.

    Returns:
        An ``argparse.ArgumentParser``. All of its options are ``store_true``
        flags, so each defaults to ``False``.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--cluster",
        action="store_true",
        help="submit each run as a SLURM job instead of running it locally",
    )
    parser.add_argument(
        "--titan-x",
        action="store_true",
        help="request a titan-x GPU (cluster only)",
    )
    parser.add_argument(
        "--rm",
        action="store_true",
        help="delete the save/ directory before launching",
    )
    parser.add_argument("--priority", action="store_true", help="runs on lab partition")
    parser.add_argument(
        "--no-repeat",
        action="store_true",
        # Adjacent literals are concatenated, so each piece must carry its own
        # separating space (the original rendered as "sbatch.if False").
        help="run the jobs using standard sbatch. "
        "If not set, queues 2h jobs with dependencies "
        "until the script finishes",
    )
    return parser
if __name__ == "__main__":
    # Parse this script's own options and launch the sweep.
    main(get_parser().parse_args())