Skip to content

Commit

Permalink
Fix issue of canine forward requiring input_ids anyway (#26290)
Browse files Browse the repository at this point in the history
* fix issue of canine forward requires input_ids anyway

The `forward` requires `input_ids` for deriving other variables in all cases. Change this to use the given one between `input_ids` and `inputs_embeds`

* fix canine forward

The current `forward` requires (the shape of) `input_ids` for deriving other variables whenever `input_ids` or `inputs_embeds` is provided. Change this to use the given one instead of `input_ids` all the time.

* fix format

* fix format
  • Loading branch information
marcmk6 authored Oct 2, 2023
1 parent 7d77d7f commit 6d02ca4
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/transformers/models/canine/modeling_canine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,9 @@ def forward(
# Contextualize character embeddings using shallow Transformer.
# We use a 3D attention mask for the local attention.
# `input_char_encoding`: shape (batch_size, char_seq_len, char_dim)
char_attention_mask = self._create_3d_attention_mask_from_input_mask(input_ids, attention_mask)
char_attention_mask = self._create_3d_attention_mask_from_input_mask(
input_ids if input_ids is not None else inputs_embeds, attention_mask
)
init_chars_encoder_outputs = self.initial_char_encoder(
input_char_embeddings,
attention_mask=char_attention_mask,
Expand Down

0 comments on commit 6d02ca4

Please sign in to comment.