-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathutils.py
60 lines (48 loc) · 1.91 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# -*- coding: utf-8 -*-
"""Utils.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1pCbkGljQ9r5BJbrp8KlmAOrkvmzlDQ4c
"""
from __future__ import division
from collections import Counter, defaultdict
import os
from random import shuffle
import tensorflow as tf
def _context_windows(region, left_size, right_size):
    """Yield a (left_context, word, right_context) triple for each word.

    For the word at position ``i`` in ``region``, the left context holds up
    to ``left_size`` preceding words and the right context up to
    ``right_size`` following words; both are truncated at the region's
    boundaries by ``_window``.
    """
    for center, token in enumerate(region):
        lo = center - left_size
        hi = center + right_size
        left_ctx = _window(region, lo, center - 1)
        right_ctx = _window(region, center + 1, hi)
        yield (left_ctx, token, right_ctx)
def _window(region, start_index, end_index):
"""
Returns the list of words starting from `start_index`, going to `end_index`
taken from region. If `start_index` is a negative number, or if `end_index`
is greater than the index of the last word in region, this function will pad
its return value with `NULL_WORD`.
"""
last_index = len(region) + 1
selected_tokens = region[max(start_index, 0):min(end_index, last_index) + 1]
return selected_tokens
def _device_for_node(n):
if n.type == "MatMul":
return "/gpu:0"
else:
return "/cpu:0"
def _batchify(batch_size, *sequences):
for i in range(0, len(sequences[0]), batch_size):
yield tuple(sequence[i:i+batch_size] for sequence in sequences)
def _plot_with_labels(low_dim_embs, labels, path, size):
    """Scatter-plot 2-D embeddings with a text annotation on each point.

    `low_dim_embs` is indexed as `low_dim_embs[i, :]` yielding an (x, y)
    pair per row, with at least as many rows as there are `labels`.  The
    figure is saved to `path` when one is given, and always closed.
    """
    import matplotlib.pyplot as plt

    assert low_dim_embs.shape[0] >= len(labels), "More labels than embeddings"
    fig = plt.figure(figsize=size)  # figsize is in inches
    for idx, text in enumerate(labels):
        px, py = low_dim_embs[idx, :]
        plt.scatter(px, py)
        plt.annotate(
            text,
            xy=(px, py),
            xytext=(5, 2),
            textcoords='offset points',
            ha='right',
            va='bottom',
        )
    if path is not None:
        fig.savefig(path)
    plt.close(fig)