-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathutil.py
218 lines (162 loc) · 7.76 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import collections
import sys
import zlib
from itertools import product, chain, groupby
from string import ascii_lowercase, digits
import requests
# Base location of the Google Books ngram dumps; `{}` is filled with a shard
# file name built from FILE_TEMPLATE below.
URL_TEMPLATE = 'http://storage.googleapis.com/books/ngrams/books/{}'
# Local mirror, handy for testing without hitting Google's servers:
# URL_TEMPLATE = 'http://localhost:8001/{}'
# File-name pattern of one compressed ngram shard (version 20120701 layout).
FILE_TEMPLATE = 'googlebooks-{lang}-all-{ngram_len}gram-{version}-{index}.gz'
# One tab-separated row of an ngram file: the ngram text, the year, and the
# match/volume counts for that year.
Record = collections.namedtuple('Record', 'ngram year match_count volume_count')
class StreamInterruptionError(Exception):
    """Raised when a data stream ends before the end of the file.

    :param str url: the URL of the stream that was interrupted.
    :param str message: a human-readable description of the failure.
    """

    def __init__(self, url, message):
        # Forward the message to Exception so that str(exc) and tracebacks
        # show it; the original swallowed it (str(exc) was '').
        super().__init__(message)
        self.url = url
        self.message = message
def _iter_records(request, dec, url, chunk_size):
    """Yield `Record`s parsed from one gzip-compressed response stream.

    :param request: a streaming `requests` response for one shard file.
    :param dec: a `zlib` decompress object configured for gzip input.
    :param str url: the shard URL, used only for error reporting.
    :param int chunk_size: size of the raw compressed chunks to read.

    :raises StreamInterruptionError: if a line is malformed or the stream
        ends on a partial line.
    """
    last = b''
    for compressed_chunk in request.iter_content(chunk_size=chunk_size):
        chunk = dec.decompress(compressed_chunk)
        # Keep the trailing partial line around until the next chunk.
        complete = (last + chunk).split(b'\n')
        complete, last = complete[:-1], complete[-1]
        for raw_line in complete:
            fields = raw_line.decode('utf-8').split('\t')
            # Explicit check instead of `assert`: asserts are stripped
            # under `python -O` and would let bad data through silently.
            if len(fields) != 4:
                raise StreamInterruptionError(
                    url,
                    'Malformed record line: {!r}'.format(raw_line),
                )
            yield Record(fields[0], *map(int, fields[1:]))
    if last:
        raise StreamInterruptionError(
            url,
            "Data stream ended on a non-empty line. This might be due "
            "to temporary networking problems.")


def readline_google_store(ngram_len, lang='eng', indices=None, chunk_size=1024 ** 2, verbose=False):
    """Iterate over the data in the Google ngram collection.

    :param int ngram_len: the length of ngrams to be streamed.
    :param str lang: the language of the ngrams.
    :param iter indices: the file indices to be downloaded.
    :param int chunk_size: the size of the chunks of raw compressed data.
    :param bool verbose: if `True`, then debug information is shown on `sys.stderr`.

    :returns: an iterator over triples `(fname, url, records)`.
    """
    for fname, url, request in iter_google_store(ngram_len, verbose=verbose, lang=lang, indices=indices):
        # 32 + MAX_WBITS tells zlib to expect (and skip) a gzip header.
        dec = zlib.decompressobj(32 + zlib.MAX_WBITS)
        # Pass the per-file state as arguments: the original used a closure,
        # which late-binds `dec`/`request`/`url` and would read the *next*
        # file's state if the caller advanced this iterator before
        # exhausting the records generator.
        yield fname, url, _iter_records(request, dec, url, chunk_size)
def ngram_to_cooc(ngram, count, index):
    """Turn an ngram into `((item_id, context_id), count)` pairs.

    The middle token of the ngram is the item; every other token is its
    context. Word ids are looked up (and assigned on first sight) via
    `word_to_id`, mutating `index` in place.
    """
    tokens = ngram.split()
    mid = len(tokens) // 2
    # Resolve the item id first so id assignment order matches insertion.
    item_id = word_to_id(tokens[mid], index)
    pairs = []
    for position, token in enumerate(tokens):
        if position == mid:
            continue
        pairs.append(((item_id, word_to_id(token, index)), count))
    return tuple(pairs)
