
CANINE unexpectedly requires input_ids anyway #26288

Closed · Fixed by #26290

marcmk6 (Contributor) commented Sep 20, 2023

System Info

  • transformers version: 4.33.2
  • Platform: Linux-5.15.109+-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.17.2
  • Safetensors version: 0.3.3
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1+cu118 (False)
  • Tensorflow version (GPU?): 2.13.0 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.7.2 (cpu)
  • Jax version: 0.4.14
  • JaxLib version: 0.4.14
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@ArthurZucker and @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import CanineModel, BertModel
import torch

BERT_model = BertModel.from_pretrained('bert-base-uncased')
canine_model = CanineModel.from_pretrained('google/canine-c')

fake_input = torch.rand(1, 10, 768)  # (batch_size, seq_len, hidden_size); 768 matches both models

_ = BERT_model.forward(inputs_embeds=fake_input) # no error
_ = canine_model.forward(inputs_embeds=fake_input) # error
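
Passing an explicit attention_mask alongside inputs_embeds does not help either, because (as the traceback below shows) the shape is read from input_ids itself, which is still None. A quick check, which should fail with the same AttributeError:

mask = torch.ones(1, 10)  # 2D mask matching fake_input's (batch_size, seq_len)
_ = canine_model.forward(inputs_embeds=fake_input, attention_mask=mask)  # same error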

The error message

File /miniconda3/envs/tmp/lib/python3.10/site-packages/transformers/models/canine/modeling_canine.py:1172, in CanineModel.forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)
   1162 input_char_embeddings = self.char_embeddings(
   1163     input_ids=input_ids,
   1164     position_ids=position_ids,
   1165     token_type_ids=token_type_ids,
   1166     inputs_embeds=inputs_embeds,
   1167 )
   1169 # Contextualize character embeddings using shallow Transformer.
   1170 # We use a 3D attention mask for the local attention.
   1171 # `input_char_encoding`: shape (batch_size, char_seq_len, char_dim)
-> 1172 char_attention_mask = self._create_3d_attention_mask_from_input_mask(input_ids, attention_mask)
   1173 init_chars_encoder_outputs = self.initial_char_encoder(
   1174     input_char_embeddings,
   1175     attention_mask=char_attention_mask,
   1176     output_attentions=output_attentions,
   1177     output_hidden_states=output_hidden_states,
   1178 )
   1179 input_char_encoding = init_chars_encoder_outputs.last_hidden_state

File /miniconda3/envs/tmp/lib/python3.10/site-packages/transformers/models/canine/modeling_canine.py:1042, in CanineModel._create_3d_attention_mask_from_input_mask(self, from_tensor, to_mask)
   1031 def _create_3d_attention_mask_from_input_mask(self, from_tensor, to_mask):
   1032     """
   1033     Create 3D attention mask from a 2D tensor mask.
   1034 
   (...)
   1040         float Tensor of shape [batch_size, from_seq_length, to_seq_length].
   1041     """
-> 1042     batch_size, from_seq_length = from_tensor.shape[0], from_tensor.shape[1]
   1044     to_seq_length = to_mask.shape[1]
   1046     to_mask = torch.reshape(to_mask, (batch_size, 1, to_seq_length)).float()

AttributeError: 'NoneType' object has no attribute 'shape'
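
Judging from the lines shown above, _create_3d_attention_mask_from_input_mask appears to use from_tensor only for its 2D shape, so one plausible local fix (a sketch only, not necessarily the change that landed in #26290) is to pass a tensor whose shape is always defined, e.g. the resolved attention mask:

# Sketch, inside CanineModel.forward: default the mask first, then use it
# (rather than the possibly-None input_ids) as the shape source.
# Assumes input_shape was resolved from whichever of input_ids / inputs_embeds
# was given (see the pattern under "Expected behavior" below).
if attention_mask is None:
    attention_mask = torch.ones(input_shape, device=device)
char_attention_mask = self._create_3d_attention_mask_from_input_mask(attention_mask, attention_mask)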

Expected behavior

According to the documentation, forward should work when either input_ids or inputs_embeds is provided. In practice, however, CanineModel.forward uses input_ids to derive other variables (here, the shape of the 3D attention mask) in all cases, so it crashes whenever only inputs_embeds is given.
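
For reference, BertModel (and most other models in the library) resolves the sequence shape up front from whichever input was provided, which is why the BERT call in the reproduction succeeds. A minimal sketch of that pattern:

# Input-resolution pattern used across transformers models (e.g. modeling_bert.py):
if input_ids is not None and inputs_embeds is not None:
    raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
    input_shape = input_ids.size()
elif inputs_embeds is not None:
    input_shape = inputs_embeds.size()[:-1]  # drop the hidden dimension
else:
    raise ValueError("You have to specify either input_ids or inputs_embeds")

Applying the same resolved shape when building CANINE's 3D attention mask should avoid the crash; the actual fix is in #26290.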
