-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathhs_e2py.py
140 lines (109 loc) · 4.41 KB
/
hs_e2py.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#!/usr/bin/env python
# coding: utf-8
import itertools
import os
import pickle
import numpy as np
import pandas as pd
import torch
from benchmark import KME2P_config, evaluate, loss_weights
from tqdm import auto
from tphenotype.baselines import KME2P
def evaluate_predictor(method, config, loss_weights_, splits, seed=0, epochs=50, steps=(-1,), metric="Hprc"):
    """Train and evaluate one model per data split; collect one score each.

    Args:
        method: model class to instantiate (e.g. ``KME2P``); called as ``method(**config)``.
        config: keyword arguments forwarded to the model constructor.
        loss_weights_: loss-weight settings forwarded to ``model.fit``.
        splits: sequence of ``(train_set, valid_set, test_set)`` tuples.
        seed: base RNG seed; split ``i`` is trained with seed ``seed + i``.
        epochs: number of training epochs per split.
        steps: forwarded to ``evaluate`` — presumably selects evaluation
            time steps, with -1 meaning the last step (TODO confirm against
            ``benchmark.evaluate``).
        metric: key of the score extracted from ``evaluate``'s result dict.

    Returns:
        ``(results, model)``: a numpy array with one ``metric`` score per
        split, and the model fitted on the *last* split.

    Raises:
        ValueError: if ``splits`` is empty (previously this surfaced as an
            UnboundLocalError on ``model`` at the return statement).
    """
    if len(splits) == 0:
        raise ValueError("`splits` must contain at least one dataset split")
    # Determinism is a global torch setting — request it once, not per split.
    torch.use_deterministic_algorithms(True)
    results = []
    model = None
    for i, dataset in auto.tqdm(enumerate(splits), total=len(splits), desc=f"{method.__name__}"):
        train_set, valid_set, test_set = dataset
        # Distinct but reproducible seed for each split.
        torch.random.manual_seed(seed + i)
        model = method(**config)
        model = model.fit(train_set, loss_weights_, valid_set=valid_set, epochs=epochs, verbose=False)
        scores = evaluate(model, test_set, steps)
        results.append(scores[metric])
    return np.array(results), model
# Output directory for the per-dataset hyperparameter-search result CSVs.
os.makedirs("hyperparam_selection", exist_ok=True)
def load_data(dataname, verbose=False):
    """Load the pickled hyperparameter-search splits for one dataset.

    Args:
        dataname: one of ``"Synth"``, ``"ADNI"``, ``"ICU"``.
        verbose: if True, print summary statistics of the first split.

    Returns:
        ``(splits, feat_list, temporal_dims)``: the unpickled list of
        ``(train, valid, test)`` splits, the dataset's feature names, and
        the indices of the time-varying feature dimensions.

    Raises:
        ValueError: if ``dataname`` is not recognized. Validation happens
            *before* the file is opened, so an unknown name no longer
            surfaces as a FileNotFoundError.
    """
    # Known datasets -> (feature names, temporal feature indices).
    known = {
        "Synth": (["x1", "x2"], [0, 1]),
        "ADNI": (["APOE4", "CDRSB", "Hippocampus"], [1, 2]),
        "ICU": (["Age", "Gender", "GCS", "PaCO2"], [2, 3]),
    }
    if dataname not in known:
        raise ValueError(f"unknown dataset {dataname}")
    feat_list, temporal_dims = known[dataname]
    # NOTE(review): pickle.load on a data file — safe only for trusted,
    # locally-generated data.
    with open(f"data/{dataname}_data_hs.pkl", "rb") as file:
        splits = pickle.load(file)
    if verbose:
        tr_set, va_set, te_set = splits[0]
        # Assumes x is (samples, T, x_dim) and y is (samples, T, y_dim).
        _, T, x_dim = tr_set["x"].shape
        _, _, y_dim = tr_set["y"].shape
        print(dataname)
        print(f"total samples: {len(tr_set['x'])+len(va_set['x'])+len(te_set['x'])}")
        print(f"max length: {T}")
        print(f"x_dim: {x_dim}")
        print(f"y_dim: {y_dim}")
        print(f"features: {feat_list}")
        print(f"temporal dims: {temporal_dims}")
    return splits, feat_list, temporal_dims
