[Bug]: Distributed inference fails on certain multimodal models #8983

Closed
suna-123 opened this issue Oct 1, 2024 · 3 comments · Fixed by #8986
Labels: bug (Something isn't working)

suna-123 commented Oct 1, 2024

Your current environment

The output of `python collect_env.py`

PyTorch version: 2.4.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04 LTS (x86_64)
GCC version: (Ubuntu 13.2.0-23ubuntu4) 13.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39

Python version: 3.12.3 (main, Jul 31 2024, 17:43:48) [GCC 13.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-1012-aws-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.4.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A10G
GPU 1: NVIDIA A10G
GPU 2: NVIDIA A10G
GPU 3: NVIDIA A10G

Nvidia driver version: 550.54.14
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        48 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               48
On-line CPU(s) list:                  0-47
Vendor ID:                            AuthenticAMD
Model name:                           AMD EPYC 7R32
CPU family:                           23
Model:                                49
Thread(s) per core:                   2
Core(s) per socket:                   24
Socket(s):                            1
Stepping:                             0
BogoMIPS:                             5599.33
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save rdpid
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            768 KiB (24 instances)
L1i cache:                            768 KiB (24 instances)
L2 cache:                             12 MiB (24 instances)
L3 cache:                             96 MiB (6 instances)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-47
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow:   Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Retpolines; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.6.68
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] pyzmq==26.2.0
[pip3] torch==2.4.0
[pip3] torchvision==0.19.0
[pip3] transformers==4.45.1
[pip3] triton==3.0.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.1.dev238+ge2c6e0a82
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0	GPU1	GPU2	GPU3	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	PHB	PHB	PHB	0-47	0		N/A
GPU1	PHB	 X 	PHB	PHB	0-47	0		N/A
GPU2	PHB	PHB	 X 	PHB	0-47	0		N/A
GPU3	PHB	PHB	PHB	 X 	0-47	0		N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

Model Input Dumps

err_execute_model_input_20240930-213352.pkl.zip

🐛 Describe the bug

Sample code:

from vllm import LLM, SamplingParams
llm = LLM(model="adept/fuyu-8b", tensor_parallel_size=4, pipeline_parallel_size=1)

This sample code throws the following error on an instance with 4 A10G GPUs.

RuntimeError: Error in model execution (input dumped to /tmp/err_execute_model_input_20240930-231823.pkl): shape mismatch: value tensor of shape [16128, 1024] cannot be broadcast to indexing result of shape [16128, 4096]
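
For what it's worth, 1024 is exactly 4096 / 4, i.e. the hidden size divided by tensor_parallel_size, so it looks like the vision embeddings end up with the per-rank hidden size instead of the full one (just an observation, not a confirmed root cause). Below is a standalone sketch of the failing index assignment with the same shapes as the error message, written in plain PyTorch rather than vLLM code; the variable names are made up for illustration.

import torch

num_image_tokens = 16128
hidden_size = 4096                 # width of the indexing result in the error
per_rank_size = hidden_size // 4   # 1024, matching the value tensor in the error

inputs_embeds = torch.zeros(num_image_tokens, hidden_size)
vision_embeddings = torch.zeros(num_image_tokens, per_rank_size)
mask = torch.ones(num_image_tokens, dtype=torch.bool)

# Raises: shape mismatch: value tensor of shape [16128, 1024]
# cannot be broadcast to indexing result of shape [16128, 4096]
inputs_embeds[mask] = vision_embeddings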

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
suna-123 added the bug label Oct 1, 2024
Isotr0py (Collaborator) commented Oct 1, 2024

Can you provide the full error logs, so that I can figure out which part is going wrong?

suna-123 (Author) commented Oct 1, 2024

@Isotr0py here you go.


RuntimeError Traceback (most recent call last)
File ~/vllm_venv/lib/python3.12/site-packages/vllm/worker/model_runner_base.py:116, in dump_input_when_exception.<locals>._inner.<locals>._wrapper(*args, **kwargs)
115 try:
--> 116 return func(*args, **kwargs)
117 except Exception as err:

File ~/vllm_venv/lib/python3.12/site-packages/vllm/worker/model_runner.py:1590, in ModelRunner.execute_model(self, model_input, kv_caches, intermediate_tensors, num_steps)
1588 model_forward_start.record()
-> 1590 hidden_or_intermediate_states = model_executable(
1591 input_ids=model_input.input_tokens,
1592 positions=model_input.input_positions,
1593 kv_caches=kv_caches,
1594 attn_metadata=model_input.attn_metadata,
1595 intermediate_tensors=intermediate_tensors,
1596 **MultiModalInputs.as_kwargs(multi_modal_kwargs,
1597 device=self.device),
1598 **seqlen_agnostic_kwargs)
1600 if (self.observability_config is not None
1601 and self.observability_config.collect_model_forward_time):

File ~/vllm_venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)

File ~/vllm_venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:

File ~/vllm_venv/lib/python3.12/site-packages/vllm/model_executor/models/fuyu.py:285, in FuyuForCausalLM.forward(self, input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, **kwargs)
284 inputs_embeds = self.language_model.model.embed_tokens(input_ids)
--> 285 inputs_embeds = merge_multimodal_embeddings(
286 input_ids, inputs_embeds, vision_embeddings,
287 self.image_token_id)
289 else:

File ~/vllm_venv/lib/python3.12/site-packages/vllm/model_executor/models/utils.py:180, in merge_multimodal_embeddings(input_ids, inputs_embeds, multimodal_embeddings, placeholder_token_id)
176 raise ValueError(
177 f"Attempted to assign {expr} = {flattened.shape[0]} "
178 f"multimodal tokens to {num_expected_tokens} placeholders")
--> 180 inputs_embeds[mask] = flattened
181 return inputs_embeds

RuntimeError: shape mismatch: value tensor of shape [16128, 1024] cannot be broadcast to indexing result of shape [16128, 4096]

The above exception was the direct cause of the following exception:

RuntimeError Traceback (most recent call last)
Cell In[2], line 1
----> 1 llm = LLM(model="adept/fuyu-8b", tensor_parallel_size=4, pipeline_parallel_size=1)

File ~/vllm_venv/lib/python3.12/site-packages/vllm/entrypoints/llm.py:214, in LLM.__init__(self, model, tokenizer, tokenizer_mode, skip_tokenizer_init, trust_remote_code, tensor_parallel_size, dtype, quantization, revision, tokenizer_revision, seed, gpu_memory_utilization, swap_space, cpu_offload_gb, enforce_eager, max_context_len_to_capture, max_seq_len_to_capture, disable_custom_all_reduce, disable_async_output_proc, mm_processor_kwargs, **kwargs)
189 raise TypeError(
190 "There is no need to pass vision-related arguments anymore.")
191 engine_args = EngineArgs(
192 model=model,
193 tokenizer=tokenizer,
(...)
212 **kwargs,
213 )
--> 214 self.llm_engine = LLMEngine.from_engine_args(
215 engine_args, usage_context=UsageContext.LLM_CLASS)
216 self.request_counter = Counter()

File ~/vllm_venv/lib/python3.12/site-packages/vllm/engine/llm_engine.py:564, in LLMEngine.from_engine_args(cls, engine_args, usage_context, stat_loggers)
562 executor_class = cls._get_executor_cls(engine_config)
563 # Create the LLM engine.
--> 564 engine = cls(
565 **engine_config.to_dict(),
566 executor_class=executor_class,
567 log_stats=not engine_args.disable_log_stats,
568 usage_context=usage_context,
569 stat_loggers=stat_loggers,
570 )
572 return engine

File ~/vllm_venv/lib/python3.12/site-packages/vllm/engine/llm_engine.py:339, in LLMEngine.__init__(self, model_config, cache_config, parallel_config, scheduler_config, device_config, load_config, lora_config, speculative_config, decoding_config, observability_config, prompt_adapter_config, executor_class, log_stats, usage_context, stat_loggers, input_registry, use_cached_outputs)
325 self.model_executor = executor_class(
326 model_config=model_config,
327 cache_config=cache_config,
(...)
335 observability_config=self.observability_config,
336 )
338 if not self.model_config.embedding_mode:
--> 339 self._initialize_kv_caches()
341 # If usage stat is enabled, collect relevant info.
342 if is_usage_stats_enabled():

File ~/vllm_venv/lib/python3.12/site-packages/vllm/engine/llm_engine.py:474, in LLMEngine._initialize_kv_caches(self)
467 def _initialize_kv_caches(self) -> None:
468 """Initialize the KV cache in the worker(s).
469
470 The workers will determine the number of blocks in both the GPU cache
471 and the swap CPU cache.
472 """
473 num_gpu_blocks, num_cpu_blocks = (
--> 474 self.model_executor.determine_num_available_blocks())
476 if self.cache_config.num_gpu_blocks_override is not None:
477 num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override

File ~/vllm_venv/lib/python3.12/site-packages/vllm/executor/distributed_gpu_executor.py:39, in DistributedGPUExecutor.determine_num_available_blocks(self)
29 """Determine the number of available KV blocks.
30
31 This invokes determine_num_available_blocks on each worker and takes
(...)
36 - tuple[num_gpu_blocks, num_cpu_blocks]
37 """
38 # Get the maximum number of blocks that can be allocated on GPU and CPU.
---> 39 num_blocks = self._run_workers("determine_num_available_blocks", )
41 # Since we use a shared centralized controller, we take the minimum
42 # number of blocks across all workers to make sure all the memory
43 # operators can be applied to all workers.
44 num_gpu_blocks = min(b[0] for b in num_blocks)

File ~/vllm_venv/lib/python3.12/site-packages/vllm/executor/multiproc_gpu_executor.py:185, in MultiprocessingGPUExecutor._run_workers(self, method, async_run_tensor_parallel_workers_only, max_concurrent_workers, *args, **kwargs)
179 worker_outputs = [
180 worker.execute_method(method, *args, **kwargs)
181 for worker in self.workers
182 ]
184 driver_worker_method = getattr(self.driver_worker, method)
--> 185 driver_worker_output = driver_worker_method(*args, **kwargs)
187 # Get the results of the workers.
188 return [driver_worker_output
189 ] + [output.get() for output in worker_outputs]

File ~/vllm_venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)

File ~/vllm_venv/lib/python3.12/site-packages/vllm/worker/worker.py:223, in Worker.determine_num_available_blocks(self)
219 torch.cuda.empty_cache()
221 # Execute a forward pass with dummy inputs to profile the memory usage
222 # of the model.
--> 223 self.model_runner.profile_run()
225 # Calculate the number of blocks that can be allocated with the
226 # profiled peak memory.
227 torch.cuda.synchronize()

File ~/vllm_venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)

File ~/vllm_venv/lib/python3.12/site-packages/vllm/worker/model_runner.py:1236, in GPUModelRunnerBase.profile_run(self)
1231 if not get_pp_group().is_first_rank:
1232 intermediate_tensors = self.model.make_empty_intermediate_tensors(
1233 batch_size=batch_size,
1234 dtype=self.model_config.dtype,
1235 device=self.device)
-> 1236 self.execute_model(model_input, kv_caches, intermediate_tensors)
1237 torch.cuda.synchronize()
1238 return

File ~/vllm_venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)

File ~/vllm_venv/lib/python3.12/site-packages/vllm/worker/model_runner_base.py:152, in dump_input_when_exception.<locals>._inner.<locals>._wrapper(*args, **kwargs)
146 raise type(err)(f"Error in model execution: "
147 f"{str(err)}") from err
149 logger.info(
150 "Completed writing input of failed execution to %s.",
151 filename)
--> 152 raise type(err)(
153 f"Error in model execution (input dumped to {filename}): "
154 f"{str(err)}") from err

RuntimeError: Error in model execution (input dumped to /tmp/err_execute_model_input_20240930-231823.pkl): shape mismatch: value tensor of shape [16128, 1024] cannot be broadcast to indexing result of shape [16128, 4096]

Isotr0py (Collaborator) commented Oct 1, 2024

@suna-123 #8986 should fix this (tested with tp_size=2)
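
For reference, the crash above happens while the engine profiles memory during LLM() construction, so a sufficient smoke test once the fix is installed is that the constructor completes with tensor parallelism enabled, e.g. with the tp_size=2 setup mentioned above:

from vllm import LLM

# Hypothetical check: the original failure occurs during engine
# initialization (profile_run), so this completing without the shape
# mismatch error is enough to confirm the fix.
llm = LLM(model="adept/fuyu-8b", tensor_parallel_size=2)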
