diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py
index 6007580d94e..41aeaca7ab3 100644
--- a/neural_compressor/common/base_config.py
+++ b/neural_compressor/common/base_config.py
@@ -21,6 +21,8 @@
 import re
 from abc import ABC, abstractmethod
 from collections import OrderedDict
+from copy import deepcopy
+from itertools import product
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 from neural_compressor.common.logger import Logger
@@ -225,6 +227,57 @@ def __add__(self, other: BaseConfig) -> BaseConfig:
         else:
             return ComposableConfig(configs=[self, other])
 
+    def expand(self) -> List[BaseConfig]:
+        """Expand the config.
+
+        case 1
+        {
+            "global": { "weight_bits": [4, 6]}
+        }
+        expand to:
+        1st trial config:
+        {
+            "global": { "weight_bits": 4}
+        }
+        2nd trial config:
+        {
+            "global": { "weight_bits": 6}
+        }
+        case 2
+        # TODO (Yi) to support the expansion of config with `local`
+        {
+            "global": {
+                "weight_bits": [4, 6]
+            },
+            "local":
+            {
+                "fc1":{
+                    "weight_bits": [6, 8]
+                },
+                "fc2":{
+                    "weight_bits": [4]
+                }
+            }
+
+        } -> ?
+        """
+        config_list: List[BaseConfig] = []
+        params_list = self.params_list
+        params_dict = OrderedDict()
+        config = self
+        for param in params_list:
+            param_val = getattr(config, param)
+            # TODO (Yi) handle the case where param_val itself is a list
+            if isinstance(param_val, list):
+                params_dict[param] = param_val
+            else:
+                params_dict[param] = [param_val]
+        for params_values in product(*params_dict.values()):
+            new_config = self.__class__(**dict(zip(params_list, params_values)))
+            config_list.append(new_config)
+        logger.info(f"Expanded the {self.__class__.name} and got {len(config_list)} configs.")
+        return config_list
+
     def _get_op_name_op_type_config(self):
         op_type_config_dict = dict()
         op_name_config_dict = dict()
diff --git a/neural_compressor/common/base_tune.py b/neural_compressor/common/base_tune.py
new file mode 100644
index 00000000000..69652d5856b
--- /dev/null
+++ b/neural_compressor/common/base_tune.py
@@ -0,0 +1,159 @@
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import abstractmethod
+from typing import Any, Callable, Dict, List, Optional, Union
+
+from neural_compressor.common.base_config import BaseConfig, ComposableConfig
+from neural_compressor.common.logger import Logger
+
+logger = Logger().get_logger()
+
+
+class FrameworkWrapper:
+    """Abstract base class for wrapping a framework's APIs.
+
+    FrameworkWrapper provides a uniform interface for encapsulating the APIs of different frameworks.
+    This class is intended to be used by a `tuner` to obtain quantized models.
+    """
+
+    def __init__(self, model) -> None:
+        self.model = model
+
+    @abstractmethod
+    def apply(self) -> Any:
+        """The entry to apply algorithms on a given model."""
+        raise NotImplementedError
+
+
+class TuningObjectives:
+    EVAL_FN = "eval_fn"
+    WEIGHT = "weight"
+    FN_NAME = "name"
+    EVAL_FN_TEMPLATE: Dict[str, Any] = {EVAL_FN: None, WEIGHT: 1.0, FN_NAME: None}
+
+    def __init__(self) -> None:
+        self.eval_fn_registry: List[Dict[str, Any]] = []
+
+    def evaluate(self, model) -> float:
+        """Evaluate the model using registered evaluation functions.
+
+        Args:
+            model: The fp32 model or quantized model.
+
+        Returns:
+            The overall result of all registered evaluation functions.
+        """
+        result = 0
+        for eval_pair in self.eval_fn_registry:
+            eval_fn = eval_pair[self.EVAL_FN]
+            eval_result = eval_fn(model)
+            result = self._update_the_objective_score(eval_pair, eval_result, result)
+        return result
+
+    def _update_the_objective_score(self, eval_pair, eval_result, overall_result) -> float:
+        # TODO update the result according to the weight and algo_name
+        return overall_result + eval_result * eval_pair[self.WEIGHT]
+
+    def get_number_of_tuning_objectives(self) -> int:
+        return len(self.eval_fn_registry)
+
+    def _set_eval_fn_registry(self, user_eval_fns: List[Dict]) -> None:
+        self.eval_fn_registry = [
+            {
+                self.EVAL_FN: user_eval_fn_pair[self.EVAL_FN],
+                self.WEIGHT: user_eval_fn_pair.get(self.WEIGHT, 1.0),
+                self.FN_NAME: user_eval_fn_pair.get(self.FN_NAME, user_eval_fn_pair[self.EVAL_FN].__name__),
+            }
+            for user_eval_fn_pair in user_eval_fns
+        ]
+
+    def set_eval_fn_registry(self, eval_fns: Optional[Union[Dict, List[Dict]]] = None) -> None:
+        if eval_fns is None:
+            return
+        elif isinstance(eval_fns, Dict):
+            eval_fns = [eval_fns]
+        elif isinstance(eval_fns, List):
+            assert all([isinstance(eval_fn_pair, Dict) for eval_fn_pair in eval_fns])
+        else:
+            raise NotImplementedError(f"The eval_fns should be a dict or a list of dict, but got {type(eval_fns)}.")
+        self._set_eval_fn_registry(eval_fns)
+
+
+tuning_objectives = TuningObjectives()
+
+
+class BaseTuningConfig:
+    """Base class for tuning config.
+
+    Args:
+        quant_configs: quantization configs. Default value is empty.
+        timeout: Tuning timeout (seconds). Default value is 0 which means early stop.
+        max_trials: Max tuning trials. Default value is 100. Combined with the timeout field to decide when to exit.
+    """
+
+    def __init__(self, quant_configs=None, timeout=0, max_trials=100) -> None:
+        """Init a BaseTuningConfig object."""
+        self.quant_configs = quant_configs
+        self.timeout = timeout
+        self.max_trials = max_trials
+
+
+class Tuner:
+    def __init__(
+        self, tune_config: BaseTuningConfig, tuning_objectives: TuningObjectives, fwk_wrapper: FrameworkWrapper
+    ) -> None:
+        self.tune_config = tune_config
+        self.tuning_objectives = tuning_objectives
+        self.fwk_wrapper = fwk_wrapper
+        self._post_init()
+
+    def _post_init(self) -> None:
+        # check the number of evaluation functions
+        num_tuning_objectives = self.tuning_objectives.get_number_of_tuning_objectives()
+        assert (
+            num_tuning_objectives > 0
+        ), "Please ensure that you register at least one evaluation metric for auto-tune."
+        logger.info(f"There are {num_tuning_objectives} tuning objectives.")
+
+    @staticmethod
+    def parse_quant_config(quant_config: BaseConfig) -> List[BaseConfig]:
+        if isinstance(quant_config, ComposableConfig):
+            result = []
+            for q_config in quant_config.config_list:
+                result += q_config.expand()
+            return result
+        else:
+            return quant_config.expand()
+
+    def parse_quant_configs(self) -> List[BaseConfig]:
+        quant_config_list = []
+        for quant_config in self.tune_config.quant_configs:
+            quant_config_list.extend(Tuner.parse_quant_config(quant_config))
+        return quant_config_list
+
+    def get_best_model(self, q_model, objective_score: Union[float, int]) -> Any:
+        # TODO(Yi) enable it at the next PR
+        pass
+
+    def get_tuning_objective_score(self, model) -> float:
+        eval_result = self.tuning_objectives.evaluate(model)
+        return eval_result
+
+    def search(self) -> Any:
+        for config in self.parse_quant_configs():
+            logger.info(f"config {config}")
+            q_model = self.fwk_wrapper.apply(quant_config=config)
+            if self.get_best_model(q_model, self.get_tuning_objective_score(q_model)):
+                return q_model
diff --git a/neural_compressor/torch/__init__.py b/neural_compressor/torch/__init__.py
index a9f9d6731e3..57cfe472297 100644
--- a/neural_compressor/torch/__init__.py
+++ b/neural_compressor/torch/__init__.py
@@ -22,3 +22,5 @@
     GPTQConfig,
     get_default_gptq_config,
 )
+
+from neural_compressor.torch.tune import autotune, TuningConfig, get_default_tune_config
diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py
index ecf813331ef..2a1f1ca599b 100644
--- a/neural_compressor/torch/quantization/quantize.py
+++ b/neural_compressor/torch/quantization/quantize.py
@@ -54,7 +54,7 @@ def quantize(
     else:
         assert isinstance(
             quant_config, BaseConfig
-        ), "Please pass a dict or config instance as the quantization configuration."
+        ), f"Please pass a dict or config instance as the quantization configuration, but got {type(quant_config)}."
 
     logger.info(f"Quantize model with config: \n {quant_config.to_json_string()} \n")
     # select quantization algo according to config
diff --git a/neural_compressor/torch/tune.py b/neural_compressor/torch/tune.py
new file mode 100644
index 00000000000..656d6c5b1be
--- /dev/null
+++ b/neural_compressor/torch/tune.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
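+
+# This module adapts the framework-agnostic tuning loop in
+# neural_compressor/common/base_tune.py to PyTorch: `TorchWrapper` wraps the
+# torch `quantize` entry point, and `autotune` registers the user's evaluation
+# functions before driving `Tuner.search()` over the expanded quant configs.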
+
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from neural_compressor.common.base_tune import BaseTuningConfig, FrameworkWrapper, Tuner, tuning_objectives
+from neural_compressor.common.logger import Logger
+from neural_compressor.torch.quantization.config import GPTQConfig, RTNWeightQuantConfig
+
+logger = Logger().get_logger()
+
+
+def get_default_tuning_config():
+    # TODO (Yi) support it in the next PR
+    return None
+
+
+class TorchWrapper(FrameworkWrapper):
+    """Concrete implementation of `FrameworkWrapper` for PyTorch models."""
+
+    def __init__(
+        self, model: torch.nn.Module, run_fn: Optional[Callable] = None, run_args: Optional[Tuple] = None
+    ) -> None:
+        super().__init__(model)
+        self.run_fn = run_fn
+        self.run_args = run_args
+
+    def apply(self, quant_config):
+        """The entry to apply quantization algorithms on a given model."""
+        logger.info(f"apply quant_config: {quant_config}.")
+        from neural_compressor.torch import quantize
+
+        q_model = quantize(model=self.model, quant_config=quant_config, run_fn=self.run_fn, run_args=self.run_args)
+        return q_model
+
+
+class TuningConfig(BaseTuningConfig):
+    def __init__(self, quant_configs=None, timeout=0, max_trials=100):
+        super().__init__(quant_configs, timeout, max_trials)
+
+
+def autotune(
+    model: torch.nn.Module,
+    tune_config: TuningConfig,
+    eval_fns: Optional[Union[Dict, List[Dict]]] = None,
+    run_fn=None,
+    run_args=None,
+):
+    tuning_objectives.set_eval_fn_registry(eval_fns)
+    torch_wrapper = TorchWrapper(model, run_fn, run_args)
+    tuner = Tuner(tune_config=tune_config, tuning_objectives=tuning_objectives, fwk_wrapper=torch_wrapper)
+    best_qmodel = tuner.search()
+    return best_qmodel
+
+
+def get_default_tune_config():
+    # TODO use the registered default tuning config in the next PR
+    return TuningConfig(quant_configs=[GPTQConfig(weight_bits=[4, 8]), RTNWeightQuantConfig(weight_bits=[4, 8])])
diff --git a/test/3x/tensorflow/test_config.py b/test/3x/tensorflow/test_config.py
index fa3d81cade5..287a86e4cff 100644
--- a/test/3x/tensorflow/test_config.py
+++ b/test/3x/tensorflow/test_config.py
@@ -297,5 +297,24 @@ def test_config_to_dict(self):
         self.assertIn("local", config_dict)
 
 
+class TestQuantConfigForAutotune(unittest.TestCase):
+    def test_expand_config(self):
+        # test the expand functionality; the user is not aware of it
+        from neural_compressor.tensorflow import StaticQuantConfig
+
+        quant_configs = StaticQuantConfig(
+            weight_dtype="int8",
+            weight_sym=True,
+            weight_granularity=["per_channel", "per_tensor"],
+            act_dtype="int8",
+            act_sym=True,
+            act_granularity="per_channel",
+        )
+
+        expand_config_list = StaticQuantConfig.expand(quant_configs)
+        self.assertEqual(expand_config_list[0].weight_granularity, "per_channel")
+        self.assertEqual(expand_config_list[1].weight_granularity, "per_tensor")
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/3x/torch/test_autotune.py b/test/3x/torch/test_autotune.py
new file mode 100644
index 00000000000..876311355f1
--- /dev/null
+++ b/test/3x/torch/test_autotune.py
@@ -0,0 +1,118 @@
+import unittest
+
+import transformers
+
+from neural_compressor.common.logger import Logger
+
+logger = Logger().get_logger()
+from functools import wraps
+
+import torch
+
+
+def reset_tuning_target(test_func):
+    @wraps(test_func)
+    def wrapper(*args, **kwargs):
+        # Reset tuning targets before running the test
+        from neural_compressor.common.base_tune import tuning_objectives
+
+        tuning_objectives.eval_fn_registry = []
+        return test_func(*args, **kwargs)
+
+    return wrapper
+
+
+def build_simple_torch_model():
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super(Model, self).__init__()
+            self.fc1 = torch.nn.Linear(30, 50)
+            self.fc2 = torch.nn.Linear(50, 30)
+            self.fc3 = torch.nn.Linear(30, 5)
+
+        def forward(self, x):
+            out = self.fc1(x)
+            out = self.fc2(out)
+            out = self.fc3(out)
+            return out
+
+    model = Model()
+    return model
+
+
+class TestAutoTune(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        self.fp32_model = build_simple_torch_model()
+        self.input = torch.randn(1, 30)
+        self.gptj = transformers.AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-GPTJForCausalLM",
+        )
+        self.lm_input = torch.ones([1, 10], dtype=torch.long)
+
+    @classmethod
+    def tearDownClass(self):
+        pass
+
+    def setUp(self):
+        # print the test name
+        logger.info(f"Running TestAutoTune test: {self.id()}")
+
+    @reset_tuning_target
+    def test_autotune_api(self):
+        logger.info("test_autotune_api")
+        from neural_compressor.common.base_tune import tuning_objectives
+        from neural_compressor.torch import RTNWeightQuantConfig, TuningConfig, autotune
+
+        def eval_acc_fn(model) -> float:
+            return 1.0
+
+        custom_tune_config = TuningConfig(quant_configs=[RTNWeightQuantConfig(weight_bits=[4, 6])], max_trials=2)
+        best_model = autotune(
+            model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}]
+        )
+        self.assertIsNone(best_model)
+        self.assertEqual(len(tuning_objectives.eval_fn_registry), 1)
+
+    @reset_tuning_target
+    def test_autotune_api_2(self):
+        logger.info("test_autotune_api_2")
+        from neural_compressor.common.base_tune import tuning_objectives
+        from neural_compressor.torch import RTNWeightQuantConfig, TuningConfig, autotune
+
+        def eval_acc_fn(model) -> float:
+            return 1.0
+
+        def eval_perf_fn(model) -> float:
+            return 1.0
+
+        eval_fns = [
+            {"eval_fn": eval_acc_fn, "weight": 0.5, "name": "accuracy"},
+            {
+                "eval_fn": eval_perf_fn,
+                "weight": 0.5,
+            },
+        ]
+
+        custom_tune_config = TuningConfig(quant_configs=[RTNWeightQuantConfig(weight_bits=[4, 6])], max_trials=2)
+        best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=eval_fns)
+        self.assertIsNone(best_model)
+        self.assertEqual(len(tuning_objectives.eval_fn_registry), 2)
+
+    @reset_tuning_target
+    def test_autotune_not_eval_func(self):
+        logger.info("test_autotune_not_eval_func")
+        from neural_compressor.torch import RTNWeightQuantConfig, TuningConfig, autotune
+
+        custom_tune_config = TuningConfig(quant_configs=[RTNWeightQuantConfig(weight_bits=[4, 6])], max_trials=2)
+
+        # Use assertRaises to check that an AssertionError is raised
+        with self.assertRaises(AssertionError) as context:
+            best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config)
+        self.assertEqual(
+            str(context.exception), "Please ensure that you register at least one evaluation metric for auto-tune."
+ ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/3x/torch/test_config.py b/test/3x/torch/test_config.py index e00d6cdad17..0e0925685c0 100644 --- a/test/3x/torch/test_config.py +++ b/test/3x/torch/test_config.py @@ -324,5 +324,16 @@ def test_gptq_config(self): self.assertEqual(gptq_config1.to_dict(), gptq_config2.to_dict()) +class TestQuantConfigForAutotune(unittest.TestCase): + def test_expand_config(self): + # test the expand functionalities, the user is not aware it + from neural_compressor.torch import RTNWeightQuantConfig + + tune_config = RTNWeightQuantConfig(weight_bits=[4, 6]) + expand_config_list = RTNWeightQuantConfig.expand(tune_config) + self.assertEqual(expand_config_list[0].weight_bits, 4) + self.assertEqual(expand_config_list[1].weight_bits, 6) + + if __name__ == "__main__": unittest.main()