def word_to_id(word, index):
    """Return the integer id of `word`, assigning a fresh one on first sight.

    `index` maps words to ids and is mutated in place; a new word receives
    an id equal to the current size of the mapping.
    """
    return index.setdefault(word, len(index))
def count_coccurrence(records, index):
    """Accumulate co-occurrence counts over a stream of records.

    `records` must already be grouped by ngram (consecutive records with the
    same `.ngram` — the usual `itertools.groupby` requirement). For each
    ngram, the match counts are summed and spread over its
    `(item_id, context_id)` pairs via `ngram_to_cooc`, which mutates
    `index` in place.

    :returns: a `collections.Counter` mapping `(item_id, context_id)` pairs
        to total match counts.
    """
    counter = collections.Counter()
    for ngram, group in groupby(records, key=lambda record: record.ngram):
        total = sum(record.match_count for record in group)
        for pair, value in ngram_to_cooc(ngram, total, index):
            counter[pair] += value
    return counter
def iter_google_store(ngram_len, lang="eng", indices=None, verbose=False):
    """Iterate over the collection files stored at Google.

    :param int ngram_len: the length of ngrams to be streamed.
    :param str lang: the language of the ngrams.
    :param iter indices: the file indices to be downloaded.
    :param bool verbose: if `True`, then debug information is shown on `sys.stderr`.

    :returns: an iterator over triples `(fname, url, request)` where
        `request` is a streaming `requests` response for the file.
    :raises requests.HTTPError: if a shard could not be downloaded.
    """
    version = '20120701'
    # One session so the underlying HTTP connection is reused across shards.
    session = requests.Session()
    indices = get_indices(ngram_len) if indices is None else indices
    for index in indices:
        fname = FILE_TEMPLATE.format(
            lang=lang,
            ngram_len=ngram_len,
            version=version,
            index=index,
        )
        url = URL_TEMPLATE.format(fname)
        if verbose:
            sys.stderr.write(
                'Downloading {url} '
                ''.format(
                    url=url,
                ),
            )
            sys.stderr.flush()
        request = session.get(url, stream=True)
        # Explicit HTTP-error check instead of `assert`: asserts are
        # stripped under `python -O` and would silently yield bad streams.
        request.raise_for_status()
        yield fname, url, request
        if verbose:
            sys.stderr.write('\n')
def get_indices(ngram_len):
    """Generate the file indices depending on the ngram length (version 20120701).

    For 1grams the indices are the digits ``0``–``9``, the single letters
    ``a``–``z`` and the special indices ``other``, ``punctuation`` and
    ``pos``.

    For longer ngrams they are the digits, every two-character combination
    of a letter followed by a letter or an underscore (``aa`` … ``z_``),
    and the special indices ``other``, ``punctuation`` plus the
    part-of-speech tags ``_ADJ_`` … ``_VERB_``. Note that there is no index
    ``qk`` for 5grams.

    See http://storage.googleapis.com/books/ngrams/books/datasetsv2.html for
    more details.
    """
    special = ['other', 'punctuation']
    if ngram_len == 1:
        letters = iter(ascii_lowercase)
        special.append('pos')
    else:
        pairs = (first + second
                 for first, second in product(ascii_lowercase, ascii_lowercase + '_'))
        if ngram_len == 5:
            # The 'qk' shard does not exist in the 5gram collection.
            letters = (pair for pair in pairs if pair != 'qk')
        else:
            letters = pairs
        special.extend([
            '_ADJ_',
            '_ADP_',
            '_ADV_',
            '_CONJ_',
            '_DET_',
            '_NOUN_',
            '_NUM_',
            '_PRON_',
            '_PRT_',
            '_VERB_',
        ])
    return chain(digits, letters, special)