diff --git a/neural_compressor/torch/algorithms/weight_only/awq.py b/neural_compressor/torch/algorithms/weight_only/awq.py
index 00d7fb5172c..677f3cb9899 100644
--- a/neural_compressor/torch/algorithms/weight_only/awq.py
+++ b/neural_compressor/torch/algorithms/weight_only/awq.py
@@ -516,6 +516,9 @@ def block_inference(self, model):
         """
         total_out = []
         for args, kwargs in zip(self.total_block_args, self.total_block_kwargs):
+            # avoid passing layer_past as a DynamicCache when transformers is higher than 4.45.1
+            if "layer_past" in kwargs.keys() and kwargs["layer_past"] is not None:
+                kwargs["layer_past"] = None
             out = model(*args, **kwargs)
             if isinstance(out, tuple):  # pragma: no cover
                 out = out[0]
diff --git a/neural_compressor/torch/algorithms/weight_only/save_load.py b/neural_compressor/torch/algorithms/weight_only/save_load.py
index 8d1259cad00..7d22c7efbc9 100644
--- a/neural_compressor/torch/algorithms/weight_only/save_load.py
+++ b/neural_compressor/torch/algorithms/weight_only/save_load.py
@@ -834,19 +834,36 @@ def _load_remaining_pretrained_weight(self, model):
             resolved_archive_file = [resolved_archive_file]
         for shard_file in resolved_archive_file:
             state_dict = load_state_dict(shard_file)
-            _load_state_dict_into_meta_model(
-                model=model,
-                state_dict=state_dict,
-                loaded_state_dict_keys=self.loaded_state_dict_keys,
-                start_prefix="",
-                expected_keys=list(state_dict.keys()),
-                device_map={"": self.device},
-                offload_folder=offload_folder,
-                state_dict_folder=tempfile.mkdtemp() if offload_state_dict else None,
-                state_dict_index={} if offload_state_dict else None,
-                dtype=torch_dtype,
-                keep_in_fp32_modules=[],
-            )
+            import transformers
+            from packaging.version import Version
+
+            if Version(transformers.__version__) >= Version("4.45.0"):  # pragma: no cover
+                _load_state_dict_into_meta_model(
+                    model=model,
+                    state_dict=state_dict,
+                    start_prefix="",
+                    expected_keys=list(state_dict.keys()),
+                    device_map={"": self.device},
+                    offload_folder=offload_folder,
+                    state_dict_folder=tempfile.mkdtemp() if offload_state_dict else None,
+                    state_dict_index={} if offload_state_dict else None,
+                    dtype=torch_dtype,
+                    keep_in_fp32_modules=[],
+                )
+            else:
+                _load_state_dict_into_meta_model(
+                    model=model,
+                    state_dict=state_dict,
+                    loaded_state_dict_keys=self.loaded_state_dict_keys,
+                    start_prefix="",
+                    expected_keys=list(state_dict.keys()),
+                    device_map={"": self.device},
+                    offload_folder=offload_folder,
+                    state_dict_folder=tempfile.mkdtemp() if offload_state_dict else None,
+                    state_dict_index={} if offload_state_dict else None,
+                    dtype=torch_dtype,
+                    keep_in_fp32_modules=[],
+                )

         # make sure token embedding weights are still tied if needed
         model.tie_weights()
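
Note on the save_load.py change: the version gate hard-codes 4.45.0 because, from that release onward, `_load_state_dict_into_meta_model` no longer accepts `loaded_state_dict_keys`. A minimal sketch of an alternative (not part of this patch) is to derive the accepted keywords from the helper's signature instead of pinning a version; `build_meta_load_kwargs` and `common_kwargs` are hypothetical names used only for illustration, and the helper itself is a private transformers API whose location and signature may change again.

```python
import inspect

# Private transformers helper; it may move or change in future releases.
from transformers.modeling_utils import _load_state_dict_into_meta_model


def build_meta_load_kwargs(common_kwargs, loaded_state_dict_keys):
    """Keep only the keyword arguments the installed transformers version accepts."""
    accepted = inspect.signature(_load_state_dict_into_meta_model).parameters
    kwargs = dict(common_kwargs)
    # transformers >= 4.45.0 dropped `loaded_state_dict_keys`; pass it only when the
    # parameter still exists in the installed version.
    if "loaded_state_dict_keys" in accepted:
        kwargs["loaded_state_dict_keys"] = loaded_state_dict_keys
    return kwargs
```

With such a helper, the call site could collapse to `_load_state_dict_into_meta_model(**build_meta_load_kwargs(common_kwargs, self.loaded_state_dict_keys))`, avoiding the duplicated argument list in the two branches; the explicit version check above is the more conservative choice actually taken in this patch.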