remove dead code
SolitaryThinker committed Sep 10, 2024
1 parent 251d1d0 commit bdb6473
Showing 1 changed file with 4 additions and 81 deletions.
vllm/worker/multi_step_model_runner.py (4 additions, 81 deletions)
@@ -4,13 +4,6 @@
 from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                     Union)
 
-try:
-    from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-except ModuleNotFoundError:
-    # vllm_flash_attn is not installed, use the identical ROCm FA metadata
-    from vllm.attention.backends.rocm_flash_attn import (
-        ROCmFlashAttentionMetadata as FlashAttentionMetadata)
-
 import torch
 
 from vllm.distributed import get_pp_group
@@ -486,70 +479,6 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs,
             assert seq_group.seq_len is None  # Decode
             assert seq_group.query_len is None  # Decode
 
-    def _advance_step_flashattn(self, model_input: StatefulModelInput,
-                                out: SamplerOutput) -> StatefulModelInput:
-        frozen_model_input = model_input.frozen_model_input
-        assert frozen_model_input is not None
-        assert frozen_model_input.attn_metadata is not None
-
-        num_seqs = model_input.num_seqs
-        num_queries = model_input.num_queries
-        assert num_seqs > 0
-        assert num_queries > 0
-        assert num_seqs >= num_queries
-
-        attn_metadata = frozen_model_input.attn_metadata
-        assert isinstance(attn_metadata, FlashAttentionMetadata)
-
-        attn_metadata.advance_step(
-            frozen_model_input,
-            model_input.cached_outputs[-1].sampled_token_ids, self.block_size,
-            num_seqs, num_queries)
-
-        if frozen_model_input.seq_lens is not None:
-            for i in range(num_queries):
-                frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i]
-
-        return model_input
-
-    def _advance_step_flashinfer(
-            self,
-            model_input: StatefulModelInput,
-            out: SamplerOutput,
-    ) -> StatefulModelInput:
-        """Advance the model input for the next step."""
-        # Append the output token to the sequence data.
-        frozen_model_input = model_input.frozen_model_input
-        assert frozen_model_input is not None
-        assert frozen_model_input.attn_metadata is not None
-        attn_metadata = frozen_model_input.attn_metadata
-        num_seqs = model_input.num_seqs
-        num_queries = model_input.num_queries
-
-        sampled_tokens = model_input.cached_outputs[-1].sampled_token_ids
-        assert sampled_tokens is not None
-        assert frozen_model_input.input_tokens is not None
-        frozen_model_input.input_tokens[:num_queries] = sampled_tokens.flatten(
-        )
-
-        # Update GPU tensors
-        ops.advance_step_flashinfer(
-            num_seqs=num_seqs,
-            num_queries=num_queries,
-            block_size=self.block_size,
-            input_tokens=frozen_model_input.input_tokens,
-            sampled_token_ids=frozen_model_input.input_tokens,
-            input_positions=frozen_model_input.input_positions,
-            seq_lens=attn_metadata.seq_lens_tensor,
-            slot_mapping=attn_metadata.slot_mapping,
-            block_tables=attn_metadata.block_tables,
-            paged_kv_indices=attn_metadata.paged_kv_indices,
-            paged_kv_indptr=attn_metadata.paged_kv_indptr,
-            paged_kv_last_page_len=attn_metadata.paged_kv_last_page_len,
-            block_table_bound=attn_metadata.block_table_bound)
-
-        return model_input
-
     def _advance_step(self, model_input: StatefulModelInput,
                       out: SamplerOutput) -> StatefulModelInput:
         if self.attn_backend.get_name() not in MULTI_STEP_ATTENTION_BACKENDS:
@@ -563,24 +492,18 @@ def _advance_step(self, model_input: StatefulModelInput,
         num_queries = model_input.num_queries
         frozen_model_input = model_input.frozen_model_input
         assert frozen_model_input is not None
+        attn_metadata = frozen_model_input.attn_metadata
+        assert attn_metadata is not None
 
-        self.attn_backend.advance_step(
+        attn_metadata.advance_step(
             frozen_model_input,
             sampled_token_ids,
             self.block_size,
             num_seqs,
             num_queries,
         )
 
-        return
-
-        if self.attn_backend.get_name() == "flash-attn":
-            return self._advance_step_flashattn(model_input, out)
-        elif self.attn_backend.get_name() == "flashinfer":
-            return self._advance_step_flashinfer(model_input, out)
-        else:
-            raise ValueError(
-                f"Unsupported attention backend: {self.attn_backend}")
+        return model_input
 
     def load_model(self) -> None:
         return self._base_model_runner.load_model()
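
Note on the consolidated path: with the per-backend helpers removed, _advance_step calls attn_metadata.advance_step() unconditionally for every backend listed in MULTI_STEP_ATTENTION_BACKENDS, so each supported backend's attention metadata is expected to implement that hook. The sketch below is a minimal illustration of the hook and its call site only; ToyAttentionMetadata and ToyFrozenInput are hypothetical stand-ins, not types from the vLLM source, and the update logic uses plain Python lists rather than the GPU tensors the real backends operate on.

# Illustrative sketch, not vLLM code: a toy attention-metadata class with the
# advance_step() hook that the consolidated _advance_step relies on.
from dataclasses import dataclass
from typing import List


@dataclass
class ToyFrozenInput:
    input_tokens: List[int]
    input_positions: List[int]
    seq_lens: List[int]


@dataclass
class ToyAttentionMetadata:
    seq_lens: List[int]

    def advance_step(self, frozen_input: ToyFrozenInput,
                     sampled_token_ids: List[int], block_size: int,
                     num_seqs: int, num_queries: int) -> None:
        # Feed each query's sampled token back in as the next input token
        # and advance positions and sequence lengths by one decode step.
        for i in range(num_queries):
            frozen_input.input_tokens[i] = sampled_token_ids[i]
            frozen_input.input_positions[i] += 1
            self.seq_lens[i] += 1
            frozen_input.seq_lens[i] = self.seq_lens[i]


# Usage mirroring the call site in _advance_step above.
meta = ToyAttentionMetadata(seq_lens=[7, 3])
frozen = ToyFrozenInput(input_tokens=[11, 22],
                        input_positions=[6, 2],
                        seq_lens=[7, 3])
meta.advance_step(frozen, sampled_token_ids=[101, 202], block_size=16,
                  num_seqs=2, num_queries=2)
assert frozen.input_tokens == [101, 202]
assert frozen.seq_lens == [8, 4]

The shape of the change: the model runner no longer branches on backend name or keeps backend-specific helpers; it hands the frozen inputs and sampled tokens to whatever metadata object the backend produced and lets that object advance the step.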