[V1][BugFix] Fix edge case in VLM scheduling (vllm-project#12065)
Signed-off-by: Woosuk Kwon <[email protected]>
WoosukKwon authored and abmfy committed Jan 24, 2025
1 parent c0e179c · commit b15fe72
Showing 1 changed file with 15 additions and 11 deletions.

vllm/v1/core/scheduler.py
@@ -373,18 +373,22 @@ def _try_schedule_encoder_inputs(
             if self.encoder_cache_manager.has_cache(request, i):
                 # The encoder input is already computed and cached.
                 continue
-            if not self.encoder_cache_manager.can_allocate(request, i):
-                # The encoder cache is full. We can only schedule the decoder
-                # tokens just before the encoder input.
-                num_new_tokens = start_pos - num_computed_tokens
-                break
-            if num_encoder_tokens > encoder_budget:
-                # The encoder budget is exhausted. We can only schedule the
-                # decoder tokens up until the encoder input.
-                # NOTE(woosuk): We assume that the encoder tokens should be
-                # processed altogether, as the encoder usually uses
+            if (not self.encoder_cache_manager.can_allocate(request, i)
+                    or num_encoder_tokens > encoder_budget):
+                # The encoder cache is full or the encoder budget is exhausted.
+                # NOTE(woosuk): We assume that the encoder input tokens should
+                # be processed altogether, as the encoder usually uses
                 # bidirectional attention.
-                num_new_tokens = start_pos - num_computed_tokens
+                if num_computed_tokens < start_pos:
+                    # We only schedule the decoder tokens just before the
+                    # encoder input.
+                    num_new_tokens = start_pos - num_computed_tokens
+                else:
+                    # Because of prefix caching, num_computed_tokens is greater
+                    # than start_pos even though its encoder input is not
+                    # available. In this case, we can't schedule any token for
+                    # the request in this step.
+                    num_new_tokens = 0
                 break

             encoder_budget -= num_encoder_tokens
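To make the edge case concrete, here is a minimal standalone sketch of the scheduling decision after this fix. The function name schedulable_decoder_tokens and its flat argument list are hypothetical illustrations, not part of the commit; in the actual code this logic runs inside Scheduler._try_schedule_encoder_inputs against the request and encoder-cache-manager state shown in the diff above, where encoder_ready stands in for "the cache can allocate the input and the encoder budget covers it".

def schedulable_decoder_tokens(num_computed_tokens: int,
                               num_new_tokens: int,
                               start_pos: int,
                               encoder_ready: bool) -> int:
    """Return how many decoder tokens may be scheduled in this step.

    start_pos is the offset of an encoder input (e.g. an image) whose
    embeddings are not yet available.
    """
    if encoder_ready:
        # The encoder input can be computed this step, so the requested
        # number of decoder tokens can be scheduled unchanged.
        return num_new_tokens
    if num_computed_tokens < start_pos:
        # Pre-existing behavior: schedule only the decoder tokens that
        # come before the encoder input.
        return start_pos - num_computed_tokens
    # The edge case this commit fixes: with prefix caching,
    # num_computed_tokens can already be >= start_pos while the encoder
    # input is still unavailable, so the subtraction above would go
    # negative. Schedule nothing for this request in this step instead.
    return 0

# With 40 tokens restored from the prefix cache but the image at
# position 10 still unencoded, the old formula gave 10 - 40 = -30;
# the fixed logic yields 0.
assert schedulable_decoder_tokens(40, 5, 10, encoder_ready=False) == 0
assert schedulable_decoder_tokens(4, 20, 10, encoder_ready=False) == 6

Folding cache pressure and budget exhaustion into a single condition is consistent with the NOTE in the diff: since the encoder input's tokens must be processed altogether, either shortfall means the whole input has to wait.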
