# classifier.py
# You must install MeCab before running this script:
# $ brew install mecab
# $ brew install mecab-ipadic
# $ pip install mecab-python3
# You also need gensim:
# $ pip install gensim
import json
import numpy as np
from chainer import Variable, optimizers, serializers
from news_chain import NewsChain
import MeCab
from gensim import corpora, matutils

mecab = MeCab.Tagger("-Ochasen")
# Hack: parse an empty string once up front. The text apparently gets
# garbage-collected at a bad moment and MeCab throws an error otherwise.
# This is awful, but it works around the problem.
mecab.parse('')

# json usage reminder (I'm a beginner in Python :P):
#
#   foo = open('categories.json', 'r')
#   bar = json.load(foo)

# for debugging
from IPython import embed
from IPython.terminal.embed import InteractiveShellEmbed

def extract_words(text):
    # Return the surface forms of all nouns (名詞) in the text.
    node = mecab.parseToNode(text)
    words = []
    while node:
        meta = node.feature.split(",")
        if meta[0] == "名詞":
            words.append(node.surface)
        node = node.next
    return words
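
# A quick sanity check (hypothetical example; the exact output depends on the
# installed MeCab dictionary):
#
#   extract_words("今日は良い天気です")  # -> ['今日', '天気']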

def get_news_list():
    # Load the sample articles, grouped by news source.
    with open('sample_news/news.json', 'r') as f:
        news_list = json.load(f)
    return [
        news_list['YahooNews'],
        news_list['news_line_me'],
        news_list['TwitterNewsJP'],
        news_list['nhk_news']]

def get_dictionary(news_list):
    # Build a gensim dictionary over the nouns of every article and save it
    # as text so the same vocabulary can be reused at classification time.
    dictionary_name = 'words.txt'
    words = []
    for each_news in news_list:
        for key in each_news:
            news = each_news[key]
            text = news['content']
            words.append(extract_words(text))
    dictionary = corpora.Dictionary(words)
    dictionary.save_as_text(dictionary_name)
    return dictionary

def convert_text_into_dense(dictionary, text):
    # Bag-of-words: text -> (word id, count) pairs -> dense vector whose
    # length equals the dictionary size.
    words = extract_words(text)
    vec = dictionary.doc2bow(words)
    dense = list(matutils.corpus2dense([vec], num_terms=len(dictionary)).T[0])
    return dense
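
# For intuition (a made-up sketch, not part of the pipeline): with a 5-word
# dictionary where "野球" has id 2, a text mentioning 野球 twice would come
# back as [0.0, 0.0, 2.0, 0.0, 0.0].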

def convert_text_into_variable(dictionary, text):
    # Cast to float32 explicitly; the training data below is float32, and
    # Chainer expects matching dtypes at classification time too.
    dense = convert_text_into_dense(dictionary, text)
    return Variable(np.array([dense], dtype=np.float32))

def prepare_train_variables(dictionary, news_list):
    # Vectorize every article and pair it with its integer label.
    x_list = []
    y_list = []
    for each_news in news_list:
        for key in each_news:
            news = each_news[key]
            text = news['content']
            dense = convert_text_into_dense(dictionary, text)
            x_list.append(dense)
            y_list.append(int(news['label']))
    X = np.array(x_list).astype(np.float32)
    Y = np.array(y_list).astype(np.int32)
    N = len(X)
    return X, Y, N
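
# prepare_train_variables yields X with shape (N, len(dictionary)) as float32
# and Y with shape (N,) as int32, the dtypes Chainer's softmax cross entropy
# expects for inputs and labels.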

def train(model, optimizer, xtrain, ytrain):
    # Plain mini-batch training: 5000 passes over the data with a batch size
    # of 25, shuffling the sample order on each pass.
    n = len(xtrain)
    bs = 25
    for j in range(5000):
        sffindx = np.random.permutation(n)
        for i in range(0, n, bs):
            idx = sffindx[i:(i + bs) if (i + bs) < n else n]
            x = Variable(xtrain[idx])
            y = Variable(ytrain[idx])
            model.zerograds()
            loss = model(x, y)
            loss.backward()
            optimizer.update()
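
# Note: model.zerograds() is the old Chainer API; newer Chainer versions
# deprecate it in favor of model.cleargrads(). Either clears the gradients
# before the backward pass accumulates new ones.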

def exec_test(model, x, yans):
    # Print test-set accuracy: the predicted class for each sample is the
    # argmax over its output row.
    y = model.fwd(x)
    ans = y.data
    nrow, ncol = ans.shape
    ok = 0
    for i in range(nrow):
        cls = np.argmax(ans[i, :])
        if cls == yans[i]:
            ok += 1
    print(ok, "/", nrow, " = ", (ok * 1.0) / nrow)

def classify_text(model, dictionary, text):
    # categories.json maps a class index (as a string) to a category name.
    x = convert_text_into_variable(dictionary, text)
    idx = str(np.argmax(model.fwd(x).data))
    with open('categories.json', 'r') as f:
        categories = json.load(f)
    return categories[idx]
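
# Hypothetical call (the text and the resulting category are made up):
#
#   classify_text(model, dictionary, "日本代表がワールドカップで勝利した")  # -> "スポーツ"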

news_list = get_news_list()
dictionary = get_dictionary(news_list)
input_length = len(dictionary)
model = NewsChain(input_length)
optimizer = optimizers.Adam()
optimizer.setup(model)
X, Y, N = prepare_train_variables(dictionary, news_list)
# Split the data in half: odd indices for training, even indices for testing.
index = np.arange(N)
xtrain = X[index[index % 2 != 0]]
ytrain = Y[index[index % 2 != 0]]
xtest = X[index[index % 2 == 0]]
yans = Y[index[index % 2 == 0]]
train(model, optimizer, xtrain, ytrain)
x = Variable(xtest)
exec_test(model, x, yans)
dictionary.save_as_text('trained/words.txt')
serializers.save_npz('trained/news_classifier.npz', model)

# Saving a model:
#   serializers.save_npz('hoge_01.npz', model)
# Loading a model:
#   serializers.load_npz('hoge_01.npz', model)
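
# A minimal reload-and-classify sketch (assumes the files written above exist
# and that NewsChain only needs the dictionary size to be rebuilt):
#
#   dictionary = corpora.Dictionary.load_from_text('trained/words.txt')
#   model = NewsChain(len(dictionary))
#   serializers.load_npz('trained/news_classifier.npz', model)
#   print(classify_text(model, dictionary, "なんらかのニュース本文"))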