-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
124 lines (103 loc) · 5.48 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#!/usr/bin/env python
# _*_ coding:utf-8 _*_
import keras.models
from NewsContent import *
from UserContent import *
from preprocessing import *
from Generator import *
import Generator
from AttentionModel import *
from utils import *
from model import *
import os
import numpy as np
import json
import random
print("System version: {}".format(sys.version))
print("Tensorflow version: {}".format(tf.__version__))
print(tf.test.is_gpu_available())
seed = 25
os.environ['PYTHONHASHSEED'] = str(seed)
tf.compat.v1.set_random_seed(seed)
random.seed(seed)
tf.random.set_seed(seed)
np.random.seed(seed)
os.environ['TF_DETERMINISTIC_OPS'] = '0'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
config = tf.compat.v1.ConfigProto(gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))
sess = tf.compat.v1.Session(config=config)
# data_root_path = None
embedding_path = '/home/ywy/PycharmProjects/News_Recommendation/data/word_embedding'
# KG_root_path = None
model_save = '/home/ywy/PycharmProjects/News_Recommendation/model_checkpoint/news_model.h5'
popularity_path = '../popularity/'
config = {'title_length': 30,
'body_length': 100,
'max_clicked_news': 50,
'np_ratio': 1,
'news_encoder_name': "SelfAtt",
'user_encoder_name': "popularity_user_modeling",
'attrs': ['title', 'entity'],
'word_filter': 0,
'data_root_path': '/home/ywy/PycharmProjects/News_Recommendation/data',
'embedding_path': '/home/ywy/PycharmProjects/News_Recommendation/data/word_embedding',
'KG_root_path': '/home/ywy/PycharmProjects/News_Recommendation/data/entity',
'popularity_path': '/home/ywy/PycharmProjects/News_Recommendation/data/popularity',
'max_entity_num': 5,
'day': 1,
'popularity_hp': 1,
'title_entity_popularity': 0.5,
'combine': "mean"
}
News = NewsContent(config)
TrainUsers = UserContent(News.news_index, News.entity_popularity, News.news_entity_index, News.word_popularity,
News.title, config, 'train.tsv')
ValidUsers = UserContent(News.news_index, News.entity_popularity, News.news_entity_index, News.word_popularity,
News.title, config, 'val.tsv')
# TestUsers = UserContent(News.news_index, News.entity_popularity, News.news_entity_index, News.word_popularity,
# News.title, config, 'test.tsv')
train_sess, train_buckets, train_news_exist_time, train_user_id, train_label, uid = get_train_input(TrainUsers.session,
News.news_index,
config)
val_impressions, val_userids, val_uid = get_test_input(ValidUsers.session, News.news_index)
# test_impressions, test_userids, test_uid= get_test_input(TestUsers.session, News.news_index)
title_word_embedding_matrix, have_word = load_matrix(embedding_path, News.word_dict)
train_generator = TrainGenerator(News, TrainUsers, train_sess, train_user_id, train_buckets, train_news_exist_time,
train_label, uid, 2)
val_user_generator = UserGenerator(News, ValidUsers, val_uid, 2)
# test_user_generator = UserGenerator(News, TestUsers, test_uid, 2)
val_popularity_generator = PopularityGenerator(val_impressions, 1)
# test_popularity_generator = PopularityGenerator(test_impressions, 1)
news_generator = NewsGenerator(News, 2)
for i in range(50):
model_config = {
'fusion_hp': 0.2,
'popularity_hp': 0.2,
'popularity_time': 2}
# seed = 25
print(model_config['popularity_time'])
model, user_encoder, news_encoder, activity_gater, popularity_encoder, popularity_fusion_encoder = create_pe_model(
config, model_config, News, title_word_embedding_matrix, News.entity_embedding, seed)
if os.path.exists(model_save):
model.load_weights(model_save)
# model.summary()
model.fit(train_generator, epochs=1)
model.save_weights(model_save)
news_scoring = news_encoder.predict(news_generator, verbose=True)
# user_scoring = user_encoder.predict_generator(test_user_generator, verbose=True)
val_user_scoring = user_encoder.predict(val_user_generator, verbose=True)
val_activity_gater = activity_gater.predict(val_user_scoring, verbose=True)
# test_activity_gater = activity_gater.predict(user_scoring, verbose=True)
val_popularity_vec = popularity_fusion_encoder.predict(val_popularity_generator, verbose=True)
# test_popularity_vec = popularity_fusion_encoder.predict(test_popularity_generator, verbose=True)
val_popularity = popularity_encoder.predict(val_popularity_vec, verbose=True)
# test_popularity = popularity_encoder.predict(test_popularity_vec, verbose=True)
val_activity_gater = val_activity_gater[:, 0]
# test_activity_gater = test_activity_gater[:, 0]
# test_rankings = news_ranking(user_scoring, news_scoring, test_impressions, test_activity_gater, test_popularity)
val_rankings = news_ranking(val_user_scoring, news_scoring, val_impressions, val_activity_gater, val_popularity)
# AUC_1, MRR_1, nDCG5_1, nDCG10_1 = evaluate_performance(test_rankings, test_impressions)
# print('test AUC:', AUC_1, 'MRR:', MRR_1, 'nDCG5:', nDCG5_1, 'nDCG10:', nDCG10_1)
AUC_2, MRR_2, nDCG5_2, nDCG10_2 = evaluate_performance(val_rankings, val_impressions)
print('Val AUC:', AUC_2, 'MRR:', MRR_2, 'nDCG5:', nDCG5_2, 'nDCG10:', nDCG10_2)