From 0eced1478c6796a5e2dcb254a65bbc96af4d1b8b Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Wed, 12 Jun 2024 18:49:17 -0700 Subject: [PATCH] Enhance INC WOQ model loading & support Huggingface WOQ model loading (#1826) Signed-off-by: yuwenzho --- docs/3x/PT_WeightOnlyQuant.md | 16 +- .../quantization/habana_fp8/run_llm.py | 12 +- .../quantization/llm/run_clm_no_trainer.py | 5 +- .../torch/algorithms/weight_only/__init__.py | 3 + .../torch/algorithms/weight_only/save_load.py | 604 +++++++++++++++++- .../torch/algorithms/weight_only/utility.py | 26 + .../torch/quantization/load_entry.py | 87 ++- neural_compressor/torch/utils/constants.py | 9 + .../weight_only/test_woq_utility.py | 18 + .../weight_only/test_autoround.py | 2 +- .../quantization/weight_only/test_awq.py | 2 +- .../quantization/weight_only/test_gptq.py | 2 +- .../weight_only/test_load_woq_hf_model.py | 24 + .../quantization/weight_only/test_rtn.py | 2 +- .../quantization/weight_only/test_teq.py | 2 +- 15 files changed, 762 insertions(+), 52 deletions(-) create mode 100644 test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py diff --git a/docs/3x/PT_WeightOnlyQuant.md b/docs/3x/PT_WeightOnlyQuant.md index e7e5c543215..37cc934592a 100644 --- a/docs/3x/PT_WeightOnlyQuant.md +++ b/docs/3x/PT_WeightOnlyQuant.md @@ -31,13 +31,13 @@ Theoretically, round-to-nearest (RTN) is the most straightforward way to quantiz ## Supported Matrix -| Algorithms/Backend | PyTorch eager mode | +| Algorithms/Backend | PyTorch eager mode | |--------------|----------| | RTN | ✔ | | GPTQ | ✔ | | AutoRound| ✔ | | AWQ | ✔ | -| TEQ | ✔ | +| TEQ | ✔ | | HQQ | ✔ | > **RTN:** A quantification method that we can think of very intuitively. It does not require additional datasets and is a very fast quantization method. Generally speaking, RTN will convert the weight into a uniformly distributed integer data type, but some algorithms, such as Qlora, propose a non-uniform NF4 data type and prove its theoretical optimality. 
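As a quick illustration of the RTN behavior described in the doc text above, here is a minimal sketch of symmetric per-group round-to-nearest quantization. It is illustrative only and not part of this patch; the 4-bit width, group size of 32, and random weight tensor are arbitrary demonstration values.

```python
import torch

def rtn_quantize(weight: torch.Tensor, bits: int = 4, group_size: int = 32):
    """Toy symmetric RTN: one scale per group of `group_size` input channels."""
    out_features, in_features = weight.shape
    maxq = 2 ** (bits - 1) - 1  # 7 for signed 4-bit
    grouped = weight.reshape(out_features, in_features // group_size, group_size)
    # per-group scale, guarded against all-zero groups
    scale = (grouped.abs().amax(dim=-1, keepdim=True) / maxq).clamp(min=1e-8)
    qweight = torch.clamp(torch.round(grouped / scale), -maxq - 1, maxq)
    dequant = (qweight * scale).reshape(out_features, in_features)
    return qweight.to(torch.int8), scale, dequant

weight = torch.randn(8, 64)
qweight, scale, w_dq = rtn_quantize(weight)
print("max round-to-nearest error:", (weight - w_dq).abs().max().item())
```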
@@ -64,8 +64,8 @@ WeightOnlyQuant quantization for PyTorch is using prepare and convert [APIs](./P | bits (int)| [1, ..., 8] | | group_size (int)| [-1, 1, ..., $C_{in}$] | | use_sym (bool)| [True, False] | -| use_double_quant (bool) | [True, False] | -| double_quant_dtype (str) | ['int'] | +| use_double_quant (bool) | [True, False] | +| double_quant_dtype (str) | ['int'] | | double_quant_bits (int) | [1, ..., bits] | | double_quant_use_sym (bool) | [True, False] | | double_quant_group_size (int) | [-1, 1, ..., $C_{in}$] | @@ -98,7 +98,7 @@ model = convert(model) #### GPTQ | gptq_args | comments | default value | |----------|-------------|-------------------------------------------------------------------| -| use_mse_search (bool) | Enables mean squared error (MSE) search | False +| use_mse_search (bool) | Enables mean squared error (MSE) search | False | use_layer_wise (bool) | Enables quantize model per layer | False | | model_path (str) | Model path that is used to load state_dict per layer | | | use_double_quant (bool) | Enables double quantization | False | @@ -120,7 +120,7 @@ model = convert(model) #### AutoRound | autoround_args | comments | default value | |----------|-------------|-------------------------------------------------------------------| -| enable_full_range (bool) | Whether to enable full range quantization | False +| enable_full_range (bool) | Whether to enable full range quantization | False | batch_size (int) | Batch size for training | 8 | | lr_scheduler | The learning rate scheduler to be used | None | | enable_quanted_input (bool) | Whether to use quantized input data | True | @@ -251,8 +251,8 @@ from neural_compressor.torch.quantization import load orig_model = YOURMODEL() loaded_model = load( - "saved_results", model=orig_model -) # Please note that the model parameter passes the original model. + "saved_results", original_model=orig_model +) # Please note that the original_model parameter passes the original model. ``` diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py index 5cd0f046aba..e77ef2c6a33 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py @@ -63,7 +63,7 @@ parser.add_argument("--calib_iters", default=100, type=int, help="calibration iters.") parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], \ - type=str, choices=["hellaswag", "lambada_openai", "piqa", "winogrande", "copa", + type=str, choices=["hellaswag", "lambada_openai", "piqa", "winogrande", "copa", "rte", "openbookqa", "lambada_standard", "wikitext"], help="tasks list for accuracy validation") parser.add_argument("--limit", default=None, type=int, @@ -117,10 +117,10 @@ for examples in calib_dataset: calib_data.append( tokenizer( - examples["text"], - return_tensors="pt", - max_length=64, - padding="max_length", + examples["text"], + return_tensors="pt", + max_length=64, + padding="max_length", truncation=True ) ) @@ -154,7 +154,7 @@ def calib_func(model): -# If torch.matmul and torch.bmm are not replaced by INC module, +# If torch.matmul and torch.bmm are not replaced by INC module, # Below codes can make torch.matmul and torch.bmm run on fp8 by injection. 
if not args.skip_fp8_mm and args.precision in ['fp8_e4m3', 'fp8_e5m2']: def replace_torch_mm_bmm(): diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py index 5d39cf3a62b..c586f8d765e 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py @@ -367,7 +367,7 @@ def run_fn(model): user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs) run_fn(user_model) user_model = convert(user_model) - + user_model.save(args.output_dir) @@ -377,9 +377,10 @@ def run_fn(model): print("load int8 model") from neural_compressor.torch.quantization import load + user_model, _ = get_user_model() tokenizer = AutoTokenizer.from_pretrained(args.model) config = AutoConfig.from_pretrained(args.model) - user_model = load(os.path.abspath(os.path.expanduser(args.output_dir))) + user_model = load(os.path.abspath(os.path.expanduser(args.output_dir)), user_model) setattr(user_model, "config", config) else: user_model, tokenizer = get_user_model() diff --git a/neural_compressor/torch/algorithms/weight_only/__init__.py b/neural_compressor/torch/algorithms/weight_only/__init__.py index 28f108cb636..fc9ef0a5b3b 100644 --- a/neural_compressor/torch/algorithms/weight_only/__init__.py +++ b/neural_compressor/torch/algorithms/weight_only/__init__.py @@ -11,3 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+
+
+from .save_load import save, load
diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py
index 990ba3f5de4..7494dac86f9 100644
--- a/neural_compressor/torch/algorithms/weight_only/save_load.py
+++ b/neural_compressor/torch/algorithms/weight_only/save_load.py
@@ -14,18 +14,20 @@
 # pylint:disable=import-error

+import copy
 import json
 import os
+import re

 import torch

 from neural_compressor.common.utils import load_config_mapping, save_config_mapping
-from neural_compressor.torch.utils import QCONFIG_NAME, WEIGHT_NAME, logger
+from neural_compressor.torch.utils import QCONFIG_NAME, WEIGHT_NAME, LoadFormat, logger


 def save(model, output_dir="./saved_results"):
     os.makedirs(output_dir, exist_ok=True)
-    qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
+    qmodel_weight_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
     qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
     # saving process
     save_config_mapping(model.qconfig, qconfig_file_path)
@@ -37,14 +39,600 @@ def save(model, output_dir="./saved_results"):
     # MethodType 'save' not in state_dict
     del model.save

-    torch.save(model, qmodel_file_path)
+    torch.save(model.state_dict(), qmodel_weight_file_path)

-    logger.info("Save quantized model to {}.".format(qmodel_file_path))
+    logger.info("Save quantized model weight to {}.".format(qmodel_weight_file_path))
     logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path))


-def load(output_dir="./saved_results"):
-    qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
-    model = torch.load(qmodel_file_path)
-    logger.info("Quantized model loading successful.")
+def load(model_name_or_path, original_model=None, format=LoadFormat.DEFAULT, device="cpu", **kwargs):
+    """Load a weight-only quantized model.
+
+    1. Load an INC weight-only quantized model saved locally.
+    2. Load a HuggingFace weight-only quantized model,
+       including GPTQ/AWQ models and upstreamed INC quantized models from the HF model hub.
+
+    Args:
+        model_name_or_path (str): torch checkpoint directory or huggingface model_name_or_path.
+            If 'format' is set to 'huggingface', it means the huggingface model_name_or_path.
+            If 'format' is set to 'default', it means the 'checkpoint_dir'.
+            This parameter should not be None. It works together with the 'original_model' parameter
+            to load an INC weight-only quantized model saved locally.
+        original_model (torch.nn.Module, optional): original model before quantization.
+            Needed if 'format' is set to 'default' and the model is not a TorchScript model. Defaults to None.
+        format (str, optional): 'default' for loading an INC weight-only quantized model.
+            'huggingface' for loading a huggingface WOQ causal language model. Defaults to "default".
+        kwargs (remaining dictionary of keyword arguments, optional):
+            remaining dictionary of keyword arguments for loading huggingface models.
+            Will be passed to the huggingface model's `__init__` method, such as 'trust_remote_code', 'revision'.
+ Returns: + torch.nn.Module: quantized model + """ + model_loader = WOQModelLoader(model_name_or_path, original_model, format, device, **kwargs) + model = model_loader.load_woq_model() return model + + +class WOQModelLoader: + def __init__(self, model_name_or_path, original_model=None, format=LoadFormat.DEFAULT, device="cpu", **kwargs): + # TODO: When loading WOQ model, use different WeightOnlyLinear module according to device. + self.model_name_or_path = model_name_or_path + self.original_model = original_model + self.format = format + self.device = device + self.kwargs = kwargs + self.quantization_config = {} + self.loaded_state_dict_keys = {} + + def load_woq_model(self): + if self.format == LoadFormat.HUGGINGFACE: + model = self.load_hf_format_woq_model() + logger.info("Loading HuggingFace weight-only quantization model successfully.") + elif self.format == LoadFormat.DEFAULT: + qmodel_weight_file_path = os.path.join( + os.path.abspath(os.path.expanduser(self.model_name_or_path)), WEIGHT_NAME + ) + assert os.path.exists(qmodel_weight_file_path), ( + "Cannot load model weight from path {}. " + "Please make sure '{}' file is saved in your '{}' directory ".format( + qmodel_weight_file_path, WEIGHT_NAME, self.model_name_or_path + ) + ) + + qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(self.model_name_or_path)), QCONFIG_NAME) + assert os.path.exists(qconfig_file_path), ( + "Cannot load model quantization config from path {}. " + "Please make sure '{}' file is saved in your '{}' directory".format( + qconfig_file_path, QCONFIG_NAME, self.model_name_or_path + ) + ) + + assert ( + self.original_model is not None + ), "Can't get original model. Please pass `original_model` to load function." + + model = self.load_inc_format_woq_model(qmodel_weight_file_path, qconfig_file_path) + logger.info("Loading weight-only quantization model successfully.") + else: + raise ValueError(f"`format` in load function can only be 'huggingface' or 'default', but get {self.format}") + + return model + + def load_inc_format_woq_model(self, qmodel_weight_file_path, qconfig_file_path): + qweights = torch.load(qmodel_weight_file_path) + self.loaded_state_dict_keys = qweights.keys() + + with open(qconfig_file_path, "r") as file: + self.quantization_config = json.load(file) + + model = self._build_woq_model() + model.load_state_dict(qweights, assign=True) + model.eval() + return model + + def load_hf_format_woq_model(self): + # check required package + from neural_compressor.torch.utils import is_package_available + + if not is_package_available("transformers"): + raise ImportError("Loading huggingface model requires transformers: `pip install transformers`") + if not is_package_available("accelerate"): + raise ImportError("Loading huggingface model requires accelerate: `pip install accelerate`") + + # get model_class and config + model_class, config = self._get_model_class_and_config() + self.quantization_config = config.quantization_config + + # get loaded_state_dict_keys + self.loaded_state_dict_keys = self._get_loaded_state_dict_keys(config) + + # initiate the huggingface model + self.original_model = self._init_hf_model(model_class, config) + + # build weight-only quantization model with WeightOnlyLinear module + model = self._build_woq_model() + + # load quantized weight to woq model + model = self._load_pretrained_weight(model, model_class) + + return model + + def _build_woq_model(self): + """Build weight-only quantization model.""" + from neural_compressor.torch.utils import set_module + + 
from .modules import MulLinear + + for name, module in self.original_model.named_modules(): + _is_autoround = False + # get quantization config of module + module_quantization_config = self.quantization_config + # pattern will map (module_name, moduele_type) + pattern = rf"(\(.*{re.escape(name)}.*{re.escape(type(module).__name__)}.*\))" + for q_config_key, q_config_value in self.quantization_config.items(): + if re.search(pattern, q_config_key): + if isinstance(q_config_value, dict) and [algo for algo in q_config_value.keys()][0] == "autoround": + _is_autoround = True + module_quantization_config = [config for config in q_config_value.values()][0] + + if isinstance(module, torch.nn.Linear): + # module without qweight means it is not quantized, then skip it + loaded_state_dict_keys_set = set(self.loaded_state_dict_keys) + if ( + name + ".qweight" not in loaded_state_dict_keys_set + and name + ".linear.qweight" not in loaded_state_dict_keys_set + ): + continue + + # insert MulLinear module + if name + ".linear.qweight" in loaded_state_dict_keys_set: + new_module = MulLinear(module) + set_module(self.original_model, name, new_module) + name += ".linear" + + # replace `torch.nn.Linear` with `WeightOnlyLinear` + zp = True if name + ".qzeros" in loaded_state_dict_keys_set else False + g_idx = True if name + ".g_idx" in loaded_state_dict_keys_set else False + + kwargs = {} + if _is_autoround: + from auto_round.export.export_to_itrex.model_wrapper import ( + WeightOnlyLinear as AutoRoundWeightOnlyLinear, + ) + + from .utility import convert_dtype_str2torch + + WeightOnlyLinearClass = AutoRoundWeightOnlyLinear + kwargs["groupsize"] = module_quantization_config.get("group_size", 32) + kwargs["scale_dtype"] = convert_dtype_str2torch( + module_quantization_config.get("scale_dtype", "fp16") + ) + else: + from .modules import WeightOnlyLinear as INCWeightOnlyLinear + + WeightOnlyLinearClass = INCWeightOnlyLinear + kwargs["group_size"] = module_quantization_config.get("group_size", 32) + kwargs["g_idx"] = g_idx + + new_module = WeightOnlyLinearClass( + module.in_features, + module.out_features, + dtype=module_quantization_config.get("dtype", "int"), + bits=module_quantization_config.get("bits", 4), + zp=zp, + bias=module.bias is not None, + use_optimum_format=True, + **kwargs, + ) + set_module(self.original_model, name, new_module) + woq_model = self.original_model + return woq_model + + def _get_model_class_and_config(self): + from transformers import AutoConfig, AutoModelForCausalLM + from transformers.dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code + from transformers.models.auto.auto_factory import _get_model_class + + # Autofactory + kwargs_orig = copy.deepcopy(self.kwargs) + trust_remote_code = self.kwargs.pop("trust_remote_code", None) + kwarg_attn_imp = self.kwargs.pop("attn_implementation", None) + + config = AutoConfig.from_pretrained(self.model_name_or_path) + # quantization_config = config.quantization_config + + if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: # pragma: no cover + config._attn_implementation = kwarg_attn_imp + + has_remote_code = hasattr(config, "auto_map") and AutoModelForCausalLM.__name__ in config.auto_map + + has_local_code = type(config) in AutoModelForCausalLM._model_mapping.keys() + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, + self.model_name_or_path, + has_local_code, + has_remote_code, + ) + + if has_remote_code and trust_remote_code: # pragma: no cover + class_ref = 
config.auto_map[AutoModelForCausalLM.__name__] + model_class = get_class_from_dynamic_module(class_ref, self.model_name_or_path, **kwargs_orig) + if os.path.isdir(self.model_name_or_path): + model_class.register_for_auto_class(AutoModelForCausalLM.__name__) + else: + AutoModelForCausalLM.register(config.__class__, model_class, exist_ok=True) + elif type(config) in AutoModelForCausalLM._model_mapping.keys(): + model_class = _get_model_class(config, AutoModelForCausalLM._model_mapping) + + return model_class, config + + def _get_loaded_state_dict_keys(self, config): + from transformers.configuration_utils import PretrainedConfig + from transformers.modeling_utils import _add_variant, get_checkpoint_shard_files, load_state_dict + from transformers.utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + cached_file, + download_url, + extract_commit_hash, + has_file, + is_remote_url, + is_safetensors_available, + ) + + subfolder = self.kwargs.pop("subfolder", "") + variant = self.kwargs.pop("variant", None) + cache_dir = self.kwargs.pop("cache_dir", None) + force_download = self.kwargs.pop("force_download", False) + proxies = self.kwargs.pop("proxies", None) + resume_download = self.kwargs.pop("resume_download", False) + local_files_only = self.kwargs.pop("local_files_only", False) + offload_folder = self.kwargs.pop("offload_folder", None) + offload_state_dict = self.kwargs.pop("offload_state_dict", False) + use_auth_token = self.kwargs.pop("use_auth_token", None) + token = self.kwargs.pop("token", None) + from_pipeline = self.kwargs.pop("_from_pipeline", None) + from_auto_class = self.kwargs.pop("_from_auto", False) + revision = self.kwargs.pop("revision", "main") + commit_hash = self.kwargs.pop("_commit_hash", None) + use_safetensors = self.kwargs.pop("use_safetensors", None) + + if use_safetensors is None and not is_safetensors_available(): + use_safetensors = False + + if use_auth_token is not None: # pragma: no cover + logger.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. " + "Please use `token` instead." + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + user_agent = { + "file_type": "model", + "framework": "pytorch", + "from_auto_class": from_auto_class, + } + if from_pipeline is not None: # pragma: no cover + user_agent["using_pipeline"] = from_pipeline + + if commit_hash is None: + if not isinstance(config, PretrainedConfig): # pragma: no cover + # We make a call to the config file first (which may be absent) + # to get the commit hash as soon as possible. 
+ resolved_config_file = cached_file( + self.model_name_or_path, + "config.json", + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + ) + commit_hash = extract_commit_hash(resolved_config_file, commit_hash) + else: + commit_hash = getattr(config, "_commit_hash", None) + + is_sharded = False + sharded_metadata = None + + if self.model_name_or_path is not None: # pragma: no cover + self.model_name_or_path = str(self.model_name_or_path) + is_local = os.path.isdir(self.model_name_or_path) + if is_local: + if os.path.isfile( + os.path.join( + self.model_name_or_path, + subfolder, + _add_variant(WEIGHTS_NAME, variant), + ) + ): + # Load from a PyTorch checkpoint + archive_file = os.path.join( + self.model_name_or_path, + subfolder, + _add_variant(WEIGHTS_NAME, variant), + ) + elif os.path.isfile( + os.path.join( + self.model_name_or_path, + subfolder, + _add_variant(WEIGHTS_INDEX_NAME, variant), + ) + ): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join( + self.model_name_or_path, + subfolder, + _add_variant(WEIGHTS_INDEX_NAME, variant), + ) + is_sharded = True + elif os.path.isfile( + os.path.join( + self.model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_NAME, variant), + ) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join( + self.model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_NAME, variant), + ) + elif os.path.isfile( + os.path.join( + self.model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + ) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join( + self.model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + ) + is_sharded = True + elif os.path.isfile(os.path.join(subfolder, self.model_name_or_path)): + archive_file = self.model_name_or_path + is_local = True + elif is_remote_url(self.model_name_or_path): + filename = self.model_name_or_path + resolved_archive_file = download_url(self.model_name_or_path) + else: + if use_safetensors is not False: + filename = _add_variant(SAFE_WEIGHTS_NAME, variant) + else: + filename = _add_variant(WEIGHTS_NAME, variant) + try: + # Load from URL or cache if already cached + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "token": token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + resolved_archive_file = cached_file(self.model_name_or_path, filename, **cached_file_kwargs) + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. 
+ resolved_archive_file = cached_file( + self.model_name_or_path, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + elif use_safetensors: + raise EnvironmentError( + f"{self.model_name_or_path} does not appear to have a file named" + f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or " + f"{_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} " + "and thus cannot be loaded with `safetensors`. Please make sure that the model has " + "been saved with `safe_serialization=True` or do not set `use_safetensors=True`." + ) + else: + # This repo has no safetensors file of any kind, we switch to PyTorch. + filename = _add_variant(WEIGHTS_NAME, variant) + resolved_archive_file = cached_file(self.model_name_or_path, filename, **cached_file_kwargs) + if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + self.model_name_or_path, + _add_variant(WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + + if resolved_archive_file is None: + # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error + # message. + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + } + if variant is not None and has_file(self.model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): + raise EnvironmentError( + f"{self.model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant" + f" {variant}. Use `variant=None` to load this model from those weights." + ) + else: + raise EnvironmentError( + f"{self.model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)}." + ) + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + except Exception as e: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load the model for '{self.model_name_or_path}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{self.model_name_or_path}' is the correct path to a" + f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}." + ) from e + + if is_local: + logger.info(f"loading weights file {archive_file}") + resolved_archive_file = archive_file + else: + logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") + else: # pragma: no cover + resolved_archive_file = None + + if is_sharded: # pragma: no cover + # rsolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. 
+ resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + self.model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _commit_hash=commit_hash, + ) + self.kwargs["sharded_metadata"] = sharded_metadata + + if is_sharded: # pragma: no cover + loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] + else: + # Time to load the checkpoint + state_dict = load_state_dict(resolved_archive_file) + loaded_state_dict_keys = list(state_dict.keys()) + + # set kwargs for next functions to use + self.kwargs["is_sharded"] = is_sharded + self.kwargs["offload_folder"] = offload_folder + self.kwargs["offload_state_dict"] = offload_state_dict + self.kwargs["resolved_archive_file"] = resolved_archive_file + + return loaded_state_dict_keys + + def _init_hf_model(self, model_class, config): + from accelerate.big_modeling import init_empty_weights + from transformers.modeling_utils import no_init_weights + from transformers.utils import ContextManagers + + _fast_init = self.kwargs.pop("_fast_init", True) + torch_dtype = self.kwargs.pop("torch_dtype", "auto") + is_sharded = self.kwargs.pop("is_sharded", False) + sharded_metadata = self.kwargs.pop("sharded_metadata", None) + offload_folder = self.kwargs.pop("offload_folder", None) + offload_state_dict = self.kwargs.pop("offload_state_dict", False) + resolved_archive_file = self.kwargs.pop("resolved_archive_file", None) + + # set dtype to instantiate the model under: + # 1. If torch_dtype is not None, we use that dtype + # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, + # by checking its first weights entry that is of a floating type + # - we assume all floating dtype weights are of the same dtype + dtype_orig = None + if torch_dtype is not None: + if isinstance(torch_dtype, str): + if torch_dtype == "auto": + if ( + hasattr(config, "torch_dtype") + and config.torch_dtype is not None + and config.torch_dtype != "auto" + ): + torch_dtype = config.torch_dtype + else: # pragma: no cover + if is_sharded and "dtype" in sharded_metadata: + torch_dtype = sharded_metadata["dtype"] + else: + torch_dtype = torch.float32 + else: # pragma: no cover + assert False, f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}' + + dtype_orig = model_class._set_default_torch_dtype(torch_dtype) + + init_contexts = [no_init_weights(_enable=_fast_init)] + init_contexts.append(init_empty_weights()) + + with ContextManagers(init_contexts): + model = model_class(config, **self.kwargs) + + # set kwargs for next functions to use + self.kwargs["resolved_archive_file"] = resolved_archive_file + self.kwargs["sharded_metadata"] = sharded_metadata + self.kwargs["torch_dtype"] = torch_dtype + self.kwargs["dtype_orig"] = dtype_orig + self.kwargs["_fast_init"] = _fast_init + self.kwargs["offload_folder"] = offload_folder + self.kwargs["offload_state_dict"] = offload_state_dict + + return model + + def _load_pretrained_weight(self, model, model_class): + resolved_archive_file = self.kwargs.pop("resolved_archive_file", None) + sharded_metadata = self.kwargs.pop("sharded_metadata", None) + torch_dtype = self.kwargs.pop("torch_dtype", torch.float32) + dtype_orig = self.kwargs.pop("dtype_orig", None) + _fast_init = self.kwargs.pop("_fast_init", True) + offload_folder = self.kwargs.pop("offload_folder", 
None) + offload_state_dict = self.kwargs.pop("offload_state_dict", False) + + # restore default dtype + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + offload_index, + error_msgs, + ) = model_class._load_pretrained_model( + model, + None, + self.loaded_state_dict_keys, + resolved_archive_file, + self.model_name_or_path, + sharded_metadata=sharded_metadata, + _fast_init=_fast_init, + low_cpu_mem_usage=True, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + keep_in_fp32_modules=[], + ) + + # make sure token embedding weights are still tied if needed + model.tie_weights() + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + + return model diff --git a/neural_compressor/torch/algorithms/weight_only/utility.py b/neural_compressor/torch/algorithms/weight_only/utility.py index 322fcbd7d82..ce13990c00f 100644 --- a/neural_compressor/torch/algorithms/weight_only/utility.py +++ b/neural_compressor/torch/algorithms/weight_only/utility.py @@ -1108,3 +1108,29 @@ def forward(self, *args, **kwargs): with torch.no_grad(): self.args_list.append(args) self.kwargs_list.append(kwargs) + + +def convert_dtype_str2torch(str_dtype): + """Converts a string dtype to its corresponding PyTorch dtype. + + Args: + str_dtype (str): The string representation of the dtype. + + Returns: + torch.dtype: The PyTorch dtype. + + Raises: + AssertionError: If the input str_dtype is unsupported. + """ + if isinstance(str_dtype, torch.dtype) or str_dtype is None: + return str_dtype + if str_dtype == "int8": + return torch.int8 + elif str_dtype == "fp32" or str_dtype == "float32" or str_dtype == "auto": + return torch.float + elif str_dtype == "fp16" or str_dtype == "float16": + return torch.float16 + elif str_dtype == "bf16" or str_dtype == "bfloat16": + return torch.bfloat16 + else: + assert False, "Unsupported str dtype {} to torch dtype".format(str_dtype) diff --git a/neural_compressor/torch/quantization/load_entry.py b/neural_compressor/torch/quantization/load_entry.py index fb870a92e77..d20f828659d 100644 --- a/neural_compressor/torch/quantization/load_entry.py +++ b/neural_compressor/torch/quantization/load_entry.py @@ -25,43 +25,84 @@ RTNConfig, TEQConfig, ) +from neural_compressor.torch.utils import LoadFormat config_name_mapping = { FP8_QUANT: FP8Config, } -def load(output_dir="./saved_results", model=None): - """The main entry of load for all algorithms. +def load(model_name_or_path, original_model=None, format="default", device="cpu", **kwargs): + """Load quantized model. + + 1. Load INC quantized model in local. + case 1: WOQ + from neural_compressor.torch.quantization import load + load(model_name_or_path="saved_results", original_model=fp32_model) + + case 2: INT8/FP8 + from neural_compressor.torch.quantization import load + load(model_name_or_path='saved_result', original_model=fp32_model) + + case 3: TorchScript (IPEX) + from neural_compressor.torch.quantization import load + load(model_name_or_path='saved_result') + + 2. Load HuggingFace quantized model, including GPTQ/AWQ models and upstreamed INC quantized models in HF model hub. + case 1: WOQ + from neural_compressor.torch.quantization import load + load(model_name_or_path=model_name_or_path) - Args: - output_dir (str, optional): path to quantized model folder. Defaults to "./saved_results". - model (torch.nn.Module, optional): original model, suggest to use empty tensor. 
+    Args:
+        model_name_or_path (str): torch checkpoint directory or huggingface model_name_or_path.
+            If 'format' is set to 'huggingface', it means the huggingface model_name_or_path.
+            If 'format' is set to 'default', it means the 'checkpoint_dir'.
+            This parameter should not be None. It works together with the 'original_model' parameter
+            to load an INC quantized model saved locally.
+        original_model (torch.nn.Module, TorchScript model (IPEX), or fx graph (pt2e), optional):
+            original model before quantization. Needed if 'format' is set to 'default' and the model
+            is not a TorchScript model. Defaults to None.
+        format (str, optional): 'default' for loading an INC quantized model.
+            'huggingface' for loading a huggingface WOQ causal language model. Defaults to "default".
+        device (str, optional): 'cpu', 'hpu' or 'cuda'. Specifies the device the model will be loaded to.
+        kwargs (remaining dictionary of keyword arguments, optional):
+            remaining dictionary of keyword arguments for loading huggingface models.
+            Will be passed to the huggingface model's `__init__` method, such as 'trust_remote_code', 'revision'.

     Returns:
         The quantized model
     """
-    from neural_compressor.common.base_config import ConfigRegistry
+    # TODO: When loading WOQ model, use different WeightOnlyLinear module according to device.
+    if format == LoadFormat.DEFAULT.value:
+        from neural_compressor.common.base_config import ConfigRegistry

-    qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), "qconfig.json")
-    with open(qconfig_file_path, "r") as f:
-        per_op_qconfig = json.load(f)
+        qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), "qconfig.json")
+        with open(qconfig_file_path, "r") as f:
+            per_op_qconfig = json.load(f)

-    if " " in per_op_qconfig.keys():  # ipex qconfig format: {' ': {'q_op_infos': {'0': {'op_type': ...
-        from neural_compressor.torch.algorithms.static_quant import load
+        if " " in per_op_qconfig.keys():  # ipex qconfig format: {' ': {'q_op_infos': {'0': {'op_type': ...
+ from neural_compressor.torch.algorithms import static_quant - return load(output_dir) - else: - config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"]) - # select load function - config_object = config_mapping[next(iter(config_mapping))] - if isinstance(config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig)): # WOQ - from neural_compressor.torch.algorithms.weight_only.save_load import load + return static_quant.load(model_name_or_path) + else: + config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"]) + # select load function + config_object = config_mapping[next(iter(config_mapping))] + + if isinstance(config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig)): # WOQ + from neural_compressor.torch.algorithms import weight_only - return load(output_dir) + return weight_only.load(model_name_or_path, original_model, format=LoadFormat.DEFAULT) - model.qconfig = config_mapping - if isinstance(config_object, FP8Config): # FP8 - from neural_compressor.torch.algorithms.habana_fp8 import load + original_model.qconfig = config_mapping + if isinstance(config_object, FP8Config): # FP8 + from neural_compressor.torch.algorithms import habana_fp8 - return load(model, output_dir) # pylint: disable=E1121 + return habana_fp8.load(model_name_or_path, original_model) + elif format == LoadFormat.HUGGINGFACE.value: + # now only support load huggingface WOQ causal language model + from neural_compressor.torch.algorithms import weight_only + + return weight_only.load(model_name_or_path, format=LoadFormat.HUGGINGFACE, **kwargs) + else: + raise ValueError("`format` in load function can only be 'huggingface' or 'default', but get {}".format(format)) diff --git a/neural_compressor/torch/utils/constants.py b/neural_compressor/torch/utils/constants.py index 429851e311b..a655a70b8ed 100644 --- a/neural_compressor/torch/utils/constants.py +++ b/neural_compressor/torch/utils/constants.py @@ -53,3 +53,12 @@ PT2E_STATIC_QUANT = "pt2e_static_quant" PT2E_DYNAMIC_QUANT = "pt2e_dynamic_quant" + + +# load format name +from enum import Enum + + +class LoadFormat(Enum): + DEFAULT = "default" + HUGGINGFACE = "huggingface" diff --git a/test/3x/torch/algorithms/weight_only/test_woq_utility.py b/test/3x/torch/algorithms/weight_only/test_woq_utility.py index f672ec0ac1c..712ba52d889 100644 --- a/test/3x/torch/algorithms/weight_only/test_woq_utility.py +++ b/test/3x/torch/algorithms/weight_only/test_woq_utility.py @@ -11,3 +11,21 @@ def test_quant_tensor_id(shape): output = quant_tensor(input) id2 = id(output) assert id1 == id2, "quant_tensor function is an in-place operator" + + +def test_convert_dtype_str2torch(): + from neural_compressor.torch.algorithms.weight_only.utility import convert_dtype_str2torch + + # Test for supported dtypes + assert convert_dtype_str2torch("int8") == torch.int8 + assert convert_dtype_str2torch("fp32") == torch.float + assert convert_dtype_str2torch("float32") == torch.float + assert convert_dtype_str2torch("auto") == torch.float + assert convert_dtype_str2torch("fp16") == torch.float16 + assert convert_dtype_str2torch("float16") == torch.float16 + assert convert_dtype_str2torch("bf16") == torch.bfloat16 + assert convert_dtype_str2torch("bfloat16") == torch.bfloat16 + + # Test for unsupported dtypes + with pytest.raises(AssertionError): + convert_dtype_str2torch("int16") diff --git a/test/3x/torch/quantization/weight_only/test_autoround.py 
b/test/3x/torch/quantization/weight_only/test_autoround.py index b4ca66ad00b..f1539b072b7 100644 --- a/test/3x/torch/quantization/weight_only/test_autoround.py +++ b/test/3x/torch/quantization/weight_only/test_autoround.py @@ -117,7 +117,7 @@ def test_save_and_load(self): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load("saved_results") + loaded_model = load("saved_results", copy.deepcopy(self.gptj)) loaded_out = loaded_model(self.inp)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." assert isinstance( diff --git a/test/3x/torch/quantization/weight_only/test_awq.py b/test/3x/torch/quantization/weight_only/test_awq.py index 54f4af1ffa9..6e44a14acca 100644 --- a/test/3x/torch/quantization/weight_only/test_awq.py +++ b/test/3x/torch/quantization/weight_only/test_awq.py @@ -133,7 +133,7 @@ def calib_func(model): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load("saved_results") + loaded_model = load("saved_results", copy.deepcopy(self.tiny_gptj)) loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." assert isinstance(loaded_model.lm_head, WeightOnlyLinear), "loading compressed model failed." diff --git a/test/3x/torch/quantization/weight_only/test_gptq.py b/test/3x/torch/quantization/weight_only/test_gptq.py index cd48edd8c35..be408af2564 100644 --- a/test/3x/torch/quantization/weight_only/test_gptq.py +++ b/test/3x/torch/quantization/weight_only/test_gptq.py @@ -254,7 +254,7 @@ def test_save_and_load(self): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load("saved_results") + loaded_model = load("saved_results", copy.deepcopy(self.tiny_gptj)) loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." 
assert isinstance( diff --git a/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py b/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py new file mode 100644 index 00000000000..c12197d211c --- /dev/null +++ b/test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py @@ -0,0 +1,24 @@ +import torch + +from neural_compressor.torch.utils import accelerator + +device = accelerator.current_device_name() + + +class TestHFModelLoad: + def setup_class(self): + self.model_name = "TheBloke/TinyLlama-1.1B-python-v0.1-GPTQ" + self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long).to(device) + + def test_load_hf_woq_model(self): + from neural_compressor.torch.quantization import load + + qmodel = load(model_name_or_path=self.model_name, format="huggingface", torch_dtype=torch.float32) + + woq_linear_num = 0 + for _, module in qmodel.named_modules(): + if module.__class__.__name__ == "WeightOnlyLinear": + woq_linear_num += 1 + assert woq_linear_num == 154, "Incorrect number of WeightOnlyLinear modules" + output = qmodel(self.example_inputs)[0] + assert len(output) > 0, "Not loading the model correctly" diff --git a/test/3x/torch/quantization/weight_only/test_rtn.py b/test/3x/torch/quantization/weight_only/test_rtn.py index d6b31bbca25..f82185cc82e 100644 --- a/test/3x/torch/quantization/weight_only/test_rtn.py +++ b/test/3x/torch/quantization/weight_only/test_rtn.py @@ -292,7 +292,7 @@ def test_save_and_load(self): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load("saved_results") + loaded_model = load("saved_results", copy.deepcopy(self.tiny_gptj)) loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." assert isinstance(loaded_model.lm_head, WeightOnlyLinear), "loading compressed model failed." diff --git a/test/3x/torch/quantization/weight_only/test_teq.py b/test/3x/torch/quantization/weight_only/test_teq.py index 79447054050..9f4df1c4226 100644 --- a/test/3x/torch/quantization/weight_only/test_teq.py +++ b/test/3x/torch/quantization/weight_only/test_teq.py @@ -141,7 +141,7 @@ def test_save_and_load(self): from neural_compressor.torch.quantization import load # loading compressed model - loaded_model = load("saved_results") + loaded_model = load("saved_results", copy.deepcopy(self.gptj)) loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check."
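For reviewers, a short usage sketch of the two load paths this patch enables, not part of the diff itself: path 1 reloads a local INC weight-only checkpoint produced by `q_model.save("saved_results")`, path 2 loads a GPTQ model from the HuggingFace hub. The `facebook/opt-125m` id is a placeholder for whichever model was originally quantized into `saved_results`; the GPTQ repo id and kwargs mirror the new test added in this PR.

```python
import torch
from transformers import AutoModelForCausalLM
from neural_compressor.torch.quantization import load

# Path 1: reload a locally saved INC weight-only checkpoint.
# The original FP32 model must be supplied so its Linear layers can be rebuilt as WeightOnlyLinear.
fp32_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder: use the model you quantized
woq_model = load("saved_results", original_model=fp32_model)

# Path 2: load a weight-only quantized model (GPTQ/AWQ or upstreamed INC) directly from the HF hub.
hf_woq_model = load(
    model_name_or_path="TheBloke/TinyLlama-1.1B-python-v0.1-GPTQ",
    format="huggingface",
    torch_dtype=torch.float32,
)
example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long)
print(hf_woq_model(example_inputs)[0].shape)  # logits from the loaded WOQ model
```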