
[Bug]: different garbage output of same prompt when inferred with single sequence vs concurrent requests on vllm openai server, temp=0 (mixed batching in longrope) #10336

Open
1 task done
bhupendrathore opened this issue Nov 14, 2024 · 8 comments
Labels
bug Something isn't working

Comments

@bhupendrathore

bhupendrathore commented Nov 14, 2024

Your current environment

The output of `python collect_env.py` was not provided.

Model Input Dumps

No response

🐛 Describe the bug

vLLM version (the latest was failing due to other issues, e.g. it could not decode): 0.6.1.post1
The model was hosted with:

CUDA_VISIBLE_DEVICES=0 python3 -m  vllm.entrypoints.openai.api_server --model csp-phi-3-mini-128k-ft-outputs/qlora_merged_model_csp_phi-ckp-23850 --dtype bfloat16 --gpu-memory-utilization 0.9 --disable-log-requests --max-model-len 14000
import requests
import json
import time

VLLM_INFER_URL = "http://0.0.0.0:8000/v1/completions"

def infer_vllm(prompt: str, max_new_tokens=800, temp=0.0) -> str:
    '''Infer from the hosted vLLM server.'''
    payload = json.dumps({
        "model": "csp-phi-3-mini-128k-ft-outputs/qlora_merged_model_csp_phi-ckp-23850",
        "prompt": prompt,
        "temperature": temp,
        # "top_k": 50,
        "top_p": 1,
        "max_tokens": max_new_tokens
    })
    headers = {
        'Content-Type': 'application/json'
    }

    try:
        start_time = time.time()  # kept for timing; not used below
        response = requests.request("POST", VLLM_INFER_URL, headers=headers, data=payload)
        if response.status_code == 200:
            resp = json.loads(response.text)["choices"][0]["text"]
            return resp
        else:
            print(response.json())
            return "None"
    except requests.RequestException as e:
        # the original snippet's try block had no except clause; handle request errors here
        print(e)
        return "None"

from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

# `data` is a pandas DataFrame with a "prompt" column
prompts = data.prompt.tolist()

with ThreadPoolExecutor(max_workers=5) as executor:
    list_of_results5 = list(tqdm(executor.map(infer_vllm, prompts[:10]), total=len(prompts[:10])))

# first output sample - let's check the second response
print(list_of_results5[2])

# vs

print(infer_vllm(prompts[2]))

# The outputs are different. I initially thought this might be due to pad tokens, but I don't think so.

What could be the possible reason for that? Can the model's pad tokens affect this?
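To narrow it down, here is a minimal check (a sketch that assumes the infer_vllm helper and the prompts / list_of_results5 objects above): at temp=0, repeated single requests should be identical, so any mismatch only shows up against the concurrently obtained output.

# Sketch: repeated single-sequence runs should agree with each other at temp=0;
# compare them against the output obtained under concurrency for the same prompt.
single_runs = [infer_vllm(prompts[2]) for _ in range(3)]
print("single-sequence runs identical:", len(set(single_runs)) == 1)
print("matches concurrent output:", single_runs[0] == list_of_results5[2])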

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
@bhupendrathore added the bug label Nov 14, 2024
@bhupendrathore
Author

bhupendrathore commented Nov 14, 2024

@jeejeelee
Collaborator

Maybe similar issue: #9567

@bhupendrathore
Author

bhupendrathore commented Nov 22, 2024

I've been looking into this some more. What I meant by "different" is that sometimes the model gives garbage output under batching but not under single-sequence inference.

I tried running it with --max-model-len 4096 and the garbage output issue was gone. The reason is possibly the RoPE scaling or the FP8 KV cache in that particular model, phi-3-mini-128k-instruct:
huggingface/transformers#33129 (when I infer with transformers, the output is garbage-free.)
#6135

@bhupendrathore
Author

bhupendrathore commented Nov 26, 2024

@jeejeelee it's because of mixed batching. If all sequences in a batch are longer than 4096 there is no garbage, and if they are all shorter than 4096 there is no garbage either. It only happens with mixed batches. I think the commit also mentions the same thing, where @caiom noted:

when a batch contains long and short sequences, it will always use long factor, even for short samples. Currently we don't support such mixed batches.

#4298 (comment)
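To illustrate the behavior described in that comment, here is a minimal sketch (hypothetical pseudocode for a LongRoPE-style factor selection, not vLLM's actual implementation) of why a short prompt in a mixed batch ends up rotated with the long factor:

# Hypothetical sketch of LongRoPE-style factor selection (not vLLM's actual code).
# original_max_position_embeddings is 4096 for phi-3-mini-128k-instruct.
ORIGINAL_MAX_POS = 4096

def pick_rescale_factor(seq_lens, short_factor, long_factor):
    """Pick the rescale factor for one batch of sequence lengths.

    If any sequence exceeds ORIGINAL_MAX_POS, the long factor is applied to the
    whole batch, including the short sequences -- the 'mixed batch' case that
    produces garbage for the short prompts.
    """
    if max(seq_lens) > ORIGINAL_MAX_POS:
        return long_factor  # applied to every sequence in the batch
    return short_factor

print(pick_rescale_factor([1000, 2000], "short", "long"))  # short: homogeneous short batch
print(pick_rescale_factor([8000, 9000], "short", "long"))  # long: homogeneous long batch
print(pick_rescale_factor([1000, 9000], "short", "long"))  # long, even for the 1000-token prompt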

Is there something we can do to avoid this, or any suggestion from your side?

@bhupendrathore changed the title from [Bug]: different output of same prompt when inferred with single sequence vs concurrent requests on vllm openai server, temp=0 to [Bug]: different garbage output of same prompt when inferred with single sequence vs concurrent requests on vllm openai server, temp=0 (mixed batching in longrope) Nov 26, 2024
@jeejeelee
Collaborator

@bhupendrathore I currently don't have any ideas - perhaps @DarkLight1337 could provide something more insightful

@DarkLight1337
Member

@WoosukKwon may be more familiar with this part of the code.

@Galigator

I have the same problem with 0.6.3.post1. I run the model like this: vllm serve neuralmagic/Llama-3.1-Nemotron-70B-Instruct-HF-FP8-dynamic --tensor-parallel-size 2 --max-model-len 8192

The max-model-len has been set to avoid the problem... but it is a shame.

@bhupendrathore
Author

bhupendrathore commented Nov 29, 2024

@WoosukKwon any direction for me? It depends on model.original_max_position_embeddings (4096 in my case), and mixed batches produce garbage, which prevents me from using concurrency. If I infer only prompts < 4096 or only prompts > 4096 at a time, no garbage comes out. Is there anything I can change in Phi3LongRoPEScaledRotaryEmbedding to avoid this?
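In the meantime, one possible client-side workaround is to make sure short and long prompts are never in flight at the same time, so this client never contributes a mixed batch. A sketch (hypothetical helper, assuming the infer_vllm function above and that the merged model directory also contains the tokenizer):

from concurrent.futures import ThreadPoolExecutor
from transformers import AutoTokenizer

ORIGINAL_MAX_POS = 4096  # original_max_position_embeddings of phi-3-mini-128k
tokenizer = AutoTokenizer.from_pretrained(
    "csp-phi-3-mini-128k-ft-outputs/qlora_merged_model_csp_phi-ckp-23850")

def infer_grouped(prompts, max_workers=5):
    """Run all short prompts first, then all long prompts, so short and long
    sequences from this client never share a decoding batch."""
    lengths = [len(tokenizer(p).input_ids) for p in prompts]
    short_idx = [i for i, n in enumerate(lengths) if n < ORIGINAL_MAX_POS]
    long_idx = [i for i, n in enumerate(lengths) if n >= ORIGINAL_MAX_POS]
    results = [None] * len(prompts)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # the first map is fully consumed before the long prompts are submitted,
        # so the two waves never overlap
        for i, out in zip(short_idx, executor.map(infer_vllm, [prompts[j] for j in short_idx])):
            results[i] = out
        for i, out in zip(long_idx, executor.map(infer_vllm, [prompts[j] for j in long_idx])):
            results[i] = out
    return results

This only prevents mixed batches coming from a single client; other concurrent clients could still mix lengths on the server side.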

4 participants