I am running EditNTS (https://github.com/yuedongP/EditNTS) without teacher forcing on some training data. When I run main.py, I get the following error:
File "/home/jba5337/work/ds440w/EditNTS-Google/editnts.py", line 252, in forward
output_t = torch.cat((output_edits, attn_applied_org_t, c, hidden_words[0]),
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 32 but got size 1 for tensor number 3 in the list.
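If I understand torch.cat correctly, every dimension except the one being concatenated over has to match across all inputs. Here is a minimal standalone sketch that reproduces the exact message (the shapes are made up to mirror my batch size of 32):

```python
import torch

bsz, nhid = 32, 256

# three tensors shaped (bsz, 1, nhid), like output_edits, attn_applied_org_t, and c
a = torch.randn(bsz, 1, nhid)
b = torch.randn(bsz, 1, nhid)
c = torch.randn(bsz, 1, nhid)

# one tensor shaped (1, bsz, nhid), like an RNN's final hidden state h_n
d = torch.randn(1, bsz, nhid)

# raises: RuntimeError: Sizes of tensors must match except in dimension 2.
# Expected size 32 but got size 1 for tensor number 3 in the list.
torch.cat((a, b, c, d), 2)
```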
When I print c, it looks like it is greater than size 1, so I am unsure where the issue is. Here is the function the error comes from, if you could please take a look:
```python
else:  # no teacher forcing
    decoder_input_edit = input_edits[:, :1]
    decoder_input_word = simp_sent[:, :1]
    t, tt = 0, max(MAX_LEN, input_edits.size(1) - 1)

    # initialize
    embedded_edits = self.embedding(decoder_input_edit)
    output_edits, hidden_edits = self.rnn_edits(embedded_edits, hidden_org)

    embedded_words = self.embedding(decoder_input_word)
    output_words, hidden_words = self.rnn_words(embedded_words, hidden_org)
    #
    # # give previous word from tgt simp_sent
    # inds = torch.LongTensor(counter_for_keep_ins)
    # dummy = inds.view(-1, 1, 1)
    # dummy = dummy.expand(dummy.size(0), dummy.size(1), output_words.size(2)).cuda()
    # c_word = output_words.gather(1, dummy)

    while t < tt:
        if t > 0:
            embedded_edits = self.embedding(decoder_input_edit)
            output_edits, hidden_edits = self.rnn_edits(embedded_edits, hidden_edits)

        key_org = self.attn_Projection_org(output_edits)  # bsz x nsteps x nhid
        logits_org = torch.bmm(key_org, encoder_outputs_org.transpose(1, 2))  # bsz x nsteps x encsteps
        attn_weights_org_t = F.softmax(logits_org, dim=-1)  # bsz x nsteps x encsteps
        attn_applied_org_t = torch.bmm(attn_weights_org_t, encoder_outputs_org)  # bsz x nsteps x nhid

        ## find current word
        inds = torch.LongTensor(counter_for_keep_del)
        dummy = inds.view(-1, 1, 1)
        dummy = dummy.expand(dummy.size(0), dummy.size(1), encoder_outputs_org.size(2)).cuda()
        c = encoder_outputs_org.gather(1, dummy)
        print('c', c)

        output_t = torch.cat((output_edits, attn_applied_org_t, c, hidden_words[0]),
                             2)  # bsz*nsteps x nhid*2
        output_t = self.attn_MLP(output_t)
        output_t = F.log_softmax(self.out(output_t), dim=-1)
        decoder_out.append(output_t)
        decoder_input_edit = torch.argmax(output_t, dim=2)

        # gold_action = input[:, t + 1].vocab_data.cpu().numpy()  # might need to realign here because start added
        pred_action = torch.argmax(output_t, dim=2)
        counter_for_keep_del = [i[0] + 1 if i[1] == 2 or i[1] == 3 or i[1] == 5 else i[0]
                                for i in zip(counter_for_keep_del, pred_action)]

        # update rnn_words
        # find previous generated word
        # give previous word from tgt simp_sent
        dummy_2 = inds.view(-1, 1).cuda()
        org_t = org_ids.gather(1, dummy_2)
        hidden_words = self.execute_batch(pred_action, org_t, hidden_words)  # we give the editted subsequence
        # hidden_words = self.execute_batch(pred_action, org_t, hidden_org)  # here we only give the word

        t += 1
        check = sum([x >= org_ids.size(1) for x in counter_for_keep_del])
        if check:
            break

return torch.cat(decoder_out, dim=1), hidden_edits
```
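In case it helps with diagnosing this, my current guess is that hidden_words[0] is the RNN's final hidden state, which PyTorch returns as (num_layers, bsz, nhid) rather than (bsz, 1, nhid) like the other three tensors in the cat, and that would produce exactly this size-1 mismatch. A self-contained sketch of what I mean (the stand-in shapes and the transpose fix are my assumptions, not necessarily the authors' intent):

```python
import torch

num_layers, bsz, nhid = 1, 32, 256

# stand-ins for the tensors at the failing torch.cat
output_edits = torch.randn(bsz, 1, nhid)
attn_applied_org_t = torch.randn(bsz, 1, nhid)
c = torch.randn(bsz, 1, nhid)
h_n = torch.randn(num_layers, bsz, nhid)  # what hidden_words[0] looks like for an nn.LSTM/nn.GRU

# take the last layer's state and move the batch dimension first: (bsz, 1, nhid)
h_word = h_n[-1:].transpose(0, 1)

output_t = torch.cat((output_edits, attn_applied_org_t, c, h_word), 2)
print(output_t.shape)  # torch.Size([32, 1, 1024])
```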