wordembedding.py
import torch
import torch.nn as nn
import numpy as np

from api import api_question_embed

class WordEmbedding(nn.Module):
    """Looks up (or fetches) word vectors and packs them into padded batch tensors."""

    def __init__(self, N_word, word_emb=None):
        super(WordEmbedding, self).__init__()
        self.N_word = N_word  # dimensionality of each word vector
        # Dict mapping token -> np.ndarray; when None, embeddings are fetched
        # from the external api_question_embed service instead.
        self.word_emb = word_emb

    def gen_x_batch(self, q, col=None):
        """Embed a batch of tokenized questions.

        Each question is wrapped in zero vectors (begin/end markers) and
        right-padded with zeros to the longest question in the batch.
        `col` is accepted for API compatibility but is not used here.
        """
        batch_size = len(q)
        val_embs = []
        val_len = np.zeros(batch_size, dtype=np.int64)
        for i, q_one in enumerate(q):
            if self.word_emb:
                # Unknown tokens fall back to the zero vector.
                q_val = [self.word_emb.get(x, np.zeros(self.N_word, dtype=np.float32))
                         for x in q_one]
            else:
                q_val = api_question_embed(q_one)
            val_embs.append([np.zeros(self.N_word, dtype=np.float32)]
                            + q_val
                            + [np.zeros(self.N_word, dtype=np.float32)])
            val_len[i] = len(q_val) + 2  # +2 for the begin/end markers
        max_len = max(val_len)
        val_emb_array = np.zeros((batch_size, max_len, self.N_word), dtype=np.float32)
        for i in range(batch_size):
            for j in range(len(val_embs[i])):
                val_emb_array[i, j, :] = val_embs[i][j]
        input_tensor_var = torch.from_numpy(val_emb_array)
        return input_tensor_var, val_len
    def gen_column_batch(self, cols):
        """Embed the column names of every table in the batch.

        `cols` holds, per question, the column names (each a token list) of
        the table that question refers to.
        """
        # Number of columns in the table associated with each question.
        col_len = np.zeros(len(cols), dtype=np.int64)
        # Flatten the batch into a single list of column names.
        names = []
        for i, col in enumerate(cols):
            names.extend(col)
            col_len[i] = len(col)
        name_inp_var, name_len = self.list_to_batch(names)
        return name_inp_var, name_len, col_len
    def list_to_batch(self, col_list):
        """Embed a flat list of column names and right-pad them into one tensor."""
        total_columns = len(col_list)
        val_embs = []
        val_len = np.zeros(total_columns, dtype=np.int64)
        for i, col in enumerate(col_list):
            if self.word_emb:
                # One vector per token; unknown tokens map to the zero vector.
                val = [self.word_emb.get(x, np.zeros(self.N_word, dtype=np.float32))
                       for x in col]
            else:
                val = api_question_embed(col)
            val_embs.append(val)
            val_len[i] = len(val)
        max_len = max(val_len)
        val_emb_array = np.zeros((total_columns, max_len, self.N_word), dtype=np.float32)
        for i in range(total_columns):
            for j in range(len(val_embs[i])):
                val_emb_array[i, j, :] = val_embs[i][j]
        val_inp_var = torch.from_numpy(val_emb_array)
        return val_inp_var, val_len
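

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it exercises the
    # dictionary-backed path with a tiny hand-made embedding table, so the
    # external api_question_embed service is never called. N_word=4 and the
    # toy vectors below are illustrative assumptions, not project values.
    toy_emb = {
        "what": np.ones(4, dtype=np.float32),
        "is": np.full(4, 0.5, dtype=np.float32),
    }
    embed_layer = WordEmbedding(N_word=4, word_emb=toy_emb)

    # Two tokenized questions; out-of-vocabulary tokens embed to zero vectors.
    questions = [["what", "is", "this"], ["what", "now"]]
    x_var, x_len = embed_layer.gen_x_batch(questions)
    print(x_var.shape, x_len)  # torch.Size([2, 5, 4]) [5 4]

    # Per-question column lists; each column name is itself a token list.
    cols = [[["city", "name"], ["population"]], [["id"]]]
    name_var, name_len, col_len = embed_layer.gen_column_batch(cols)
    print(name_var.shape, name_len, col_len)  # torch.Size([3, 2, 4]) [2 1 1] [2 1]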