Generate: fix assisted generation with `past_key_values` passed as kwargs #31644
Conversation
LGTM
assistant_kwargs[key] = (
    value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
)

# Remove potential default DynamicCache if assistant does not support it
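For context, a minimal sketch of what the fix amounts to (the helper name `prepare_assistant_kwargs` is hypothetical; this is not the exact `transformers` implementation): tensors are detached and moved to the assistant's device, other values are deep-copied, and the main model's `past_key_values` is skipped entirely, since the two models keep separate caches.

```python
import copy

import torch


def prepare_assistant_kwargs(model_kwargs: dict, device: torch.device) -> dict:
    """Copy generation kwargs for the assistant model (illustrative sketch)."""
    assistant_kwargs = {}
    for key, value in model_kwargs.items():
        # The main model's cache must not be forwarded to the assistant:
        # it belongs to a different decoder.
        if key == "past_key_values":
            continue
        assistant_kwargs[key] = (
            value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
        )
    return assistant_kwargs
```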
were tests failing because of this?
@ArthurZucker assisted generation was failing if we passed `past_key_values` to `generate`! The assistant should not copy the cache from the main model by default (as they will likely have different decoders)
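For illustration (shapes hypothetical), a cache built for the main model generally cannot be reused by a smaller assistant, because the decoder geometry differs:

```python
import torch

# Cache entries have shape (batch, num_heads, seq_len, head_dim)
main_cache_layer = torch.zeros(1, 32, 10, 128)     # e.g. a large main model
assistant_cache_layer = torch.zeros(1, 8, 10, 64)  # e.g. a small assistant model

# Forwarding the main model's cache to the assistant would fail (or silently
# misbehave) because num_heads and head_dim do not match.
assert main_cache_layer.shape != assistant_cache_layer.shape
```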
LGTM - thanks for fixing!
- def crop(self, maximum_length: int):
-     """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
-     negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
+ def crop(self, max_length: int):
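A short usage sketch of the renamed method, assuming the post-rename signature from this PR and a `DynamicCache` populated via its `update` method (shapes are illustrative):

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
# Populate one layer with 12 cached tokens; shape is (batch, heads, seq, dim)
keys = torch.zeros(1, 8, 12, 64)
values = torch.zeros(1, 8, 12, 64)
cache.update(keys, values, layer_idx=0)

cache.crop(max_length=10)  # keep only the first 10 tokens
cache.crop(max_length=-2)  # negative values drop that many tokens from the end
```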
What's the reason for changing the name? `max_length` seems more consistent with the rest of the repo, but wondering if there's another reason?
none, just consistency 🤗 It wasn't caught in a former PR, fixing it now before the function gets released :)
Thanks for the quick review 💛
What does this PR do?
This PR:
- Fixes assisted generation when `generate` is called with the `past_key_values` kwarg (contrarily to other kwargs, this one shouldn't be passed to the assistant model, as it is the cache of the main model)
- Renames `maximum_length` to `max_length` in the newly added `DynamicCache.crop` function (`max_length` is the common variable name to depict a maximum length; this standardizes the API before the function is released)
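For reference, a hedged end-to-end sketch of the call pattern this PR fixes: passing `past_key_values` to `generate` while also using an assistant model. Model names are illustrative; any main/assistant pair sharing a tokenizer works.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
model = AutoModelForCausalLM.from_pretrained("gpt2-large")
assistant = AutoModelForCausalLM.from_pretrained("gpt2")  # smaller assistant

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Before this fix, the user-provided cache was also forwarded to the assistant
# model, whose decoder differs from the main model's, breaking generation.
outputs = model.generate(
    **inputs,
    assistant_model=assistant,
    past_key_values=DynamicCache(),
    max_new_tokens=20,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```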