add some new features for layer-wise quant #1899

Closed · wants to merge 2 commits
Changes from all commits
167 changes: 161 additions & 6 deletions neural_compressor/torch/algorithms/layer_wise/utils.py
@@ -18,19 +18,24 @@

import gc
import json
+import logging
Contributor review, suggested change:
-import logging
import os
import pickle
from collections import OrderedDict
from functools import partial

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.auto.auto_factory import _BaseAutoModelClass

-from neural_compressor.common import options

from .load import load

-LWQ_WORKSPACE = os.path.join(options.workspace, "layer_wise_tmp")
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(filename)s L%(lineno)d: %(message)s")
+logger = logging.getLogger("layer_wise_tools")
Contributor review on lines +35 to +36, suggested change:
-logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(filename)s L%(lineno)d: %(message)s")
-logger = logging.getLogger("layer_wise_tools")
+from neural_compressor.torch.utils import logger


LWQ_WORKSPACE = os.path.join("layer_wise_tmp")
Contributor review, suggested change:
-LWQ_WORKSPACE = os.path.join("layer_wise_tmp")
+from neural_compressor.common import options
+LWQ_WORKSPACE = os.path.join(options.workspace, "lwq_tmpdir")



class QDQLayer(torch.nn.Module):
@@ -121,7 +126,7 @@ def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None):
    return file_path


-def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, **kwargs):
+def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, save_path=None, **kwargs):
Contributor review: save_path defaults to LWQ_WORKSPACE

"""Load a empty model."""
is_local = os.path.isdir(pretrained_model_name_or_path)
if is_local: # pragma: no cover
@@ -139,6 +144,10 @@ def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, **kwargs):
    model.tie_weights()
    model.eval()
    model.path = pretrained_model_name_or_path

    if save_path is None:
        save_path = LWQ_WORKSPACE
Contributor review on lines +148 to +149: remove if using the default value

    convert_model(model, save_path)
    return model

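For context, a minimal usage sketch of the changed entry point. The checkpoint name and scratch directory below are illustrative placeholders, not part of this PR:

# Sketch: load a weight-less (meta-device) model and redirect the layer-wise
# scratch space. "facebook/opt-125m" is an arbitrary small example checkpoint;
# "/tmp/lwq_shards" is a hypothetical directory.
model = load_empty_model("facebook/opt-125m", save_path="/tmp/lwq_shards")
print(next(model.parameters()).device)  # "meta": weights are not materialized yet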

@@ -163,6 +172,40 @@ def update_module(model, module_name, new_module):
    setattr(super_module, module_name.split(".")[-1], new_module)


def get_layers_before_block(model):
    """Get the embed layers before blocks."""
    return_layers = []
    block_name = None

    def _forward(module, name, *args, **kwargs):
        if name == block_name:
            # if 'DecoderLayer' in name:
            raise NotImplementedError
        if len(module._modules) == 0:
            return_layers.append((name, module))
        return module.ori_forward(*args, **kwargs)

    for n, m in model.named_modules():
        if isinstance(m, torch.nn.ModuleList):
            block_name = n + "." + m.named_children().__next__()[0]
        m.ori_forward = m.forward
        m.forward = partial(_forward, m, n)

    try:
        model.forward(
Contributor review: Is it applicable to all or most models?

            input_ids=torch.zeros((1, 1), device="meta", dtype=torch.int),
            attention_mask=torch.zeros((1, 1), device="meta", dtype=torch.int),
        )
    except NotImplementedError:
        pass

    for n, m in model.named_modules():
        m.forward = m.ori_forward
        del m.ori_forward

    return return_layers

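A sketch of how this helper might be exercised on the meta-device model from the previous sketch; the printed names are illustrative for an OPT-style checkpoint:

layers = get_layers_before_block(model)
for name, module in layers:
    print(name, type(module).__name__)
# e.g. "model.decoder.embed_tokens Embedding" before the first decoder block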

def load_layer_wise_quantized_model(path):  # pragma: no cover
    """Load layer wise quantized model."""
    model = torch.load(os.path.join(path, "model_arch.pt"))
@@ -207,6 +250,8 @@ def load_tensor(path, tensor_name=None, prefix=None):


def _get_path(pretrained_model_name_or_path):
    if pretrained_model_name_or_path is None:
        return None
    is_local = os.path.isdir(pretrained_model_name_or_path)
    if is_local:  # pragma: no cover
        path = pretrained_model_name_or_path
@@ -216,6 +261,7 @@ def _get_path(pretrained_model_name_or_path):


def load_value(model, param_name, path):
    logger.debug(f"load value for layer: {param_name}")
    if "lm_head" in param_name and getattr(model.config, "tie_word_embeddings", True):
        input_embeddings = model.get_input_embeddings()
        modules = get_named_children(model)
@@ -244,21 +290,24 @@ def register_weight_hooks(model, path, device="cpu", clean_weight=True, saved_pa

    def forward_pre_hook(name):
        def hook(module, input):
            logger.debug(f"{name} forward pre-hook: load value")
            state_dict = None
-            if os.path.exists(os.path.join(LWQ_WORKSPACE, f"{name}.pt")):
-                state_dict = torch.load(os.path.join(LWQ_WORKSPACE, f"{name}.pt"))
+            if os.path.exists(os.path.join(saved_path, f"{name}.pt")):
+                state_dict = torch.load(os.path.join(saved_path, f"{name}.pt"))
            for n, p in module.named_parameters():
                param_name = name + "." + n
                if state_dict:
                    value = state_dict[n]
                else:
                    value = load_value(model, param_name, path)
                set_module_tensor_to_device(model, param_name, device, value)
            module = module.to(device)

        return hook

    def forward_hook(name):
        def hook(module, input, output):
            logger.debug(f"{name} forward hook: clean value")
            if saved_path:
                file_path = os.path.join(saved_path, f"{name}.pt")
                torch.save(module.state_dict(), file_path)
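The pre-/post-hook pair above follows the standard PyTorch just-in-time weight pattern: materialize a layer's parameters right before its forward runs, then drop them again afterwards. A self-contained toy sketch of the same idea (not code from this PR):

import torch

lin = torch.nn.Linear(4, 4)
saved = {k: v.clone() for k, v in lin.state_dict().items()}  # stand-in for on-disk shards

def pre_hook(module, inputs):
    for n, p in module.named_parameters():
        p.data = saved[n]  # materialize the weights just in time

def post_hook(module, inputs, output):
    for p in module.parameters():
        p.data = torch.empty(0)  # free the weights once the layer has run

lin.register_forward_pre_hook(pre_hook)
lin.register_forward_hook(post_hook)
out = lin(torch.randn(2, 4))  # weights exist only for the duration of this call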
@@ -294,3 +343,109 @@ def clean_module_weight(module):
                new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to("meta")
                submodule._parameters[n] = new_value
    gc.collect()


def convert_model(empty_model, saved_path=LWQ_WORKSPACE):
    def _get_value(name, n):
        state_dict = None
        if os.path.exists(os.path.join(saved_path, f"{name}.pt")):
            state_dict = torch.load(os.path.join(saved_path, f"{name}.pt"))
        param_name = name + "." + n
        if state_dict:
            value = state_dict[n]
        else:
            value = load_value(empty_model, param_name, empty_model.path)
        return value

    def _update(module, name):
        # name is passed explicitly (bound per layer via partial below) rather
        # than captured from the enclosing loop, which would late-bind to the
        # last iterated layer.
        state_dict = None
        if os.path.exists(os.path.join(saved_path, f"{name}.pt")):
            state_dict = torch.load(os.path.join(saved_path, f"{name}.pt"))
        for n, p in module.named_parameters():
            if str(p.device) != "meta":
                continue
            param_name = name + "." + n
            if state_dict:
                value = state_dict[n]
            else:
                value = load_value(empty_model, param_name, saved_path)
            set_module_tensor_to_device(empty_model, param_name, "cpu", value)
        file_path = os.path.join(saved_path, f"{name}.pt")
        torch.save(module.state_dict(), file_path)

    def _layer_wise_to(module, name, device_or_dtype):
        if isinstance(device_or_dtype, torch.dtype):
            return module.ori_to(device_or_dtype)
        elif len(module._modules) == 0:
            # skip method type
            if len(module._parameters) == 0 or module.weight.device.type != "meta":
                return module.ori_to(device_or_dtype)
            else:
                for n, _ in module.named_parameters():
                    param_name = name + "." + n
                    value = load_value(empty_model, param_name, empty_model.path)
                    dtype = None
                    if hasattr(module, "dtype"):
                        dtype = module.dtype
                    set_module_tensor_to_device(module, n, device_or_dtype, value, dtype=dtype)
                return module.ori_to(device_or_dtype)
        else:
            for n, m in module.named_children():
                m.to(device_or_dtype)
            return module

    modules = get_named_children(empty_model)
    for name, module in modules:
        if hasattr(module, "weight"):
            # delattr(module, 'weight')
            # module.weight = partial(_get_value, name, 'weight')()
            module.get_weight = partial(_get_value, name, "weight")
        if hasattr(module, "bias") and module.bias is not None:
            module.get_bias = partial(_get_value, name, "bias")
        module.update = partial(_update, module, name)

    def _replace_to(module, name):
        if len(module._modules) > 0:
            for n, m in module.named_children():
                if len(name) > 0:
                    n = name + "." + n
                _replace_to(m, n)
        module.ori_to = module.to
        module.to = partial(_layer_wise_to, module, name)

    _replace_to(empty_model, "")

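With the accessors installed above, single layers can be inspected or materialized without loading the whole network. A sketch, assuming `model` came from load_empty_model (which calls convert_model):

modules = get_named_children(model)
name, module = modules[0]  # e.g. the embedding layer
w = module.get_weight()    # loads only this layer's weight tensor
print(name, tuple(w.shape))
module.update()            # materializes meta params and re-saves this layer's shard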

def load_model_with_hooks(
    pretrained_model_name_or_path, cls=AutoModelForCausalLM, device="cpu", clean_weight=True, saved_path=None, **kwargs
):
    if saved_path is None:
        saved_path = LWQ_WORKSPACE
Contributor review on lines +420 to +423, suggested change:
-    pretrained_model_name_or_path, cls=AutoModelForCausalLM, device="cpu", clean_weight=True, saved_path=None, **kwargs
-):
-    if saved_path is None:
-        saved_path = LWQ_WORKSPACE
+    pretrained_model_name_or_path, cls=AutoModelForCausalLM, device="cpu", clean_weight=True, saved_path=LWQ_WORKSPACE, **kwargs
+):

    empty_model = load_empty_model(pretrained_model_name_or_path, cls=cls, **kwargs)
    register_weight_hooks(empty_model, empty_model.path, device, clean_weight, saved_path)
    return empty_model

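An end-to-end sketch of the hook-based loader; the checkpoint name is an arbitrary example, and each layer's weights are pulled in by the pre-hook as its forward fires, then cleaned by the post-hook:

from transformers import AutoTokenizer

model = load_model_with_hooks("facebook/opt-125m", saved_path="/tmp/lwq_shards")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
batch = tokenizer("layer-wise quantization keeps memory flat", return_tensors="pt")
with torch.no_grad():
    logits = model(**batch).logits  # runs layer by layer under the weight hooks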

def layer_wise_save(model, path):
    os.makedirs(path, exist_ok=True)
    file_path = os.path.join(path, "layer_wise_model.bin")
    modules = get_named_children(model)
    with open(file_path, "wb") as f:
        for name, module in modules:
            output = OrderedDict()
            if hasattr(module, "get_weight"):
                output[f"{name}.weight"] = module.get_weight()
            if hasattr(module, "get_bias"):
                output[f"{name}.bias"] = module.get_bias()
            output = pickle.dumps(output)
            f.write(output + b"split_tag")


def layer_wise_load(path):
    file_path = os.path.join(path, "layer_wise_model.bin")
    state_dict = OrderedDict()
    with open(file_path, "rb") as f:
        data = f.read().split(b"split_tag")
    for d in data:
        if len(d) > 0:
            d = pickle.loads(d)
            state_dict.update(d)
    return state_dict
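Finally, a round-trip sketch for the two serialization helpers; the output directory is a placeholder:

layer_wise_save(model, "/tmp/lwq_out")        # one pickled OrderedDict per layer, b"split_tag"-delimited
state_dict = layer_wise_load("/tmp/lwq_out")  # reassembled into a flat state_dict
print(list(state_dict)[:3])                   # first few "<layer>.weight" keys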