Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix max_length criteria when using inputs_embeds #28994

Merged
merged 18 commits into from
Feb 16, 2024

Conversation

zucchini-nlp
Copy link
Member

What does this PR do?

Fixes #28953 . StoppingCriteria with max_length behaves differently when provided input_ids or inputs_embeds, this happens only on decoder-only models. The PR fixes it so that the criteria accounts for the length of input_embeds when generating

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@gante

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically fulfils the main request of the GH issue, but I'd like for us to go one step further!

In the test you wrote, we check self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1] - 1). Ideally, the final -1 shouldn't be there: we initialize input_ids with decoder_start_id, causing the additional length, and we probably shouldn't. As such, we can add an additional condition in _prepare_decoder_input_ids_for_generation: in this specific case, input_ids should be empty.

src/transformers/generation/utils.py Outdated Show resolved Hide resolved
@zucchini-nlp
Copy link
Member Author

oh, i see, added a new fix and checked that creating an empty tensor does not break anything

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect! Thank you for iterating 🤗

Regarding failing CI: it seems unrelated to this PR and main does not have this failure. Therefore, it will likely be solved by rebasing with main and then force-pushing.

@gante gante requested a review from amyeroberts February 14, 2024 10:37
Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this! Very nice and clean PR :)

Just some outstanding questions so I can understand what's happening here before approving

@@ -2730,6 +2730,20 @@ def test_max_length_warning_if_different(self):
**model_kwargs,
)

def test_max_length_if_input_embeds(self):
# PT-only test: TF doesn't have StoppingCriteria
article = "Hey, are you conscious?"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use a different phrase here? Talking about consciousness with these LLMs isn't ideal

input_len = input_ids.shape[-1]
out_gen = model.generate(input_ids=input_ids, max_length=max_length)
out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, max_length=max_length)
self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my own understanding - why is the returned generation when passing in input_ids, a concatenation of the input and newly generated tokens, but for embeds we only return the new embeddings?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The addition of input_length here is needed because the output of generation with inputs_embeds return only newly generated text, while the input_ids return the whole text, including prompt. So, we are just making sure the lengths of both are equal

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, but why is the behaviour different for embeddings and input_ids?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand the question correctly, the lengths here differ because we return the whole text (prompt+new) when user passes ids. But we cannot recover prompt text from input_embeds, so we just return the newly generated part

Copy link
Member

@gante gante Feb 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As @zucchini-nlp wrote.

There is no mismatch if the user passes input_ids and inputs_embeds, as generate continues populating input_ids. But passing both kinda defeats the point of feeding inputs_embeds, which is used mostly for experimental purposes, and thus the shape difference when only inputs_embeds is set. Although we can technically recover input_ids from inputs_embeds (reverse lookup search) in most cases to make the shapes consistent, it's probably not a good use of our engineering time :D

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zucchini-nlp @gante Thanks for the explanation!

src/transformers/generation/utils.py Outdated Show resolved Hide resolved
@@ -441,6 +441,9 @@ def _maybe_initialize_input_ids_for_generation(
if isinstance(value, torch.Tensor):
batch_size = value.shape[0]
break

if "inputs_embeds" in model_kwargs:
return torch.ones((batch_size, 0), dtype=torch.long, device=self.device)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my own understanding - am I correct in understanding when using input_embeds we don't use any initialization then, this is just an empty placeholder?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, when we initialized with size 1 filled with BOS tokens, that ruined max_length by one token. We want want the final generation be a continuation of input_embeds and not start with BOS

@@ -1421,6 +1424,11 @@ def generate(
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_length

# adjust max_length when using `input_embeds` in decoder-only models
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than saying what this is doing (we can tell from the code) it would be useful for the comment to explain why we need to do this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amyeroberts done for all comments

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great - thanks for iterating!

@gante
Copy link
Member

gante commented Feb 15, 2024

@amyeroberts unrelated CI failures, I believe this can be merged 🤗

@amyeroberts
Copy link
Collaborator

@zucchini-nlp Can you try rebasing? Fixes should have been merged into main with resolve the currently failing tests

@zucchini-nlp
Copy link
Member Author

@amyeroberts thanks, now it's all green and can be merged

@amyeroberts amyeroberts merged commit aee11fe into huggingface:main Feb 16, 2024
21 checks passed
zucchini-nlp added a commit to zucchini-nlp/transformers that referenced this pull request Feb 19, 2024
* fix max_length for inputs_embeds

* make style

* Update src/transformers/generation/utils.py

Co-authored-by: Joao Gante <[email protected]>

* Static Cache: load models with MQA or GQA (huggingface#28975)

* fix

* fix tests

* fix tests

* Update src/transformers/generation/utils.py

Co-authored-by: amyeroberts <[email protected]>

* more fixes

* make style

---------

Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
@zucchini-nlp zucchini-nlp deleted the fix/max_length_generation branch February 26, 2024 12:47
itazap pushed a commit that referenced this pull request May 14, 2024
* fix max_length for inputs_embeds

* make style

* Update src/transformers/generation/utils.py

Co-authored-by: Joao Gante <[email protected]>

* Static Cache: load models with MQA or GQA (#28975)

* fix

* fix tests

* fix tests

* Update src/transformers/generation/utils.py

Co-authored-by: amyeroberts <[email protected]>

* more fixes

* make style

---------

Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Report inconsistent output length from decoder-only model generate with input_ids and inputs_embeds
4 participants