-
Notifications
You must be signed in to change notification settings - Fork 258
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add some new features for layer-wise quant #1899
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -18,19 +18,24 @@ | |||||||||||||
|
||||||||||||||
import gc | ||||||||||||||
import json | ||||||||||||||
import logging | ||||||||||||||
import os | ||||||||||||||
import pickle | ||||||||||||||
from collections import OrderedDict | ||||||||||||||
from functools import partial | ||||||||||||||
|
||||||||||||||
import torch | ||||||||||||||
from accelerate import init_empty_weights | ||||||||||||||
from accelerate.utils import set_module_tensor_to_device | ||||||||||||||
from transformers import AutoConfig, AutoModelForCausalLM | ||||||||||||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass | ||||||||||||||
|
||||||||||||||
from neural_compressor.common import options | ||||||||||||||
|
||||||||||||||
from .load import load | ||||||||||||||
|
||||||||||||||
LWQ_WORKSPACE = os.path.join(options.workspace, "layer_wise_tmp") | ||||||||||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(filename)s L%(lineno)d: %(message)s") | ||||||||||||||
logger = logging.getLogger("layer_wise_tools") | ||||||||||||||
Comment on lines
+35
to
+36
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
|
||||||||||||||
LWQ_WORKSPACE = os.path.join("layer_wise_tmp") | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
|
||||||||||||||
|
||||||||||||||
class QDQLayer(torch.nn.Module): | ||||||||||||||
|
@@ -121,7 +126,7 @@ def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None): | |||||||||||||
return file_path | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, **kwargs): | ||||||||||||||
def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, save_path=None, **kwargs): | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||
"""Load a empty model.""" | ||||||||||||||
is_local = os.path.isdir(pretrained_model_name_or_path) | ||||||||||||||
if is_local: # pragma: no cover | ||||||||||||||
|
@@ -139,6 +144,10 @@ def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, ** | |||||||||||||
model.tie_weights() | ||||||||||||||
model.eval() | ||||||||||||||
model.path = pretrained_model_name_or_path | ||||||||||||||
|
||||||||||||||
if save_path is None: | ||||||||||||||
save_path = LWQ_WORKSPACE | ||||||||||||||
Comment on lines
+148
to
+149
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove if use default value |
||||||||||||||
convert_model(model, save_path) | ||||||||||||||
return model | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
|
@@ -163,6 +172,40 @@ def update_module(model, module_name, new_module): | |||||||||||||
setattr(super_module, module_name.split(".")[-1], new_module) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def get_layers_before_block(model): | ||||||||||||||
"""Get the embed layers before blocks.""" | ||||||||||||||
return_layers = [] | ||||||||||||||
block_name = None | ||||||||||||||
|
||||||||||||||
def _forward(module, name, *args, **kwargs): | ||||||||||||||
if name == block_name: | ||||||||||||||
# if 'DecoderLayer' in name: | ||||||||||||||
raise NotImplementedError | ||||||||||||||
if len(module._modules) == 0: | ||||||||||||||
return_layers.append((name, module)) | ||||||||||||||
return module.ori_forward(*args, **kwargs) | ||||||||||||||
|
||||||||||||||
for n, m in model.named_modules(): | ||||||||||||||
if isinstance(m, torch.nn.ModuleList): | ||||||||||||||
block_name = n + "." + m.named_children().__next__()[0] | ||||||||||||||
m.ori_forward = m.forward | ||||||||||||||
m.forward = partial(_forward, m, n) | ||||||||||||||
|
||||||||||||||
try: | ||||||||||||||
model.forward( | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it applicable to all or most models? |
||||||||||||||
input_ids=torch.zeros((1, 1), device="meta", dtype=torch.int), | ||||||||||||||
attention_mask=torch.zeros((1, 1), device="meta", dtype=torch.int), | ||||||||||||||
) | ||||||||||||||
except NotImplementedError: | ||||||||||||||
pass | ||||||||||||||
|
||||||||||||||
for n, m in model.named_modules(): | ||||||||||||||
m.forward = m.ori_forward | ||||||||||||||
del m.ori_forward | ||||||||||||||
|
||||||||||||||
return return_layers | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def load_layer_wise_quantized_model(path): # pragma: no cover | ||||||||||||||
"""Load layer wise quantized model.""" | ||||||||||||||
model = torch.load(os.path.join(path, "model_arch.pt")) | ||||||||||||||
|
@@ -207,6 +250,8 @@ def load_tensor(path, tensor_name=None, prefix=None): | |||||||||||||
|
||||||||||||||
|
||||||||||||||
def _get_path(pretrained_model_name_or_path): | ||||||||||||||
if pretrained_model_name_or_path is None: | ||||||||||||||
return None | ||||||||||||||
is_local = os.path.isdir(pretrained_model_name_or_path) | ||||||||||||||
if is_local: # pragma: no cover | ||||||||||||||
path = pretrained_model_name_or_path | ||||||||||||||
|
@@ -216,6 +261,7 @@ def _get_path(pretrained_model_name_or_path): | |||||||||||||
|
||||||||||||||
|
||||||||||||||
def load_value(model, param_name, path): | ||||||||||||||
logger.debug(f"load value for layer: {param_name}") | ||||||||||||||
if "lm_head" in param_name and getattr(model.config, "tie_word_embeddings", True): | ||||||||||||||
input_embeddings = model.get_input_embeddings() | ||||||||||||||
modules = get_named_children(model) | ||||||||||||||
|
@@ -244,21 +290,24 @@ def register_weight_hooks(model, path, device="cpu", clean_weight=True, saved_pa | |||||||||||||
|
||||||||||||||
def forward_pre_hook(name): | ||||||||||||||
def hook(module, input): | ||||||||||||||
logger.debug(f"{name} forward hood load value") | ||||||||||||||
state_dict = None | ||||||||||||||
if os.path.exists(os.path.join(LWQ_WORKSPACE, f"{name}.pt")): | ||||||||||||||
state_dict = torch.load(os.path.join(LWQ_WORKSPACE, f"{name}.pt")) | ||||||||||||||
if os.path.exists(os.path.join(saved_path, f"{name}.pt")): | ||||||||||||||
state_dict = torch.load(os.path.join(saved_path, f"{name}.pt")) | ||||||||||||||
for n, p in module.named_parameters(): | ||||||||||||||
param_name = name + "." + n | ||||||||||||||
if state_dict: | ||||||||||||||
value = state_dict[n] | ||||||||||||||
else: | ||||||||||||||
value = load_value(model, param_name, path) | ||||||||||||||
set_module_tensor_to_device(model, param_name, device, value) | ||||||||||||||
module = module.to(device) | ||||||||||||||
|
||||||||||||||
return hook | ||||||||||||||
|
||||||||||||||
def forward_hook(name): | ||||||||||||||
def hook(module, input, output): | ||||||||||||||
logger.debug(f"{name} forward hood clean value") | ||||||||||||||
if saved_path: | ||||||||||||||
file_path = os.path.join(saved_path, f"{name}.pt") | ||||||||||||||
torch.save(module.state_dict(), file_path) | ||||||||||||||
|
@@ -294,3 +343,109 @@ def clean_module_weight(module): | |||||||||||||
new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to("meta") | ||||||||||||||
submodule._parameters[n] = new_value | ||||||||||||||
gc.collect() | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def convert_model(empty_model, saved_path=LWQ_WORKSPACE): | ||||||||||||||
def _get_value(name, n): | ||||||||||||||
state_dict = None | ||||||||||||||
if os.path.exists(os.path.join(saved_path, f"{name}.pt")): | ||||||||||||||
state_dict = torch.load(os.path.join(saved_path, f"{name}.pt")) | ||||||||||||||
param_name = name + "." + n | ||||||||||||||
if state_dict: | ||||||||||||||
value = state_dict[n] | ||||||||||||||
else: | ||||||||||||||
value = load_value(empty_model, param_name, empty_model.path) | ||||||||||||||
return value | ||||||||||||||
|
||||||||||||||
def _update(module): | ||||||||||||||
state_dict = None | ||||||||||||||
if os.path.exists(os.path.join(saved_path, f"{name}.pt")): | ||||||||||||||
state_dict = torch.load(os.path.join(saved_path, f"{name}.pt")) | ||||||||||||||
for n, p in module.named_parameters(): | ||||||||||||||
if str(p.device) != "meta": | ||||||||||||||
continue | ||||||||||||||
param_name = name + "." + n | ||||||||||||||
if state_dict: | ||||||||||||||
value = state_dict[n] | ||||||||||||||
else: | ||||||||||||||
value = load_value(empty_model, param_name, saved_path) | ||||||||||||||
set_module_tensor_to_device(empty_model, param_name, "cpu", value) | ||||||||||||||
file_path = os.path.join(saved_path, f"{name}.pt") | ||||||||||||||
torch.save(module.state_dict(), file_path) | ||||||||||||||
|
||||||||||||||
def _layer_wise_to(module, name, device_or_dtype): | ||||||||||||||
if isinstance(device_or_dtype, torch.dtype): | ||||||||||||||
return module.ori_to(device_or_dtype) | ||||||||||||||
elif len(module._modules) == 0: | ||||||||||||||
# skip method type | ||||||||||||||
if len(module._parameters) == 0 or module.weight.device.type != "meta": | ||||||||||||||
return module.ori_to(device_or_dtype) | ||||||||||||||
else: | ||||||||||||||
for n, _ in module.named_parameters(): | ||||||||||||||
param_name = name + "." + n | ||||||||||||||
value = load_value(empty_model, param_name, empty_model.path) | ||||||||||||||
dtype = None | ||||||||||||||
if hasattr(module, "dtype"): | ||||||||||||||
dtype = module.dtype | ||||||||||||||
set_module_tensor_to_device(module, n, device_or_dtype, value, dtype=dtype) | ||||||||||||||
return module.ori_to(device_or_dtype) | ||||||||||||||
else: | ||||||||||||||
for n, m in module.named_children(): | ||||||||||||||
m.to(device_or_dtype) | ||||||||||||||
return module | ||||||||||||||
|
||||||||||||||
modules = get_named_children(empty_model) | ||||||||||||||
for name, module in modules: | ||||||||||||||
if hasattr(module, "weight"): | ||||||||||||||
# delattr(module, 'weight') | ||||||||||||||
# module.weight = partial(_get_value, name, 'weight')() | ||||||||||||||
module.get_weight = partial(_get_value, name, "weight") | ||||||||||||||
if hasattr(module, "bias") and module.bias is not None: | ||||||||||||||
module.get_bias = partial(_get_value, name, "bias") | ||||||||||||||
module.update = partial(_update, module) | ||||||||||||||
|
||||||||||||||
def _repalce_to(module, name): | ||||||||||||||
if len(module._modules) > 0: | ||||||||||||||
for n, m in module.named_children(): | ||||||||||||||
if len(name) > 0: | ||||||||||||||
n = name + "." + n | ||||||||||||||
_repalce_to(m, n) | ||||||||||||||
module.ori_to = module.to | ||||||||||||||
module.to = partial(_layer_wise_to, module, name) | ||||||||||||||
|
||||||||||||||
_repalce_to(empty_model, "") | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def load_model_with_hooks( | ||||||||||||||
pretrained_model_name_or_path, cls=AutoModelForCausalLM, device="cpu", clean_weight=True, saved_path=None, **kwargs | ||||||||||||||
): | ||||||||||||||
if saved_path is None: | ||||||||||||||
saved_path = LWQ_WORKSPACE | ||||||||||||||
Comment on lines
+420
to
+423
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
empty_model = load_empty_model(pretrained_model_name_or_path, cls=cls, **kwargs) | ||||||||||||||
register_weight_hooks(empty_model, empty_model.path, device, clean_weight, saved_path) | ||||||||||||||
return empty_model | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def layer_wise_save(model, path): | ||||||||||||||
os.makedirs(path, exist_ok=True) | ||||||||||||||
file_path = os.path.join(path, "layer_wise_model.bin") | ||||||||||||||
modules = get_named_children(model) | ||||||||||||||
with open(file_path, "wb") as f: | ||||||||||||||
for name, module in modules: | ||||||||||||||
output = OrderedDict() | ||||||||||||||
if hasattr(module, "get_weight"): | ||||||||||||||
output[f"{name}.weight"] = module.get_weight() | ||||||||||||||
if hasattr(module, "get_bias"): | ||||||||||||||
output[f"{name}.bias"] = module.get_bias() | ||||||||||||||
output = pickle.dumps(output) | ||||||||||||||
f.write(output + b"split_tag") | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def layer_wise_load(path): | ||||||||||||||
file_path = os.path.join(path, "layer_wise_model.bin") | ||||||||||||||
state_dict = OrderedDict() | ||||||||||||||
data = open(file_path, "rb").read().split(b"split_tag") | ||||||||||||||
for d in data: | ||||||||||||||
if len(d) > 0: | ||||||||||||||
d = pickle.loads(d) | ||||||||||||||
state_dict.update(d) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.