Fix issue of canine forward requiring input_ids anyway #26290
The `forward` method requires `input_ids` for deriving other variables in all cases. Change this to use whichever of `input_ids` and `inputs_embeds` is given.
The current `forward` derives other variables from (the shape of) `input_ids` regardless of whether `input_ids` or `inputs_embeds` is provided. Change this to use whichever one is given instead of always using `input_ids`.
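A minimal sketch of the shape-resolution pattern described above, assuming that `inputs_embeds` has the same leading dimensions as `input_ids` plus one trailing hidden dimension. The helper name `resolve_input_shape` and the `FakeTensor` stand-in are hypothetical and are not part of the library; the real `forward` operates on tensors.

```python
from typing import Optional, Tuple


class FakeTensor:
    """Hypothetical stand-in exposing only a `.shape` tuple, as tensors do."""

    def __init__(self, shape: Tuple[int, ...]):
        self.shape = tuple(shape)


def resolve_input_shape(
    input_ids: Optional[FakeTensor] = None,
    inputs_embeds: Optional[FakeTensor] = None,
) -> Tuple[int, int]:
    """Return (batch_size, seq_length) from whichever input was given."""
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("Specify either input_ids or inputs_embeds, not both")
    if input_ids is not None:
        # input_ids: (batch_size, seq_length)
        return input_ids.shape[0], input_ids.shape[1]
    if inputs_embeds is not None:
        # inputs_embeds: (batch_size, seq_length, hidden_size); ignore the last dim
        return inputs_embeds.shape[0], inputs_embeds.shape[1]
    raise ValueError("Specify either input_ids or inputs_embeds")
```

With this shape in hand, later steps no longer need to touch `input_ids` directly, so a call that passes only `inputs_embeds` does not fail.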
Thanks a lot for your PR!
The method should correctly retrieve the sequence length as the second dimension corresponds to it: https://github.com/huggingface/transformers/blob/main/src/transformers/models/canine/modeling_canine.py#L1042
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Thanks!
char_attention_mask = self._create_3d_attention_mask_from_input_mask(
    input_ids if input_ids is not None else inputs_embeds, attention_mask
)
We should use the `input_shape` defined above. Since the function is only used once, there is no problem changing it.
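The suggestion above could look roughly like the following, assuming the mask helper only needs the batch size and sequence length of the input rather than the input itself. `create_3d_attention_mask` is a hypothetical pure-Python stand-in that uses nested lists instead of tensors; it is not the library's method.

```python
from typing import List, Tuple


def create_3d_attention_mask(
    input_shape: Tuple[int, int], to_mask: List[List[int]]
) -> List[List[List[float]]]:
    """Broadcast a 2D mask (batch, to_seq) to 3D (batch, from_seq, to_seq)."""
    batch_size, from_seq_length = input_shape
    if len(to_mask) != batch_size:
        raise ValueError("to_mask batch size does not match input_shape")
    # Every "from" position sees the same set of unmasked "to" positions,
    # so each batch row is repeated from_seq_length times.
    return [
        [[float(v) for v in row] for _ in range(from_seq_length)]
        for row in to_mask
    ]


mask = create_3d_attention_mask((1, 2), [[1, 1, 0]])
# mask == [[[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]]
```

Because only the shape is used, the same call works whether the caller started from `input_ids` or `inputs_embeds`.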
I'm pretty sure we have tests in our CI that make sure generation works with `use_cache`, and in that case you need this to work. Good for me as is.
…6290)
* fix issue of canine forward requires input_ids anyway
  The `forward` requires `input_ids` for deriving other variables in all cases. Change this to use the given one between `input_ids` and `inputs_embeds`
* fix canine forward
  The current `forward` requires (the shape of) `input_ids` for deriving other variables whenever `input_ids` or `inputs_embeds` is provided. Change this to use the given one instead of `input_ids` all the time.
* fix format
* fix format
The current `forward` requires (the shape of) `input_ids` for deriving other variables whenever `input_ids` or `inputs_embeds` is provided. Change this to use the given one instead of `input_ids` all the time.
What does this PR do?
Fixes #26288
Who can review?
@ArthurZucker and @younesbelkada