forked from yuzhTHU/ND2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsearch.py
112 lines (101 loc) · 3.53 KB
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import json
import time
import signal
import logging
import warnings
import numpy as np
import traceback
from socket import gethostname
from argparse import ArgumentParser
from setproctitle import setproctitle
from ND2.model import NDformer
from ND2.utils import init_logger, AutoGPU, seed_all
from ND2.search import MCTS
from ND2.GDExpr import GDExpr
from ND2.search.reward_solver import RewardSolver
warnings.filterwarnings(action="ignore", category=RuntimeWarning)


def handler(signum, frame):
    """Convert a delivered signal into ``KeyboardInterrupt``.

    Registered below for both SIGINT and SIGTERM, so a termination request
    unwinds through the same ``except KeyboardInterrupt`` / ``finally`` path
    in ``main`` that a Ctrl-C would.
    """
    raise KeyboardInterrupt


# Same registration order as before: SIGINT first, then SIGTERM.
for _sig in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_sig, handler)
del _sig

logger = logging.getLogger('ND2.search')
def main(args, data_path='./data/synthetic/KUR.json',
         checkpoint_path='./weights/checkpoint.pth',
         save_path='./result/search.csv'):
    """Run an MCTS expression search guided by a pretrained NDformer.

    Loads a dataset from ``data_path``, builds a ``RewardSolver`` and an
    ``NDformer`` over the node variables ``omega`` and ``x`` with target
    ``dx``, runs ``MCTS.fit()``, then logs the best expression and appends one
    result record to ``save_path``.

    Args:
        args: Parsed CLI namespace; ``device``, ``name`` and ``seed`` are read.
        data_path: JSON file whose top-level values are each converted with
            ``np.array``. Must contain at least ``omega``, ``x``, ``dx``,
            ``A`` and ``G``.
        checkpoint_path: NDformer weights file passed to ``NDformer.load``.
        save_path: File to which one JSON object per run is appended, one per
            line. Despite the default ``.csv`` extension the content is JSON
            Lines, not CSV.

    A ``KeyboardInterrupt`` during ``fit`` (Ctrl-C, or SIGTERM via the module
    signal handler) is logged as a manual interruption. Any other exception
    is logged with its traceback. In every case, whatever ``est.best_model``
    and ``est.best_metric`` hold afterwards is reported and saved.
    """
    # %% Load Data & Init Model
    with open(data_path, 'r') as f:
        data = {k: np.array(v) for k, v in json.load(f).items()}
    # NOTE(review): A and G look like integer index/connectivity arrays
    # (graph structure) — confirm against the dataset generator.
    data['A'] = data['A'].astype(int)
    data['G'] = data['G'].astype(int)
    # Node-level inputs shared by the rewarder and NDformer. Each consumer gets
    # its own dict copy so neither can see mutations made by the other.
    node_vars = {'omega': data['omega'], 'x': data['x']}

    # init Rewarder
    rewarder = RewardSolver(
        Xv=dict(node_vars),
        Xe={},
        A=data['A'],
        G=data['G'],
        Y=data['dx'],
        mask=None,
    )
    # init NDformer
    ndformer = NDformer(device=args.device)
    # NOTE(review): weights_only=False presumably reaches torch.load and allows
    # arbitrary pickled objects — only load checkpoints from a trusted source.
    ndformer.load(checkpoint_path, weights_only=False)
    ndformer.eval()
    ndformer.set_data(
        Xv=dict(node_vars),
        Xe={},
        A=data['A'],
        G=data['G'],
        Y=data['dx'],
        root_type='node',
        cache_data_emb=True,
    )
    # init Monte-Carlo Tree Search algorithm
    est = MCTS(
        rewarder=rewarder,
        ndformer=ndformer,
        vars_node=['x', 'omega'],
        vars_edge=[],
        # binary=['add', 'sub'],
        # unary=['sin', 'aggr', 'sour', 'targ'],
        # constant=[],
        log_per_episode=10,
        log_per_second=None,
        beam_size=10,
        use_random_simulate=False,
    )

    # %% Search
    try:
        est.fit()
    except KeyboardInterrupt:
        logger.info('Interrupted manually.')
    except Exception:
        logger.error(traceback.format_exc())
    finally:
        logger.note(f'Search finished. Discovered model: {GDExpr.prefix2str(est.best_model)}')
        logger.note(' | '.join(f'\033[4m{k}\033[0m:{v}' for k, v in est.best_metric.items()))
        # dirname is '' for a bare filename, and os.makedirs('') raises.
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        record = dict(
            host=gethostname(),
            name=args.name,
            seed=args.seed,
            result=est.best_model,
            **est.best_metric,
        )
        # One JSON object per line keeps the file parseable across repeated
        # runs; the context manager flushes and closes the handle.
        with open(save_path, 'a') as f:
            f.write(json.dumps(record) + '\n')
if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('-n', '--name', type=str, default=f'Search_{time.strftime("%Y%m%d_%H%M%S")}',
                        help='Run name; used for the log directory and process title.')
    parser.add_argument('-d', '--device', type=str, default='cuda',
                        help="Device for NDformer, or 'auto' to let AutoGPU choose a GPU.")
    parser.add_argument('-s', '--seed', type=int, default=1,
                        help='Random seed; pass a negative value to draw a random seed.')
    parser.add_argument('--info_level', choices=['debug', 'info', 'note', 'warning', 'error', 'critical'], default='info',
                        help='Logging level passed to init_logger.')
    args, unknown = parser.parse_known_args()
    if unknown:
        warnings.warn(f'Unknown args: {unknown}')
    init_logger(args.name, f'./log/search/{args.name}/info.log', root_name='ND2', info_level=args.info_level)
    setproctitle(f'{args.name}@ZihanYu')
    # argparse always yields an int here (type=int, default=1), so checking for
    # None alone could never trigger the random fallback. A negative seed is now
    # the explicit way to request a randomly drawn one.
    if args.seed is None or args.seed < 0:
        args.seed = np.random.randint(0, 32768)
    seed_all(args.seed)
    if args.device == 'auto':
        args.device = AutoGPU().choice_gpu(900, interval=15, force=False)
    logger.info(f'Args: {args}')
    main(args)