layer-wise quant for pytorch weight only (#1255)
* support layer-wise quant for weight-only RTN

Signed-off-by: Guo, Heng <[email protected]>
n1ck-guo authored Sep 25, 2023
1 parent 595d3a1 commit ebd1e24
Showing 7 changed files with 234 additions and 83 deletions.
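For context, the whole feature is driven from the quantization config; below is a minimal usage sketch assembled from the new test at the bottom of this commit (the model name, recipe keys, and toy calibration data are taken from that test, not new API):

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM

from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell
from neural_compressor.utils.pytorch import load

# Build an "empty shell" model whose weights stay on disk until each layer needs them.
fp32_model = load_shell("facebook/opt-125m", AutoModelForCausalLM, torchscript=True)

conf = PostTrainingQuantConfig(
    approach="weight_only",
    recipes={
        "layer_wise_quant": True,
        "layer_wise_quant_args": {"model_path": "facebook/opt-125m"},
        "rtn_args": {"enable_full_range": True},
    },
)

calib_dataloader = DataLoader(torch.randint(0, 30522, (5, 128)), batch_size=8)  # toy calibration data
q_model = quantization.fit(fp32_model, conf, calib_dataloader=calib_dataloader, eval_func=lambda m: 0.1)

q_model.save("./saved_model")                                 # layer-wise mode writes one .pt file per parameter
loaded = load("./saved_model", fp32_model, weight_only=True)  # reassembles the quantized model from those files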
29 changes: 29 additions & 0 deletions neural_compressor/adaptor/pytorch.py
@@ -4532,6 +4532,19 @@ def rtn_quantize(self, model, tune_cfg):
from .torch_utils.util import fetch_module, set_module
from .torch_utils.weight_only import rtn_quantize

# for layer_wise quant mode
recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
if recipe_cfgs.get("layer_wise_quant", False):
from neural_compressor.config import options

from .torch_utils.layer_wise_quant.utils import _get_path, load_module

lwq_workspace = os.path.join(options.workspace, "lwq_tmpdir")
os.makedirs(lwq_workspace, exist_ok=True)
model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None)
assert model_path, "model_path should be specified in layer_wise_quant_args."
model_path = _get_path(model_path)

for key, config in tune_cfg["op"].items():
op_name, op_type = key
if config["weight"]["dtype"] == "fp32":
@@ -4545,6 +4558,11 @@ def rtn_quantize(self, model, tune_cfg):
if algorithm != "RTN":
continue
m = fetch_module(model, op_name)
# load weight when using layer-wise quant mode
recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
if recipe_cfgs.get("layer_wise_quant", False):
# load weight
load_module(model, op_name, model_path, device=self.device)
m = rtn_quantize(
m,
num_bits,
@@ -4556,7 +4574,18 @@ def rtn_quantize(self, model, tune_cfg):
enable_mse_search=enable_mse_search,
group_dim=group_dim,
)
if recipe_cfgs.get("layer_wise_quant", False):
# save and clean weight
from .torch_utils.layer_wise_quant.utils import clean_module_weight

torch.save(m.state_dict(), os.path.join(lwq_workspace, f"{op_name}.pt"))
clean_module_weight(m)
set_module(model, op_name, m)
if recipe_cfgs.get("layer_wise_quant", False):
# register hooks
from .torch_utils.layer_wise_quant.utils import register_weight_hooks

register_weight_hooks(model, model_path, device=self.device, clean_weight=True)
return model

def gptq_quantize(self, model, tune_cfg, dataloader):
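Both the adaptor code above and the hooks in layer_wise_quant/utils.py (further down) stage each op's weights in the same temporary workspace under options.workspace; a small, hypothetical way to peek at it after quantization has run (the path construction mirrors the lines above):

import os

from neural_compressor.config import options

lwq_workspace = os.path.join(options.workspace, "lwq_tmpdir")  # same path as LWQ_WORKSPACE in layer_wise_quant/utils.py
print(sorted(os.listdir(lwq_workspace)))                       # expect one "<op_name>.pt" per RTN-quantized module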
neural_compressor/adaptor/torch_utils/layer_wise_quant/torch_load.py
@@ -21,6 +21,7 @@
import warnings
from typing import IO, Any, BinaryIO, Callable, Dict, Optional, Union

from packaging.version import Version
from torch.serialization import (
StorageType,
_get_restore_location,
Expand All @@ -35,31 +36,36 @@

from .utils import torch

torch_version = torch.__version__.split("+")[0]
version = Version(torch_version)

FILE_LIKE = Union[str, os.PathLike, BinaryIO, IO[bytes]]
MAP_LOCATION = Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]]

if version.release < Version("1.13.0").release:
UntypedStorage = torch._UntypedStorage
else:
UntypedStorage = torch.UntypedStorage


def _load(zip_file, tensor_name, prefix, map_location, pickle_module, pickle_file="data.pkl", **pickle_load_args):
restore_location = _get_restore_location(map_location)

loaded_storages = {}

from packaging.version import Version

torch_version = torch.__version__.split("+")[0]
version = Version(torch_version)

def load_tensor(dtype, numel, key, location):
name = f"data/{key}"

if version.release < Version("2.0.0").release: # pragma: no cover
storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage).storage().untyped()
if version.release < Version("1.13.0").release:
storage = zip_file.get_storage_from_record(name, numel, torch._UntypedStorage).storage()._untyped()
typed_storage = torch.storage._TypedStorage(wrap_storage=restore_location(storage, location), dtype=dtype)
loaded_storages[key] = typed_storage
elif version.release < Version("2.0.0").release: # pragma: no cover
storage = zip_file.get_storage_from_record(name, numel, UntypedStorage).storage().untyped()
typed_storage = torch.storage.TypedStorage(wrap_storage=restore_location(storage, location), dtype=dtype)
loaded_storages[key] = typed_storage
else:
storage = (
zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage
)
storage = zip_file.get_storage_from_record(name, numel, UntypedStorage)._typed_storage()._untyped_storage
typed_storage = torch.storage.TypedStorage(
wrap_storage=restore_location(storage, location), dtype=dtype, _internal=True
)
@@ -69,28 +75,6 @@ def load_tensor(dtype, numel, key, location):

return typed_storage

# def persistent_load(saved_id):
# print(saved_id)
# assert isinstance(saved_id, tuple)
# typename = _maybe_decode_ascii(saved_id[0])
# data = saved_id[1:]

# assert typename == 'storage', \
# f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
# storage_type, key, location, numel = data
# if storage_type is torch.UntypedStorage:
# dtype = torch.uint8
# else:
# dtype = storage_type.dtype

# if key in loaded_storages:
# typed_storage = loaded_storages[key]
# else:
# nbytes = numel * torch._utils._element_size(dtype)
# typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))

# return typed_storage

load_module_mapping: Dict[str, str] = {"torch.tensor": "torch._tensor"}

# Need to subclass Unpickler instead of directly monkey-patching the find_class method
@@ -115,7 +99,8 @@ def persistent_load(self, saved_id):
typename == "storage"
), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
storage_type, key, location, numel = data
if storage_type is torch.UntypedStorage: # pragma: no cover

if storage_type is UntypedStorage: # pragma: no cover
dtype = torch.uint8
else:
dtype = storage_type.dtype
@@ -126,7 +111,8 @@ def persistent_load(self, saved_id):
name_list = [self.tensor_name]
if prefix:
no_prefix_name = self.tensor_name.split(".")
no_prefix_name.remove(prefix)
if prefix in no_prefix_name:
no_prefix_name.remove(prefix)
no_prefix_name = ".".join(no_prefix_name)
name_list.append(no_prefix_name)
if self.tensor_name and self.metastack[-1][-2] not in name_list:
47 changes: 33 additions & 14 deletions neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py
@@ -28,10 +28,13 @@
from transformers import AutoConfig
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from ....config import options
from ..model_wrapper import QDQLayer
from ..util import logger
from .torch_load import load

LWQ_WORKSPACE = os.path.join(options.workspace, "lwq_tmpdir")


def get_module(model, key):
"""Get module from model by key name.
@@ -197,25 +200,41 @@ def _get_path(pretrained_model_name_or_path):
return path


def load_value(model, param_name, path):
if "lm_head" in param_name and getattr(model.config, "tie_word_embeddings", True):
input_embeddings = model.get_input_embeddings()
modules = get_named_children(model)
for name, module in modules:
if module == input_embeddings:
param_name = name + "." + param_name.split(".")[-1]
prefix = model.base_model_prefix
if "pytorch_model.bin.index.json" in os.listdir(path):
value = load_tensor_from_shard(path, param_name, prefix)
else:
value = load_tensor(os.path.join(path, "pytorch_model.bin"), param_name, prefix)
return value


def load_module(model, module_name, path, device="cpu"):
module = get_module(model, module_name)
for n, p in module.named_parameters():
param_name = module_name + "." + n
value = load_value(model, param_name, path)
set_module_tensor_to_device(model, param_name, device, value)


def register_weight_hooks(model, path, device="cpu", clean_weight=True):
def forward_pre_hook(name):
def load_value(param_name):
if "lm_head" in param_name and getattr(model.config, "tie_word_embeddings", True):
input_embeddings = model.get_input_embeddings()
for name, module in modules:
if module == input_embeddings:
param_name = name + "." + param_name.split(".")[-1]
prefix = model.base_model_prefix
if "pytorch_model.bin.index.json" in os.listdir(path):
value = load_tensor_from_shard(path, param_name, prefix)
else:
value = load_tensor(os.path.join(path, "pytorch_model.bin"), param_name, prefix)
return value

def hook(module, input):
state_dict = None
if os.path.exists(os.path.join(LWQ_WORKSPACE, f"{name}.pt")):
state_dict = torch.load(os.path.join(LWQ_WORKSPACE, f"{name}.pt"))
for n, p in module.named_parameters():
param_name = name + "." + n
value = load_value(param_name)
if state_dict:
value = state_dict[n]
else:
value = load_value(model, param_name, path)
set_module_tensor_to_device(model, param_name, device, value)

return hook
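These hooks are what allow a meta-device "shell" model to run forward passes with weights paged in from the original checkpoint; a hypothetical standalone sketch of using them directly (the model name and input shape are placeholders, and the clean-up behaviour is assumed from the clean_weight flag's name):

import torch
from transformers import AutoModelForCausalLM

from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell
from neural_compressor.adaptor.torch_utils.layer_wise_quant.utils import _get_path, register_weight_hooks

shell = load_shell("facebook/opt-125m", AutoModelForCausalLM, torchscript=True)  # parameters live on the meta device
path = _get_path("facebook/opt-125m")                                            # resolve the local checkpoint directory
register_weight_hooks(shell, path, device="cpu", clean_weight=True)              # pre-forward: load each module's weights from disk

input_ids = torch.randint(0, 30522, (1, 8), dtype=torch.int64)
with torch.no_grad():
    out = shell(input_ids)  # weights are materialized just in time, then (assumed) released again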
28 changes: 28 additions & 0 deletions neural_compressor/model/torch_model.py
@@ -340,6 +340,34 @@ def save(self, root=None):
gptq_config_path = os.path.join(root, "gptq_config.json")
with open(gptq_config_path, "w") as f:
json.dump(self.gptq_config, f, indent=4)
# for layer_wise quant mode
if self.q_config["recipe_cfgs"].get("layer_wise_quant", False):
from ..adaptor.torch_utils.layer_wise_quant.utils import (
LWQ_WORKSPACE,
_get_path,
get_named_children,
load_value,
set_module_tensor_to_device,
)

modules = get_named_children(self._model)
for name, module in modules:
state_dict = None
if os.path.exists(os.path.join(LWQ_WORKSPACE, f"{name}.pt")):
state_dict = torch.load(os.path.join(LWQ_WORKSPACE, f"{name}.pt"))
model_path = _get_path(
self.q_config["recipe_cfgs"]["layer_wise_quant_args"].get("model_path")
)
for n, p in module.named_parameters():
param_name = name + "." + n
if state_dict:
value = state_dict[n]
else:
value = load_value(self._model, param_name, model_path)
# set_module_tensor_to_device(self._model, param_name, "cpu", value)
torch.save(value, os.path.join(root, f"{param_name}.pt"))
# stat_dict = self._model.state_dict()
return
else:
stat_dict["best_configure"] = self.q_config
torch.save(stat_dict, os.path.join(root, "best_model.pt"))
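Because save() returns early here, a layer-wise checkpoint contains one .pt file per parameter rather than a single best_model.pt; the files are read back by load_weight_only in the next file, either by passing layer_wise=True or by handing it a meta-device shell. A hypothetical reload (directory name matches the new test):

from transformers import AutoModelForCausalLM

from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell
from neural_compressor.utils.pytorch import load

fp32_model = load_shell("facebook/opt-125m", AutoModelForCausalLM, torchscript=True)  # same architecture that was quantized
loaded = load("./saved_model", fp32_model, weight_only=True, layer_wise=True)         # loads one tensor file per parameter onto CPU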
21 changes: 17 additions & 4 deletions neural_compressor/utils/pytorch.py
@@ -197,7 +197,7 @@ def _load_int8_orchestration(model, tune_cfg, stat_dict, example_inputs, **kwarg
return model


def load_weight_only(checkpoint_dir, model):
def load_weight_only(checkpoint_dir, model, layer_wise=False):
"""Load model in weight_only mode.
Args:
@@ -226,12 +226,25 @@ def load_weight_only(checkpoint_dir, model):
module = util.fetch_module(model, op_name)
new_module = MulLinear(module)
util.set_module(model, op_name, new_module)
model.load_state_dict(torch.load(weights_file))
if layer_wise or (hasattr(model, "device") and str(model.device) == "meta"):
from ..adaptor.torch_utils.layer_wise_quant.utils import get_named_children, set_module_tensor_to_device

# state_dict = torch.load(weights_file)
modules = get_named_children(model)
for name, module in modules:
for n, p in module.named_parameters():
param_name = name + "." + n
value = torch.load(
os.path.join(os.path.abspath(os.path.expanduser(checkpoint_dir)), f"{param_name}.pt")
)
set_module_tensor_to_device(model, param_name, "cpu", value)
else:
model.load_state_dict(torch.load(weights_file))
logger.info("Load weight_only quantized model")
return model


def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
def load(checkpoint_dir=None, model=None, layer_wise=False, history_cfg=None, **kwargs):
"""Execute the quantize process on the specified model.
Args:
@@ -248,7 +261,7 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
"""
weight_only = kwargs.get("weight_only", False)
if weight_only:
return load_weight_only(checkpoint_dir, model)
return load_weight_only(checkpoint_dir, model, layer_wise=layer_wise)
if checkpoint_dir is not None:
if isinstance(checkpoint_dir, dict):
stat_dict = checkpoint_dir
59 changes: 59 additions & 0 deletions test/algorithm/test_lwq_weight_only.py
@@ -0,0 +1,59 @@
import shutil
import sys
import unittest

sys.path.insert(0, "./")
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell
from neural_compressor.utils.pytorch import load


class TestLayerWise(unittest.TestCase):
def test_layer_wise(self):
model_name_or_path = "facebook/opt-125m"
fp32_model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True)

class TestDataset(Dataset):
def __init__(self, size=5, shape=128):
self.len = size
self.input_ids = torch.randint(low=0, high=30522, size=(size, shape), dtype=torch.int64)

def __getitem__(self, index):
return self.input_ids[index]

def __len__(self):
return self.len

eval_dataset = TestDataset()
eval_dataloader = DataLoader(eval_dataset, batch_size=8)

conf = PostTrainingQuantConfig(
approach="weight_only",
recipes={
"layer_wise_quant": True,
"layer_wise_quant_args": {
"model_path": "facebook/opt-125m",
},
"rtn_args": {"enable_full_range": True},
},
)

q_model = quantization.fit(
fp32_model,
conf,
calib_dataloader=eval_dataloader,
eval_func=lambda x: 0.1,
)
output_dir = "./saved_model"
q_model.save(output_dir)
load_model = load(output_dir, fp32_model, weight_only=True)
self.assertNotEqual(load_model.lm_head.weight.device.type, "meta")
shutil.rmtree(output_dir)


if __name__ == "__main__":
unittest.main()
