From 91964c4145695c1e13ae0075db5a50fad0d67a15 Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Thu, 20 Feb 2025 17:39:16 +0100 Subject: [PATCH] Add ov_submodels property to OVBaseModel --- optimum/intel/openvino/modeling_base.py | 13 ++++- .../intel/openvino/modeling_base_seq2seq.py | 19 ++++--- optimum/intel/openvino/modeling_diffusion.py | 31 ++++++++---- .../openvino/modeling_visual_language.py | 14 ++---- optimum/intel/openvino/quantization.py | 46 ++++++----------- tests/openvino/test_exporters_cli.py | 28 +---------- tests/openvino/test_quantization.py | 50 ++++--------------- tests/openvino/utils_tests.py | 12 ++--- 8 files changed, 80 insertions(+), 133 deletions(-) diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index 932b505b70..25c095628f 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -17,7 +17,7 @@ import warnings from pathlib import Path from tempfile import gettempdir -from typing import Dict, Optional, Union +from typing import Dict, List, Optional, Union import openvino import torch @@ -204,6 +204,17 @@ def dtype(self) -> Optional[torch.dtype]: return None + @property + def ov_submodels(self) -> Dict[str, openvino.runtime.Model]: + return {submodel_name: getattr(self, submodel_name) for submodel_name in self._ov_submodel_names} + + @property + def _ov_submodel_names(self) -> List[str]: + """ + List of openvino submodel names. Used as keys for a dictionary returned by `.submodels` property. + """ + return ["model"] + @staticmethod def load_model( file_name: Union[str, Path], diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py index a61ec2bad8..fa225e17e5 100644 --- a/optimum/intel/openvino/modeling_base_seq2seq.py +++ b/optimum/intel/openvino/modeling_base_seq2seq.py @@ -15,7 +15,7 @@ import logging import os from pathlib import Path -from typing import Dict, Optional, Union +from typing import Dict, List, Optional, Union import openvino from huggingface_hub import hf_hub_download @@ -106,6 +106,13 @@ def __init__( self._openvino_config = OVConfig(quantization_config=quantization_config) self._set_ov_config_parameters() + @property + def _ov_submodel_names(self) -> List[str]: + submodel_names = ["encoder", "decoder"] + if self.decoder_with_past_model is not None: + submodel_names.append("decoder_with_past") + return submodel_names + def _save_pretrained(self, save_directory: Union[str, Path]): """ Saves the model to the OpenVINO IR format so that it can be re-loaded using the @@ -482,13 +489,9 @@ def half(self): raise ValueError( "`half()` is not supported with `compile_only` mode, please intialize model without this option" ) - apply_moc_transformations(self.encoder_model, cf=False) - apply_moc_transformations(self.decoder_model, cf=False) - compress_model_transformation(self.encoder_model) - compress_model_transformation(self.decoder_model) - if self.decoder_with_past_model is not None: - apply_moc_transformations(self.decoder_with_past_model, cf=False) - compress_model_transformation(self.decoder_with_past_model) + for submodel in self.ov_submodels.values(): + apply_moc_transformations(submodel, cf=False) + compress_model_transformation(submodel) return self def forward(self, *args, **kwargs): diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py index 05cb39e91c..9a51c3bbbc 100644 --- a/optimum/intel/openvino/modeling_diffusion.py +++ b/optimum/intel/openvino/modeling_diffusion.py @@ -264,6 +264,24 @@ def __init__( if compile and not self._compile_only: self.compile() + @property + def ov_submodels(self) -> Dict[str, openvino.runtime.Model]: + return {name: getattr(getattr(self, name), "model") for name in self._ov_submodel_names} + + @property + def _ov_submodel_names(self) -> List[str]: + submodel_name_candidates = [ + "unet", + "transformer", + "vae_decoder", + "vae_encoder", + "text_encoder", + "text_encoder_2", + "text_encoder_3", + ] + submodel_names = [name for name in submodel_name_candidates if getattr(self, name) is not None] + return submodel_names + def _save_pretrained(self, save_directory: Union[str, Path]): """ Saves the model to the OpenVINO IR format so that it can be re-loaded using the @@ -879,17 +897,8 @@ def half(self): "`half()` is not supported with `compile_only` mode, please intialize model without this option" ) - for component in { - self.unet, - self.transformer, - self.vae_encoder, - self.vae_decoder, - self.text_encoder, - self.text_encoder_2, - self.text_encoder_3, - }: - if component is not None: - compress_model_transformation(component.model) + for submodel in self.ov_submodels.values(): + compress_model_transformation(submodel) self.clear_requests() diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py index cd8d07ea95..54b3cf5588 100644 --- a/optimum/intel/openvino/modeling_visual_language.py +++ b/optimum/intel/openvino/modeling_visual_language.py @@ -385,17 +385,17 @@ def _save_pretrained(self, save_directory: Union[str, Path]): save_directory (`str` or `Path`): The directory where to save the model files. """ - src_models = self.submodels + src_models = self.ov_submodels dst_file_names = { "lm_model": OV_LANGUAGE_MODEL_NAME, "text_embeddings_model": OV_TEXT_EMBEDDINGS_MODEL_NAME, "vision_embeddings_model": OV_VISION_EMBEDDINGS_MODEL_NAME, } - for name in self._submodel_names: + for name in self._ov_submodel_names: if name not in dst_file_names: dst_file_names[name] = f"openvino_{name}.xml" - for name in self._submodel_names: + for name in self._ov_submodel_names: model = src_models[name] dst_file_name = dst_file_names[name] dst_path = os.path.join(save_directory, dst_file_name) @@ -653,17 +653,13 @@ def components(self): return {component_name: getattr(self, component_name) for component_name in self._component_names} @property - def _submodel_names(self): + def _ov_submodel_names(self): model_names = ["lm_model", "text_embeddings_model", "vision_embeddings_model"] for part in self.additional_parts: if getattr(self, part, None) is not None: model_names.append(part + "_model") return model_names - @property - def submodels(self): - return {submodel_name: getattr(self, submodel_name) for submodel_name in self._submodel_names} - def reshape(self, batch_size: int, sequence_length: int): logger.warning("Static shapes are not supported for causal language model.") return self @@ -672,7 +668,7 @@ def half(self): """ Converts all the model weights to FP16 for more efficient inference on GPU. """ - for _, submodel in self.submodels.items(): + for submodel in self.ov_submodels.values(): apply_moc_transformations(submodel, cf=False) compress_model_transformation(submodel) return self diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 9c3838cfa9..451a7dba96 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -393,30 +393,23 @@ def _quantize_ovbasemodel( if calibration_dataset is None: raise ValueError("Calibration dataset is required to run hybrid quantization.") if is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline): - # Apply weight-only quantization to all SD submodels except UNet + # Apply weight-only quantization to all SD submodels except UNet/Transformer quantization_config_copy = quantization_config.clone() quantization_config_copy.dataset = None quantization_config_copy.quant_method = OVQuantizationMethod.DEFAULT - sub_model_names = [ - "vae_encoder", - "vae_decoder", - "text_encoder", - "text_encoder_2", - "text_encoder_3", - ] - sub_models = filter(lambda x: x, (getattr(self.model, name) for name in sub_model_names)) + sub_models = [v for (k, v) in self.model.ov_submodels.items() if k not in ("unet", "transformer")] for sub_model in sub_models: - _weight_only_quantization(sub_model.model, quantization_config_copy, **kwargs) + _weight_only_quantization(sub_model, quantization_config_copy, **kwargs) - if self.model.unet is not None: - # Apply hybrid quantization to UNet - self.model.unet.model = _hybrid_quantization( - self.model.unet.model, quantization_config, calibration_dataset, **kwargs - ) + unet_is_present = self.model.unet is not None + vision_model = (self.model.unet if unet_is_present else self.model.transformer).model + quantized_vision_model = _hybrid_quantization( + vision_model, quantization_config, calibration_dataset, **kwargs + ) + if unet_is_present: + self.model.unet.model = quantized_vision_model else: - self.model.transformer.model = _hybrid_quantization( - self.model.transformer.model, quantization_config, calibration_dataset, **kwargs - ) + self.model.transformer.model = quantized_vision_model self.model.clear_requests() else: @@ -427,24 +420,13 @@ def _quantize_ovbasemodel( self.model.request = None else: if is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline): - sub_model_names = [ - "vae_encoder", - "vae_decoder", - "text_encoder", - "text_encoder_2", - "unet", - "transformer", - "text_encoder_3", - ] - sub_models = filter(lambda x: x, (getattr(self.model, name) for name in sub_model_names)) - for sub_model in sub_models: - _weight_only_quantization(sub_model.model, quantization_config, **kwargs) + for submodel in self.model.ov_submodels.values(): + _weight_only_quantization(submodel, quantization_config, **kwargs) self.model.clear_requests() elif isinstance(self.model, OVModelForVisualCausalLM): language_model = self.model.language_model _weight_only_quantization(language_model.model, quantization_config, calibration_dataset, **kwargs) - sub_model_names = ["vision_embeddings", "text_embeddings"] + self.model.additional_parts - sub_models = [getattr(self.model, f"{name}_model") for name in sub_model_names] + sub_models = [v for (k, v) in self.model.ov_submodels.items() if k != "lm_model"] for sub_model in sub_models: _weight_only_quantization(sub_model, OVWeightQuantizationConfig(bits=8, sym=True), **kwargs) self.model.clear_requests() diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index 1ba82ffd7d..c0cee70e7f 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -606,29 +606,11 @@ def test_exporters_cli_int8(self, task: str, model_type: str): else _HEAD_TO_AUTOMODELS[model_type.replace("-refiner", "")] ).from_pretrained(tmpdir, **model_kwargs) - if task.startswith("text2text-generation"): - models = [model.encoder, model.decoder] - if task.endswith("with-past") and not model.decoder.stateful: - models.append(model.decoder_with_past) - elif ( - model_type.startswith("stable-diffusion") - or model_type.startswith("flux") - or model_type.startswith("sana") - ): - models = [model.unet or model.transformer, model.vae_encoder, model.vae_decoder] - models.append( - model.text_encoder if model_type in ["stable-diffusion", "sana"] else model.text_encoder_2 - ) - elif task.startswith("image-text-to-text"): - models = list(model.submodels.values()) - else: - models = [model] - expected_int8 = _ARCHITECTURES_TO_EXPECTED_INT8[model_type] expected_int8 = [{"int8": it} for it in expected_int8] if task.startswith("text2text-generation") and (not task.endswith("with-past") or model.decoder.stateful): expected_int8 = expected_int8[:2] - check_compression_state_per_model(self, models, expected_int8) + check_compression_state_per_model(self, model.ov_submodels.values(), expected_int8) @parameterized.expand(SUPPORTED_SD_HYBRID_ARCHITECTURES) def test_exporters_cli_hybrid_quantization( @@ -667,13 +649,7 @@ def test_exporters_cli_4bit( else _HEAD_TO_AUTOMODELS[model_type.replace("-refiner", "")] ).from_pretrained(tmpdir, **model_kwargs) - submodels = [] - if task == "text-generation-with-past": - submodels = [model] - elif task == "image-text-to-text": - submodels = list(model.submodels.values()) - - check_compression_state_per_model(self, submodels, expected_num_weight_nodes_per_model) + check_compression_state_per_model(self, model.ov_submodels.values(), expected_num_weight_nodes_per_model) self.assertTrue("--awq" not in option or b"Applying AWQ" in result.stdout) self.assertTrue("--scale-estimation" not in option or b"Applying Scale Estimation" in result.stdout) diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index 6cf926d3c6..41e8e179e1 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -818,20 +818,6 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type, trust self.assertEqual(model._openvino_config.quantization_config.bits, 8) self.assertEqual(model._openvino_config.dtype, "int8") - if model.export_feature.startswith("text2text-generation"): - models = [model.encoder, model.decoder] - if model.decoder_with_past is not None: - models.append(model.decoder_with_past) - elif model.export_feature == "text-to-image": - models = [model.unet, model.vae_encoder, model.vae_decoder] - models.append(model.text_encoder if model_type in ["stable-diffusion", "sana"] else model.text_encoder_2) - elif model_type == "open-clip": - models = [model.text_model, model.visual_model] - elif model.export_feature == "image-text-to-text": - models = list(model.submodels.values()) - else: - models = [model] - if model_type == "open-clip": pytest.skip(reason="ticket 161043") elif model_type == "t5": @@ -839,9 +825,12 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type, trust else: check_optimization_not_applicable_to_optimized_model(model, quantization_config={"bits": 8}) + submodels = ( + [model.text_model, model.visual_model] if model_type == "open-clip" else model.ov_submodels.values() + ) expected_ov_int8 = _ARCHITECTURES_TO_EXPECTED_INT8[model_type] expected_ov_int8 = [{"int8": it} for it in expected_ov_int8] - check_compression_state_per_model(self, models, expected_ov_int8) + check_compression_state_per_model(self, submodels, expected_ov_int8) @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION) def test_ovmodel_hybrid_quantization(self, model_cls, model_type, expected_fake_nodes, expected_int8_nodes): @@ -938,11 +927,7 @@ def test_ovmodel_4bit_auto_compression_with_config( # TODO: Check that AWQ was actually applied pass - submodels = [] - if isinstance(model, OVModelForCausalLM): - submodels = [model.model] - elif isinstance(model, OVModelForVisualCausalLM): - submodels = list(model.submodels.values()) + submodels = list(model.ov_submodels.values()) check_compression_state_per_model(self, submodels, expected_num_weight_nodes_per_model) model.save_pretrained(tmp_dir) @@ -976,21 +961,11 @@ def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_type, tru model = model_cls.from_pretrained( MODEL_NAMES[model_type], export=True, load_in_8bit=False, trust_remote_code=trust_remote_code ) - if model.export_feature.startswith("text2text-generation"): - models = [model.encoder, model.decoder] - if model.decoder_with_past is not None: - models.append(model.decoder_with_past) - elif model.export_feature == "text-to-image": - models = [model.unet, model.vae_encoder, model.vae_decoder] - models.append(model.text_encoder if model_type in ["stable-diffusion", "sana"] else model.text_encoder_2) - elif model_type == "open-clip": - models = [model.text_model, model.visual_model] - elif model.export_feature == "image-text-to-text": - models = list(model.submodels.values()) - else: - models = [model] - for i, submodel in enumerate(models): + submodels = ( + [model.text_model, model.visual_model] if model_type == "open-clip" else model.ov_submodels.values() + ) + for i, submodel in enumerate(submodels): ov_model = submodel if isinstance(submodel, ov.Model) else submodel.model _, num_weight_nodes = get_num_quantized_nodes(ov_model) self.assertEqual(0, num_weight_nodes["int8"]) @@ -1106,12 +1081,7 @@ def test_ovmodel_4bit_dynamic_with_config( self.assertEqual(model.ov_config["DYNAMIC_QUANTIZATION_GROUP_SIZE"], str(group_size)) self.assertEqual(model.ov_config["KV_CACHE_PRECISION"], "u8") - submodels = [] - if isinstance(model, OVModelForCausalLM): - submodels = [model.model] - elif isinstance(model, OVModelForVisualCausalLM): - submodels = list(model.submodels.values()) - check_compression_state_per_model(self, submodels, expected_num_weight_nodes_per_model) + check_compression_state_per_model(self, model.ov_submodels.values(), expected_num_weight_nodes_per_model) model.save_pretrained(tmp_dir) openvino_config = OVConfig.from_pretrained(tmp_dir) diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index b347fbdb45..cdb86962e0 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -190,13 +190,13 @@ "wav2vec2": (34,), "distilbert": (66,), "t5": (64, 104, 84), - "stable-diffusion": (242, 34, 42, 64), - "stable-diffusion-xl": (366, 34, 42, 66), - "stable-diffusion-xl-refiner": (366, 34, 42, 66), + "stable-diffusion": (242, 42, 34, 64), + "stable-diffusion-xl": (366, 42, 34, 64, 66), + "stable-diffusion-xl-refiner": (366, 42, 34, 66), "open-clip": (20, 28), - "stable-diffusion-3": (66, 42, 58, 30), - "flux": (56, 24, 28, 64), - "flux-fill": (56, 24, 28, 64), + "stable-diffusion-3": (66, 58, 42, 30, 30, 32), + "flux": (56, 28, 24, 64, 64), + "flux-fill": (56, 28, 24, 64, 64), "llava": (30, 1, 9), "llava_next": (30, 1, 9), "minicpmv": (30, 1, 26, 6),