Generate: Add new decoding strategy "DoLa" in .generate()
#29619
Conversation
@voidism thank you for this cool PR! 🔥
In addition to the interface and user experience comments left below, there is one task missing: tests. We should add two tests:
- A very small mixin test, to ensure the interface works on all models as expected. See here for an example.
- One (or more) heavy integration test(s), to ensure the method retains its correctness as we add other changes. See here for an example. You can add them on any model you believe is appropriate.
src/transformers/generation/utils.py
Outdated
```python
mask = final_logits[0] < -1e3
base_logits[0][mask] = -1e3
```
Can we add a comment about `-1e3`, for future reference? Why not any other number? It is okay if it is simply a number with which you got good results empirically 🤗
Line 2047 is removed, as I can directly get the mask from the `_relative_top_filter` function. The `-1e3` in line 2048 is simply a number that was tested to work empirically. Any number that is not `-float("Inf")` should work as well. I have cleaned up the code and moved it all into the `_relative_top_filter()` function. The `-1e3` is assigned to the `base_filter_value` variable.
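(For readers skimming the thread, here is a minimal sketch of the relative-top filtering idea being discussed — an illustration only, not the exact code merged in this PR, and the default values shown are just indicative: tokens whose final-layer probability falls below `relative_top` times the maximum probability are masked out of both distributions before the contrast is taken.)

```python
import math
import torch


def relative_top_filter_sketch(
    final_logits: torch.FloatTensor,
    base_logits: torch.FloatTensor,
    relative_top: float = 0.1,
    filter_value: float = -float("Inf"),
    base_filter_value: float = -1e-3,
    min_tokens_to_keep: int = 1,
):
    """Keep only tokens whose final-layer log-probability is within log(relative_top)
    of the most likely token; mask everything else in both distributions."""
    final_norm = final_logits.log_softmax(dim=-1)
    base_norm = base_logits.log_softmax(dim=-1)
    # threshold: at least `min_tokens_to_keep` tokens survive, and everything within
    # log(relative_top) of the max log-probability survives
    sorted_logits, _ = torch.sort(final_norm, descending=True)
    min_thresh = sorted_logits[..., min_tokens_to_keep - 1]
    probs_max = final_norm.max(dim=-1).values
    probs_thresh = torch.min(min_thresh, probs_max + math.log(relative_top)).unsqueeze(-1)
    mask = final_norm < probs_thresh
    base_norm = base_norm.masked_fill(mask, base_filter_value)
    final_norm = final_norm.masked_fill(mask, filter_value)
    return final_norm, base_norm
```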
Hi @gante ! Thanks so much for your suggestions! I spent some time adding the test cases and fixed the issues you mentioned. Please let me know if you have any other concerns or suggestions! I would be happy to address any issues you may have! 🤗
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thank you for iterating 💛
It could almost be merged as is -- only the tests need to be reworked slightly. I've added a few suggestions to further improve the PR while we wait for the green light from a core maintainer 🤗
Hi @gante ! Thanks so much for your great suggestions! I have fixed all the issues you mentioned. Just let me know if you have any other concerns or suggestions!
Happy with the PR 🙌
Hi @gante ! While waiting for the core maintainer's approval, I found that the validation of the parameter ranges in the generation config mainly happens in […]. However, after I committed the new code, a test case of the XLM model failed, and it seems to have nothing to do with my commit. The failure seems related to #29297. I tried syncing with the upstream but it didn't solve the issue. I wonder if you know the reason for this failing test case. Sorry for bothering you again!
The failed test case was solved after syncing with the upstream! Please ignore my previous comment.
Hi @amyeroberts ! This PR is ready to merge after some iterations! Would you be able to review it and give me any suggestions you have?
Hi @voidism, thanks for working on adding this!
A few small comments, the main one being that the DoLa sampling method is currently far too large and needs to be broken down into smaller chunks.
```python
    input_ids, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low"
)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print("Answer here: ", text)
```
print("Answer here: ", text) |
Fixed!
```python
    input_ids, max_new_tokens=20, temperature=0, dola_layers="low", repetition_penalty=1.2
)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print("Answer here: ", text)
```
print("Answer here: ", text) |
Fixed!
```python
@@ -788,3 +789,25 @@ def test_model_7b_4bit(self):
        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

        self.assertEqual(output_text, EXPECTED_TEXTS)

    def test_model_2b_bf16_dola(self):
```
I'd rather we didn't add an integration test for each of these models for this new generation method, as it's expensive to run. Doing this for each new generation approach isn't scalable.
Rather, it's better to just have one integration test for specific generation methods, which checks the output for a select model. cc @gante
@amyeroberts I'd rather have it tested in a few key models, as we've been doing in the past for other generation methods -- generation tests are prone to false positives (due to argmax/sampling) and false negatives (due to a problem in the model used in a test).
But I understand our testing limitations, leaving the final call to you 🤗
Temporarily removed the test for Gemma! Let me know if you want me to add it back! 🤗
tests/generation/test_utils.py
Outdated
```python
            for model_name in [
                "wav2vec",
                "clvp",
                "bark",
            ]
```
nit - one line
Suggested change:
```python
for model_name in ["wav2vec", "clvp", "bark"]
```
Fixed!
tests/generation/test_utils.py
Outdated
"bark", | ||
] | ||
): | ||
self.skipTest("Skip speech models") |
Why?
I previously skipped these speech models because they don't have the regular output embeddings needed to perform early exit, and early exit is required for DoLa decoding. However, it's not really about them being speech models; we should simply check the output embeddings to decide whether to skip!
Thus, I changed this part to
```python
if model.get_output_embeddings() is None:
    self.skipTest("DoLa is not supported for models that don't have output embeddings")
```
src/transformers/generation/utils.py
Outdated
```python
        streamer.end()

    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
```
In the tests it says that this isn't supported by encoder-decoder models.
Removed the part of the code with `if self.config.is_encoder_decoder:`!
```python
        }
        generation_kwargs.update({"dola_layers": "low"})
        output_dola = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
        self._check_outputs(output_dola, input_ids, model.config, use_cache=config.use_cache)
```
We should be able to do a test which does a single forward pass and checks that the expected logits are selected, i.e. the DoLa method should be decoupled from generate itself, so we can test passing logits to the DoLa method and check the logit outputs. I believe this is a more general issue with the generation testing, however.
Specifically, this test doesn't really convince me that the implementation is correct (nor do the integration tests, unless they've been generated from the official DoLa implementation), only that it functionally works.
> We should be able to do a test which does a single forward pass and checks that the expected logits are selected, i.e. the DoLa method should be decoupled from generate itself, so we can test passing logits to the DoLa method and check the logit outputs. I believe this is a more general issue with the generation testing, however.

100% agreed. However, this is not an issue with the DoLa method, but with the structure of `generate`. At the moment, each decoding function is a monolith where we can't isolate an iteration of the loop. @zucchini-nlp and I are working to fix this problem, so we can break down (and test) each piece of the core functionality. For instance, you've recently reviewed a PR where the stopping condition of the generation loop was moved into a shared function, which works towards this goal 🤗
What this pattern of (legacy) tests does is catch flagrant API issues and/or model incompatibilities, not detect whether the decoding method matches its original implementation. And that's the extent of what we can do in unit tests until we rework things :)
@amyeroberts What I mean with this comment is that it shouldn't be @voidism's responsibility to break down the `_dola_decoding` function nor to rework tests; @voidism is simply following the existing pattern. It is our (mine and @zucchini-nlp's) responsibility to ensure what you wrote becomes true -- in fact, it is easier for us to refactor things if they keep the same imperfect pattern.
Regarding the correctness of DoLa: I am the first author of the DoLa paper and I have kept checking whether the new code in this PR can reproduce the numbers in my paper.
The left-hand side shows the new numbers I obtained with the current version of the code.
The right-hand side is a screenshot of my paper, where the numbers come from the official implementation and the experiments I ran last year.
The original implementation was based on v4.28.1. The numbers changed a little bit (also for the greedy decoding baseline), which I think is due to the version changes as well as the different machines and GPUs I used. But the same level of improvement can be achieved by the new code in this PR, e.g. ~4% on StrQA with llama-7b.
I can also provide more tests to validate the consistency between this PR and my official DoLa implementation if you think it's needed!
@voidism Thanks for providing these numbers! I think these are good enough to have a reasonable degree of certainty in the application in the absence of being able to fully test at the moment
I have checked that my latest commit today (based on v4.41.0) can also reproduce the scores here!
```python
# DoLa decoding with contrasting lower part of layers (layers 0,2,...,14)
>>> dola_low_output = model.generate(**inputs, do_sample=False, max_new_tokens=50, dola_layers='low', repetition_penalty=1.2)
>>> tokenizer.batch_decode(dola_low_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
['\nThe Declaration of Independence was signed on July 4, 1776.\nWhat was the date of the signing of the Declaration of Independence?\nThe Declaration of Independence was signed on July 4,']
```
I don't get it - the outputs are the same?
+1, otherwise users won't feel compelled into using the technique
Agreed here! I switched back to showing the output example of `dola_layers='high'` as suggested by @gante last time, and removed the `low` outputs here. In this case, the `high` output is different from the vanilla decoding output, which makes more sense to readers.
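(For context, the `high` variant is invoked exactly like the `low` example quoted above, only with `dola_layers='high'`; a hedged sketch reusing the `model`, `tokenizer`, and `inputs` from that doc example, with the generated text omitted since it depends on the checkpoint:)

```python
>>> dola_high_output = model.generate(
...     **inputs, do_sample=False, max_new_tokens=50, dola_layers='high', repetition_penalty=1.2
... )
>>> tokenizer.batch_decode(dola_high_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
```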
- If the model has tied word embeddings, we skip the word embeddings (0-th) layer and start from the 2nd layer, as the early exit from word embeddings will become identity function.
- Set the `dola_layers` to a list of integers for layer indices to contrast manually specified layers. For example, setting `dola_layers=[28,30]` will contrast the final layer (32-th layer) with the 28-th and 30-th layers.

The paper suggested that contrasting `'high'` layers to improve short-answer tasks like TruthfulQA, and contrasting `'low'` layers to improve all the other long-answer reasoning tasks, such as GSM8K, StrategyQA, FACTOR, and VicunaQA. Applying DoLa to smaller models like GPT-2 is not recommended, as the results shown in the Appendix N of the paper.
It would be good to use a better demo for `low` here.
Switched back to showing a demo of `high`! I can also try to find prompt cases that make `vanilla`, `low`, and `high` all very different, if you think it's needed!
- For `N`-layer models with `N <= 40` layers, the layers of `range(0, N // 2, 2)` and `range(N // 2, N, 2)` are used for `'low'` and `'high'` layers, respectively.
- For models with `N > 40` layers, the layers of `range(0, 20, 2)` and `range(N - 20, N, 2)` are used for `'low'` and `'high'` layers, respectively.
hmmm - is this from the paper? It seems pretty arbitrary
Yes, the layer selection logic is in Appendix F of my paper. For llama-7b we use [0, 16) and [16, 32). For llama-13b/33b/65b we use [0, 20) and [N-20, N), where N = 40/60/80 for 13b/33b/65b. They were selected based on the validation set results. In this PR, I renamed this layer selection as `low` or `high` for simplicity.
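(For illustration, a minimal sketch of that selection rule under the `low`/`high` naming — a hypothetical helper, not the code from this PR, and it ignores details such as skipping layer 0 for tied word embeddings:)

```python
def select_dola_candidate_layers(num_layers: int, dola_layers: str) -> list:
    """Even-indexed candidate layers, following the rule quoted above."""
    if num_layers <= 40:
        low = range(0, num_layers // 2, 2)
        high = range(num_layers // 2, num_layers, 2)
    else:
        low = range(0, 20, 2)
        high = range(num_layers - 20, num_layers, 2)
    return list(low if dola_layers == "low" else high)


# e.g. for a 32-layer model such as llama-7b:
# select_dola_candidate_layers(32, "low")  -> [0, 2, ..., 14]
# select_dola_candidate_layers(32, "high") -> [16, 18, ..., 30]
```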
Hi @amyeroberts ! Thanks so much for all of your great suggestions! They are very helpful and they improved my code and the test cases!
Thanks for iterating on this! Just a few small suggestions - otherwise looking great!
```python
        The method is based on the paper "DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language Models" (https://arxiv.org/abs/2309.03883) in ICLR 2024.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
```
Slight mismatch between docstring and method signature, e.g. `do_sample` is missing.
Fixed the mismatch. Now the docstring and method signature are consistent!
src/transformers/generation/utils.py
Outdated
```python
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

>>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
>>> outputs = model._dola_decoding(
```
I don't think we should show calling a private method in an example. My understanding from recent refactors is that this is now taken from the generation config. @gante Is this right?
Correct. We can remove this example :)
Removed the example!
src/transformers/generation/utils.py
Outdated
```python
        # using final layer as the mature layer
        mature_layer = self.config.num_hidden_layers
        # if the model has tied word embeddings, we skip the word embeddings (0-th) layer and start from the 2nd layer, as the early exit from word embeddings will become identity function
        # if the model is really shallow (<=2 layers), we use the 1st layer if it's not the mature layer and the 0-th layer if it's the mature layer. Notice that DoLa is not helping much to shallow models.
```
Ultra nit: possibly `mature` -> `final` to clarify? I'm not sure what a mature layer is, i.e. above, when it says the final layer is used as the mature layer.
Suggested change:
```python
# if the model is really shallow (<=2 layers), we use the 1st layer if it's not the mature layer and the 0-th layer otherwise. Notice that DoLa does not help shallow models much.
```
Changed all the `mature` layer mentions to `final` layer!
```python
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (final_layer_next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )
```
Note for future @gante - this looks like something we can abstract out for this and other generation methods
Agreed 👍
src/transformers/generation/utils.py
Outdated
```python
        else:
            # 1. Stacking all premature_layers into a new dimension
            stacked_premature_layers = torch.stack(
                [candidate_premature_logits[i] for i in candidate_premature_layers], dim=0
            )

            # 2. Calculate the softmax values for mature_layer and all premature_layers
            softmax_mature_layer = F.softmax(final_logits, dim=-1)  # shape: (batch_size, vocab_size)
            softmax_premature_layers = F.softmax(
                stacked_premature_layers, dim=-1
            )  # shape: (num_premature_layers, batch_size, vocab_size)

            # 3. Calculate M, the average distribution
            M = 0.5 * (
                softmax_mature_layer[None, :, :] + softmax_premature_layers
            )  # shape: (num_premature_layers, batch_size, vocab_size)

            # 4. Calculate log-softmax for the KL divergence
            log_softmax_mature_layer = F.log_softmax(final_logits, dim=-1)  # shape: (batch_size, vocab_size)
            log_softmax_premature_layers = F.log_softmax(
                stacked_premature_layers, dim=-1
            )  # shape: (num_premature_layers, batch_size, vocab_size)

            # 5. Calculate the KL divergences and then the JS divergences
            kl1 = F.kl_div(log_softmax_mature_layer[None, :, :], M, reduction="none").mean(
                -1
            )  # shape: (num_premature_layers, batch_size)
            kl2 = F.kl_div(log_softmax_premature_layers, M, reduction="none").mean(
                -1
            )  # shape: (num_premature_layers, batch_size)
            js_divs = 0.5 * (kl1 + kl2)  # shape: (num_premature_layers, batch_size)

            # 6. Reduce the batchmean
            js_divs = js_divs.mean(-1)  # shape: (num_premature_layers,)
            premature_layer = candidate_premature_layers[int(js_divs.argmax().cpu().item())]

            base_logits = candidate_premature_logits[premature_layer]
            final_logits, base_logits = _relative_top_filter(final_logits, base_logits)
            logits = final_logits - base_logits
            return logits
```
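(For reference, the layer selection in steps 2–6 follows the DoLa paper: the premature layer is the candidate whose early-exit next-token distribution is most distant from the final-layer distribution under the Jensen–Shannon divergence, and the contrasted scores are the difference of the two log-probabilities after relative-top filtering:)

```latex
M_l = \tfrac{1}{2}\,(q_{\text{final}} + q_l), \qquad
\mathrm{JSD}(q_{\text{final}} \,\|\, q_l) = \tfrac{1}{2}\,\mathrm{KL}(q_{\text{final}} \,\|\, M_l) + \tfrac{1}{2}\,\mathrm{KL}(q_l \,\|\, M_l)

l^{\ast} = \arg\max_{l} \mathrm{JSD}(q_{\text{final}} \,\|\, q_l), \qquad
\text{DoLa logits} = \log q_{\text{final}} - \log q_{l^{\ast}} \quad \text{(after relative-top filtering)}
```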
- We can just do an early return here, which avoids having the main block of code indented
- Comments above the line of code to avoid unnecessary splitting
Suggested change:
```python
        return logits
    # 1. Stacking all premature_layers into a new dimension
    stacked_premature_layers = torch.stack(
        [candidate_premature_logits[i] for i in candidate_premature_layers], dim=0
    )
    # 2. Calculate the softmax values for mature_layer and all premature_layers
    # shape: (batch_size, vocab_size)
    softmax_mature_layer = F.softmax(final_logits, dim=-1)
    # shape: (num_premature_layers, batch_size, vocab_size)
    softmax_premature_layers = F.softmax(stacked_premature_layers, dim=-1)
    # 3. Calculate M, the average distribution
    # shape: (num_premature_layers, batch_size, vocab_size)
    M = 0.5 * (softmax_mature_layer[None, :, :] + softmax_premature_layers)
    # 4. Calculate log-softmax for the KL divergence
    # shape: (batch_size, vocab_size)
    log_softmax_mature_layer = F.log_softmax(final_logits, dim=-1)
    # shape: (num_premature_layers, batch_size, vocab_size)
    log_softmax_premature_layers = F.log_softmax(stacked_premature_layers, dim=-1)
    # 5. Calculate the KL divergences and then the JS divergences
    # shape: (num_premature_layers, batch_size)
    kl1 = F.kl_div(log_softmax_mature_layer[None, :, :], M, reduction="none").mean(-1)
    # shape: (num_premature_layers, batch_size)
    kl2 = F.kl_div(log_softmax_premature_layers, M, reduction="none").mean(-1)
    js_divs = 0.5 * (kl1 + kl2)  # shape: (num_premature_layers, batch_size)
    # 6. Reduce the batchmean
    js_divs = js_divs.mean(-1)  # shape: (num_premature_layers,)
    premature_layer = candidate_premature_layers[int(js_divs.argmax().cpu().item())]
    base_logits = candidate_premature_logits[premature_layer]
    final_logits, base_logits = _relative_top_filter(final_logits, base_logits)
    logits = final_logits - base_logits
    return logits
```
Fixed!
src/transformers/generation/utils.py
Outdated
```python
            # 3. Calculate M, the average distribution
            M = 0.5 * (
```
As a rule, no single letter vars should be used - let's use something more descriptive e.g. avg_dist
Changed it to `avg_dist`!
```python
    return logits


def _relative_top_filter(
```
Definition of objects should go above the lines they're first used
Fixed the order!
src/transformers/generation/utils.py
Outdated
```python
    base_filter_value=-1e-3,
    min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
    """Reference: https://github.com/XiangLi1999/ContrastiveDecoding/blob/170e9142e92159c1237d731e240f5eb14aabf428/transformers/src/transformers/generation_logits_process.py#L235"""
```
Link is great! We should add a short sentence saying what this function does too
Added a description!
```python
        if not hasattr(config, "use_cache"):
            config.use_cache = False
        else:
            config.use_cache = True
```
Based on https://github.com/huggingface/transformers/pull/29619/files#r1538243054
Suggested change:
```python
        # Some models don't support the cache and returning past_key_values
        if not hasattr(config, "use_cache"):
            config.use_cache = False
        else:
            config.use_cache = True
```
Added the comment!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@voidism are you intending to continue the PR? 🤗 Or do you need a hand?
Hi @gante Sorry that I was busy with my midterm for the past few weeks 😔 so I forgot to fix this for a while... I will continue fixing the PR this or next week!
@voidism no worries, focus on your midterms 💪 we'll be here when you're ready to continue 🙌
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi @gante and @amyeroberts I am back and have fixed all of @amyeroberts's suggestions from last time! Sorry that I was busy with midterm exams and paper deadlines last month 😔, so I stopped working on this PR for a while. 🥲 It's my fault that you might need to spend more time recalling our discussions from almost two months ago; I am really sorry about that! 🥲 In addition to addressing all the suggestions from last time, I have synced this PR with the latest transformers. Let me know if you have any other concerns or suggestions. I have more free time now, so I can assure you that I will fix any new suggestions as soon as I can. No more procrastination, I promise!
Thanks for adding and iterating!
All looks good to me. As it's been open for a while, I'd like a quick re-review from @gante to confirm this is still in-line with the current generate patterns
Thanks @amyeroberts so much for approving the changes! 🙌 Hi @gante Just let me know if the current version looks good or not. I will be happy to fix any suggestions or concerns you have! Thanks! 🤗
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@voidism my turn to apologise for the delay, I'm catching up with issues :) I've re-checked the PR and I'm happy with it! I'm going to merge this Monday (to avoid breaking our CI on a weekend 😉 )
Hi @gante No problem! Thanks so much for your help!! 🤗
rebased yet again (previous …)
Ran the following slow tests locally (with the expected results):
@voidism finally all CI issues were sorted -- thank you for bearing with us 🤗 I will communicate about this feature tomorrow! 💪
Hi @gante Thanks a lot for your help! Handling these CI tests isn't easy (I learned a lot from it 😂). I really appreciate your effort. So happy that we finally made it! 🤗
@voidism hehe it looks annoying, but it is essential to ensure all our features are playing nicely with each other 🤗
What does this PR do?
Fixes #29524
We add support for a new decoding strategy proposed in a recent ICLR 2024 paper.
The main revisions are in src/transformers/generation/utils.py and src/transformers/generation/configuration_utils.py
We also update the documentation and add the test code. Run the test by:
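(The exact test command did not survive the copy above; assuming the new DoLa tests sit in `tests/generation/test_utils.py` as discussed in this thread, an invocation along these lines should run them -- treat it as a hypothetical example:)

```python
# Hypothetical example only -- the exact command from the PR description was not preserved.
import os

import pytest

os.environ["RUN_SLOW"] = "1"  # slow/integration tests are gated behind this flag
pytest.main(["-v", "-k", "dola", "tests/generation/test_utils.py"])
```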
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Link: `model.generate()` function #29524
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@gante is the main contributor of the `.generate()` function, which this PR focuses on.