Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] support minicpm3 #8297

Merged
merged 12 commits into from
Sep 14, 2024
4 changes: 4 additions & 0 deletions docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ Decoder-only Language Models
- MiniCPM
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
-
* - :code:`MiniCPM3ForCausalLM`
- MiniCPM3
- :code:`openbmb/MiniCPM3-4B`
-
* - :code:`MistralForCausalLM`
- Mistral, Mistral-Instruct
- :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
Expand Down
79 changes: 49 additions & 30 deletions vllm/model_executor/models/minicpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,38 +270,47 @@ def __init__(
) -> None:
super().__init__()
self.config = config
self.cache_config = cache_config
self.quant_config = quant_config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.rope_theta = getattr(config, "rope_theta", 10000)
self.rope_scaling = getattr(config, "rope_scaling", None)
self.max_position_embeddings = getattr(config,
"max_position_embeddings", 8192)
self._init_attn_block()
self._init_ffn_block()

def _init_attn_block(self):
self.input_layernorm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)
self.self_attn = MiniCPMAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
num_heads=self.config.num_attention_heads,
num_kv_heads=self.config.num_key_value_heads,
rope_theta=self.rope_theta,
rope_scaling=self.rope_scaling,
max_position_embeddings=self.max_position_embeddings,
cache_config=self.cache_config,
quant_config=self.quant_config,
)

def _init_ffn_block(self):
self.post_attention_layernorm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)
self.num_experts = getattr(self.config, "num_experts", 0)
if self.num_experts == 0:
self.mlp = MiniCPMMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
intermediate_size=self.config.intermediate_size,
hidden_act=self.config.hidden_act,
quant_config=self.quant_config,
)
else:
self.mlp = MiniCPMMoE(num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.mlp = MiniCPMMoE(
num_experts=self.config.num_experts,
top_k=self.config.num_experts_per_tok,
hidden_size=self.config.hidden_size,
intermediate_size=self.config.intermediate_size)

def forward(
self,
Expand Down Expand Up @@ -344,6 +353,8 @@ def __init__(
) -> None:
super().__init__()
self.config = config
self.cache_config = cache_config
self.quant_config = quant_config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
Expand All @@ -354,11 +365,15 @@ def __init__(
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self._init_layers()
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def _init_layers(self):
self.layers = nn.ModuleList([
MiniCPMDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
MiniCPMDecoderLayer(self.config, self.cache_config,
self.quant_config)
for _ in range(self.config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
embedding = self.embed_tokens(input_ids)
Expand Down Expand Up @@ -431,13 +446,11 @@ def __init__(

self.config = config
self.lora_config = lora_config
self.cache_config = cache_config
self.quant_config = quant_config

self.num_experts = getattr(self.config, "num_experts", 0)
self.quant_config = quant_config
self.model = MiniCPMModel(config,
cache_config,
quant_config,
lora_config=lora_config)
self._init_model()
unpadded_vocab_size = config.vocab_size
if lora_config:
unpadded_vocab_size += lora_config.lora_extra_vocab_size
Expand All @@ -458,6 +471,12 @@ def __init__(
config.vocab_size)
self.sampler = Sampler()

def _init_model(self):
self.model = MiniCPMModel(config=self.config,
cache_config=self.cache_config,
quant_config=self.quant_config,
lora_config=self.lora_config)

def forward(
self,
input_ids: torch.Tensor,
Expand Down
216 changes: 216 additions & 0 deletions vllm/model_executor/models/minicpm3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2024 The ModelBest team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM3 model compatible with HuggingFace weights."""
from typing import Any, Dict, Optional

import torch
from torch import nn

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer,
MiniCPMForCausalLM,
MiniCPMModel)


class MiniCPM3Attention(nn.Module):

def __init__(
self,
config,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int,
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads

tp_size = get_tensor_model_parallel_world_size()
assert self.num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size

self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings

self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config)

self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size,
self.kv_lora_rank +
self.qk_rope_head_dim,
bias=False,
quant_config=quant_config)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config)
# O projection.
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config)

self.rotary_emb = get_rope(
self.qk_rope_head_dim,
rotary_dim=self.qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_local_heads,
self.qk_head_dim,
self.scaling,
num_kv_heads=self.num_local_heads,
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
q = self.q_a_proj(hidden_states)[0]
DarkLight1337 marked this conversation as resolved.
Show resolved Hide resolved
q = self.q_a_layernorm(q)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
self.qk_head_dim)
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
dim=-1)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
kv_a, _ = latent_cache.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a.contiguous())
kv = self.kv_b_proj(kv_a)[0]
kv = kv.view(-1, self.num_local_heads,
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

k_pe = latent_cache[:, :, self.kv_lora_rank:]

q_pe, k_pe = self.rotary_emb(
positions,
q_pe.reshape(-1, self.num_local_heads * self.qk_rope_head_dim),
k_pe.reshape(-1, self.qk_rope_head_dim))
q_pe = q_pe.view(-1, self.num_local_heads, self.qk_rope_head_dim)
k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)

q[..., self.qk_nope_head_dim:] = q_pe

k = torch.empty_like(q)

k[..., :self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim:] = k_pe

q = q.reshape(-1, self.num_local_heads * self.qk_head_dim)
k = k.view(-1, self.num_local_heads * self.qk_head_dim)
v = torch.nn.functional.pad(
v, [0, self.qk_head_dim - self.v_head_dim],
value=0).view(-1, self.num_local_heads * self.qk_head_dim)

attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = attn_output.view(
-1, self.num_local_heads,
self.qk_head_dim)[..., :self.v_head_dim].reshape(
-1, self.num_local_heads * self.v_head_dim)

output, _ = self.o_proj(attn_output)
return output


class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):

def _init_attn_block(self):
self.input_layernorm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)
self.self_attn = MiniCPM3Attention(
config=self.config,
hidden_size=self.hidden_size,
num_heads=self.config.num_attention_heads,
qk_nope_head_dim=self.config.qk_nope_head_dim,
qk_rope_head_dim=self.config.qk_rope_head_dim,
v_head_dim=self.config.v_head_dim,
q_lora_rank=self.config.q_lora_rank,
kv_lora_rank=self.config.kv_lora_rank,
rope_theta=self.rope_theta,
rope_scaling=self.rope_scaling,
max_position_embeddings=self.max_position_embeddings,
cache_config=self.cache_config,
quant_config=self.quant_config,
)


class MiniCPM3Model(MiniCPMModel):

def _init_layers(self):
self.layers = nn.ModuleList([
MiniCPM3DecoderLayer(self.config, self.cache_config,
self.quant_config)
for _ in range(self.config.num_hidden_layers)
])


class MiniCPM3ForCausalLM(MiniCPMForCausalLM):

def _init_model(self):
self.model = MiniCPM3Model(config=self.config,
cache_config=self.cache_config,
quant_config=self.quant_config,
lora_config=self.lora_config)
Loading