forked from deeplearning4j/deeplearning4j
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_word_coords.py
36 lines (33 loc) · 944 Bytes
/
plot_word_coords.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# -*- coding: utf-8 -*-
"""
Created on Thu Dec 18 02:49:59 2014
@author: mcs
"""
import sys, numpy
from numpy import genfromtxt
import matplotlib.pyplot as plt
def load_coords(source):
    """Load ``x,y,label`` rows into a 1-D structured array.

    Parameters
    ----------
    source
        Anything ``numpy.genfromtxt`` accepts as input: a filename, an open
        file, or an iterable of text lines.

    Returns
    -------
    numpy.ndarray
        Structured array with fields ``f0`` (x), ``f1`` (y) and ``f2``
        (label, as text).
    """
    # 'f8' replaces the invalid 'f20' width, and 'U50' replaces the 'a50'
    # byte-string alias (removed in NumPy 2.0).  Reading labels as text rather
    # than bytes is what lets them compare equal to the words read from the
    # pruning file under Python 3.
    data = genfromtxt(source, delimiter=',', dtype='f8,f8,U50',
                      encoding='utf-8')
    # A single-row input comes back as a 0-d array; normalise it so callers
    # can always iterate and mask.
    return numpy.atleast_1d(data)


def read_prune_words(lines):
    """Collect the words to prune from an iterable of text lines.

    Each non-blank line must hold exactly one whitespace-delimited token.
    Blank lines are skipped.

    Raises
    ------
    ValueError
        If a line contains more than one token.
    """
    words = set()
    for line in lines:
        toks = line.split()
        if not toks:
            # Blank line (e.g. a trailing newline at end of file).
            continue
        if len(toks) > 1:
            raise ValueError("too many tokens per line")
        words.add(toks[0])
    return words


def prune_rows(data, words):
    """Return ``data`` without the rows whose ``f2`` label is in ``words``."""
    if not words:
        return data
    keep = numpy.array([label not in words for label in data['f2']],
                       dtype=bool)
    return data[keep]


def main(argv=None):
    """Plot the coordinates named on the command line.

    ``argv[1]`` is the coordinates file; the optional ``argv[2]`` is a file
    listing one word per line to exclude from the plot.  Defaults to
    ``sys.argv`` when ``argv`` is None.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        sys.exit("usage: plot_word_coords.py COORDS_CSV [PRUNE_WORDS_FILE]")

    data = load_coords(argv[1])
    if len(argv) > 2:
        with open(argv[2], 'r', encoding='utf-8') as prune_file:
            data = prune_rows(data, read_prune_words(prune_file))

    plt.scatter(data['f0'], data['f1'])
    plt.show()


if __name__ == "__main__":
    main()