RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 32 but got size 1 for tensor number 3 in the list #14

Open · jaugustin12 opened this issue Nov 20, 2021 · 0 comments
I am running EditNTS (https://github.com/yuedongP/EditNTS) without teacher forcing on some training data. When I run main.py I get this error:

  File "/home/jba5337/work/ds440w/EditNTS-Google/editnts.py", line 252, in forward
    output_t = torch.cat((output_edits, attn_applied_org_t, c, hidden_words[0]),
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 32 but got size 1 for tensor number 3 in the list.
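
As I understand torch.cat(), it only relaxes the concatenation dimension: every tensor in the list must agree in all other dimensions, and the "tensor number" in the message counts from zero. A minimal standalone snippet that raises the same error (the shapes are illustrative, not pulled from my run):

    import torch

    a = torch.zeros(32, 1, 400)
    b = torch.zeros(32, 1, 400)
    c = torch.zeros(32, 1, 400)
    d = torch.zeros(1, 32, 400)  # dims 0 and 1 swapped relative to the others

    # RuntimeError: Sizes of tensors must match except in dimension 2.
    # Expected size 32 but got size 1 for tensor number 3 in the list.
    torch.cat((a, b, c, d), dim=2)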

Here is what happens when I print c:

c tensor([[[-0.0353, -0.0617, -0.1176,  ...,  0.0507, -0.0174,  0.1828]],

        [[-0.0769, -0.0166, -0.1737,  ..., -0.1302, -0.1488,  0.1480]],

        [[-0.0570, -0.0683, -0.2270,  ..., -0.0820, -0.2011,  0.1915]],

        ...,

        [[-0.1127,  0.0051, -0.2119,  ..., -0.0853, -0.1813,  0.2058]],

        [[-0.0570, -0.0683, -0.2270,  ..., -0.0412, -0.1851,  0.1975]],

        [[-0.1127,  0.0051, -0.2119,  ..., -0.0477, -0.1822,  0.2200]]],
       device='cuda:0', grad_fn=<GatherBackward0>)

size torch.Size([32, 1, 400])
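
Note that the error's "tensor number 3" is 0-indexed, so it points at hidden_words[0], the last operand of the cat, rather than at c. A debug print over all four inputs (these two lines are mine, not in the repo) should confirm which shape is off:

    for i, x in enumerate((output_edits, attn_applied_org_t, c, hidden_words[0])):
        print('cat input', i, tuple(x.shape))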

So c itself has shape [32, 1, 400], not size 1, which suggests the mismatch comes from one of the other operands, but I am unsure where. Here is the function the error comes from, if you could please take a look:

        else:  # no teacher forcing
            decoder_input_edit = input_edits[:, :1]
            decoder_input_word = simp_sent[:, :1]
            t, tt = 0, max(MAX_LEN, input_edits.size(1) - 1)

            # initialize
            embedded_edits = self.embedding(decoder_input_edit)
            output_edits, hidden_edits = self.rnn_edits(embedded_edits, hidden_org)

            embedded_words = self.embedding(decoder_input_word)
            output_words, hidden_words = self.rnn_words(embedded_words, hidden_org)
            #
            # # give previous word from tgt simp_sent
            # inds = torch.LongTensor(counter_for_keep_ins)
            # dummy = inds.view(-1, 1, 1)
            # dummy = dummy.expand(dummy.size(0), dummy.size(1), output_words.size(2)).cuda()
            # c_word = output_words.gather(1, dummy)

            while t < tt:
                if t > 0:
                    embedded_edits = self.embedding(decoder_input_edit)
                    output_edits, hidden_edits = self.rnn_edits(embedded_edits, hidden_edits)

                key_org = self.attn_Projection_org(output_edits)  # bsz x nsteps x nhid
                logits_org = torch.bmm(key_org, encoder_outputs_org.transpose(1, 2))  # bsz x nsteps x encsteps
                attn_weights_org_t = F.softmax(logits_org, dim=-1)  # bsz x nsteps x encsteps
                attn_applied_org_t = torch.bmm(attn_weights_org_t, encoder_outputs_org)  # bsz x nsteps x nhid

                ## find current word
                inds = torch.LongTensor(counter_for_keep_del)
                dummy = inds.view(-1, 1, 1)
                dummy = dummy.expand(dummy.size(0), dummy.size(1), encoder_outputs_org.size(2)).cuda()
                c = encoder_outputs_org.gather(1, dummy)
                print('c',c)
                output_t = torch.cat((output_edits, attn_applied_org_t, c, hidden_words[0]),
                                     2)  # bsz*nsteps x nhid*2
                output_t = self.attn_MLP(output_t)
                output_t = F.log_softmax(self.out(output_t), dim=-1)

                decoder_out.append(output_t)
                decoder_input_edit = torch.argmax(output_t, dim=2)

                # gold_action = input[:, t + 1].vocab_data.cpu().numpy()  # might need to realign here because start added
                pred_action = torch.argmax(output_t, dim=2)
                counter_for_keep_del = [i[0] + 1 if i[1] == 2 or i[1] == 3 or i[1] == 5 else i[0]
                                        for i in zip(counter_for_keep_del, pred_action)]

                # update rnn_words
                # find previous generated word
                # give previous word from tgt simp_sent
                dummy_2 = inds.view(-1, 1).cuda()
                org_t = org_ids.gather(1, dummy_2)
                hidden_words = self.execute_batch(pred_action, org_t, hidden_words)  # we give the edited subsequence
                # hidden_words = self.execute_batch(pred_action, org_t, hidden_org)  #here we only give the word

                t += 1
                check = sum([x >= org_ids.size(1) for x in counter_for_keep_del])
                if check:
                    break
        return torch.cat(decoder_out, dim=1), hidden_edits
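
One guess, and it is only a guess since it depends on how self.rnn_words is defined: if rnn_words is an nn.LSTM, then hidden_words is the (h_n, c_n) tuple and hidden_words[0] is h_n with shape (num_layers, bsz, nhid) = (1, 32, 400). That tensor is layer-first, while output_edits, attn_applied_org_t, and c are all batch-first (32, 1, 400), which would produce exactly this error. If that is the case, transposing it into batch-first form before the cat would line the shapes up:

    # Hedged sketch, not the repo's code: assumes hidden_words comes from an
    # nn.LSTM, so hidden_words[0] is h_n of shape (num_layers, bsz, nhid)
    # with num_layers == 1 in this run.
    h_last = hidden_words[0].transpose(0, 1)  # (1, 32, 400) -> (32, 1, 400)
    output_t = torch.cat((output_edits, attn_applied_org_t, c, h_last), 2)

I have not verified this against the teacher-forcing branch, so the real fix may live elsewhere (for example in how hidden_words is initialized from hidden_org).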