def hyperparam_selection_predictor(dataname, search_space, K, seed=0, epochs=50):
    """Grid-search KME2P predictor hyperparameters (latent space ``"y"``).

    Trains one model per configuration in the cartesian product of
    ``search_space`` values, writes the per-config mean/std of the chosen
    metric to ``hyperparam_selection/{dataname}_E2Py.csv``, then prints the
    best configuration found in that file.

    Args:
        dataname: dataset name understood by ``load_data``.
        search_space: mapping of hyperparameter name -> list of candidates.
            Keys not present in the base config are silently skipped.
        K: number of clusters for the KME2P model.
        seed: base RNG seed forwarded to ``evaluate_predictor``.
        epochs: training epochs per configuration.
    """
    splits, feat_list, temporal_dims = load_data(dataname, verbose=True)  # pylint: disable=unused-variable
    tr_set, va_set, te_set = splits[0]  # pylint: disable=unused-variable
    _, T, x_dim = tr_set["x"].shape  # pylint: disable=unused-variable
    _, _, y_dim = tr_set["y"].shape

    # Base configuration; only keys listed in search_space are varied.
    e2py_config = KME2P_config.copy()
    e2py_config["K"] = K
    e2py_config["x_dim"] = x_dim
    e2py_config["y_dim"] = y_dim
    e2py_config["latent_space"] = "y"

    result_file = f"hyperparam_selection/{dataname}_E2Py.csv"
    if os.path.exists(result_file):
        # Results already on disk: empty the search space so the loop below
        # is a no-op and the cached CSV is reported as-is.
        search_space = {}

    scores = pd.DataFrame(columns=["H_mean", "H_std", "config"])
    for i, comb in enumerate(itertools.product(*search_space.values())):
        if len(comb) == 0:
            # itertools.product() over an empty space yields one empty tuple.
            continue
        test_config = e2py_config.copy()
        msg = []
        for j, k in enumerate(search_space.keys()):
            if k in e2py_config:
                test_config[k] = comb[j]
                msg.append(f"{k}:{comb[j]}")
        msg = ",".join(msg)
        print(f"test config {msg} ...")
        # Synth is scored by cluster purity; real datasets use Hprc.
        metric = "Hprc" if dataname != "Synth" else "PURITY"
        results, model = evaluate_predictor(  # pylint: disable=unused-variable
            KME2P, test_config, loss_weights, splits, seed=seed, epochs=epochs, metric=metric
        )
        scores.loc[i, "H_mean"] = np.mean(results)  # pyright: ignore
        scores.loc[i, "H_std"] = np.std(results)  # pyright: ignore
        scores.loc[i, "config"] = msg
        # Checkpoint after every configuration so an interrupted run keeps
        # the results computed so far.
        scores.to_csv(result_file)

    # Re-read from disk so both the fresh-run and cached paths report from
    # the same source of truth.
    scores = pd.read_csv(result_file, index_col=0)
    scores = scores.astype({"H_mean": "float"})
    best = scores["H_mean"].idxmax()
    print("Optimal hyperparameters:")
    print(scores.loc[best, "config"])
def read_config(config_str):
    """Parse a ``"key:value,key:value"`` string into a dict of strings.

    Inverse of the comma-joined ``f"{k}:{v}"`` format written by the
    hyperparameter-selection routines. Values stay strings; callers convert
    types themselves (e.g. ``int(config["K"])``).

    Args:
        config_str: comma-separated ``key:value`` pairs.

    Returns:
        dict mapping each key to its value as a string.
    """
    config = {}
    for item in config_str.split(","):
        # Split on the first ':' only, so values may themselves contain
        # colons (the old split(":") raised ValueError in that case).
        key, val = item.split(":", 1)
        config[key] = val
    return config
# Candidate architecture hyperparameters explored for the predictor.
search_space_ = {
    "hidden_size": [5, 10, 20],
    "num_layers": [2],
}

for dataname_ in ["Synth", "ICU", "ADNI"]:
    # Reuse the cluster count K chosen by the earlier K-selection run,
    # read back from its result CSV.
    result_file_ = f"hyperparam_selection/{dataname_}_K_orig.csv"
    scores_ = pd.read_csv(result_file_, index_col=0)
    scores_ = scores_.astype({"H_mean": "float"})
    best_ = scores_["H_mean"].idxmax()
    config_ = read_config(scores_.loc[best_, "config"])
    K_ = int(config_["K"])
    if dataname_ == "Synth":
        # K is pinned to 3 for Synth — presumably the known number of
        # ground-truth clusters in the synthetic data (TODO confirm).
        K_ = 3
    hyperparam_selection_predictor(dataname_, search_space_, K_)