
Generate: fix assisted generation with past_key_values passed as kwargs #31644

Merged: 1 commit into huggingface:main on Jun 26, 2024

Conversation

@gante (Member) commented on Jun 26, 2024:

What does this PR do?

This PR:

  • Fixes assisted generation when generate is called with the past_key_values kwarg. Unlike the other kwargs, this one should not be passed to the assistant model, since it is the cache of the main model (a minimal sketch of the idea follows this list).
  • Renames maximum_length to max_length in the newly added DynamicCache.crop function. max_length is the common variable name for a maximum length, so this standardizes the API before the function is released.
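
Below is a minimal sketch of the kwarg-filtering idea in the first bullet. The helper name build_assistant_kwargs and its surrounding structure are hypothetical, not the actual transformers internals; the point is only that past_key_values is skipped when copying the main model's kwargs over to the assistant model, because the assistant has its own decoder and therefore its own, differently shaped cache.

```python
import copy

import torch


def build_assistant_kwargs(model_kwargs: dict, device: torch.device) -> dict:
    """Hypothetical helper: copy generation kwargs for the assistant model."""
    assistant_kwargs = {}
    for key, value in model_kwargs.items():
        # The main model's cache is shape-incompatible with the assistant's
        # decoder, so it must never be handed to the assistant model.
        if key == "past_key_values":
            continue
        assistant_kwargs[key] = (
            value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
        )
    return assistant_kwargs
```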

@ArthurZucker (Collaborator) left a comment:

LGTM

```python
assistant_kwargs[key] = (
    value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
)

# Remove potential default DynamicCache if assistant does not support it
```
@ArthurZucker (Collaborator) commented on the lines above: Were tests failing because of this?

@gante (Member, Author) replied on Jun 26, 2024:

@ArthurZucker yes, assisted generation was failing when past_key_values was passed to generate! The assistant should not copy the cache from the main model by default, because the two models will likely have different decoders.

@amyeroberts (Collaborator) left a comment:

LGTM - thanks for fixing!

```diff
-    def crop(self, maximum_length: int):
-        """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
-        negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
+    def crop(self, max_length: int):
+        """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
+        negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
```
@amyeroberts (Collaborator) commented on the lines above:
What's the reason for changing the name?

max_length seems more consistent with the rest of the repo, but I'm wondering if there's another reason?

@gante (Member, Author) replied:

None, just consistency 🤗 It wasn't caught in the earlier PR, so I'm fixing it now before the function gets released :)
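
To make the semantics under discussion concrete, here is a minimal standalone sketch of the cropping behavior described in the docstring above; crop_cache and the explicit key/value lists are hypothetical stand-ins, not the actual DynamicCache implementation. Assuming each layer's cache tensors have the usual [batch, num_heads, seq_len, head_dim] shape, cropping slices the sequence dimension, and a negative max_length removes that many trailing tokens.

```python
import torch


def crop_cache(key_cache: list[torch.Tensor], value_cache: list[torch.Tensor], max_length: int):
    """Hypothetical stand-in for the cropping semantics of DynamicCache.crop."""
    if max_length < 0:
        # Negative values mean "drop that many tokens from the end".
        max_length = key_cache[0].shape[-2] + max_length
    # Slice the sequence (token) dimension of every layer's key/value tensors.
    key_cache = [k[..., :max_length, :] for k in key_cache]
    value_cache = [v[..., :max_length, :] for v in value_cache]
    return key_cache, value_cache


# Toy usage: a 2-layer cache holding 10 tokens, cropped down to 7.
keys = [torch.zeros(1, 4, 10, 8) for _ in range(2)]
values = [torch.zeros(1, 4, 10, 8) for _ in range(2)]
keys, values = crop_cache(keys, values, 7)
assert keys[0].shape[-2] == 7
```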

@gante merged commit a3fb96a into huggingface:main on Jun 26, 2024
21 checks passed
@gante (Member, Author) commented on Jun 26, 2024:

Thanks for the quick review 💛

@gante deleted the past_kv_assisted_gen branch on June 26, 2024 at 17:24
@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
