[Issue]: rPD trace vLLM benchmark failed #71

Open

alexhegit opened this issue Nov 6, 2024 · 3 comments

@alexhegit
Problem Description

Running the vLLM benchmark under runTracer.sh fails.

Operating System

Ubuntu 22.04 in the Docker image rocm/vllm-dev:20241025-tuned

CPU

AMD EPYC 9654 96-Core Processor

GPU

AMD MI300X

ROCm Version

ROCm 6.2.0

ROCm Component

No response

Steps to Reproduce

  1. Start the container
alias drun="docker run -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 256g --net host -v $PWD:/ws -v /data:/data --entrypoint /bin/bash --env HUGGINGFACE_HUB_CACHE=/data/llm -w /ws"

drun --name rPD-vllm rocm/vllm-dev:20241025-tuned
  2. Install rocmProfileData from /app/rocmProfileData in the container, following the instructions at https://github.com/ROCm/rocmProfileData/

  3. Trace the vLLM benchmark

runTracer.sh python /app/vllm/benchmarks/benchmark_latency.py \
--model /data/llm/Meta-Llama-3.1-8B/ \
--dtype float16 \
--gpu-memory-utilization 0.99 \
--distributed-executor-backend mp \
--tensor-parallel-size 1 \
--batch-size 32 \
--input-len 128 \
--output-len 128
  4. The vLLM benchmark does not start (it never runs to completion and produces no results; rocm-smi confirms the model is never loaded or run).
    The rPD log shows the ValueError below:
Creating empty rpd: trace.rpd
rpd_tracer, because
WARNING 11-06 02:52:43 rocm.py:17] `fork` method is not supported by ROCm. VLLM_WORKER_MULTIPROC_METHOD is overridden to `spawn` instead.
Namespace(model='/data/llm/Meta-Llama-3.1-8B/', speculative_model=None, num_speculative_tokens=None, speculative_draft_tensor_parallel_size=None, tokenizer=None, quantization=None, tensor_parallel_size=1, input_len=128, output_len=128, batch_size=32, n=1, use_beam_search=False, num_iters_warmup=10, num_iters=30, trust_remote_code=False, max_model_len=None, dtype='float16', enforce_eager=False, kv_cache_dtype='auto', quantization_param_path=None, profile=False, profile_result_dir=None, device='auto', block_size=16, enable_chunked_prefill=False, enable_prefix_caching=False, ray_workers_use_nsight=False, download_dir=None, output_json=None, gpu_memory_utilization=0.99, load_format='auto', distributed_executor_backend='mp', otlp_traces_endpoint=None, num_scheduler_steps=1)
WARNING 11-06 02:52:47 config.py:1711] Casting torch.bfloat16 to torch.float16.
ERROR 11-06 02:52:55 registry.py:270] Error in inspecting model architecture 'LlamaForCausalLM'
ERROR 11-06 02:52:55 registry.py:270] Traceback (most recent call last):
ERROR 11-06 02:52:55 registry.py:270]   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/vllm/model_executor/models/registry.py", line 432, in _run_in_subprocess
ERROR 11-06 02:52:55 registry.py:270]     returned.check_returncode()
ERROR 11-06 02:52:55 registry.py:270]   File "/opt/conda/envs/py_3.9/lib/python3.9/subprocess.py", line 460, in check_returncode
ERROR 11-06 02:52:55 registry.py:270]     raise CalledProcessError(self.returncode, self.args, self.stdout,
ERROR 11-06 02:52:55 registry.py:270] subprocess.CalledProcessError: Command '['/opt/conda/envs/py_3.9/bin/python', '-m', 'vllm.model_executor.models.registry']' died with <Signals.SIGABRT: 6>.
ERROR 11-06 02:52:55 registry.py:270]
ERROR 11-06 02:52:55 registry.py:270] The above exception was the direct cause of the following exception:
ERROR 11-06 02:52:55 registry.py:270]
ERROR 11-06 02:52:55 registry.py:270] Traceback (most recent call last):
ERROR 11-06 02:52:55 registry.py:270]   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/vllm/model_executor/models/registry.py", line 268, in _try_inspect_model_cls
ERROR 11-06 02:52:55 registry.py:270]     return model.inspect_model_cls()
ERROR 11-06 02:52:55 registry.py:270]   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/vllm/model_executor/models/registry.py", line 230, in inspect_model_cls
ERROR 11-06 02:52:55 registry.py:270]     return _run_in_subprocess(
ERROR 11-06 02:52:55 registry.py:270]   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/vllm/model_executor/models/registry.py", line 435, in _run_in_subprocess
ERROR 11-06 02:52:55 registry.py:270]     raise RuntimeError(f"Error raised in subprocess:\n"
ERROR 11-06 02:52:55 registry.py:270] RuntimeError: Error raised in subprocess:
ERROR 11-06 02:52:55 registry.py:270] rpd_tracer, because
ERROR 11-06 02:52:55 registry.py:270] /opt/conda/envs/py_3.9/lib/python3.9/runpy.py:127: RuntimeWarning: 'vllm.model_executor.models.registry' found in sys.modules after import of package 'vllm.model_executor.models', but prior to execution of 'vllm.model_executor.models.registry'; this may result in unpredictable behaviour
ERROR 11-06 02:52:55 registry.py:270]   warn(RuntimeWarning(msg))
ERROR 11-06 02:52:55 registry.py:270] rocpd_op: 0
ERROR 11-06 02:52:55 registry.py:270] rocpd_api_ops: 0
ERROR 11-06 02:52:55 registry.py:270] rocpd_kernelapi: 0
ERROR 11-06 02:52:55 registry.py:270] rocpd_copyapi: 0
ERROR 11-06 02:52:55 registry.py:270] rocpd_api: 0
ERROR 11-06 02:52:55 registry.py:270] rocpd_string: 0
ERROR 11-06 02:52:55 registry.py:270] rpd_tracer: finalized in 10.585086 ms
ERROR 11-06 02:52:55 registry.py:270] double free or corruption (!prev)
ERROR 11-06 02:52:55 registry.py:270]
Traceback (most recent call last):
  File "/app/vllm/benchmarks/benchmark_latency.py", line 286, in <module>
    main(args)
  File "/app/vllm/benchmarks/benchmark_latency.py", line 24, in main
    llm = LLM(
  File "vllm/utils.py", line 1181, in vllm.utils.deprecate_args.wrapper.inner
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/vllm/entrypoints/llm.py", line 193, in __init__
    self.llm_engine = LLMEngine.from_engine_args(
  File "vllm/engine/llm_engine.py", line 571, in vllm.engine.llm_engine.LLMEngine.from_engine_args
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/vllm/engine/arg_utils.py", line 918, in create_engine_config
    model_config = self.create_model_config()
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/vllm/engine/arg_utils.py", line 853, in create_model_config
    return ModelConfig(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/vllm/config.py", line 210, in __init__
    self.multimodal_config = self._init_multimodal_config(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/vllm/config.py", line 233, in _init_multimodal_config
    if ModelRegistry.is_multimodal_model(architectures):
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/vllm/model_executor/models/registry.py", line 390, in is_multimodal_model
    return self.inspect_model_cls(architectures).supports_multimodal
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/vllm/model_executor/models/registry.py", line 359, in inspect_model_cls
    return self._raise_for_unsupported(architectures)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/vllm/model_executor/models/registry.py", line 320, in _raise_for_unsupported
    raise ValueError(
ValueError: Model architectures ['LlamaForCausalLM'] are not supported for now. Supported architectures: ['AquilaModel', 'AquilaForCausalLM', 'ArcticForCausalLM', 'BaiChuanForCausalLM', 'BaichuanForCausalLM', 'BloomForCausalLM', 'CohereForCausalLM', 'DbrxForCausalLM', 'DeciLMForCausalLM', 'DeepseekForCausalLM', 'DeepseekV2ForCausalLM', 'ExaoneForCausalLM', 'FalconForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTJForCausalLM', 'GPTNeoXForCausalLM', 'GraniteForCausalLM', 'GraniteMoeForCausalLM', 'InternLMForCausalLM', 'InternLM2ForCausalLM', 'JAISLMHeadModel', 'JambaForCausalLM', 'LlamaForCausalLM', 'LLaMAForCausalLM', 'MambaForCausalLM', 'FalconMambaForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'QuantMixtralForCausalLM', 'Grok1ModelForCausalLM', 'MptForCausalLM', 'MPTForCausalLM', 'MiniCPMForCausalLM', 'MiniCPM3ForCausalLM', 'NemotronForCausalLM', 'OlmoForCausalLM', 'OlmoeForCausalLM', 'OPTForCausalLM', 'OrionForCausalLM', 'PersimmonForCausalLM', 'PhiForCausalLM', 'Phi3ForCausalLM', 'Phi3SmallForCausalLM', 'PhiMoEForCausalLM', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RWForCausalLM', 'StableLMEpochForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'SolarForCausalLM', 'XverseForCausalLM', 'BartModel', 'BartForConditionalGeneration', 'BertModel', 'Gemma2Model', 'MistralModel', 'Qwen2ForRewardModel', 'Phi3VForCausalLM', 'Blip2ForConditionalGeneration', 'ChameleonForConditionalGeneration', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', 'FuyuForCausalLM', 'InternVLChatModel', 'LlavaForConditionalGeneration', 'LlavaNextForConditionalGeneration', 'LlavaNextVideoForConditionalGeneration', 'LlavaOnevisionForConditionalGeneration', 'MiniCPMV', 'MolmoForCausalLM', 'NVLM_D', 'PaliGemmaForConditionalGeneration', 'PixtralForConditionalGeneration', 'QWenLMHeadModel', 'Qwen2VLForConditionalGeneration', 'UltravoxModel', 'MllamaForConditionalGeneration', 'EAGLEModel', 'MedusaModel', 'MLPSpeculatorPreTrainedModel']
rocpd_op: 0
rocpd_api_ops: 0
rocpd_kernelapi: 0
rocpd_copyapi: 0
rocpd_api: 0
rocpd_string: 0
rpd_tracer: finalized in 9.810928 ms
double free or corruption (!prev)
/usr/local/bin/runTracer.sh: line 42:  7872 Aborted                 LD_PRELOAD=librpd_tracer.so "$@"
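For context, the abort line above shows that runTracer.sh preloads librpd_tracer.so into the traced command via LD_PRELOAD, so every child process the benchmark launches, including the `python -m vllm.model_executor.models.registry` probe in the traceback, also runs under the tracer. The sketch below is a hypothetical minimal reproducer, not taken from this thread; it assumes the crash is triggered whenever a short-lived Python interpreter inherits the preloaded tracer and aborts at exit, which is what turns vLLM's registry probe into the CalledProcessError seen above.

# repro_sketch.py -- hypothetical reproducer (assumption, not confirmed in this issue)
import os
import subprocess
import sys

# Mimic the environment the benchmark's child processes inherit under runTracer.sh.
env = dict(os.environ, LD_PRELOAD="librpd_tracer.so")

# A trivial child interpreter, standing in for vLLM's registry probe.
proc = subprocess.run([sys.executable, "-c", "print('child ran')"], env=env)

# A negative return code means the child died from a signal (-6 == SIGABRT),
# matching the "died with <Signals.SIGABRT: 6>" line in the log above.
print("child return code:", proc.returncode)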

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

@AlexHe99 commented Nov 8, 2024

rPD works fine with vLLM v0.6.2+rocm624 (I installed it in the same Docker image by building vLLM from source). It may be a bug in vLLM v0.6.4 that is triggered by rPD.

root@tw024:/data/vllm# runTracer.sh python ./benchmarks/benchmark_latency.py --model /data/llm/Meta-Llama-3.1-8B/ --dtype float16 --gpu-memory-utilization 0.99 --distributed-executor-backend mp --tensor-parallel-size 1 --batch-size 32 --input-len 128 --output-len 128
Creating empty rpd: trace.rpd
rpd_tracer, because
WARNING 11-08 08:56:29 rocm.py:13] `fork` method is not supported by ROCm. VLLM_WORKER_MULTIPROC_METHOD is overridden to `spawn` instead.
Namespace(model='/data/llm/Meta-Llama-3.1-8B/', speculative_model=None, num_speculative_tokens=None, speculative_draft_tensor_parallel_size=None, tokenizer=None, quantization=None, tensor_parallel_size=1, input_len=128, output_len=128, batch_size=32, n=1, use_beam_search=False, num_iters_warmup=10, num_iters=30, trust_remote_code=False, max_model_len=None, dtype='float16', enforce_eager=False, kv_cache_dtype='auto', quantization_param_path=None, profile=False, profile_result_dir=None, device='auto', block_size=16, enable_chunked_prefill=False, enable_prefix_caching=False, use_v2_block_manager=False, ray_workers_use_nsight=False, download_dir=None, output_json=None, gpu_memory_utilization=0.99, load_format='auto', distributed_executor_backend='mp', otlp_traces_endpoint=None)
WARNING 11-08 08:56:32 config.py:1656] Casting torch.bfloat16 to torch.float16.
INFO 11-08 08:56:32 config.py:928] Disabled the custom all-reduce kernel because it is not supported on AMD GPUs.
WARNING 11-08 08:56:32 arg_utils.py:940] The model has a long context length (131072). This may cause OOM errors during the initial memory profiling phase, or result in low performance due to small KV cache space. Consider setting --max-model-len to a smaller value.
INFO 11-08 08:56:32 llm_engine.py:226] Initializing an LLM engine (v0.6.2) with config: model='/data/llm/Meta-Llama-3.1-8B/', speculative_config=None, tokenizer='/data/llm/Meta-Llama-3.1-8B/', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=/data/llm/Meta-Llama-3.1-8B/, use_v2_block_manager=False, num_scheduler_steps=1, multi_step_stream_outputs=False, enable_prefix_caching=False, use_async_output_proc=True, use_cached_outputs=False, mm_processor_kwargs=None)
WARNING 11-08 08:56:33 multiproc_gpu_executor.py:53] Reducing Torch parallelism from 192 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 11-08 08:56:34 selector.py:121] Using ROCmFlashAttention backend.
INFO 11-08 08:56:34 model_runner.py:1014] Starting to load model /data/llm/Meta-Llama-3.1-8B/...
INFO 11-08 08:56:34 selector.py:121] Using ROCmFlashAttention backend.
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:02,  1.18it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:04<00:05,  2.61s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:08<00:03,  3.28s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:12<00:00,  3.60s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:12<00:00,  3.21s/it]

INFO 11-08 08:56:48 model_runner.py:1025] Loading model weights took 14.9888 GB
rpd_tracer, because
rpd_tracer, because
rpd_tracer, because
rpd_tracer, because
rpd_tracer, because
rpd_tracer, because
rpd_tracer, because
rpd_tracer, because
rpd_tracer, because
INFO 11-08 08:56:57 distributed_gpu_executor.py:57] # GPU blocks: 81157, # CPU blocks: 2048
INFO 11-08 08:56:58 model_runner.py:1329] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 11-08 08:56:58 model_runner.py:1333] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 11-08 08:57:08 model_runner.py:1456] Graph capturing finished in 10 secs.
SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=1.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], stop_token_ids=[], include_stop_str_in_output=False, ignore_eos=True, max_tokens=128, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None)
Warming up...
Warmup iterations:   0%|                                                                                                                                      | 0/10 [00:00<?, ?it/s]
rpd_tracer, because
Warmup iterations: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:17<00:00,  1.72s/it]
Profiling iterations: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:50<00:00,  1.70s/it]
Avg latency: 1.6951900294904287 seconds
10% percentile latency: 1.6787860402371735 seconds
25% percentile latency: 1.6848566812113859 seconds
50% percentile latency: 1.688172210007906 seconds
75% percentile latency: 1.6905225955124479 seconds
90% percentile latency: 1.692503936321009 seconds
99% percentile latency: 1.8744713573355696 seconds
[rank0]:[W1108 08:58:17.796297280 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())
rocpd_op: 0
rocpd_api_ops: 0
rocpd_kernelapi: 0
rocpd_copyapi: 0
rocpd_api: 0
rocpd_string: 0
rpd_tracer: finalized in 1434.479868 ms
double free or corruption (!prev)
/usr/local/bin/runTracer.sh: line 42:  2216 Aborted                 LD_PRELOAD=librpd_tracer.so "$@"
root@tw024:/data/vllm# pip list | grep vllm
vllm                              0.6.2+rocm624

@seungrokj

@alexhegit
Since this issue stems from an incompatibility between vLLM's subprocess-based model-registry check and rPD, until rPD supports that subprocess inside vLLM you can still use rPD by removing the subprocess model-registry step in vLLM:

https://github.com/ROCm/vllm/blob/2eabfbc263e30399e64ab89375f494c1f3c280f7/vllm/model_executor/models/registry.py#L257-L258

    return _run_in_subprocess(
        lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

->

    return _ModelInfo.from_model_cls(self.load_model_cls())
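For orientation, the suggested edit lands inside the lazily-registered model's inspection method in registry.py. The sketch below shows roughly where the change goes; the surrounding class and method names (`_LazyRegisteredModel`, `inspect_model_cls`, `_ModelInfo`) are recalled from vLLM 0.6.x and should be verified against the linked revision before editing.

# vllm/model_executor/models/registry.py -- sketch only, verify names against
# the linked lines before applying.

class _LazyRegisteredModel(_BaseRegisteredModel):
    ...
    def inspect_model_cls(self) -> _ModelInfo:
        # Original: inspect the model class in a helper subprocess. That
        # subprocess inherits the LD_PRELOAD'ed rpd_tracer and aborts, which
        # vLLM then reports as "architecture not supported".
        # return _run_in_subprocess(
        #     lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

        # Workaround: do the inspection in-process instead.
        return _ModelInfo.from_model_cls(self.load_model_cls())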

@AlexHe99 commented Nov 8, 2024

@alexhegit Since this issue stems from an incompatibility between vLLM's subprocess-based model-registry check and rPD, until rPD supports that subprocess inside vLLM you can still use rPD by removing the subprocess model-registry step in vLLM:

https://github.com/ROCm/vllm/blob/2eabfbc263e30399e64ab89375f494c1f3c280f7/vllm/model_executor/models/registry.py#L257-L258

    return _run_in_subprocess(
        lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

-> return _ModelInfo.from_model_cls(self.load_model_cls())

It works, with only some log output changed since we removed the subprocess. Thank you @seungrokj.
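Once the workaround is in place, the trace.rpd that runTracer.sh produces is a SQLite database and can be sanity-checked directly. The minimal sketch below just counts rows per table; the table names come from the rpd_tracer summary lines in the logs above (rocpd_api, rocpd_op, and so on), while the script itself is an illustrative assumption rather than an official rocmProfileData tool.

# inspect_trace.py -- minimal sketch for sanity-checking trace.rpd.
import sqlite3

con = sqlite3.connect("trace.rpd")
for table in ("rocpd_api", "rocpd_op", "rocpd_api_ops",
              "rocpd_kernelapi", "rocpd_copyapi", "rocpd_string"):
    # Non-zero counts indicate the tracer actually recorded API calls/kernels.
    (count,) = con.execute(f"SELECT COUNT(*) FROM {table}").fetchone()
    print(f"{table}: {count}")
con.close()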
