RuntimeError: Out of workspace memory in AlignedAlloactor when there is a lot of GPU memory left #362
Hi @jl3676, the workspace memory refers to the workspace buffer allocated for the flashinfer wrappers, not the entire GPU memory. You can try increasing this value as a workaround. But it's weird to see us still run out of workspace memory even though 256MB is allocated for it. I'll take a look and send a PR to reduce the workspace memory usage. Thanks for reporting this!
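As a minimal sketch of that workaround outside of vllm, assuming flashinfer's Python wrapper API (the 256MB size and "NHD" layout here are illustrative; inside vllm the buffer is allocated by its flashinfer attention backend rather than by user code):

```python
import torch
import flashinfer

# Allocate a larger scratch (workspace) buffer, e.g. 256MB, and hand it to the wrapper.
workspace_size = 256 * 1024 * 1024
workspace_buffer = torch.empty(workspace_size, dtype=torch.uint8, device="cuda:0")
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
```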
@yzh119 Thank you for the suggestion! I tried increasing the workspace memory buffer to …
Additional information: I ran the same code with …
Would you mind sharing your config for serving gemma-2-27b-it (e.g. batch size, tp size, etc.)?
This is the minimal code I can reproduce the error with. The tp size is 4. All config variables other than the ones specified in the code are the default values set by vllm.LLM.

from vllm import LLM
import flashinfer
import os

os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASHINFER'

model = LLM(model="google/gemma-2-27b-it",
            dtype="auto",
            trust_remote_code=True,
            tokenizer_mode="auto",
            tensor_parallel_size=4)
Thank you!
Same issue here...
Package version: …
FYI: when I set a specific value …
Same issue.
v0.0.9 was released, would you mind trying the new version? https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.0.9
Same issue, also with flashinfer v0.0.9. After I kill all the processes that belong to me, it works fine, but I have to do that every time.
Same issue.
I am facing this as well!
I just got around to trying version 0.0.9 and the newer version 0.1.0, and got the same error with both. According to this comment, setting …
Got the same bug.
I got the same issue with Llama-3-70b-instruct on 8xA100 40G.
@yzh119 Is there going to be any action on this? This is blocking most folks' use of Gemma2-27B with vllm...
Hi @jvlinsta @jl3676, sure, I would love to fix the issue ASAP. The main problem is that I cannot reproduce the error on my servers (4xH100), even with @jl3676's script: #362 (comment). I'll launch an 8xA100 40G AWS server (following @warlock135's config) and see if I can reproduce it.
@yzh119, I encountered this issue on 8xH100.
@yzh119 Same problem on an 8xH100 machine; I cannot even load and run the Gemma-2-8b-it model.
Error message: …
I tried different gpu_memory_utilization values, from 0.4 to 0.85. None of them worked. I also tried increasing …
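For reference, a minimal sketch of how that knob is passed to vllm's `LLM` constructor; the model name and values just echo this thread and are not a fix. As noted earlier in the thread, overall GPU memory and the flashinfer workspace buffer are separate things:

```python
from vllm import LLM

# gpu_memory_utilization bounds vllm's overall per-GPU memory reservation;
# the flashinfer workspace buffer is a separate allocation, which is why
# sweeping this value from 0.4 to 0.85 did not avoid the workspace error.
llm = LLM(
    model="google/gemma-2-27b-it",
    tensor_parallel_size=8,
    gpu_memory_utilization=0.85,
)
```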
I can finally reproduce the error and am working on a fix. Thanks for your information.
…begin forward functions (#413)

This PR makes the following changes to the codebase:
1. Make the allocator error messages more informative; specifically, we print the buffer name and the requested buffer size in runtime errors for debugging.
2. Add checks in the prefill wrappers' `begin_forward` functions to make sure the `qo` and `kv` indptr array sizes match.

These efforts are designed to avoid issues such as #362, which needs to be fixed on the vllm side, but we should have friendlier debugging information for locating the potential bugs.
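A rough sketch of those two guard rails, using made-up helper names rather than flashinfer's actual code:

```python
import torch

def aligned_alloc(buffer: torch.Tensor, num_bytes: int, name: str) -> torch.Tensor:
    # 1. More informative allocator errors: report which buffer overflowed and
    #    how many bytes were requested, instead of a bare
    #    "Out of workspace memory in AlignedAlloactor".
    if num_bytes > buffer.numel():
        raise RuntimeError(
            f"Out of workspace memory: buffer '{name}' has {buffer.numel()} bytes, "
            f"but {num_bytes} bytes were requested."
        )
    return buffer[:num_bytes]

def check_begin_forward_args(qo_indptr: torch.Tensor, paged_kv_indptr: torch.Tensor) -> None:
    # 2. Shape check in the prefill wrappers' begin_forward: both indptr arrays
    #    describe the same batch, so they must have the same length
    #    (batch_size + 1 entries each).
    if qo_indptr.shape[0] != paged_kv_indptr.shape[0]:
        raise ValueError(
            f"qo_indptr ({qo_indptr.shape[0]}) and paged_kv_indptr "
            f"({paged_kv_indptr.shape[0]}) must have the same length."
        )
```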
It turns out to be a bug on the vllm side, not flashinfer's: … and the length of …

@LiuXiaoxuanPKU @comaniac, would you mind fixing the behavior in vllm? It should be easy: just set …

On the FlashInfer side, we made some improvements (#413) to the shape checking and added more informative error messages so that we can avoid such issues in the future.
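For illustration only (the exact vllm-side change is elided above): with a paged KV cache, consistent `begin_forward` inputs have `qo_indptr` and `paged_kv_indptr` of length `batch_size + 1`, built as int32 CUDA tensors. The per-request lengths and page counts below are made-up values:

```python
import torch

device = "cuda:0"
batch_size = 2
qo_lens = [5, 3]            # query tokens per request (hypothetical)
kv_page_counts = [2, 1]     # KV-cache pages per request (hypothetical)

# Both indptr tensors are exclusive prefix sums with batch_size + 1 entries.
qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
qo_indptr[1:] = torch.cumsum(torch.tensor(qo_lens, device=device), dim=0)

paged_kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
paged_kv_indptr[1:] = torch.cumsum(torch.tensor(kv_page_counts, device=device), dim=0)

assert qo_indptr.shape == paged_kv_indptr.shape == (batch_size + 1,)
```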
Now flashinfer 0.1.3, with verbose size-check logging, points at the exact error.
It seems this error now needs to be fixed on the vllm side, as per @yzh119's suggestion. I tried fixing it in prefill.py, but it's beyond my understanding at the moment, so I'll just sit and wait for the fix.
Should be fixed by vllm-project/vllm#7008; feel free to give it a try and report any issues.
Thank you for the fix! @LiuXiaoxuanPKU @comaniac @yzh119 However, other errors now appear.
If I run inference with a sequence length within 4096: …
If I run inference with a sequence length of more than 4096: …
I assume that with flashinfer we should be able to run inference with gemma-2 at an 8k max length. These errors look confusing and I would appreciate your guidance.
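As a point of reference (not a confirmed fix for the errors above), a sketch of the vllm settings involved when targeting gemma-2's full 8k context with the flashinfer backend; the prompt and sampling values are placeholders:

```python
import os
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

# max_model_len caps the sequence length vllm will accept; gemma-2's full
# context is 8192, so requests longer than the configured cap are rejected.
llm = LLM(
    model="google/gemma-2-27b-it",
    tensor_parallel_size=4,
    max_model_len=8192,
)
outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=64))
```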
@LiuXiaoxuanPKU @yzh119 This patch worked for me. Thanks for the prompt fix, appreciate it <3
@PeterGriffinJin It should have been fixed in vllm-project/vllm@0badabb.
Closed as vllm-project/vllm#7008 got merged.
Hi @yzh119, vllm-project/vllm@0badabb solved the "RuntimeError: paged_kv_indices must be a CUDA tensor" error. But it seems that the gemma-2 results returned from vLLM with flashinfer are still not as expected: vllm-project/vllm#7152
Same error, 8xH100.
Could you check the main branch of vllm? It should be fixed now, thanks.
Hi, I encountered an out of workspace memory error when trying to load the gemma-2-27b model using vllm with the flashinfer backend, which seems to have come from flashinfer. I printed out the GPU memory usage list when this happened (see below), and there was still a lot of free memory across the four GPUs, so I'm not sure what the problem is. I'd appreciate any advice on how to fix this issue. Thanks!
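(For reference: a generic sketch of one way to print a per-GPU free/total memory list like the one mentioned here, using PyTorch; this is not the exact snippet from the report.)

```python
import torch

# Print free and total memory for each visible GPU, in GiB.
for i in range(torch.cuda.device_count()):
    free, total = torch.cuda.mem_get_info(i)
    print(f"GPU {i}: {free / 2**30:.1f} GiB free / {total / 2**30:.1f} GiB total")
```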
Minimal amount of code to reproduce this error:
Full log:
GPU usage when the error happened: