Real-time word-by-word heatmap generation of user input. #33

Open · wants to merge 5 commits into base: master
data_utils/lazy_loader.py (18 changes: 14 additions & 4 deletions)
@@ -3,6 +3,7 @@
import pickle as pkl
import time
from itertools import accumulate
from threading import Lock

import torch

@@ -29,9 +30,13 @@ def make_lazy(path, strs, data_type='data'):
datapath = os.path.join(lazypath, data_type)
lenpath = os.path.join(lazypath, data_type+'.len.pkl')
if not torch.distributed._initialized or torch.distributed.get_rank() == 0:
with open(datapath, 'w') as f:
f.write(''.join(strs))
str_ends = list(accumulate(map(len, strs)))
with open(datapath, 'wb') as f:
str_ends = []
str_cnt = 0
for s in strs:
f.write(s.encode('utf-8'))
str_cnt += len(s)
str_ends.append(str_cnt)
pkl.dump(str_ends, open(lenpath, 'wb'))
else:
while not os.path.exists(lenpath):
@@ -53,14 +58,15 @@ def __init__(self, path, data_type='data', mem_map=False):
lazypath = get_lazy_path(path)
datapath = os.path.join(lazypath, data_type)
#get file where array entries are concatenated into one big string
self._file = open(datapath, 'r')
self._file = open(datapath, 'rb')
self.file = self._file
#memory map file if necessary
self.mem_map = mem_map
if self.mem_map:
self.file = mmap.mmap(self.file.fileno(), 0, prot=mmap.PROT_READ)
lenpath = os.path.join(lazypath, data_type+'.len.pkl')
self.ends = pkl.load(open(lenpath, 'rb'))
self.read_lock = Lock()

def __getitem__(self, index):
"""read file and splice strings based on string ending array `ends` """
@@ -88,6 +94,7 @@ def file_read(self, start=0, end=None):
"""read specified portion of file"""
#TODO: Solve race condition
#Seek to start of file read
self.read_lock.acquire()
self.file.seek(start)
##### Getting context-switched here
#read to end of file if no end point provided
@@ -96,8 +103,11 @@
#else read amount needed to reach end point
else:
rtn = self.file.read(end-start)
self.read_lock.release()
#TODO: @raulp figure out mem map byte string bug
#if mem map'd need to decode byte string to string
#rtn = rtn.decode('utf-8')
rtn = str(rtn)
if self.mem_map:
rtn = rtn.decode('unicode_escape')
return rtn
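
The lazy_loader.py change above makes the shared file handle thread-safe by taking a threading.Lock around the seek()/read() pair, so a concurrent reader can no longer move the file position between the two calls. Below is a minimal standalone sketch of the same pattern; the SafeReader class and read_span() name are illustrative and not part of the PR:

import threading

class SafeReader:
    """Reads byte spans from one file handle shared across threads (illustrative sketch)."""
    def __init__(self, path):
        self._file = open(path, 'rb')
        self._lock = threading.Lock()

    def read_span(self, start, end=None):
        # Holding the lock keeps seek() and read() atomic with respect
        # to other threads sharing this handle.
        with self._lock:
            self._file.seek(start)
            if end is None:
                data = self._file.read()
            else:
                data = self._file.read(end - start)
        # Return raw bytes; callers decode as needed.
        return data

Using the lock as a context manager, rather than explicit acquire()/release() as in the diff, also guarantees the lock is released if the read raises.
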
generate.py (83 changes: 61 additions & 22 deletions)
@@ -24,6 +24,7 @@
import seaborn as sns
sns.set_style({'font.family': 'monospace'})

import sys, termios, tty, cv2

parser = argparse.ArgumentParser(description='PyTorch Sentiment Discovery Generation/Visualization')

@@ -43,11 +44,11 @@
parser.add_argument('--tied', action='store_true',
help='tie the word embedding and softmax weights')
parser.add_argument('--load_model', type=str, default='model.pt',
help='model checkpoint to use')
help='model checkpoint to use') #use the imdb_clf.pt model provided in the README.
parser.add_argument('--save', type=str, default='generated.txt',
help='output file for generated text')
parser.add_argument('--gen_length', type=int, default='1000',
help='number of tokens to generate')
help='number of tokens to generate') #use --gen_length -1
parser.add_argument('--seed', type=int, default=-1,
help='random seed')
parser.add_argument('--temperature', type=float, default=1.0,
@@ -63,8 +64,9 @@
help='generates heatmap of main neuron activation [not working yet]')
parser.add_argument('--overwrite', type=float, default=None,
help='Overwrite value of neuron s.t. generated text reads as a +1/-1 classification')
parser.add_argument('--text', default='',
help='warm up generation with specified text first')
#don't need the --text arg.
#parser.add_argument('--text', default='',
# help='warm up generation with specified text first')
args = parser.parse_args()

args.data_size = 256
@@ -109,7 +111,7 @@ def get_neuron_and_polarity(sd, neuron):
return neuron, 1
if neuron is None:
val, neuron = torch.max(torch.abs(weight[0].float()), 0)
neuron = neuron[0]
neuron = neuron.item()
val = weight[0][neuron]
if val >= 0:
polarity = 1
@@ -196,13 +198,24 @@ def make_heatmap(text, values, save=None, polarity=1):
plt.figure(figsize=(cell_width*n_limit, cell_height*num_rows))
hmap=sns.heatmap(values, annot=text, mask=mask, fmt='', vmin=-1, vmax=1, cmap='RdYlGn',
xticklabels=False, yticklabels=False, cbar=False)
plt.tight_layout()
#plt.tight_layout()
if save is not None:
plt.savefig(save)
# clear plot for next graph since we returned `hmap`
plt.clf()
plt.close()
return hmap

#return each character entered by the user
def getchar():
fd = sys.stdin.fileno()
old_settings = termios.tcgetattr(fd)
try:
tty.setraw(sys.stdin.fileno())
ch = sys.stdin.read(1)
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
return ch

neuron, polarity = get_neuron_and_polarity(sd, args.neuron)
neuron = neuron if args.visualize or args.overwrite is not None else None
@@ -224,19 +237,45 @@ def make_heatmap(text, values, save=None, polarity=1):

outchrs = []
outvals = []
#with open(args.save, 'w') as outf:
with torch.no_grad():
if args.text != '':
chrs, vals = process_text(args.text, model, input, args.temperature, neuron, mask, args.overwrite, polarity)
outchrs += chrs
outvals += vals
chrs, vals = generate(args.gen_length, model, input, args.temperature, neuron, mask, args.overwrite, polarity)
outchrs += chrs
outvals += vals
outstr = ''.join(outchrs)
print(outstr)
with open(args.save, 'w') as f:
f.write(outstr)

if args.visualize:
make_heatmap(outchrs, outvals, os.path.splitext(args.save)[0]+'.png', polarity)
input_chars = []
text = ""
print("Enter Text:")

#In this loop, user input is processed word by word for heatmap generation.
#To exit the loop, press Esc.
while True:
sys.stdout.flush()
c = getchar()

if (c == "\x1b"):
print("\n")
exit(0)
elif (c == "\x7f" and len(input_chars) > 0):
input_chars.pop()
sys.stdout.write("\b \b")
continue
elif (c=='\r'):
print()
continue
print(c,end='')
input_chars.append(c)
text = ''.join(input_chars)

if c in " .!@#$%&*?":
#with open(args.save, 'w') as outf:
with torch.no_grad():
if text != '':
chrs, vals = process_text(text, model, input, args.temperature, neuron, mask, args.overwrite, polarity)
outchrs += chrs
outvals += vals
del input_chars[:]
chrs, vals = generate(args.gen_length, model, input, args.temperature, neuron, mask, args.overwrite, polarity)
outchrs += chrs
outvals += vals
if args.visualize:
make_heatmap(outchrs, outvals, os.path.splitext(args.save)[0]+'.png', polarity)
output_img = cv2.imread(os.path.splitext(args.save)[0]+'.png')
cv2.imshow("output",output_img)
if cv2.waitKey(1) & 0xFF == ord('q'):
sys.exit()
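
The new loop in generate.py reads the terminal one keystroke at a time in raw mode and flushes the accumulated word to the model whenever a delimiter character is typed, then redraws the heatmap PNG through OpenCV. A stripped-down sketch of just the keystroke handling, with the model and display factored out behind a placeholder callback (process_word here is hypothetical, not part of the PR):

import sys, termios, tty

def getchar():
    # Put the terminal into raw mode only long enough to read one key.
    fd = sys.stdin.fileno()
    old_settings = termios.tcgetattr(fd)
    try:
        tty.setraw(fd)
        ch = sys.stdin.read(1)
    finally:
        termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
    return ch

def input_loop(process_word, delimiters=" .!@#$%&*?"):
    buf = []
    while True:
        c = getchar()
        if c == "\x1b":                  # Esc quits the loop
            print()
            return
        if c == "\x7f":                  # Backspace erases the last char
            if buf:
                buf.pop()
                sys.stdout.write("\b \b")
                sys.stdout.flush()
            continue
        if c == "\r":                    # Enter just starts a new line
            print()
            continue
        sys.stdout.write(c)
        sys.stdout.flush()
        buf.append(c)
        if c in delimiters:              # Word boundary: hand the word off
            process_word("".join(buf))
            buf.clear()

In the PR, the equivalent of process_word() runs the text through process_text() and generate(), rewrites the heatmap PNG with make_heatmap(), and refreshes the OpenCV window with cv2.imshow(); the cv2.waitKey(1) call is what gives the window a chance to repaint between words.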