
Llama-2-7b-chat-hf will not allow temperature to be 0.0 #687

Closed
SleepingSkipper opened this issue Aug 18, 2023 · 3 comments

Comments

@SleepingSkipper

I am running the Llama-2-7b-chat-hf model from Hugging Face.
When I set temperature=0.0 or temperature=0, I get
ValueError: temperature has to be a strictly positive float, but is 0.0.
Until a week ago, it was working with the same code and environment.

My code and error message:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name="meta-llama/Llama-2-7b-chat-hf"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    trust_remote_code=True,
)
model_4bit.config.use_cache = False
model = model_4bit
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = "cuda" if torch.cuda.is_available() else "cpu"  # target device for the input tensors (assumed: CUDA when available)

def generate(text):
    prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
Summarize following sentence in three lines.
### Input:
{text}
### Response:"""
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    input_ids = input_ids.to(device)  # .to() is not in-place; reassign the moved tensor
    with torch.no_grad():
        outputs = model.generate(inputs=input_ids,
                                 temperature=0.0,
                                 max_new_tokens=500)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
    
text = """FC Barcelona's Spanish defender Jordi Alba and Turkish midfielder Arda Turan have returned to full training, according to the Spanish newspaper Marca on March 28. J. Alba returned to full training after suffering an injury in the Copa del Rey match against Athletic Bilbao on March 17. Arda, who missed the match against Atletico Madrid on March 27 due to a high fever, has also returned to the squad and is now in good shape for the match against Atletico Madrid."""

generate(text)

>> 
ValueError                                Traceback (most recent call last)
Cell In[12], line 5
      2 input_ids.to(device)
      3 with torch.no_grad():
----> 5     outputs = model.generate(inputs=input_ids,
      6                               temperature=0.0,
      7                                 max_new_tokens=500)
      8 print(tokenizer.decode(outputs[0], skip_special_tokens=True))

File ~/anaconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/anaconda3/envs/llama2/lib/python3.9/site-packages/transformers/generation/utils.py:1604, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
   1586     return self.contrastive_search(
   1587         input_ids,
   1588         top_k=generation_config.top_k,
   (...)
   1599         **model_kwargs,
   1600     )
   1602 elif is_sample_gen_mode:
   1603     # 11. prepare logits warper
-> 1604     logits_warper = self._get_logits_warper(generation_config)
   1606     # 12. expand input_ids with `num_return_sequences` additional sequences per batch
   1607     input_ids, model_kwargs = self._expand_inputs_for_generation(
   1608         input_ids=input_ids,
   1609         expand_size=generation_config.num_return_sequences,
   1610         is_encoder_decoder=self.config.is_encoder_decoder,
   1611         **model_kwargs,
   1612     )

File ~/anaconda3/envs/llama2/lib/python3.9/site-packages/transformers/generation/utils.py:809, in GenerationMixin._get_logits_warper(self, generation_config)
    806 # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
    807 # all samplers can be found in `generation_utils_samplers.py`
    808 if generation_config.temperature is not None and generation_config.temperature != 1.0:
--> 809     warpers.append(TemperatureLogitsWarper(generation_config.temperature))
    810 min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
    811 if generation_config.top_k is not None and generation_config.top_k != 0:

File ~/anaconda3/envs/llama2/lib/python3.9/site-packages/transformers/generation/logits_process.py:231, in TemperatureLogitsWarper.__init__(self, temperature)
    229 def __init__(self, temperature: float):
    230     if not isinstance(temperature, float) or not (temperature > 0):
--> 231         raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
    233     self.temperature = temperature

ValueError: `temperature` has to be a strictly positive float, but is 0.0
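
For context, TemperatureLogitsWarper rescales the next-token logits by dividing them by the temperature, so a temperature of exactly 0.0 would mean dividing by zero. A minimal sketch of that scaling step (an illustration, not the library source):

import torch

# Sketch of what a temperature warper does: scale logits by 1/temperature.
# temperature < 1 sharpens the distribution; temperature == 0.0 would divide by zero,
# which is why the constructor insists on a strictly positive float.
logits = torch.tensor([[2.0, 1.0, 0.5]])
temperature = 0.7
warped_logits = logits / temperature
probs = torch.softmax(warped_logits, dim=-1)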
@ArthurZucker

Feel free to post this on the transformers repo if you still have this issue, since you are using transformers.

@gante

gante commented Aug 24, 2023

Hey @SleepingSkipper 👋 transformers .generate() maintainer here.

We've been adding validation to .generate: exceptions for breaking option combinations, and warnings for other incorrect (but output-wise harmless) ones.

Setting temperature=0.0 means a division by zero would occur, which opens a Pandora's box of problems :) I'm assuming you want to run greedy decoding, in which case the correct flag is do_sample=False.

Two actions from this issue:

  1. Short term: the exception message will be improved to nudge users towards do_sample=False.
  2. Long term: we were already considering triggering greedy decoding when temperature = 0.0; this issue further reinforces that.
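
For anyone landing here, a minimal sketch of the greedy-decoding call described above, reusing the model, tokenizer, prompt, and device names from the original snippet (drop temperature entirely and pass do_sample=False):

input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model.generate(
        inputs=input_ids,
        do_sample=False,       # greedy decoding; replaces temperature=0.0
        max_new_tokens=500,
    )
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

With do_sample=False the sampling-only options (temperature, top_k, top_p) are not used, so they can simply be omitted.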

@SleepingSkipper
Author

@gante Thank you for the info!!
