
RuntimeError: Out of workspace memory in AlignedAlloactor when there is a lot of GPU memory left #362

Closed
jl3676 opened this issue Jul 7, 2024 · 30 comments · Fixed by vllm-project/vllm#7008
Labels: bug

@jl3676

jl3676 commented Jul 7, 2024

Hi, I encountered an out-of-workspace-memory error when trying to load the gemma-2-27b model in vLLM with the FlashInfer backend; the error seems to come from FlashInfer. I printed the GPU memory usage when this happened (see below), and there was still a lot of free memory across the four GPUs, so I'm not sure what the problem is. I'd appreciate any advice on how to fix this issue. Thanks!

Minimal amount of code to reproduce this error:

from vllm import LLM
import flashinfer
import os

os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASHINFER'
model = LLM(model="google/gemma-2-27b-it",
            dtype="auto",
            trust_remote_code=True,
            tokenizer_mode="auto",
            tensor_parallel_size=4)

Full log:

(VLLM pid=1950315) WARNING 07-07 13:43:20 utils.py:562] Gemma 2 uses sliding window attention for every odd layer, which is currently not supported by vLLM. Disabling sliding window and capping the max length to the sliding window size (4096).
(VLLM pid=1950315) INFO 07-07 13:43:20 config.py:698] Defaulting to use mp for distributed inference
(VLLM pid=1950315) INFO 07-07 13:43:20 llm_engine.py:169] Initializing an LLM engine (v0.5.1) with config: model='google/gemma-2-27b-it', speculative_config=None, tokenizer='google/gemma-2-27b-it', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=4, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=google/gemma-2-27b-it, use_v2_block_manager=False, enable_prefix_caching=False)
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) INFO 07-07 13:43:21 selector.py:79] Using Flashinfer backend.
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) WARNING 07-07 13:43:21 selector.py:80] Flashinfer will be stuck on llama-2-7b, please avoid using Flashinfer as the backend when running on llama-2-7b.
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) INFO 07-07 13:43:21 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
(VLLM pid=1950315) (VllmWorkerProcess pid=1950429) INFO 07-07 13:43:21 selector.py:79] Using Flashinfer backend.
(VLLM pid=1950315) (VllmWorkerProcess pid=1950429) WARNING 07-07 13:43:21 selector.py:80] Flashinfer will be stuck on llama-2-7b, please avoid using Flashinfer as the backend when running on llama-2-7b.
(VLLM pid=1950315) (VllmWorkerProcess pid=1950429) INFO 07-07 13:43:21 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
(VLLM pid=1950315) INFO 07-07 13:43:21 selector.py:79] Using Flashinfer backend.
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) INFO 07-07 13:43:21 selector.py:79] Using Flashinfer backend.
(VLLM pid=1950315) WARNING 07-07 13:43:21 selector.py:80] Flashinfer will be stuck on llama-2-7b, please avoid using Flashinfer as the backend when running on llama-2-7b.
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) WARNING 07-07 13:43:21 selector.py:80] Flashinfer will be stuck on llama-2-7b, please avoid using Flashinfer as the backend when running on llama-2-7b.
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) INFO 07-07 13:43:21 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
(VLLM pid=1950315) INFO 07-07 13:43:21 utils.py:741] Found nccl from library libnccl.so.2
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) INFO 07-07 13:43:21 utils.py:741] Found nccl from library libnccl.so.2
(VLLM pid=1950315) (VllmWorkerProcess pid=1950429) INFO 07-07 13:43:21 utils.py:741] Found nccl from library libnccl.so.2
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) INFO 07-07 13:43:21 utils.py:741] Found nccl from library libnccl.so.2
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) INFO 07-07 13:43:21 pynccl.py:63] vLLM is using nccl==2.20.5
(VLLM pid=1950315) INFO 07-07 13:43:21 pynccl.py:63] vLLM is using nccl==2.20.5
(VLLM pid=1950315) (VllmWorkerProcess pid=1950429) INFO 07-07 13:43:21 pynccl.py:63] vLLM is using nccl==2.20.5
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) INFO 07-07 13:43:21 pynccl.py:63] vLLM is using nccl==2.20.5
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) WARNING 07-07 13:43:21 custom_all_reduce.py:118] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VLLM pid=1950315) WARNING 07-07 13:43:21 custom_all_reduce.py:118] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) WARNING 07-07 13:43:21 custom_all_reduce.py:118] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VLLM pid=1950315) (VllmWorkerProcess pid=1950429) WARNING 07-07 13:43:21 custom_all_reduce.py:118] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) INFO 07-07 13:43:22 selector.py:79] Using Flashinfer backend.
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) WARNING 07-07 13:43:22 selector.py:80] Flashinfer will be stuck on llama-2-7b, please avoid using Flashinfer as the backend when running on llama-2-7b.
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) INFO 07-07 13:43:22 selector.py:79] Using Flashinfer backend.
(VLLM pid=1950315) (VllmWorkerProcess pid=1950429) INFO 07-07 13:43:22 selector.py:79] Using Flashinfer backend.
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) WARNING 07-07 13:43:22 selector.py:80] Flashinfer will be stuck on llama-2-7b, please avoid using Flashinfer as the backend when running on llama-2-7b.
(VLLM pid=1950315) (VllmWorkerProcess pid=1950429) WARNING 07-07 13:43:22 selector.py:80] Flashinfer will be stuck on llama-2-7b, please avoid using Flashinfer as the backend when running on llama-2-7b.
(VLLM pid=1950315) INFO 07-07 13:43:22 selector.py:79] Using Flashinfer backend.
(VLLM pid=1950315) WARNING 07-07 13:43:22 selector.py:80] Flashinfer will be stuck on llama-2-7b, please avoid using Flashinfer as the backend when running on llama-2-7b.
(VLLM pid=1950315) INFO 07-07 13:43:22 weight_utils.py:218] Using model weights format ['*.safetensors']
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) INFO 07-07 13:43:22 weight_utils.py:218] Using model weights format ['*.safetensors']
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) INFO 07-07 13:43:22 weight_utils.py:218] Using model weights format ['*.safetensors']
(VLLM pid=1950315) (VllmWorkerProcess pid=1950429) INFO 07-07 13:43:22 weight_utils.py:218] Using model weights format ['*.safetensors']
(VLLM pid=1950315) INFO 07-07 13:43:27 model_runner.py:255] Loading model weights took 12.8146 GB
(VLLM pid=1950315) (VllmWorkerProcess pid=1950429) INFO 07-07 13:43:27 model_runner.py:255] Loading model weights took 12.8146 GB
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) INFO 07-07 13:43:28 model_runner.py:255] Loading model weights took 12.8146 GB
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) INFO 07-07 13:43:28 model_runner.py:255] Loading model weights took 12.8146 GB
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks: RuntimeError: Out of workspace memory in AlignedAlloactor, Traceback (most recent call last):
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]   File "/home/jingjingl/.conda/envs/harm_project/lib/python3.10/site-packages/vllm/executor/multiproc_worker_utils.py", line 223, in _run_worker_process
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]     output = executor(*args, **kwargs)
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]   File "/home/jingjingl/.conda/envs/harm_project/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]     return func(*args, **kwargs)
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]   File "/home/jingjingl/.conda/envs/harm_project/lib/python3.10/site-packages/vllm/worker/worker.py", line 173, in determine_num_available_blocks
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]     self.model_runner.profile_run()
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]   File "/home/jingjingl/.conda/envs/harm_project/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]     return func(*args, **kwargs)
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]   File "/home/jingjingl/.conda/envs/harm_project/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 874, in profile_run
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]   File "/home/jingjingl/.conda/envs/harm_project/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]     return func(*args, **kwargs)
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]   File "/home/jingjingl/.conda/envs/harm_project/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 1221, in execute_model
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]     model_input.attn_metadata.begin_forward()
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]   File "/home/jingjingl/.conda/envs/harm_project/lib/python3.10/site-packages/vllm/attention/backends/flashinfer.py", line 132, in begin_forward
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]     self.prefill_wrapper.begin_forward(
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]   File "/home/jingjingl/.conda/envs/harm_project/lib/python3.10/site-packages/flashinfer/prefill.py", line 778, in begin_forward
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]     self._wrapper.begin_forward(
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226] RuntimeError: RuntimeError: Out of workspace memory in AlignedAlloactor
(VLLM pid=1950315) (VllmWorkerProcess pid=1950430) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226] 
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks: RuntimeError: Out of workspace memory in AlignedAlloactor, Traceback (most recent call last):
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]   File "/home/jingjingl/.conda/envs/harm_project/lib/python3.10/site-packages/vllm/executor/multiproc_worker_utils.py", line 223, in _run_worker_process
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]     output = executor(*args, **kwargs)
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]   File "/home/jingjingl/.conda/envs/harm_project/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]     return func(*args, **kwargs)
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]   File "/home/jingjingl/.conda/envs/harm_project/lib/python3.10/site-packages/vllm/worker/worker.py", line 173, in determine_num_available_blocks
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]     self.model_runner.profile_run()
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]   File "/home/jingjingl/.conda/envs/harm_project/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]     return func(*args, **kwargs)
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]   File "/home/jingjingl/.conda/envs/harm_project/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 874, in profile_run
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]   File "/home/jingjingl/.conda/envs/harm_project/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]     return func(*args, **kwargs)
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]   File "/home/jingjingl/.conda/envs/harm_project/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 1221, in execute_model
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]     model_input.attn_metadata.begin_forward()
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]   File "/home/jingjingl/.conda/envs/harm_project/lib/python3.10/site-packages/vllm/attention/backends/flashinfer.py", line 132, in begin_forward
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]     self.prefill_wrapper.begin_forward(
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]   File "/home/jingjingl/.conda/envs/harm_project/lib/python3.10/site-packages/flashinfer/prefill.py", line 778, in begin_forward
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226]     self._wrapper.begin_forward(
(VLLM pid=1950315) (VllmWorkerProcess pid=1950428) ERROR 07-07 13:43:36 multiproc_worker_utils.py:226] RuntimeError: RuntimeError: Out of workspace memory in AlignedAlloactor

GPU usage when the error happened:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    On   | 00000000:01:00.0 Off |                  Off |
| 30%   30C    P8    21W / 300W |  14151MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000    On   | 00000000:25:00.0 Off |                  Off |
| 30%   29C    P8    17W / 300W |  14119MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA RTX A6000    On   | 00000000:41:00.0 Off |                  Off |
| 30%   29C    P8    21W / 300W |  14119MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA RTX A6000    On   | 00000000:E1:00.0 Off |                  Off |
| 30%   28C    P8    19W / 300W |  14099MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
@yzh119
Collaborator

yzh119 commented Jul 7, 2024

Hi @jl3676, the workspace memory refers to the workspace buffer allocated for the FlashInfer wrappers, not the entire GPU memory. You can try increasing this value as a workaround:

https://github.com/vllm-project/vllm/blob/3b08fe2b13ced7fe76abe17c99614dd36e4b4788/vllm/worker/model_runner.py#L18

But it's strange that we still run out of workspace memory even though 256 MB is allocated for it. I'll take a look and send a PR to reduce the workspace memory usage. Thanks for reporting this!
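
For anyone who wants to try this workaround while waiting for a fix, here is a minimal sketch of what the patch might look like. It assumes vLLM v0.5.1, where FLASHINFER_WORKSPACE_BUFFER_SIZE is a module-level constant in vllm/worker/model_runner.py (linked above); with tensor parallelism the monkey-patch may not reach the worker subprocesses (it depends on the multiprocessing start method), in which case editing the constant directly in the installed file is the more reliable route.

# Sketch of the workaround (assumes vLLM v0.5.1): enlarge FlashInfer's
# workspace buffer before the engine is built. The constant below is the
# one linked above in vllm/worker/model_runner.py.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

import vllm.worker.model_runner as model_runner

# Default is 256 MiB; try a larger value, bounded by free GPU memory.
model_runner.FLASHINFER_WORKSPACE_BUFFER_SIZE = 512 * 1024 * 1024

from vllm import LLM

llm = LLM(model="google/gemma-2-27b-it", tensor_parallel_size=4)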

@jl3676
Author

jl3676 commented Jul 8, 2024

@yzh119 Thank you for the suggestion! I tried increasing the workspace memory buffer to FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 * 18, which is about the highest I could go given my GPU memory, but I still got the same RuntimeError: Out of workspace memory in AlignedAlloactor. Do you see any other potential problems or solutions?

@jl3676
Author

jl3676 commented Jul 8, 2024

Additional information: I ran the same code with google/gemma-2-9b-it and it worked without needing to increase FLASHINFER_WORKSPACE_BUFFER_SIZE, which leads me to believe this issue is specific to the larger model, google/gemma-2-27b-it.

@yzh119
Collaborator

yzh119 commented Jul 8, 2024

Would you mind sharing your config for serving gemma-2-27b-it? (e.g. batch size/tp size, etc)

@yzh119 yzh119 self-assigned this Jul 8, 2024
@yzh119 yzh119 added the bug and roadmap labels and removed the roadmap label Jul 8, 2024
@jl3676
Author

jl3676 commented Jul 8, 2024

This is the minimal code I can reproduce the error with. The TP size is 4. All config variables other than the ones specified in the code are the defaults set by vllm.LLM.

from vllm import LLM
import flashinfer
import os

os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASHINFER'
model = LLM(model="google/gemma-2-27b-it",
            dtype="auto",
            trust_remote_code=True,
            tokenizer_mode="auto",
            tensor_parallel_size=4)

@yzh119
Collaborator

yzh119 commented Jul 8, 2024

Thank you!

@jvlinsta

jvlinsta commented Jul 9, 2024

Same issue here...
My config:

Initializing an LLM engine (v0.5.1) with config: model='google/gemma-2-27b-it', speculative_config=None, tokenizer='google/gemma-2-27b-it', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=8, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=google/gemma-2-27b-it, use_v2_block_manager=False, enable_prefix_caching=False)

Package version:

Name: flashinfer
Version: 0.0.8+cu121torch2.3
Summary: FlashInfer: Kernel Library for LLM Serving
Home-page: https://github.com/flashinfer-ai/flashinfer

FYI: when I set an explicit --max_seq_len=4000, it started up without issues.
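
For reference, a minimal sketch of what capping the sequence length looks like. The --max_seq_len spelling above may come from a wrapper around vLLM; in stock vLLM the equivalent option should be max_model_len in the Python API (or --max-model-len on the OpenAI-compatible server), but treat the exact names below as assumptions to verify against your vLLM version.

# Sketch (assumed option names, verify against your vLLM version): cap the
# context length at 4000 so it stays within Gemma 2's 4096 sliding-window
# limit mentioned in the startup warning earlier in the thread.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM

llm = LLM(
    model="google/gemma-2-27b-it",
    tensor_parallel_size=4,
    max_model_len=4000,  # Python-API counterpart of the --max_seq_len=4000 workaround
)

# Server equivalent (OpenAI-compatible entrypoint):
#   python -m vllm.entrypoints.openai.api_server \
#       --model google/gemma-2-27b-it --tensor-parallel-size 4 --max-model-len 4000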

@momomobinx

same issue

@yzh119
Collaborator

yzh119 commented Jul 12, 2024

v0.0.9 has been released; would you mind trying the new version?

https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.0.9

@LinglingGreat

LinglingGreat commented Jul 13, 2024

Same issue, also with flashinfer v0.0.9. After I kill all processes that belong to me, it works fine, but I have to do that every time.

@SaeedNajafi

same issue.

@mohit-rag

I am facing this as well!

@jl3676
Author

jl3676 commented Jul 17, 2024

v0.0.9 was released, would you mind trying the new version?

https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.0.9

I just got around to trying version 0.0.9 and the newer version 0.1.0, and I got the same error with both.

According to this comment, setting --max_seq_len=4000 may be a workaround. How can I set this variable? Sorry if this is a dumb question but it isn't immediately obvious to me...

@zxia545

zxia545 commented Jul 21, 2024

Got the same bug.

@warlock135

I got the same issue with Llama-3-70b-instruct on 8xA100 40G

@jvlinsta

@yzh119 Is there going to be any action on this? This is blocking most folks' use of Gemma2-27B with vllm...
Would be highly appreciated ;)

@yzh119
Collaborator

yzh119 commented Jul 29, 2024

Hi @jvlinsta @jl3676, sure, I would love to fix the issue ASAP; the main problem is that I cannot reproduce the error on my servers (4xH100), even with @jl3676's script: #362 (comment)

I'll launch an 8xA100 40G AWS server (following @warlock135's config) and see if I can reproduce it.
@mohit-rag @SaeedNajafi would you mind posting your environment (GPU architecture, number of GPUs, etc.)?

@dstnluong-google

@yzh119, I encountered this issue on 8xh100.

@PeterGriffinJin

PeterGriffinJin commented Jul 30, 2024

@yzh119 Same problem on an 8xH100 machine; I cannot even load and run the Gemma-2-9b-it model.

from vllm import LLM
import flashinfer
import os

os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASHINFER'
model = LLM(model="google/gemma-2-9b-it",
            tensor_parallel_size=8,
            trust_remote_code=True,
            gpu_memory_utilization=0.85)

Error message

RuntimeError: RuntimeError: Out of workspace memory in AlignedAlloactor

I tried different gpu_memory_utilization values, from 0.4 to 0.85; none of them worked. I also tried increasing FLASHINFER_WORKSPACE_BUFFER_SIZE in https://github.com/vllm-project/vllm/blob/3b08fe2b13ced7fe76abe17c99614dd36e4b4788/vllm/worker/model_runner.py#L18. That did not work either.

@yzh119
Collaborator

yzh119 commented Jul 31, 2024

I can finally reproduce the error and am working on a fix. Thanks for the information.

yzh119 added a commit that referenced this issue Jul 31, 2024
…gin forward functions (#413)

This PR makes the following changes to the codebase:
1. make the allocators error information more informative, more
specifically, we print the buffer name and requested buffer size in
runtime errors for debugging.
2. add checks in prefill wrappers `begin_forward` functions to make sure
`qo` and `kv` indptr array size matches.

These efforts are designed for avoiding issues such as #362 , which
needs to be fixed on vllm side, but we should have more friendly
debugging information for locating the potential bugs.
@yzh119
Collaborator

yzh119 commented Jul 31, 2024

It turns out to be a bug on the vLLM side, not FlashInfer's.
More specifically, when initializing the CUDA graph, vLLM always sets paged_kv_indptr to torch.tensor([0]) in the prefill wrappers' begin_forward functions:
https://github.com/vllm-project/vllm/blob/c0644cf9cef0002485749defcaa02e3fec359d49/vllm/attention/backends/flashinfer.py#L146-L150

The length of qo_indptr (which equals batch_size + 1) does not match the length of paged_kv_indptr (which equals 1). We didn't check this on the FlashInfer side, so our scheduler accessed uninitialized values and computed an extremely large workspace buffer size.

@LiuXiaoxuanPKU @comaniac would you mind fixing this behavior in vLLM? It should be easy: just set paged_kv_indptr to torch.zeros(batch_size + 1) for the prefill wrappers.

On the FlashInfer side, we made some improvements (#413) to the shape checking and error messages so that such issues are easier to diagnose in the future.
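
To make the mismatch concrete, here is a small illustrative sketch of the indptr shapes involved (assumed batch size; this is not the actual vLLM patch):

# Illustrative sketch of the shape mismatch described above. During CUDA-graph
# capture, qo_indptr has batch_size + 1 entries, so paged_kv_indptr must have
# the same length; otherwise FlashInfer's scheduler reads uninitialized memory
# when sizing its workspace buffer.
import torch

batch_size = 256
qo_indptr = torch.arange(batch_size + 1, dtype=torch.int32)

# What vLLM passed during CUDA-graph capture: a single-element indptr (length 1).
broken_paged_kv_indptr = torch.tensor([0], dtype=torch.int32)

# The suggested fix: an all-zero indptr whose length matches qo_indptr.
fixed_paged_kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32)

assert fixed_paged_kv_indptr.size(0) == qo_indptr.size(0) == batch_size + 1
# This is the invariant FlashInfer now checks explicitly:
#   CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1)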

@kzos

kzos commented Jul 31, 2024

Now flashinfer 0.1.3, with verbose size-check logging, points at the exact error:

flashinfer/python/flashinfer/prefill.py", line 791, in begin_forward
(VllmWorkerProcess pid=3031282) ERROR 07-31 14:08:31 multiproc_worker_utils.py:226]     self._wrapper.begin_forward(
(VllmWorkerProcess pid=3031282) ERROR 07-31 14:08:31 multiproc_worker_utils.py:226] RuntimeError: CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1) failed. 1 vs 257
(VllmWorkerProcess pid=3031282) ERROR 07-31 14:08:31 multiproc_worker_utils.py:226]

It seems this error now needs to be fixed on the vLLM side, as per @yzh119's suggestion. I tried fixing it in prefill.py, but it's beyond my understanding at the moment.

So I will just sit and wait on the fix for now.

@LiuXiaoxuanPKU
Contributor

LiuXiaoxuanPKU commented Jul 31, 2024

This should be fixed by vllm-project/vllm#7008; feel free to give it a try and report any issues.

@PeterGriffinJin

Thank you for the fix! @LiuXiaoxuanPKU @comaniac @yzh119

However, other errors now appear:

If I run inference with a sequence length within 4096:

return self._wrapper.forward( [repeated 6x across cluster]
(RayWorkerWrapper pid=175490) ERROR 08-01 00:21:09 worker_base.py:382] RuntimeError: paged_kv_indices must be a CUDA tensor [repeated 6x across cluster]

If I run inference with a sequence length of more than 4096:

WARNING 08-01 00:25:31 scheduler.py:699] Input prompt (5552 tokens) is too long and exceeds limit of 4096

I assume that with FlashInfer we should be able to run inference with Gemma 2 at an 8k max length. These errors are confusing, and I would appreciate your guidance.

@kzos

kzos commented Aug 1, 2024

@LiuXiaoxuanPKU @yzh119, this patch worked for me; I had to manually pull it into my main branch.

Thanks for the prompt fix, appreciate it <3

@yzh119
Collaborator

yzh119 commented Aug 1, 2024

@PeterGriffinJin it should have been fixed in vllm-project/vllm@0badabb

@yzh119
Collaborator

yzh119 commented Aug 2, 2024

Closed, as vllm-project/vllm#7008 has been merged.

@yzh119 yzh119 closed this as completed Aug 2, 2024
@PeterGriffinJin

Hi @yzh119,

vllm-project/vllm@0badabb solved the "RuntimeError: paged_kv_indices must be a CUDA tensor" error.

But it seems the Gemma 2 results returned by vLLM are still not as expected with FlashInfer: vllm-project/vllm#7152

@learninmou

same error, 8xh100

@LiuXiaoxuanPKU
Contributor

same error, 8xh100

Could you check the main branch of vLLM? It should be fixed now, thanks.
