forked from lancopku/CMAC
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpreprocess.py
77 lines (62 loc) · 2.33 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import json
import numpy as np
class Dict(object):
    """Vocabulary mapping between word strings and integer ids.

    Ids 0-4 are reserved for the special tokens ``<PAD>``, ``<BOS>``,
    ``<EOS>``, ``<UNK>`` and ``<&&&>``. Every other word receives the next
    free id the first time it is passed to :meth:`add`.
    """

    def __init__(self):
        # Forward and reverse lookup tables, pre-seeded with the special tokens.
        self.word2id = {'<PAD>': 0, '<BOS>': 1, '<EOS>': 2, '<UNK>': 3, '<&&&>': 4}
        self.id2word = {0: '<PAD>', 1: '<BOS>', 2: '<EOS>', 3: '<UNK>', 4: '<&&&>'}
        # Occurrence count of each word seen by add(); not written by save().
        self.frequency = {}

    def add(self, s):
        """Register every word of ``s`` and return the list of their ids.

        A word that is not yet known is assigned the next free id. The
        occurrence count of every word in ``s`` is incremented by one.
        """
        ids = []
        for word in s:
            word_id = self.word2id.get(word)
            if word_id is None:
                word_id = len(self.word2id)
                self.word2id[word] = word_id
                self.id2word[word_id] = word
            # Use .get(): special tokens and words restored by load() exist in
            # word2id without a frequency entry, so "+= 1" would raise KeyError.
            self.frequency[word] = self.frequency.get(word, 0) + 1
            ids.append(word_id)
        return ids

    def transform(self, s):
        """Return the ids of the words in ``s`` without modifying the dictionary.

        Words that are not in the dictionary map to the id of ``<UNK>``.
        """
        return [
            self.word2id[word] if word in self.word2id else self.word2id['<UNK>']
            for word in s
        ]

    def prune(self, k):
        """Return a new ``Dict`` holding only the ``k`` most frequent words.

        Words are ranked by descending count; ties keep their first-seen
        order. Ids in the returned dictionary are reassigned in rank order
        after the special tokens, and the original counts are carried over.
        If no word has been counted yet, the result contains only the
        special tokens.
        """
        ranked = sorted(self.frequency.items(), key=lambda kv: kv[1], reverse=True)
        kept = ranked[:k]
        pruned = Dict()
        pruned.add([word for word, _ in kept])
        # add() counted each kept word once; restore the real counts instead.
        pruned.frequency = dict(kept)
        return pruned

    def save(self, fout):
        """Write ``word2id`` and ``id2word`` as one JSON object to ``fout``.

        ``fout`` is an open text-mode file object. Frequencies are not saved.
        """
        json.dump(
            {'word2id': self.word2id, 'id2word': self.id2word},
            fout,
            ensure_ascii=False,
        )

    def load(self, fin):
        """Replace both lookup tables with the content of ``fin``.

        ``fin`` is an open file object holding JSON produced by :meth:`save`.
        """
        datas = json.load(fin)
        self.word2id = datas['word2id']
        # JSON object keys are always strings; convert them back to the
        # integer ids used everywhere else in this class.
        self.id2word = {int(idx): word for idx, word in datas['id2word'].items()}

    def __len__(self):
        """Return the number of entries, special tokens included."""
        return len(self.word2id)
def preprocess(data_kind, dir_path, ch_dict=None):
    """Scan one data split and optionally build and save a vocabulary from it.

    Reads ``<data_kind>.json`` from ``dir_path``. The file must contain a
    list of records, each with a ``'src'`` and a ``'tgt'`` string. When
    ``ch_dict`` is given, the whitespace-separated tokens of both strings
    are added to it. Afterwards the full dictionary is written to
    ``dict_whole.json`` and the result of ``ch_dict.prune(50000)`` to
    ``dict_50000.json``, both inside ``dir_path``.

    Args:
        data_kind: Name of the split; also the base name of the JSON file.
        dir_path: Directory holding the data file. A trailing path
            separator is accepted but not required.
        ch_dict: Object providing ``add``, ``save`` and ``prune`` (such as
            ``Dict``), or ``None`` to only read through the file.
    """
    import os  # function-scope import keeps this block self-contained

    print(data_kind)
    data_path = os.path.join(dir_path, data_kind + '.json')
    # Context managers guarantee the handles are closed (and output flushed)
    # even if parsing or serialisation raises.
    with open(data_path, 'r', encoding='utf-8') as fin:
        datas = json.load(fin)

    for i, record in enumerate(datas):
        # Both fields are read even without a dictionary, so a malformed
        # record still fails with KeyError.
        src_sentence = record['src']
        tgt_sentence = record['tgt']
        if ch_dict is not None:
            ch_dict.add(src_sentence.split())
            ch_dict.add(tgt_sentence.split())
        if i % 10000 == 0:
            print('{} sentences finished.'.format(i))

    if ch_dict is None:
        return

    whole_path = os.path.join(dir_path, 'dict_whole.json')
    with open(whole_path, 'w', encoding='utf-8') as fout:
        ch_dict.save(fout)

    src_dict = ch_dict.prune(50000)
    pruned_path = os.path.join(dir_path, 'dict_50000.json')
    with open(pruned_path, 'w', encoding='utf-8') as fout:
        src_dict.save(fout)
if __name__ == '__main__':
    # These splits are passed without a dictionary, so they are only read.
    for split in ('valid', 'test'):
        preprocess(split, './data/')
    # The training split fills a fresh dictionary, which preprocess() saves.
    preprocess('train', './data/', Dict())