Improve TransformersModel UX (#12785)
hmellor authored Feb 6, 2025
1 parent 56534cd commit 1a6fcad
Showing 1 changed file with 32 additions and 21 deletions.
53 changes: 32 additions & 21 deletions vllm/model_executor/models/transformers.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 """Wrapper around `transformers` models"""
 import re
-from typing import Iterable, Optional, Union
+from typing import Iterable, Literal, Optional, Union
 
 import torch
 from torch import nn
@@ -72,15 +72,24 @@ def vllm_flash_attention_forward(
 ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
 
 
+def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
+    logger.debug("%s: %s -> %s", name, old_module, new_module)
+
+
 def replace_linear_class(
         linear: nn.Linear,
-        style: str,
+        style: Literal["colwise", "rowwise"],
         quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]:
     """
-    In model configurations, we use a neutral type (string) to specify parallel
-    styles, here we use it to translate nn.Linear into vllm-style tp Linear.
-
-    Quant config is not supported yet
+    Replace nn.Linear with one of vLLM's tensor parallel linear classes.
+
+    `quant_config` is not yet supported.
+    Args:
+        linear (nn.Linear): `nn.Linear` to be replaced.
+        style (str): Tensor parallel style of the new linear, e.g. "colwise".
+        quant_config (QuantConfig): Quantization config for the new linear.
+    Returns:
+        Union[ColumnParallelLinear, RowParallelLinear]: The new linear.
     """
 
     if not isinstance(style, str):
@@ -93,7 +102,10 @@ def replace_linear_class(
     }.get(style)
 
     if vllm_linear_cls is None:
-        raise ValueError(f"Unsupported parallel style value: {style}")
+        logger.warning(
+            "Unsupported parallel style value: %s. "
+            "This layer will not be tensor parallelized.", style)
+        return linear
 
     class HFCompatibleLinear(vllm_linear_cls):
         """
@@ -119,25 +131,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         logger.info("Using Transformers backend.")
 
-        self.vllm_config = vllm_config
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
-        self.quant_config = quant_config
 
         self.config = config
+        self.quant_config = quant_config
         self.vocab_size = config.vocab_size
         self.unpadded_vocab_size = config.vocab_size
 
         self.model: PreTrainedModel = AutoModel.from_config(
             self.config,
             attn_implementation="vllm",
             torch_dtype=vllm_config.model_config.dtype,
             trust_remote_code=vllm_config.model_config.trust_remote_code,
         )
         prefix = self.model.base_model_prefix
 
         # MLP modifications
-        self.tensor_parallelize(self.model)
+        self.apply_base_model_tp_plan(self.model)
 
         # Attention modifications (assumes 1 attention op per hidden layer)
         tp_size = get_tensor_model_parallel_world_size()
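
Context for the AutoModel.from_config call above: attn_implementation="vllm" resolves through the ALL_ATTENTION_FUNCTIONS registry populated at the top of this file, which is how vLLM's attention kernel gets injected into a stock transformers model. A sketch of that mechanism, assuming a transformers version that exposes the registry (4.48+); the "sdpa_passthrough" name, toy forward, and tiny LlamaConfig are invented for illustration:

import torch
from transformers import AutoModel, LlamaConfig
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

def sdpa_passthrough(module, query, key, value, attention_mask, **kwargs):
    # q/k/v arrive as (batch, heads, seq, head_dim); transformers expects
    # (batch, seq, heads, head_dim) back, plus optional attention weights.
    # A real backend (like vLLM's) would call its own kernel here.
    out = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, is_causal=True)
    return out.transpose(1, 2).contiguous(), None

ALL_ATTENTION_FUNCTIONS["sdpa_passthrough"] = sdpa_passthrough

config = LlamaConfig(hidden_size=64, intermediate_size=128,
                     num_hidden_layers=2, num_attention_heads=4,
                     num_key_value_heads=4, vocab_size=256)
model = AutoModel.from_config(config, attn_implementation="sdpa_passthrough")
out = model(input_ids=torch.randint(0, 256, (1, 8)))
print(out.last_hidden_state.shape)  # torch.Size([1, 8, 64])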
@@ -170,13 +181,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
                                                 config.vocab_size, logit_scale)
         self.sampler = get_sampler()
 
-    def log_replacement(self, name: str, old_module: nn.Module,
-                        new_module: nn.Module):
-        logger.debug("%s: %s -> %s", name, old_module, new_module)
-
-    def tensor_parallelize(self, module: nn.Module, prefix: str = ""):
+    def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
+        """
+        Apply the base model tensor parallelization plan to a module.
+        Currently only supports linear layers.
+        """
         if (self.config.base_model_tp_plan is None
-                and self.vllm_config.parallel_config.tensor_parallel_size > 1):
+                and get_tensor_model_parallel_world_size() > 1):
             raise ValueError(
                 "Trying to run tensor parallelization but the model does not "
                 "support it yet!")
@@ -189,9 +200,9 @@ def tensor_parallelize(self, module: nn.Module, prefix: str = ""):
                 new_module = replace_linear_class(child_module, style,
                                                   self.quant_config)
                 setattr(module, child_name, new_module)
-                self.log_replacement(qual_name, child_module, new_module)
+                log_replacement(qual_name, child_module, new_module)
             else:
-                self.tensor_parallelize(child_module, prefix=qual_name)
+                self.apply_base_model_tp_plan(child_module, prefix=qual_name)
 
     def replace_vocab_embed_class(self, module: nn.Module):
         # Use native set input embeddings
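
The renamed method above walks the module tree and regex-matches each child's qualified name against the HF config's base_model_tp_plan, recursing where nothing matches. A toy walk under an invented plan (the TP_PLAN dict and tiny model below are fabricated; in the real method a match triggers replace_linear_class and setattr):

import re
import torch.nn as nn

# Invented plan for illustration; real plans come from
# config.base_model_tp_plan in the HF model config.
TP_PLAN = {
    r"layers\.\d+\.mlp\.up_proj": "colwise",
    r"layers\.\d+\.mlp\.down_proj": "rowwise",
}

def walk_tp_plan(module: nn.Module, prefix: str = ""):
    for child_name, child in module.named_children():
        qual_name = f"{prefix}.{child_name}" if prefix else child_name
        style = next((s for pattern, s in TP_PLAN.items()
                      if re.match(pattern, qual_name)), None)
        if style is not None and isinstance(child, nn.Linear):
            # The real method swaps the layer in place; the sketch
            # just reports the planned replacement.
            print(f"{qual_name} -> {style}")
        else:
            walk_tp_plan(child, prefix=qual_name)

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Module()
        self.mlp.up_proj = nn.Linear(8, 32)
        self.mlp.down_proj = nn.Linear(32, 8)

model = nn.Module()
model.layers = nn.ModuleList([Block(), Block()])
walk_tp_plan(model)
# layers.0.mlp.up_proj -> colwise
# layers.0.mlp.down_proj -> rowwise
# layers.1.mlp.up_proj -> colwise
# layers.1.mlp.down_proj -> rowwise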
@@ -201,8 +212,8 @@ def replace_vocab_embed_class(self, module: nn.Module):
             org_num_embeddings=self.config.vocab_size,
             quant_config=None,
         )
-        self.log_replacement("input embedding",
-                             self.model.get_input_embeddings(), new_module)
+        log_replacement("input embedding", self.model.get_input_embeddings(),
+                        new_module)
         self.model.set_input_embeddings(new_module)
 
     def forward(
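
The embedding swap in the last hunk leans on HF's model-agnostic get/set_input_embeddings accessors, which is why it works across architectures. A quick sketch with a plain nn.Embedding standing in for VocabParallelEmbedding (the real class needs vLLM's distributed state to construct); the tiny LlamaConfig is invented for the example:

import torch.nn as nn
from transformers import AutoModel, LlamaConfig

config = LlamaConfig(hidden_size=64, intermediate_size=128,
                     num_hidden_layers=1, num_attention_heads=4,
                     num_key_value_heads=4, vocab_size=256)
model = AutoModel.from_config(config)

# Stand-in replacement; the real code builds a VocabParallelEmbedding.
new_embed = nn.Embedding(config.vocab_size, config.hidden_size)
print("input embedding:", model.get_input_embeddings(), "->", new_embed)
model.set_input_embeddings(new_embed)
assert model.get_input_embeddings() is new_embed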
