support minicpm3.0 #605

Merged · 4 commits · Nov 14, 2024
3 changes: 2 additions & 1 deletion awq/models/__init__.py
@@ -25,4 +25,5 @@
from .deepseek_v2 import DeepseekV2AWQForCausalLM
from .minicpm import MiniCPMAWQForCausalLM
from .internlm2 import InternLM2AWQForCausalLM
from .qwen2vl import Qwen2VLAWQForCausalLM
from .minicpm3 import MiniCPM3AWQForCausalLM
from .qwen2vl import Qwen2VLAWQForCausalLM
1 change: 1 addition & 0 deletions awq/models/auto.py
@@ -35,6 +35,7 @@
"deepseek_v2": DeepseekV2AWQForCausalLM,
"minicpm": MiniCPMAWQForCausalLM,
"internlm2": InternLM2AWQForCausalLM,
"minicpm3": MiniCPM3AWQForCausalLM,
"qwen2_vl": Qwen2VLAWQForCausalLM,
}

1 change: 1 addition & 0 deletions awq/models/base.py
@@ -85,6 +85,7 @@
"cohere": "AutoModelForCausalLM",
"deepseek_v2": "AutoModelForCausalLM",
"minicpm": "AutoModelForCausalLM",
"minicpm3":"AutoModelForCausalLM",
"internlm2": "AutoModelForCausalLM",
"qwen2_vl": "AutoModelForVision2Seq",
}
69 changes: 69 additions & 0 deletions awq/models/minicpm3.py
@@ -0,0 +1,69 @@
from .base import BaseAWQForCausalLM

class MiniCPM3AWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MiniCPMDecoderLayer"
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def get_model_layers(model):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module, input_feat, module_kwargs):
        layers = []

        # attention: query up-projection
        layers.append(
            dict(
                prev_op=module.self_attn.q_a_layernorm,
                layers=[
                    module.self_attn.q_b_proj,
                ],
                inp=input_feat["self_attn.q_b_proj"],
                module2inspect=module.self_attn.q_b_proj,
                kwargs=module_kwargs,
            )
        )

        # attention: key/value up-projection
        layers.append(
            dict(
                prev_op=module.self_attn.kv_a_layernorm,
                layers=[
                    module.self_attn.kv_b_proj,
                ],
                inp=input_feat["self_attn.kv_b_proj"],
                module2inspect=module.self_attn.kv_b_proj,
                kwargs=module_kwargs,
            )
        )

        # linear 2
        layers.append(
            dict(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat["mlp.down_proj"],
            )
        )

        # linear 1
        layers.append(
            dict(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.gate_proj, module.mlp.up_proj],
                inp=input_feat["mlp.gate_proj"],
                module2inspect=module.mlp,
            )
        )

        return layers
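
For orientation, here is a compact, illustrative summary (not part of the diff) of which sub-modules each scaling group above ties together; the attention output projection and the low-rank down-projections are left unscaled by this PR.

```python
# Illustrative only: prev_op -> linear layers scaled against it,
# restating the groups returned by get_layers_for_scaling above.
SCALING_GROUPS = {
    "self_attn.q_a_layernorm": ["self_attn.q_b_proj"],
    "self_attn.kv_a_layernorm": ["self_attn.kv_b_proj"],
    "mlp.up_proj": ["mlp.down_proj"],
    "post_attention_layernorm": ["mlp.gate_proj", "mlp.up_proj"],
}

for prev_op, scaled in SCALING_GROUPS.items():
    print(f"{prev_op} -> {', '.join(scaled)}")
```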
96 changes: 96 additions & 0 deletions docs/examples.md
@@ -424,6 +424,102 @@ model.model.config.use_cache = model.model.generation_config.use_cache = True
model.save_quantized(quant_path, safetensors=True, shard_size="4GB")
```

### Another Custom Quantizer (MiniCPM3 Example)

Here is another custom quantizer, contributed by the MiniCPM team at OpenBMB. It only
overrides the weight-clipping search (`_compute_best_clip`) so that quantization works for MiniCPM3.

```python
import torch
from transformers import AutoTokenizer

from awq import AutoAWQForCausalLM
from awq.quantize.quantizer import AwqQuantizer, clear_memory

class CPM3AwqQuantizer(AwqQuantizer):
@torch.no_grad()
def _compute_best_clip(
self,
w: torch.Tensor,
input_feat: torch.Tensor,
n_grid=20,
max_shrink=0.5,
n_sample_token=512,
):
assert w.dim() == 2
org_w_shape = w.shape
# w [co, ci] -> [co, 1, n_group, group size]
# input_feat [n_token, ci] -> [1, n_token, n_group, group size]
group_size = self.group_size if self.group_size > 0 else org_w_shape[1]
input_feat = input_feat.view(-1, input_feat.shape[-1])
input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)

# Compute input feature step size (minimum 1)
step_size = max(1, input_feat.shape[1] // n_sample_token)
input_feat = input_feat[:, ::step_size]

w = w.reshape(org_w_shape[0], 1, -1, group_size)

oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64 # prevent OOM
if org_w_shape[0] % oc_batch_size != 0:
oc_batch_size = org_w_shape[0]
assert org_w_shape[0] % oc_batch_size == 0
w_all = w
best_max_val_all = []

for i_b in range(org_w_shape[0] // oc_batch_size):
w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size]

org_max_val = w.abs().amax(dim=-1, keepdim=True) # co, 1, n_group, 1

best_max_val = org_max_val.clone()
min_errs = torch.ones_like(org_max_val) * 1e9
input_feat = input_feat.to(w.device)
org_out = (input_feat * w).sum(dim=-1) # co, n_token, n_group

for i_s in range(int(max_shrink * n_grid)):
max_val = org_max_val * (1 - i_s / n_grid)
min_val = -max_val
cur_w = torch.clamp(w, min_val, max_val)
q_w = self.pseudo_quantize_tensor(cur_w)[0]
cur_out = (input_feat * q_w).sum(dim=-1)

# co, 1, n_group, 1
err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
del cur_w
del cur_out
cur_best_idx = err < min_errs
min_errs[cur_best_idx] = err[cur_best_idx]
best_max_val[cur_best_idx] = max_val[cur_best_idx]
best_max_val_all.append(best_max_val)

best_max_val = torch.cat(best_max_val_all, dim=0)

clear_memory(input_feat)
clear_memory(org_out)

return best_max_val.squeeze(1)

model_path = 'openbmb/MiniCPM3-4B'
quant_path = 'minicpm3-4b-awq'
quant_config = { "zero_point": True, "q_group_size": 64, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, use_cache=False, safetensors=False
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
model.quantize(tokenizer, quant_config=quant_config, quantizer_cls=CPM3AwqQuantizer)

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

print(f'Model is quantized and saved at "{quant_path}"')
```
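
Once saved, the quantized checkpoint can be reloaded for a quick sanity check. This is a minimal sketch that mirrors AutoAWQ's standard `from_quantized` flow; the prompt and generation settings are placeholders, and layer fusion is left off since fused kernels for MiniCPM3 are not assumed here.

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = 'minicpm3-4b-awq'

# Reload the quantized checkpoint produced above.
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)

prompt = "Briefly explain what AWQ quantization does."
inputs = tokenizer(prompt, return_tensors="pt").to(next(model.model.parameters()).device)

# Greedy decoding keeps the check deterministic; tune sampling as needed.
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```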

## Basic Inference

### Inference With GPU