Skip to content

Commit

Permalink
use load_target_platform_capabilities in all facades to manage TPC ch…
Browse files Browse the repository at this point in the history
…ecking and allowing MCT to accept TPC as JSON
  • Loading branch information
ofirgo committed Jan 15, 2025
1 parent 259ff8e commit a60f0b2
Show file tree
Hide file tree
Showing 12 changed files with 67 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.
# ==============================================================================

from typing import Callable
from typing import Callable, Union
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig, CoreConfig
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import TENSORFLOW
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import compute_resource_utilization_data
from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
from model_compression_toolkit.verify_packages import FOUND_TF

if FOUND_TF:
Expand All @@ -38,7 +39,7 @@ def keras_resource_utilization_data(in_model: Model,
representative_data_gen: Callable,
core_config: CoreConfig = CoreConfig(
mixed_precision_config=MixedPrecisionQuantizationConfig()),
target_platform_capabilities: TargetPlatformCapabilities = KERAS_DEFAULT_TPC
target_platform_capabilities: Union[TargetPlatformCapabilities, str] = KERAS_DEFAULT_TPC
) -> ResourceUtilization:
"""
Computes resource utilization data that can be used to calculate the desired target resource utilization
Expand All @@ -50,7 +51,7 @@ def keras_resource_utilization_data(in_model: Model,
in_model (Model): Keras model to quantize.
representative_data_gen (Callable): Dataset used for calibration.
core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision of how the model should be quantized.
target_platform_capabilities (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities to optimize the Keras model according to.
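target_platform_capabilities (Union[TargetPlatformCapabilities, str]): TargetPlatformCapabilities object, or a string (presumably a path to a TPC JSON file) that load_target_platform_capabilities resolves to one, to optimize the Keras model according to.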
Returns:
Expand Down Expand Up @@ -81,6 +82,7 @@ def keras_resource_utilization_data(in_model: Model,

fw_impl = KerasImplementation()

target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
# Attach tpc model to framework
attach2keras = AttachTpcToKeras()
target_platform_capabilities = attach2keras.attach(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

from typing import Callable
from typing import Callable, Union

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import PYTORCH
Expand All @@ -23,6 +23,7 @@
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
from model_compression_toolkit.verify_packages import FOUND_TORCH

if FOUND_TORCH:
Expand All @@ -40,7 +41,7 @@
def pytorch_resource_utilization_data(in_model: Module,
representative_data_gen: Callable,
core_config: CoreConfig = CoreConfig(),
target_platform_capabilities: TargetPlatformCapabilities= PYTORCH_DEFAULT_TPC
target_platform_capabilities: Union[TargetPlatformCapabilities, str] = PYTORCH_DEFAULT_TPC
) -> ResourceUtilization:
"""
Computes resource utilization data that can be used to calculate the desired target resource utilization for mixed-precision quantization.
Expand All @@ -50,7 +51,7 @@ def pytorch_resource_utilization_data(in_model: Module,
in_model (Module): PyTorch model to quantize.
representative_data_gen (Callable): Dataset used for calibration.
core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision
target_platform_capabilities (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities to optimize the PyTorch model according to.
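target_platform_capabilities (Union[TargetPlatformCapabilities, str]): TargetPlatformCapabilities object, or a string (presumably a path to a TPC JSON file) that load_target_platform_capabilities resolves to one, to optimize the PyTorch model according to.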
Returns:
Expand Down Expand Up @@ -81,6 +82,7 @@ def pytorch_resource_utilization_data(in_model: Module,

fw_impl = PytorchImplementation()

target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
# Attach tpc model to framework
attach2pytorch = AttachTpcToPytorch()
target_platform_capabilities = (
Expand Down
7 changes: 5 additions & 2 deletions model_compression_toolkit/gptq/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
AttachTpcToKeras
from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
from model_compression_toolkit.verify_packages import FOUND_TF
from model_compression_toolkit.core.common.user_info import UserInformation
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig, \
Expand Down Expand Up @@ -156,7 +157,8 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da
gptq_representative_data_gen: Callable = None,
target_resource_utilization: ResourceUtilization = None,
core_config: CoreConfig = CoreConfig(),
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
target_platform_capabilities: Union[TargetPlatformCapabilities, str]
= DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
"""
Quantize a trained Keras model using post-training quantization. The model is quantized using a
symmetric constraint quantization thresholds (power of two).
Expand All @@ -180,7 +182,7 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da
gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
target_resource_utilization (ResourceUtilization): ResourceUtilization object to limit the search of the mixed-precision configuration as desired.
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
target_platform_capabilities (Union[TargetPlatformCapabilities, str]): TargetPlatformCapabilities to optimize the Keras model according to.
Returns:
Expand Down Expand Up @@ -241,6 +243,7 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da

fw_impl = GPTQKerasImplemantation()

target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
# Attach tpc model to framework
attach2keras = AttachTpcToKeras()
framework_platform_capabilities = attach2keras.attach(
Expand Down
6 changes: 4 additions & 2 deletions model_compression_toolkit/gptq/pytorch/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.metadata import create_model_metadata
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
from model_compression_toolkit.verify_packages import FOUND_TORCH


Expand Down Expand Up @@ -145,7 +146,7 @@ def pytorch_gradient_post_training_quantization(model: Module,
core_config: CoreConfig = CoreConfig(),
gptq_config: GradientPTQConfig = None,
gptq_representative_data_gen: Callable = None,
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
target_platform_capabilities: Union[TargetPlatformCapabilities, str] = DEFAULT_PYTORCH_TPC):
"""
Quantize a trained Pytorch module using post-training quantization.
By default, the module is quantized using a symmetric constraint quantization thresholds
Expand All @@ -169,7 +170,7 @@ def pytorch_gradient_post_training_quantization(model: Module,
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
target_platform_capabilities (Union[TargetPlatformCapabilities, str]): TargetPlatformCapabilities to optimize the PyTorch model according to.
Returns:
A quantized module and information the user may need to handle the quantized module.
Expand Down Expand Up @@ -214,6 +215,7 @@ def pytorch_gradient_post_training_quantization(model: Module,

fw_impl = GPTQPytorchImplemantation()

target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
# Attach tpc model to framework
attach2pytorch = AttachTpcToPytorch()
framework_quantization_capabilities = attach2pytorch.attach(target_platform_capabilities,
Expand Down
9 changes: 6 additions & 3 deletions model_compression_toolkit/pruning/keras/pruning_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# limitations under the License.
# ==============================================================================

from typing import Callable, Tuple
from typing import Callable, Tuple, Union

from model_compression_toolkit import get_target_platform_capabilities
from model_compression_toolkit.constants import TENSORFLOW
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
from model_compression_toolkit.verify_packages import FOUND_TF
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
from model_compression_toolkit.core.common.pruning.pruner import Pruner
Expand All @@ -43,7 +44,8 @@ def keras_pruning_experimental(model: Model,
target_resource_utilization: ResourceUtilization,
representative_data_gen: Callable,
pruning_config: PruningConfig = PruningConfig(),
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]:
target_platform_capabilities: Union[TargetPlatformCapabilities, str]
= DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]:
"""
Perform structured pruning on a Keras model to meet a specified target resource utilization.
This function prunes the provided model according to the target resource utilization by grouping and pruning
Expand All @@ -61,7 +63,7 @@ def keras_pruning_experimental(model: Model,
target_resource_utilization (ResourceUtilization): The target Key Performance Indicators to be achieved through pruning.
representative_data_gen (Callable): A function to generate representative data for pruning analysis.
pruning_config (PruningConfig): Configuration settings for the pruning process. Defaults to standard config.
target_platform_capabilities (FrameworkQuantizationCapabilities): Platform-specific constraints and capabilities. Defaults to DEFAULT_KERAS_TPC.
target_platform_capabilities (Union[TargetPlatformCapabilities, str]): Platform-specific constraints and capabilities. Defaults to DEFAULT_KERAS_TPC.
Returns:
Tuple[Model, PruningInfo]: A tuple containing the pruned Keras model and associated pruning information.
Expand Down Expand Up @@ -112,6 +114,7 @@ def keras_pruning_experimental(model: Model,
# Instantiate the Keras framework implementation.
fw_impl = PruningKerasImplementation()

target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
# Attach tpc model to framework
attach2keras = AttachTpcToKeras()
target_platform_capabilities = attach2keras.attach(target_platform_capabilities)
Expand Down
9 changes: 6 additions & 3 deletions model_compression_toolkit/pruning/pytorch/pruning_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# limitations under the License.
# ==============================================================================

from typing import Callable, Tuple
from typing import Callable, Tuple, Union
from model_compression_toolkit import get_target_platform_capabilities
from model_compression_toolkit.constants import PYTORCH
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
from model_compression_toolkit.verify_packages import FOUND_TORCH
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
from model_compression_toolkit.core.common.pruning.pruner import Pruner
Expand Down Expand Up @@ -47,7 +48,8 @@ def pytorch_pruning_experimental(model: Module,
target_resource_utilization: ResourceUtilization,
representative_data_gen: Callable,
pruning_config: PruningConfig = PruningConfig(),
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYOTRCH_TPC) -> \
target_platform_capabilities: Union[TargetPlatformCapabilities, str]
= DEFAULT_PYOTRCH_TPC) -> \
Tuple[Module, PruningInfo]:
"""
Perform structured pruning on a Pytorch model to meet a specified target resource utilization.
Expand All @@ -66,7 +68,7 @@ def pytorch_pruning_experimental(model: Module,
target_resource_utilization (ResourceUtilization): Key Performance Indicators specifying the pruning targets.
representative_data_gen (Callable): A function to generate representative data for pruning analysis.
pruning_config (PruningConfig): Configuration settings for the pruning process. Defaults to standard config.
target_platform_capabilities (TargetPlatformCapabilities): Platform-specific constraints and capabilities.
target_platform_capabilities (Union[TargetPlatformCapabilities, str]): Platform-specific constraints and capabilities.
Defaults to DEFAULT_PYTORCH_TPC.
Returns:
Expand Down Expand Up @@ -118,6 +120,7 @@ def pytorch_pruning_experimental(model: Module,
# Instantiate the Pytorch framework implementation.
fw_impl = PruningPytorchImplementation()

target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
# Attach TPC to framework
attach2pytorch = AttachTpcToPytorch()
framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities)
Expand Down
4 changes: 3 additions & 1 deletion model_compression_toolkit/ptq/keras/quantization_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import TENSORFLOW
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities
from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
from model_compression_toolkit.verify_packages import FOUND_TF
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
Expand Down Expand Up @@ -70,7 +71,7 @@ def keras_post_training_quantization(in_model: Model,
representative_data_gen (Callable): Dataset used for calibration.
target_resource_utilization (ResourceUtilization): ResourceUtilization object to limit the search of the mixed-precision configuration as desired.
core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
target_platform_capabilities (Union[TargetPlatformCapabilities, str]): TargetPlatformCapabilities to optimize the Keras model according to.
Returns:
Expand Down Expand Up @@ -137,6 +138,7 @@ def keras_post_training_quantization(in_model: Model,

fw_impl = KerasImplementation()

target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
attach2keras = AttachTpcToKeras()
framework_platform_capabilities = attach2keras.attach(
target_platform_capabilities,
Expand Down
Loading

0 comments on commit a60f0b2

Please sign in to comment.