From d1789c0f24f9c942a7e21815cb4ea6d553bd6ef1 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Sat, 28 Dec 2024 02:48:24 +0800
Subject: [PATCH] [Misc]Add BNB quantization for MolmoForCausalLM (#11551)

Signed-off-by: Jee Jee Li
Signed-off-by: Bowen Wang
---
 vllm/model_executor/model_loader/loader.py | 26 +++++--
 vllm/model_executor/models/molmo.py        | 90 ++++++++++++++++------
 2 files changed, 83 insertions(+), 33 deletions(-)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index f2d9293b31a83..4bca13cb2f60c 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -11,7 +11,8 @@
 import warnings
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
+from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional,
+                    Tuple, cast)
 
 import gguf
 import huggingface_hub
@@ -706,6 +707,8 @@ def __init__(self, load_config: LoadConfig):
         # Store all module names (from transformers) that support
         # BNB quantization.
         self.target_modules: List[str] = []
+        # mapping weight names from transformers to vllm.
+        self.weight_mapper: Callable = lambda name: name
 
     def _get_weight_files(
         self,
@@ -763,9 +766,12 @@ def _prepare_weights(self, model_name_or_path: str,
 
     def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
         if use_safetensors:
-            return safetensors_weights_iterator(hf_weights_files)
+            iterator = safetensors_weights_iterator(hf_weights_files)
         else:
-            return pt_weights_iterator(hf_weights_files)
+            iterator = pt_weights_iterator(hf_weights_files)
+        for name, param in iterator:
+            # mapping weight names from transformers to vllm.
+            yield self.weight_mapper(name), param
 
     def _get_quantized_weights_iterator(
         self,
@@ -782,12 +788,12 @@ def _get_quantized_weights_iterator(
         try:
             import bitsandbytes
 
-            if bitsandbytes.__version__ < "0.44.0":
+            if bitsandbytes.__version__ < "0.45.0":
                 raise ImportError("bitsandbytes version is wrong. Please "
-                                  "install bitsandbytes>=0.44.0.")
+                                  "install bitsandbytes>=0.45.0.")
         except ImportError as err:
-            raise ImportError("Please install bitsandbytes>=0.44.0 via "
-                              "`pip install bitsandbytes>=0.44.0` to use "
+            raise ImportError("Please install bitsandbytes>=0.45.0 via "
+                              "`pip install bitsandbytes>=0.45.0` to use "
                               "bitsandbytes quantizer.") from err
 
         hf_weights_files, use_safetensors = self._prepare_weights(
@@ -991,7 +997,7 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None:
             if isinstance(module, (LinearBase, )):
                 last_name = name.split(".")[-1]
                 if sub_modules := inverse_stacked_mapping.get(last_name, []):
-                    # Map vllm's names to transformers' names.
+                    # Map vllm's names to transformers's names.
                     for sub_name in sub_modules:
                         self.target_modules.append(
                             name.replace(last_name, sub_name))
@@ -1013,6 +1019,10 @@ def _load_weights(self, model_config: ModelConfig,
             f"Model {type(model).__name__} does not support BitsAndBytes "
             "quantization yet.")
 
+        # For some models like Molmo, we need to use hf_to_vllm_mapper
+        # to ensure correct loading of weights.
+        if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
+            self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
         # Modules whose weights might have fused on disk
         # we need their output_sizes to make shard in flight correctly with TP
         self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index 8938f62d0c494..5d52d2c3e6b48 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -461,30 +461,71 @@ def forward(
         return output
 
 
-class MolmoMLP(nn.Module):
+class SwiGLU(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, gate = x.chunk(2, dim=-1)
+        # Note that the order is reversed compared to
+        # SiluAndMul.
+        return x * F.silu(gate)
+
+
+class LanuageModelMLP(nn.Module):
     """Molmo's LLM mlp."""
 
     def __init__(self,
                  config: PretrainedConfig,
                  input_dim: Optional[int] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 proj_name: str = "gate_up_proj") -> None:
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size // 2
 
-        # Molmo's LLM proj weights are already merged into the disk, while
-        # image_projector proj is separate. If the same proj_name were used, it
-        # would create ambiguity and make it difficult to support BNB and LoRA.
-        self.proj_name = proj_name
-        setattr(
-            self, proj_name,
-            MergedColumnParallelLinear(
-                input_dim or self.hidden_size,
-                [self.intermediate_size] * 2,
-                bias=False,
-                quant_config=quant_config,
-            ))
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_dim or self.hidden_size,
+            [self.intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+        )
+        # Activation function.
+        self.act_fn = SwiGLU()
+        # Feed-forward output projection.
+        self.down_proj = RowParallelLinear(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class ImageProjectorMLP(nn.Module):
+    """Molmo's image_projector mlp."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        input_dim: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size // 2
+
+        self.merged_linear = MergedColumnParallelLinear(
+            input_dim or self.hidden_size,
+            [self.intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+        )
         # Activation function.
         self.act_fn = SiluAndMul()
 
@@ -500,7 +541,7 @@ def forward(
         self,
         x: torch.Tensor,
     ) -> torch.Tensor:
-        gate_up, _ = getattr(self, self.proj_name)(x)
+        gate_up, _ = self.merged_linear(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -523,9 +564,7 @@ def __init__(
             prefix=f"{prefix}.self_attn")
 
         # MLP block.
-        self.mlp = MolmoMLP(config,
-                            quant_config=quant_config,
-                            proj_name="gate_up_proj")
+        self.mlp = LanuageModelMLP(config, quant_config=quant_config)
 
         # LayerNorm
         assert config.layer_norm_type == "rms"
@@ -617,11 +656,10 @@ def __init__(
             vision_config,
             nlayers=len(self.vit_layers),
             quant_config=quant_config)
-        self.image_projector = MolmoMLP(
+        self.image_projector = ImageProjectorMLP(
             config,
             input_dim=vision_config.image_emb_dim,
             quant_config=quant_config,
-            proj_name="merged_linear",
         )
 
         image_dim = vision_config.image_emb_dim * len(self.vit_layers)
@@ -842,10 +880,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
         loaded_params: Set[str] = set()
 
         for name, loaded_weight in weights:
-            if "gate_up_proj" in name:
-                up_proj, gate_proj = loaded_weight.chunk(2, dim=0)
-                loaded_weight = torch.cat([gate_proj, up_proj], dim=0)
-
             if name.endswith(".bias") and name not in params_dict:
                 continue
             if is_pp_missing_parameter(name, self):
@@ -1157,6 +1191,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         },
     )
 
+    # BitandBytes specific attributes
+    bitsandbytes_stacked_params_mapping = {
+        "gate_proj": ("merged_linear", 0),
+        "up_proj": ("merged_linear", 1),
+    }
+
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
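Example usage: with this patch applied, a Molmo checkpoint can be loaded with in-flight bitsandbytes quantization through the regular vLLM entry point. The snippet below is a minimal sketch, not part of the diff above; the model ID, dtype, and prompt are illustrative assumptions, and bitsandbytes>=0.45.0 must be installed, as the loader now enforces.

# Sketch: in-flight BNB quantization for Molmo via vLLM (assumed model ID).
from vllm import LLM, SamplingParams

llm = LLM(
    model="allenai/Molmo-7B-D-0924",  # assumption: any Molmo checkpoint should work
    quantization="bitsandbytes",      # route weights through the BNB quantizer
    load_format="bitsandbytes",       # use the BitsAndBytesModelLoader path changed above
    trust_remote_code=True,           # Molmo relies on remote processor/config code
    dtype="bfloat16",                 # assumption: any supported dtype is fine
)
outputs = llm.generate(["Describe the Molmo model family in one sentence."],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)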