-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
64 lines (59 loc) · 2.23 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from dataset import get_data, get_iterators
from fitness import fit, evaluate
from utils import load_embedding
import pickle
def _apply_automl_params(config):
    """Overwrite the tunable entries of ``config`` with the best AutoML trial.

    Loads the pickle at ``config["automl_path"]``, selects the trial with the
    highest score and copies its hyper-parameters into ``config`` in place.

    The pickle is assumed to hold an iterable of trials indexable as
    ``trial[0]`` (a dict of hyper-parameters) and ``trial[1]`` (its score,
    presumably validation F1) -- inferred from the indexing below; TODO
    confirm against the AutoML script that writes these files.

    NOTE: ``pickle.load`` can execute arbitrary code; only point
    ``automl_path`` at files you produced yourself.

    Args:
        config: Run configuration dict; mutated in place.

    Raises:
        ValueError: If the pickle contains no trials.
    """
    with open(config["automl_path"], "rb") as f:
        trials = list(pickle.load(f))
    if not trials:
        raise ValueError(f"No AutoML trials found in {config['automl_path']!r}")

    # max() returns the first trial with the top score, matching the original
    # strict ">" scan, but no longer yields an empty dict when every score <= 0.
    best_params = max(trials, key=lambda trial: trial[1])[0]

    config["lr"] = best_params["lr"]
    config["rnn_units"] = int(best_params["rnn_units"])
    config["convs_filter_banks"] = int(best_params["convs_filter_banks"])
    config["denses_depth1"] = int(best_params["dense_depth1"])
    # The second dense layer has its own tuned width; reading "dense_depth1"
    # here would silently tie both layers to the same size.
    config["denses_depth2"] = int(best_params["dense_depth2"])
    # The search encodes similarity numerically: 0 -> "dot", otherwise "cosine".
    config["similarity_type"] = (
        "dot" if best_params["similarity_type"] == 0 else "cosine"
    )


if __name__ == "__main__":
    config = {
        "embedding_path": "./dataset/embeddings/w2v/w2v_title_300Epochs_1MinCount_9ContextWindow_100d.txt",
        "epochs": 30,
        "lr": 1e-03,
        "rnn_units": 100,
        "convs_filter_banks": 32,
        "convs_kernel_size": 2,
        "denses_depth1": 32,
        "denses_depth2": 16,
        "similarity_type": "dot",
        "automl_path": './exps/w2v_10Epochs_100d_CrossEntropy_BothDenses.pickle',
    }

    # When an AutoML results file is configured, its best trial overrides the
    # hand-picked defaults above.
    if config["automl_path"]:
        _apply_automl_params(config)
    print(config)

    train_ds, valid_ds, test_ds, TEXT = get_data(
        train_path="./dataset/computers/train/computers_train_splitted_medium.json",
        valid_path="./dataset/computers/valid/computers_splitted_valid_medium.json",
        test_path="./dataset/computers/test/computers_gs.json",
    )
    train_dl, valid_dl, test_dl = get_iterators(train_ds, valid_ds, test_ds)
    load_embedding(TEXT, config["embedding_path"])

    model = fit(
        TEXT,
        train_dl,
        valid_dl,
        config=config,
        hidden_dim=config["rnn_units"],
        conv_depth=config["convs_filter_banks"],
        kernel_size=config["convs_kernel_size"],
        dense_depth1=config["denses_depth1"],
        dense_depth2=config["denses_depth2"],
        lr=config["lr"],
        # Previously hard-coded to "dot", which discarded the tuned similarity.
        similarity=config["similarity_type"],
        loss="CrossEntropyLoss",
        validate_each_epoch=True,
        trainable=True,
    )
    evaluate(model, test_dl, print_results=True)