diff --git a/neural_compressor/torch/algorithms/layer_wise/utils.py b/neural_compressor/torch/algorithms/layer_wise/utils.py
index 464a25cdee0..974d744f45d 100644
--- a/neural_compressor/torch/algorithms/layer_wise/utils.py
+++ b/neural_compressor/torch/algorithms/layer_wise/utils.py
@@ -18,7 +18,11 @@
 
 import gc
 import json
+import logging
 import os
+import pickle
+from collections import OrderedDict
+from functools import partial
 
 import torch
 from accelerate import init_empty_weights
@@ -26,11 +30,12 @@
 from transformers import AutoConfig, AutoModelForCausalLM
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
-from neural_compressor.common import options
-
 from .load import load
 
-LWQ_WORKSPACE = os.path.join(options.workspace, "layer_wise_tmp")
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(filename)s L%(lineno)d: %(message)s")
+logger = logging.getLogger("layer_wise_tools")
+
+LWQ_WORKSPACE = os.path.join("layer_wise_tmp")
 
 
 class QDQLayer(torch.nn.Module):
@@ -121,7 +126,7 @@ def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None):
     return file_path
 
 
-def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, **kwargs):
+def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, save_path=None, **kwargs):
     """Load a empty model."""
     is_local = os.path.isdir(pretrained_model_name_or_path)
     if is_local:  # pragma: no cover
@@ -139,6 +144,10 @@ def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, **
     model.tie_weights()
     model.eval()
     model.path = pretrained_model_name_or_path
+
+    if save_path is None:
+        save_path = LWQ_WORKSPACE
+    convert_model(model, save_path)
     return model
 
 
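With this hunk, load_empty_model no longer only returns a meta-device skeleton: it routes the model through convert_model, so each layer can fetch its tensors on demand from save_path (defaulting to the layer_wise_tmp workspace). A minimal usage sketch, assuming a Hugging Face checkpoint name and scratch directory that are placeholders, not part of the patch:

    from neural_compressor.torch.algorithms.layer_wise.utils import load_empty_model

    # The architecture is built under init_empty_weights(); parameters stay on
    # the "meta" device, while convert_model wires up get_weight()/update()/to()
    # so real tensors are loaded lazily from disk or the original checkpoint.
    model = load_empty_model("facebook/opt-125m", save_path="./layer_wise_tmp")
    print(next(model.parameters()).device)  # meta
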
@@ -163,6 +172,40 @@ def update_module(model, module_name, new_module):
     setattr(super_module, module_name.split(".")[-1], new_module)
 
 
+def get_layers_before_block(model):
+    """Get the embed layers before blocks."""
+    return_layers = []
+    block_name = None
+
+    def _forward(module, name, *args, **kwargs):
+        if name == block_name:
+            # if 'DecoderLayer' in name:
+            raise NotImplementedError
+        if len(module._modules) == 0:
+            return_layers.append((name, module))
+        return module.ori_forward(*args, **kwargs)
+
+    for n, m in model.named_modules():
+        if isinstance(m, torch.nn.ModuleList):
+            block_name = n + "." + m.named_children().__next__()[0]
+        m.ori_forward = m.forward
+        m.forward = partial(_forward, m, n)
+
+    try:
+        model.forward(
+            input_ids=torch.zeros((1, 1), device="meta", dtype=torch.int),
+            attention_mask=torch.zeros((1, 1), device="meta", dtype=torch.int),
+        )
+    except NotImplementedError:
+        pass
+
+    for n, m in model.named_modules():
+        m.forward = m.ori_forward
+        del m.ori_forward
+
+    return return_layers
+
+
 def load_layer_wise_quantized_model(path):  # pragma: no cover
     """Load layer wise quantized model."""
     model = torch.load(os.path.join(path, "model_arch.pt"))
@@ -207,6 +250,8 @@ def load_tensor(path, tensor_name=None, prefix=None):
 
 
 def _get_path(pretrained_model_name_or_path):
+    if pretrained_model_name_or_path is None:
+        return None
     is_local = os.path.isdir(pretrained_model_name_or_path)
     if is_local:  # pragma: no cover
         path = pretrained_model_name_or_path
@@ -216,6 +261,7 @@
 
 
 def load_value(model, param_name, path):
+    logger.debug(f"load value for layer: {param_name}")
     if "lm_head" in param_name and getattr(model.config, "tie_word_embeddings", True):
         input_embeddings = model.get_input_embeddings()
         modules = get_named_children(model)
@@ -244,9 +290,10 @@ def register_weight_hooks(model, path, device="cpu", clean_weight=True, saved_pa
 
     def forward_pre_hook(name):
         def hook(module, input):
+            logger.debug(f"{name} forward hook load value")
             state_dict = None
-            if os.path.exists(os.path.join(LWQ_WORKSPACE, f"{name}.pt")):
-                state_dict = torch.load(os.path.join(LWQ_WORKSPACE, f"{name}.pt"))
+            if os.path.exists(os.path.join(saved_path, f"{name}.pt")):
+                state_dict = torch.load(os.path.join(saved_path, f"{name}.pt"))
             for n, p in module.named_parameters():
                 param_name = name + "." + n
                 if state_dict:
@@ -254,11 +301,13 @@ def hook(module, input):
                 else:
                     value = load_value(model, param_name, path)
                 set_module_tensor_to_device(model, param_name, device, value)
+            module = module.to(device)
 
         return hook
 
     def forward_hook(name):
         def hook(module, input, output):
+            logger.debug(f"{name} forward hook clean value")
            if saved_path:
                 file_path = os.path.join(saved_path, f"{name}.pt")
                 torch.save(module.state_dict(), file_path)
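Together the two hooks implement a just-in-time weight cycle: the pre-hook pulls a module's tensors from <saved_path>/<name>.pt (or from the original checkpoint via load_value) right before that module runs, and the forward hook saves them back and re-cleans the module to the meta device. A hedged sketch of the intended call pattern, where the checkpoint name and dummy input are assumptions:

    import torch

    model = load_empty_model("facebook/opt-125m", save_path="./layer_wise_tmp")
    register_weight_hooks(model, model.path, device="cpu", clean_weight=True,
                          saved_path="./layer_wise_tmp")

    with torch.no_grad():
        # Each submodule's weights are loaded just before use and offloaded
        # right after, so peak host memory stays near one layer's footprint.
        model(torch.zeros((1, 8), dtype=torch.long))
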
@@ -294,3 +343,109 @@ def clean_module_weight(module):
         new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to("meta")
         submodule._parameters[n] = new_value
     gc.collect()
+
+
+def convert_model(empty_model, saved_path=LWQ_WORKSPACE):
+    def _get_value(name, n):
+        state_dict = None
+        if os.path.exists(os.path.join(saved_path, f"{name}.pt")):
+            state_dict = torch.load(os.path.join(saved_path, f"{name}.pt"))
+        param_name = name + "." + n
+        if state_dict:
+            value = state_dict[n]
+        else:
+            value = load_value(empty_model, param_name, empty_model.path)
+        return value
+
+    def _update(module):
+        state_dict = None
+        if os.path.exists(os.path.join(saved_path, f"{name}.pt")):
+            state_dict = torch.load(os.path.join(saved_path, f"{name}.pt"))
+        for n, p in module.named_parameters():
+            if str(p.device) != "meta":
+                continue
+            param_name = name + "." + n
+            if state_dict:
+                value = state_dict[n]
+            else:
+                value = load_value(empty_model, param_name, saved_path)
+            set_module_tensor_to_device(empty_model, param_name, "cpu", value)
+        file_path = os.path.join(saved_path, f"{name}.pt")
+        torch.save(module.state_dict(), file_path)
+
+    def _layer_wise_to(module, name, device_or_dtype):
+        if isinstance(device_or_dtype, torch.dtype):
+            return module.ori_to(device_or_dtype)
+        elif len(module._modules) == 0:
+            # skip method type
+            if len(module._parameters) == 0 or module.weight.device.type != "meta":
+                return module.ori_to(device_or_dtype)
+            else:
+                for n, _ in module.named_parameters():
+                    param_name = name + "." + n
+                    value = load_value(empty_model, param_name, empty_model.path)
+                    dtype = None
+                    if hasattr(module, "dtype"):
+                        dtype = module.dtype
+                    set_module_tensor_to_device(module, n, device_or_dtype, value, dtype=dtype)
+                return module.ori_to(device_or_dtype)
+        else:
+            for n, m in module.named_children():
+                m.to(device_or_dtype)
+            return module
+
+    modules = get_named_children(empty_model)
+    for name, module in modules:
+        if hasattr(module, "weight"):
+            # delattr(module, 'weight')
+            # module.weight = partial(_get_value, name, 'weight')()
+            module.get_weight = partial(_get_value, name, "weight")
+        if hasattr(module, "bias") and module.bias is not None:
+            module.get_bias = partial(_get_value, name, "bias")
+        module.update = partial(_update, module)
+
+    def _replace_to(module, name):
+        if len(module._modules) > 0:
+            for n, m in module.named_children():
+                if len(name) > 0:
+                    n = name + "." + n
+                _replace_to(m, n)
+        module.ori_to = module.to
+        module.to = partial(_layer_wise_to, module, name)
+
+    _replace_to(empty_model, "")
+
+
+def load_model_with_hooks(
+    pretrained_model_name_or_path, cls=AutoModelForCausalLM, device="cpu", clean_weight=True, saved_path=None, **kwargs
+):
+    if saved_path is None:
+        saved_path = LWQ_WORKSPACE
+    empty_model = load_empty_model(pretrained_model_name_or_path, cls=cls, **kwargs)
+    register_weight_hooks(empty_model, empty_model.path, device, clean_weight, saved_path)
+    return empty_model
+
+
+def layer_wise_save(model, path):
+    os.makedirs(path, exist_ok=True)
+    file_path = os.path.join(path, "layer_wise_model.bin")
+    modules = get_named_children(model)
+    with open(file_path, "wb") as f:
+        for name, module in modules:
+            output = OrderedDict()
+            if hasattr(module, "get_weight"):
+                output[f"{name}.weight"] = module.get_weight()
+            if hasattr(module, "get_bias"):
+                output[f"{name}.bias"] = module.get_bias()
+            output = pickle.dumps(output)
+            f.write(output + b"split_tag")
+
+
+def layer_wise_load(path):
+    file_path = os.path.join(path, "layer_wise_model.bin")
+    state_dict = OrderedDict()
+    data = open(file_path, "rb").read().split(b"split_tag")
+    for d in data:
+        if len(d) > 0:
+            d = pickle.loads(d)
+            state_dict.update(d)
+    return state_dict
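layer_wise_save and layer_wise_load round-trip the lazily materialized weights through pickle: one OrderedDict per module, concatenated with a literal b"split_tag" delimiter, a design that assumes the marker bytes never occur inside a pickled payload. A round-trip sketch, with placeholder checkpoint name and paths:

    model = load_empty_model("facebook/opt-125m")   # convert_model attaches get_weight/get_bias
    layer_wise_save(model, "./lw_ckpt")             # writes ./lw_ckpt/layer_wise_model.bin
    state_dict = layer_wise_load("./lw_ckpt")       # flat {"<layer>.weight": tensor, ...}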