# @Time : 2020/10/6
# @Author : Shanlei Mu
# @Email : [email protected]
"""
recbole.quick_start
########################
"""
import logging
import os
from logging import getLogger

from recbole.config import Config
from recbole.data import create_dataset, data_preparation
from recbole.utils import init_logger, get_model, get_trainer, init_seed
from recbole.utils.utils import set_color


def run_recbole(model=None, dataset=None, config_file_list=None, config_dict=None,
                saved=True, do_eval=False, noise_ratio=None):
    r""" A fast running api, which includes the complete process of
    training and testing a model on a specified dataset

    Args:
        model (str): model name
        dataset (str): dataset name
        config_file_list (list): config files used to modify experiment parameters
        config_dict (dict): parameters dictionary used to modify experiment parameters
        saved (bool): whether to save the model
        do_eval (bool): if True, skip training and only evaluate a saved checkpoint
        noise_ratio (float or None): optional noise ratio forwarded to the trainer's
            evaluation; None runs the standard evaluation
    """
    # configurations initialization
    config = Config(model=model, dataset=dataset, config_file_list=config_file_list, config_dict=config_dict)
    # init_seed(config['seed'], config['reproducibility'])
    # logger initialization
    init_logger(config)
    logger = getLogger()
    # directory of the current log file; plots and arrays produced below are
    # saved alongside it
    log_dir = os.path.dirname(logger.handlers[0].baseFilename)
    config['log_dir'] = log_dir
    logger.info(config)
    dataset_name = dataset
    # dataset filtering
    dataset = create_dataset(config)
    logger.info(dataset)
    # dataset splitting
    train_data, valid_data, test_data = data_preparation(config, dataset)
    # model loading and initialization
    model = get_model(config['model'])(config, train_data).to(config['device'])
    logger.info(model)
    # set up training and evaluation (item embeddings are exported during evaluation)
    # trainer loading and initialization
    trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)
    # checkpoint used when only evaluating (do_eval=True)
    model_file = os.path.join(os.getcwd(), "log", "Checkpoint", dataset_name, "model.pth")
    if do_eval:
        # model evaluation
        print_emb = True
        if print_emb:
            from recbole.utils.utils import test_embedding
            # export item embeddings (padding item at index 0 excluded)
            embedding_matrix = model.item_embedding.weight[1:].cpu().detach().numpy()
            test_embedding(embedding_matrix, range(model.n_items - 1), dataset_name, "duorec")
        if noise_ratio is None:
            test_result = trainer.evaluate(test_data, load_best_model=saved, model_file=model_file,
                                           show_progress=config['show_progress'])
        else:
            test_result = trainer.evaluate(test_data, load_best_model=saved, model_file=model_file,
                                           show_progress=config['show_progress'], noise_ratio=noise_ratio)
        logger.info(set_color('test result', 'yellow') + f': {test_result}')
        return {
            'valid_score_bigger': config['valid_metric_bigger'],
            'test_result': test_result
        }
    else:
        # training phase
        # model training
        best_valid_score, best_valid_result = trainer.fit(
            train_data, valid_data, saved=saved, show_progress=config['show_progress']
        )
        # plotting dependencies are imported lazily so evaluation-only runs do
        # not require them
        import numpy as np
        import seaborn as sns
        import matplotlib.pyplot as plt
        from sklearn.decomposition import TruncatedSVD
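        # Visualize the learned item embedding space: project the embedding
        # matrix (padding item excluded) onto its top-2 singular directions via
        # truncated SVD, giving a 2-D scatter of all item representations.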
        embedding_matrix = model.item_embedding.weight[1:].cpu().detach().numpy()
        svd = TruncatedSVD(n_components=2)
        svd.fit(embedding_matrix)
        comp_tr = np.transpose(svd.components_)
        proj = np.dot(embedding_matrix, comp_tr)
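        # Count how often each item id occurs in the dataset; the counts color
        # the scatter plot so popular and rare items can be told apart.
        # (collections.Counter would be an equivalent alternative.)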
        cnt = {}
        for i in dataset['item_id']:
            if i.item() in cnt:
                cnt[i.item()] += 1
            else:
                cnt[i.item()] = 1
        freq = np.zeros(embedding_matrix.shape[0])
        for i in cnt:
            freq[i - 1] = cnt[i]
        # freq /= freq.max()
        sns.set(style='darkgrid')
        sns.set_context("notebook", font_scale=1.8, rc={"lines.linewidth": 3, 'lines.markersize': 20})
        plt.figure(figsize=(6, 4.5))
        plt.scatter(proj[:, 0], proj[:, 1], s=1, c=freq, cmap='viridis_r')
        plt.colorbar()
        plt.xlim(-2, 2)
        plt.ylim(-2, 2)
        # plt.axis('square')
        # plt.show()
        plt.savefig(log_dir + '/' + config['model'] + '-' + config['dataset'] + '.pdf', format='pdf', transparent=False, bbox_inches='tight')
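        # Also record the full singular value spectrum of the embedding matrix,
        # normalized by the largest singular value; a slowly decaying spectrum
        # indicates the embedding dimensions are used more uniformly.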
        from scipy.linalg import svdvals
        svs = svdvals(embedding_matrix)
        svs /= svs.max()
        np.save(log_dir + '/sv.npy', svs)
        sns.set(style='darkgrid')
        sns.set_context("notebook", font_scale=1.8, rc={"lines.linewidth": 3, 'lines.markersize': 20})
        plt.figure(figsize=(6, 4.5))
        plt.plot(svs)
        # plt.show()
        plt.savefig(log_dir + '/svs.pdf', format='pdf', transparent=False, bbox_inches='tight')
        # model evaluation
        test_result = trainer.evaluate(test_data, load_best_model=saved,
                                       show_progress=config['show_progress'], noise_ratio=noise_ratio)
        logger.info(set_color('best valid ', 'yellow') + f': {best_valid_result}')
        logger.info(set_color('test result', 'yellow') + f': {test_result}')
        return {
            'best_valid_score': best_valid_score,
            'valid_score_bigger': config['valid_metric_bigger'],
            'best_valid_result': best_valid_result,
            'test_result': test_result
        }
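
# A minimal usage sketch (model/dataset names are illustrative; any model and
# dataset registered in this RecBole fork should work):
#
#     run_recbole(model='BPR', dataset='ml-100k')                # train, then test
#     run_recbole(model='BPR', dataset='ml-100k', do_eval=True)  # evaluate ./log/Checkpoint/ml-100k/model.pth
#     run_recbole(model='BPR', dataset='ml-100k', do_eval=True, noise_ratio=0.1)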

def objective_function(config_dict=None, config_file_list=None, saved=True):
    r""" The default objective_function used in HyperTuning

    Args:
        config_dict (dict): parameters dictionary used to modify experiment parameters
        config_file_list (list): config files used to modify experiment parameters
        saved (bool): whether to save the model
    """
    config = Config(config_dict=config_dict, config_file_list=config_file_list)
    init_seed(config['seed'], config['reproducibility'])
    logging.basicConfig(level=logging.ERROR)
    dataset = create_dataset(config)
    train_data, valid_data, test_data = data_preparation(config, dataset)
    model = get_model(config['model'])(config, train_data).to(config['device'])
    trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)
    best_valid_score, best_valid_result = trainer.fit(train_data, valid_data, verbose=False, saved=saved)
    test_result = trainer.evaluate(test_data, load_best_model=saved)
    return {
        'best_valid_score': best_valid_score,
        'valid_score_bigger': config['valid_metric_bigger'],
        'best_valid_result': best_valid_result,
        'test_result': test_result
    }
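
# A sketch of how objective_function is typically handed to RecBole's
# hyper-parameter tuner (assuming this fork keeps the upstream HyperTuning
# class; the params file name is illustrative):
#
#     from recbole.trainer import HyperTuning
#     hp = HyperTuning(objective_function, algo='exhaustive',
#                      params_file='hyper.test', fixed_config_file_list=['test.yaml'])
#     hp.run()
#     print(hp.best_params)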