Skip to content

Commit

Permalink
Add StableLM support (#410)
Browse files Browse the repository at this point in the history
Co-authored-by: Casper <[email protected]>
  • Loading branch information
Isotr0py and casper-hansen authored Apr 6, 2024
1 parent 33dfb04 commit e9f6269
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 9 deletions.
3 changes: 2 additions & 1 deletion awq/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@
from .mixtral import MixtralAWQForCausalLM
from .qwen2 import Qwen2AWQForCausalLM
from .gemma import GemmaAWQForCausalLM
from .starcoder2 import Starcoder2AWQForCausalLM
from .stablelm import StableLmAWQForCausalLM
from .starcoder2 import Starcoder2AWQForCausalLM
1 change: 1 addition & 0 deletions awq/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"llava": LlavaAWQForCausalLM,
"qwen2": Qwen2AWQForCausalLM,
"gemma": GemmaAWQForCausalLM,
"stablelm": StableLmAWQForCausalLM,
"starcoder2": Starcoder2AWQForCausalLM,
}

Expand Down
1 change: 1 addition & 0 deletions awq/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
"llava": "AutoModelForVision2Seq",
"qwen2": "AutoModelForCausalLM",
"gemma": "AutoModelForCausalLM",
"stablelm": "AutoModelForCausalLM",
"starcoder2": "AutoModelForCausalLM",
}

Expand Down
136 changes: 136 additions & 0 deletions awq/models/stablelm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.stablelm import StableLmForCausalLM as OldStableLmForCausalLM
from transformers.models.stablelm.modeling_stablelm import (
StableLmDecoderLayer as OldStableLmDecoderLayer,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm


class StableLmAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "StableLmDecoderLayer"
max_seq_len_key = "max_position_embeddings"

@staticmethod
def fuse_layers(model: OldStableLmForCausalLM):
fuser = StableLmFuser(model)
fuser.fuse_transformer()

@staticmethod
def get_model_layers(model: OldStableLmForCausalLM):
return model.model.layers

@staticmethod
def get_act_for_scaling(module: OldStableLmForCausalLM):
return dict(is_scalable=False)

@staticmethod
def move_embed(model: OldStableLmForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)

@staticmethod
def get_layers_for_scaling(
module: OldStableLmDecoderLayer, input_feat, module_kwargs
):
layers = []

# attention input
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)

# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)

# linear 1
layers.append(
dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp,
)
)

# linear 2
layers.append(
dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)

return layers


class StableLmFuser:
def __init__(self, model: OldStableLmForCausalLM):
self.model = model

self.stablelm_blocks: List[Tuple[str, OldStableLmDecoderLayer]] = [
(name, module)
for name, module in self.model.named_modules()
if "StableLmDecoderLayer".lower() in module.__class__.__name__.lower()
]

def fuse_transformer(self):
blocks = []

module: OldStableLmDecoderLayer
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
)
norm_1 = module.input_layernorm
norm_2 = module.post_attention_layernorm
blocks.append(
LlamaLikeBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
rope_theta=self.model.config.rope_theta,
partial_rotary_factor=self.model.config.partial_rotary_factor,
)
)

self.model.model = LlamaLikeModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
30 changes: 22 additions & 8 deletions awq/modules/fused/attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ def __init__(self, head_dim, max_seq_len, device, rope_theta):
super(RoPE, self).__init__()

self.freqs_cis = nn.Parameter(
self.precompute_freqs_cis(
head_dim, max_seq_len * 2, rope_theta
).to(device),
self.precompute_freqs_cis(head_dim, max_seq_len * 2, rope_theta).to(device),
requires_grad=False,
)

Expand Down Expand Up @@ -118,6 +116,7 @@ def __init__(
use_alibi=False,
attention_shapes=None,
rope_theta=10000,
partial_rotary_factor=1.0,
head_dim=None,
**kwargs
):
Expand All @@ -127,7 +126,7 @@ def __init__(
self.n_kv_heads = n_kv_heads
self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
self.head_dim = head_dim

if head_dim is None:
self.head_dim = hidden_size // n_heads

Expand Down Expand Up @@ -167,8 +166,9 @@ def __init__(
self.is_neox = False
else:
self.alibi = None
self.rope = RoPE(self.head_dim, max_seq_len, dev, rope_theta)
self.rotary_dim = self.head_dim
self.partial_rotary_factor = partial_rotary_factor
self.rotary_dim = int(self.head_dim * self.partial_rotary_factor)
self.rope = RoPE(self.rotary_dim, max_seq_len, dev, rope_theta)
self.is_neox = True

def forward(
Expand Down Expand Up @@ -209,13 +209,27 @@ def forward(
xk = self.attention_shapes["xk_slice"](xqkv)
xv = self.attention_shapes["xv_slice"](xqkv)

if seqlen > 1 or not FT_INSTALLED:
if seqlen > 1 or self.partial_rotary_factor < 1 or not FT_INSTALLED:
xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

if not self.use_alibi:
xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)
# Partial rotary embedding
if self.partial_rotary_factor < 1:
xq_rot, xq_pass = (
xq[..., : self.rotary_dim],
xq[..., self.rotary_dim :],
)
xk_rot, xk_pass = (
xk[..., : self.rotary_dim],
xk[..., self.rotary_dim :],
)
xq_rot, xk_rot = self.rope.forward(xq_rot, xk_rot, self.start_pos, seqlen)
xq = torch.cat((xq_rot, xq_pass), dim=-1)
xk = torch.cat((xk_rot, xk_pass), dim=-1)
else:
xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)

values_store = xv.transpose(2, 1)
keys_store = (
Expand Down
2 changes: 2 additions & 0 deletions awq/modules/fused/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
dev,
max_seq_len,
rope_theta=10000,
partial_rotary_factor=1.0,
use_alibi=False,
head_dim=None,
):
Expand All @@ -103,6 +104,7 @@ def __init__(
max_seq_len=max_seq_len,
use_alibi=use_alibi,
rope_theta=rope_theta,
partial_rotary_factor=partial_rotary_factor,
head_dim=head_dim,
).to(dev)
self.norm_2 = norm_2.to(dev)
Expand Down

0 comments on commit e9f6269

Please sign in to comment.