From ac47d9b97b597f809ab56f9f6cb1a86951e2e334 Mon Sep 17 00:00:00 2001 From: Yi30 <106061964+yiliu30@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:30:52 +0800 Subject: [PATCH] Enhance auto-tune module (#1608) Signed-off-by: Kaihui-intel Signed-off-by: yiliu30 Signed-off-by: chensuyue Co-authored-by: Kaihui-intel Co-authored-by: chensuyue --- .../scripts/codeScan/pylint/pylint.sh | 3 +- neural_compressor/common/__init__.py | 2 + neural_compressor/common/base_config.py | 54 +++++-- neural_compressor/common/base_tuning.py | 140 +++++++++++++----- neural_compressor/common/tuning_param.py | 100 +++++++++++++ neural_compressor/common/utils/utility.py | 28 ++++ .../torch/quantization/autotune.py | 2 + requirements_ort.txt | 1 + requirements_pt.txt | 1 + requirements_tf.txt | 1 + test/3x/common/test_common.py | 67 +++++---- test/3x/common/test_param.py | 23 +++ test/3x/onnxrt/test_autotune.py | 6 +- test/3x/onnxrt/test_config.py | 8 - test/3x/tensorflow/keras/test_config.py | 7 - 15 files changed, 345 insertions(+), 98 deletions(-) create mode 100644 neural_compressor/common/tuning_param.py create mode 100644 test/3x/common/test_param.py diff --git a/.azure-pipelines/scripts/codeScan/pylint/pylint.sh b/.azure-pipelines/scripts/codeScan/pylint/pylint.sh index 3f31a7327a3..9feb5d2051f 100644 --- a/.azure-pipelines/scripts/codeScan/pylint/pylint.sh +++ b/.azure-pipelines/scripts/codeScan/pylint/pylint.sh @@ -39,7 +39,8 @@ pip install torch \ prettytable \ psutil \ py-cpuinfo \ - pyyaml + pyyaml \ + pydantic \ if [ "${scan_module}" = "neural_solution" ]; then cd /neural-compressor diff --git a/neural_compressor/common/__init__.py b/neural_compressor/common/__init__.py index 68ec1311b39..5a65e40ecb1 100644 --- a/neural_compressor/common/__init__.py +++ b/neural_compressor/common/__init__.py @@ -20,6 +20,7 @@ set_resume_from, set_workspace, set_tensorboard, + dump_elapsed_time, ) from neural_compressor.common.base_config import options @@ -33,4 +34,5 @@ "set_random_seed", "set_resume_from", "set_tensorboard", + "dump_elapsed_time", ] diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py index 3ae12259fd7..4a206d37486 100644 --- a/neural_compressor/common/base_config.py +++ b/neural_compressor/common/base_config.py @@ -17,6 +17,7 @@ from __future__ import annotations +import inspect import json import re from abc import ABC, abstractmethod @@ -25,6 +26,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from neural_compressor.common import Logger +from neural_compressor.common.tuning_param import TuningParam from neural_compressor.common.utils import ( BASE_CONFIG, COMPOSABLE_CONFIG, @@ -295,6 +297,15 @@ def __add__(self, other: BaseConfig) -> BaseConfig: else: return ComposableConfig(configs=[self, other]) + @staticmethod + def get_the_default_value_of_param(config: BaseConfig, param: str) -> Any: + # Get the signature of the __init__ method + signature = inspect.signature(config.__init__) + + # Get the parameters and their default values + parameters = signature.parameters + return parameters.get(param).default + def expand(self) -> List[BaseConfig]: """Expand the config. 
@@ -331,19 +342,42 @@ def expand(self) -> List[BaseConfig]: """ config_list: List[BaseConfig] = [] params_list = self.params_list - params_dict = OrderedDict() config = self + tuning_param_list = [] + not_tuning_param_pair = {} # key is the param name, value is the user specified value for param in params_list: - param_val = getattr(config, param) - # TODO (Yi) to handle param_val itself is a list - if isinstance(param_val, list): - params_dict[param] = param_val + # Create `TuningParam` for each param + # There are two cases: + # 1. The param is a string. + # 2. The param is a `TuningParam` instance. + if isinstance(param, str): + default_param = self.get_the_default_value_of_param(config, param) + tuning_param = TuningParam(name=param, tunable_type=List[type(default_param)]) + elif isinstance(param, TuningParam): + tuning_param = param else: - params_dict[param] = [param_val] - for params_values in product(*params_dict.values()): - new_config = self.__class__(**dict(zip(params_list, params_values))) - config_list.append(new_config) - logger.info(f"Expanded the {self.__class__.name} and got {len(config_list)} configs.") + raise ValueError(f"Unsupported param type: {param}") + # Assign the options to the `TuningParam` instance + param_val = getattr(config, tuning_param.name) + if param_val is not None: + if tuning_param.is_tunable(param_val): + tuning_param.options = param_val + tuning_param_list.append(tuning_param) + else: + not_tuning_param_pair[tuning_param.name] = param_val + logger.debug("Tuning param list: %s", tuning_param_list) + logger.debug("Not tuning param pair: %s", not_tuning_param_pair) + if len(tuning_param_list) == 0: + config_list = [config] + else: + tuning_param_name_lst = [tuning_param.name for tuning_param in tuning_param_list] + for params_values in product(*[tuning_param.options for tuning_param in tuning_param_list]): + tuning_param_pair = dict(zip(tuning_param_name_lst, params_values)) + tmp_params_dict = {**not_tuning_param_pair, **tuning_param_pair} + new_config = self.__class__(**tmp_params_dict) + logger.info(new_config.to_dict()) + config_list.append(new_config) + logger.info("Expanded the %s and got %d configs.", self.__class__.name, len(config_list)) return config_list def _get_op_name_op_type_config(self): diff --git a/neural_compressor/common/base_tuning.py b/neural_compressor/common/base_tuning.py index 4438a702bd7..6e9ca5486fc 100644 --- a/neural_compressor/common/base_tuning.py +++ b/neural_compressor/common/base_tuning.py @@ -16,7 +16,7 @@ import copy import inspect import uuid -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Sized, Tuple, Union from neural_compressor.common import Logger from neural_compressor.common.base_config import BaseConfig, ComposableConfig @@ -31,6 +31,10 @@ "TuningMonitor", "TuningLogger", "init_tuning", + "Sampler", + "SequentialSampler", + "default_sampler", + "ConfigSet", ] @@ -123,36 +127,103 @@ def self_check(self) -> None: evaluator = Evaluator() -class Sampler: - # TODO Separate sorting functionality of `ConfigLoader` into `Sampler` in the follow-up PR. 
- pass +class ConfigSet: + def __init__(self, config_list: List[BaseConfig]) -> None: + self.config_list = config_list -class ConfigLoader: - def __init__(self, config_set, sampler: Sampler) -> None: - self.config_set = config_set - self.sampler = sampler + def __getitem__(self, index) -> BaseConfig: + assert 0 <= index < len(self.config_list), f"Index {index} out of range." + return self.config_list[index] - @staticmethod - def parse_quant_config(quant_config: BaseConfig) -> List[BaseConfig]: - if isinstance(quant_config, ComposableConfig): - result = [] - for q_config in quant_config.config_list: - result += q_config.expand() - return result + def __len__(self) -> int: + return len(self.config_list) + + @classmethod + def _from_single_config(cls, config: BaseConfig) -> List[BaseConfig]: + config_list = [] + config_list = config.expand() + return config_list + + @classmethod + def _from_list_of_configs(cls, fwk_configs: List[BaseConfig]) -> List[BaseConfig]: + config_list = [] + for config in fwk_configs: + config_list += cls._from_single_config(config) + return config_list + + @classmethod + def generate_config_list(cls, fwk_configs: Union[BaseConfig, List[BaseConfig]]): + # There are several cases for the input `fwk_configs`: + # 1. fwk_configs is a single config + # 2. fwk_configs is a list of configs + # For a single config, we need to check if it can be expanded or not. + config_list = [] + if isinstance(fwk_configs, BaseConfig): + config_list = cls._from_single_config(fwk_configs) + elif isinstance(fwk_configs, List): + config_list = cls._from_list_of_configs(fwk_configs) else: - return quant_config.expand() + raise NotImplementedError(f"Unsupported type {type(fwk_configs)} for fwk_configs.") + return config_list + + @classmethod + def from_fwk_configs(cls, fwk_configs: Union[BaseConfig, List[BaseConfig]]) -> "ConfigSet": + """Create a ConfigSet object from a single config or a list of configs. + + Args: + fwk_configs: A single config or a list of configs. + Examples: + 1) single config: RTNConfig(weight_group_size=32) + 2) single expandable config: RTNConfig(weight_group_size=[32, 64]) + 3) mixed 1) and 2): [RTNConfig(weight_group_size=32), RTNConfig(weight_group_size=[32, 64])] + + Returns: + ConfigSet: A ConfigSet object. + """ + config_list = cls.generate_config_list(fwk_configs) + return cls(config_list) + + +class Sampler: + def __init__(self, config_source: Optional[ConfigSet]) -> None: + pass + + def __iter__(self) -> Iterator[BaseConfig]: + """Iterate over indices of config set elements.""" + raise NotImplementedError - def parse_quant_configs(self) -> List[BaseConfig]: - # TODO (Yi) separate this functionality into `Sampler` in the next PR - quant_config_list = [] - for quant_config in self.config_set: - quant_config_list.extend(ConfigLoader.parse_quant_config(quant_config)) - return quant_config_list + +class SequentialSampler(Sampler): + """Samples elements sequentially, always in the same order. 
+
+    Args:
+        config_source (ConfigSet): config set to sample from.
+    """
+
+    config_source: Sized
+
+    def __init__(self, config_source: Sized) -> None:
+        self.config_source = config_source
+
+    def __iter__(self) -> Iterator[int]:
+        return iter(range(len(self.config_source)))
+
+    def __len__(self) -> int:
+        return len(self.config_source)
+
+
+default_sampler = SequentialSampler
+
+
+class ConfigLoader:
+    def __init__(self, config_set: ConfigSet, sampler: Sampler = default_sampler) -> None:
+        self.config_set = ConfigSet.from_fwk_configs(config_set)
+        self._sampler = sampler(self.config_set)
 
     def __iter__(self) -> Generator[BaseConfig, Any, None]:
-        for config in self.parse_quant_configs():
-            yield config
+        for index in self._sampler:
+            yield self.config_set[index]
 
 
 class TuningLogger:
@@ -211,12 +282,14 @@ class TuningConfig:
 
     Args:
         config_set: quantization configs. Default value is empty.
-        timeout: Tuning timeout (seconds). Default value is 0 which means early stop.
+            A single config or a list of configs. More details can
+            be found in the `from_fwk_configs` method of the `ConfigSet` class.
         max_trials: Max tuning times. Default value is 100. Combine with timeout field to decide when to exit.
         tolerable_loss: This float indicates how much metric loss we can accept. \
            The metric loss is relative, it can be both positive and negative. Default is 0.01.
 
     Examples:
+        # TODO: refine this example
         from neural_compressor import TuningConfig
         tune_config = TuningConfig(
             config_set=[config1, config2, ...],
@@ -239,28 +312,13 @@ class TuningConfig:
         # The best tuning config is config2, because of the following:
         # 1. Not achieving the set goal. (config_metric < fp32_baseline * (1 - tolerable_loss))
         # 2. Reached maximum tuning times.
-
-        # Case 3: Timeout
-        tune_config = TuningConfig(
-            config_set=[config1, config2, ...],
-            timeout=10, # seconds
-            max_trials=3,
-            tolerable_loss=0.01
-        )
-        config1_tuning_time, config2_tuning_time, config3_tuning_time, ... = 4, 5, 6, ... # seconds
-        fp32_baseline = 100
-        config1_metric, config2_metric, config3_metric, ... = 98, 98, 97, ...
-
-        # Tuning result of case 3:
-        # The best tuning config is config2, due to timeout, the third trial was forced to exit.
     """
 
     def __init__(
-        self, config_set=None, timeout=0, max_trials=100, sampler: Sampler = None, tolerable_loss=0.01
+        self, config_set=None, max_trials=100, sampler: Sampler = default_sampler, tolerable_loss=0.01
     ) -> None:
         """Init a TuneCriterion object."""
         self.config_set = config_set
-        self.timeout = timeout
         self.max_trials = max_trials
         self.sampler = sampler
         self.tolerable_loss = tolerable_loss
diff --git a/neural_compressor/common/tuning_param.py b/neural_compressor/common/tuning_param.py
new file mode 100644
index 00000000000..f7c894bb892
--- /dev/null
+++ b/neural_compressor/common/tuning_param.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import typing
+from enum import Enum, auto
+from typing import Any, List
+
+from pydantic import BaseModel
+
+from neural_compressor.common import logger
+
+
+class ParamLevel(Enum):
+    OP_LEVEL = auto()
+    OP_TYPE_LEVEL = auto()
+    MODEL_LEVEL = auto()
+
+
+class TuningParam:
+    """Define the tunable parameter for the algorithm.
+
+    Example:
+        class FakeAlgoConfig(BaseConfig):
+            '''Fake algo config.'''
+
+            params_list = [
+                ...
+                # For simple tunable types, like a list of int, giving
+                # the param name is enough. `BaseConfig` class will
+                # create the `TuningParam` implicitly.
+                "simple_attr",
+
+                # For complex tunable types, like a list of lists,
+                # developers need to create the `TuningParam` explicitly.
+                TuningParam("complex_attr", tunable_type=List[List[str]]),
+
+                # The default parameter level is `ParamLevel.OP_LEVEL`.
+                # If the parameter is at a different level, developers need
+                # to specify it explicitly.
+                TuningParam("model_attr", level=ParamLevel.MODEL_LEVEL),
+
+                ...
+
+    # TODO: more examples to explain the usage of `TuningParam`.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        default_val: Any = None,
+        tunable_type=None,
+        options=None,
+        level: ParamLevel = ParamLevel.OP_LEVEL,
+    ) -> None:
+        self.name = name
+        self.default_val = default_val
+        self.tunable_type = tunable_type
+        self.options = options
+        self.level = level
+
+    @staticmethod
+    def create_input_args_model(expect_args_type: Any) -> type:
+        """Dynamically create an InputArgsModel based on the provided type hint.
+
+        Args:
+            expect_args_type (Any): The user-provided type hint for input_args.
+
+        Returns:
+            type: The dynamically created InputArgsModel class.
+        """
+
+        class DynamicInputArgsModel(BaseModel):
+            input_args: expect_args_type
+
+        return DynamicInputArgsModel
+
+    def is_tunable(self, value: Any) -> bool:
+        # Use `Pydantic` to validate the input_args.
+        # TODO: refine the implementation further.
+        assert isinstance(
+            self.tunable_type, typing._GenericAlias
+        ), f"Expected a type hint, got {self.tunable_type} instead."
+        DynamicInputArgsModel = TuningParam.create_input_args_model(self.tunable_type)
+        try:
+            new_args = DynamicInputArgsModel(input_args=value)
+            return True
+        except Exception as e:
+            logger.debug(f"Failed to validate the input_args: {e}")
+            return False
diff --git a/neural_compressor/common/utils/utility.py b/neural_compressor/common/utils/utility.py
index edf1c0cb55b..9d8591f7f18 100644
--- a/neural_compressor/common/utils/utility.py
+++ b/neural_compressor/common/utils/utility.py
@@ -15,14 +15,42 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import time
+
+from neural_compressor.common.utils.logger import logger
+
 __all__ = [
     "set_workspace",
     "set_random_seed",
     "set_resume_from",
     "set_tensorboard",
+    "dump_elapsed_time",
 ]
 
 
+def dump_elapsed_time(customized_msg=""):
+    """Log the elapsed time of the decorated function.
+
+    Args:
+        customized_msg (str, optional): Message used in the elapsed-time log instead of the function name. Defaults to "".
+ """ + + def f(func): + def fi(*args, **kwargs): + start = time.time() + res = func(*args, **kwargs) + end = time.time() + logger.info( + "%s elapsed time: %s ms" + % (customized_msg if customized_msg else func.__qualname__, round((end - start) * 1000, 2)) + ) + return res + + return fi + + return f + + def set_random_seed(seed: int): """Set the random seed in config.""" from neural_compressor.common import options diff --git a/neural_compressor/torch/quantization/autotune.py b/neural_compressor/torch/quantization/autotune.py index 2aeb101b308..3a900985967 100644 --- a/neural_compressor/torch/quantization/autotune.py +++ b/neural_compressor/torch/quantization/autotune.py @@ -19,6 +19,7 @@ from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry from neural_compressor.common.base_tuning import TuningConfig, evaluator, init_tuning +from neural_compressor.common.utils import dump_elapsed_time from neural_compressor.torch.quantization import quantize from neural_compressor.torch.quantization.config import FRAMEWORK_NAME, RTNConfig from neural_compressor.torch.utils import constants, logger @@ -41,6 +42,7 @@ def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]: return get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME) +@dump_elapsed_time("Pass auto-tune") def autotune( model: torch.nn.Module, tune_config: TuningConfig, diff --git a/requirements_ort.txt b/requirements_ort.txt index 3a27c292e06..cf1352e4395 100644 --- a/requirements_ort.txt +++ b/requirements_ort.txt @@ -2,3 +2,4 @@ numpy onnx onnxruntime onnxruntime-extensions +pydantic diff --git a/requirements_pt.txt b/requirements_pt.txt index 9d42cfe7c53..e3129bee51a 100644 --- a/requirements_pt.txt +++ b/requirements_pt.txt @@ -1,2 +1,3 @@ intel_extension_for_pytorch +pydantic torch diff --git a/requirements_tf.txt b/requirements_tf.txt index 5fbd34f6ae7..f8075c2a068 100644 --- a/requirements_tf.txt +++ b/requirements_tf.txt @@ -1,5 +1,6 @@ prettytable psutil py-cpuinfo +pydantic pyyaml tensorflow diff --git a/test/3x/common/test_common.py b/test/3x/common/test_common.py index c683f5220af..177ec9a361e 100644 --- a/test/3x/common/test_common.py +++ b/test/3x/common/test_common.py @@ -48,7 +48,8 @@ get_all_config_set_from_config_registry, register_config, ) -from neural_compressor.common.base_tuning import ConfigLoader, Sampler +from neural_compressor.common.base_tuning import ConfigLoader, ConfigSet, SequentialSampler +from neural_compressor.common.tuning_param import TuningParam from neural_compressor.common.utils import DEFAULT_WHITE_LIST, OP_NAME_OR_MODULE_TYPE PRIORITY_FAKE_ALGO = 100 @@ -66,6 +67,7 @@ class FakeAlgoConfig(BaseConfig): params_list = [ "weight_dtype", "weight_bits", + TuningParam("target_op_type_list", tunable_type=List[List[str]]), ] name = FAKE_CONFIG_NAME @@ -73,6 +75,7 @@ def __init__( self, weight_dtype: str = "int", weight_bits: int = 4, + target_op_type_list: List[str] = ["Conv", "Gemm"], white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, ): """Init fake config. 
@@ -84,6 +87,7 @@ def __init__( super().__init__(white_list=white_list) self.weight_bits = weight_bits self.weight_dtype = weight_dtype + self.target_op_type_list = target_op_type_list self._post_init() def to_dict(self): @@ -142,36 +146,43 @@ def test_api(self): self.assertEqual(len(config_set), 1) self.assertEqual(config_set[0].weight_bits, DEFAULT_WEIGHT_BITS) + def test_config_expand_complex_tunable_type(self): + target_op_type_list_options = [["Conv", "Gemm"], ["Conv", "Matmul"]] + configs = FakeAlgoConfig(target_op_type_list=target_op_type_list_options) + configs_list = configs.expand() + self.assertEqual(len(configs_list), len(target_op_type_list_options)) + for i in range(len(configs_list)): + self.assertEqual(configs_list[i].target_op_type_list, target_op_type_list_options[i]) -class TestConfigLoader(unittest.TestCase): + +class TestConfigSet(unittest.TestCase): + def setUp(self): + self.config_set = [get_default_fake_config(), get_default_fake_config()] + self.config_set_obj = ConfigSet.from_fwk_configs(self.config_set) + + def test_config_set(self) -> None: + self.assertEqual(len(self.config_set_obj), len(self.config_set)) + self.assertEqual(self.config_set_obj[0].weight_bits, self.config_set[0].weight_bits) + + +class TestConfigSampler(unittest.TestCase): def setUp(self): self.config_set = [get_default_fake_config(), get_default_fake_config()] - self.loader = ConfigLoader(self.config_set, Sampler()) - - def test_parse_quant_config_single(self): - quant_config = get_default_fake_config() - result = ConfigLoader.parse_quant_config(quant_config) - self.assertEqual(str(result), str(quant_config.expand())) - - def test_parse_quant_config_composable(self): - quant_config = get_default_fake_config() - composable_config = ComposableConfig(get_default_fake_config()) - composable_config.config_list = [quant_config] - result = ConfigLoader.parse_quant_config(composable_config) - self.assertEqual(str(result), str(quant_config.expand())) - - def test_parse_quant_configs(self): - quant_configs = [get_default_fake_config(), get_default_fake_config()] - self.config_set[0].expand = lambda: quant_configs - self.config_set[1].expand = lambda: [] - result = self.loader.parse_quant_configs() - self.assertEqual(result, quant_configs) - - def test_iteration(self): - quant_configs = [get_default_fake_config(), get_default_fake_config()] - self.loader.parse_quant_configs = lambda: quant_configs - result = list(self.loader) - self.assertEqual(result, quant_configs) + self.seq_sampler = SequentialSampler(self.config_set) + + def test_config_sampler(self) -> None: + self.assertEqual(list(self.seq_sampler), list(range(len(self.config_set)))) + + +class TestConfigLoader(unittest.TestCase): + def setUp(self): + self.config_set = [FakeAlgoConfig(weight_bits=4), FakeAlgoConfig(weight_bits=8)] + self.loader = ConfigLoader(self.config_set) + + def test_config_loader(self) -> None: + self.assertEqual(len(list(self.loader)), len(self.config_set)) + for i, config in enumerate(self.loader): + self.assertEqual(config, self.config_set[i]) if __name__ == "__main__": diff --git a/test/3x/common/test_param.py b/test/3x/common/test_param.py new file mode 100644 index 00000000000..879efc9921f --- /dev/null +++ b/test/3x/common/test_param.py @@ -0,0 +1,23 @@ +import unittest +from typing import List + +from neural_compressor.common.tuning_param import TuningParam + + +class TestTuningParam(unittest.TestCase): + def test_is_tunable_same_type(self): + # Test when tunable_type has the same type as the default value + param = 
TuningParam("param_name", [1, 2, 3], List[int]) + self.assertTrue(param.is_tunable([4, 5, 6])) + self.assertFalse(param.is_tunable(["not_an_int"])) + + def test_is_tunable_recursive(self): + # Test recursive type checking for iterables + param = TuningParam("param_name", [[1, 2], [3, 4]], List[List[int]]) + self.assertTrue(param.is_tunable([[5, 6], [7, 8]])) + # TODO: double check if this is the expected behavior + self.assertTrue(param.is_tunable([[5, 6], [7, "8"]])) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/3x/onnxrt/test_autotune.py b/test/3x/onnxrt/test_autotune.py index 41109e4b800..a1909868352 100644 --- a/test/3x/onnxrt/test_autotune.py +++ b/test/3x/onnxrt/test_autotune.py @@ -93,9 +93,9 @@ def eval_acc_fn(model) -> float: calibration_data_reader=self.data_reader, ) call_args_list = mock_warning.call_args_list - self.assertEqual( - call_args_list[0][0][0], - "Please refine your eval_fns to accept model path (str) as input.", + # There may be multiple calls to warning, so we need to check all of them + self.assertIn( + "Please refine your eval_fns to accept model path (str) as input.", [info[0][0] for info in call_args_list] ) def test_sq_auto_tune(self): diff --git a/test/3x/onnxrt/test_config.py b/test/3x/onnxrt/test_config.py index 5cc668965ea..9b0c49de1b8 100644 --- a/test/3x/onnxrt/test_config.py +++ b/test/3x/onnxrt/test_config.py @@ -246,14 +246,6 @@ def test_expand_config(self): self.assertEqual(expand_config_list[0].weight_bits, 4) self.assertEqual(expand_config_list[1].weight_bits, 8) - def test_config_set_api(self): - # *Note: this test is only for improving the code coverage and can be removed once the test_common is enabled. - from neural_compressor.common.base_config import config_registry, get_all_config_set_from_config_registry - from neural_compressor.onnxrt.quantization.config import FRAMEWORK_NAME - - config_set = get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME) - self.assertEqual(len(config_set), len(config_registry.registered_configs[FRAMEWORK_NAME])) - if __name__ == "__main__": unittest.main() diff --git a/test/3x/tensorflow/keras/test_config.py b/test/3x/tensorflow/keras/test_config.py index 8d7a0dcc340..fe9c7830356 100644 --- a/test/3x/tensorflow/keras/test_config.py +++ b/test/3x/tensorflow/keras/test_config.py @@ -315,13 +315,6 @@ def test_expand_config(self): self.assertEqual(expand_config_list[0].weight_granularity, "per_channel") self.assertEqual(expand_config_list[1].weight_granularity, "per_tensor") - def test_config_set_api(self): - # *Note: this test is only for improving the code coverage and can be removed once the test_common is enabled. - from neural_compressor.common.base_config import config_registry, get_all_config_set_from_config_registry - - config_set = get_all_config_set_from_config_registry(fwk_name="keras") - self.assertEqual(len(config_set), len(config_registry.registered_configs["keras"])) - if __name__ == "__main__": unittest.main()
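Usage sketch for the tuning pieces touched by this patch (TuningParam, ConfigSet, SequentialSampler, ConfigLoader, TuningConfig). It assumes the PyTorch backend so that RTNConfig is importable, and the weight_group_size candidates simply mirror the illustrative values in the `from_fwk_configs` docstring above; adapt the config and its parameter names to the framework actually in use.

from typing import List

from neural_compressor.common.base_tuning import ConfigLoader, ConfigSet, SequentialSampler, TuningConfig
from neural_compressor.common.tuning_param import TuningParam
from neural_compressor.torch.quantization.config import RTNConfig

# Validate candidate values against a type hint before treating them as a search space.
group_size = TuningParam("weight_group_size", tunable_type=List[int])
assert group_size.is_tunable([32, 64])  # a list of ints is a valid candidate set
assert not group_size.is_tunable("32")  # a bare string is rejected by the pydantic check

# Expand a config whose attribute holds a list of candidate values into a concrete config set.
config_set = ConfigSet.from_fwk_configs(RTNConfig(weight_group_size=[32, 64]))
print(len(config_set))  # expected: one concrete config per candidate group size

# ConfigLoader performs the same expansion itself and walks the candidates through a
# sampler; SequentialSampler is the default, shown explicitly here.
loader = ConfigLoader(config_set=RTNConfig(weight_group_size=[32, 64]), sampler=SequentialSampler)
for candidate in loader:
    print(candidate.to_dict())

# TuningConfig bundles the candidate set with the stop criteria consumed by the
# framework-level autotune() entry points shown earlier in this patch.
tune_config = TuningConfig(config_set=RTNConfig(weight_group_size=[32, 64]), max_trials=2, tolerable_loss=0.01)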