diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py
index 5e9e72a8882..35b0f532738 100644
--- a/neural_compressor/common/base_config.py
+++ b/neural_compressor/common/base_config.py
@@ -436,6 +436,13 @@ def _is_op_type(name: str) -> bool:
     def get_config_set_for_tuning(cls):
         raise NotImplementedError
 
+    def __eq__(self, other: BaseConfig) -> bool:
+        if not isinstance(other, type(self)):
+            return False
+        return self.params_list == other.params_list and all(
+            getattr(self, str(attr)) == getattr(other, str(attr)) for attr in self.params_list
+        )
+
 
 class ComposableConfig(BaseConfig):
     name = COMPOSABLE_CONFIG
diff --git a/neural_compressor/common/base_tuning.py b/neural_compressor/common/base_tuning.py
index 54f908232ad..2a1adfa480b 100644
--- a/neural_compressor/common/base_tuning.py
+++ b/neural_compressor/common/base_tuning.py
@@ -231,13 +231,29 @@ def __len__(self) -> int:
 
 
 class ConfigLoader:
-    def __init__(self, config_set: ConfigSet, sampler: Sampler = default_sampler) -> None:
+    def __init__(
+        self, config_set: ConfigSet, sampler: Sampler = default_sampler, skip_verified_config: bool = True
+    ) -> None:
         self.config_set = ConfigSet.from_fwk_configs(config_set)
         self._sampler = sampler(self.config_set)
+        self.skip_verified_config = skip_verified_config
+        self.verify_config_list = list()
+
+    def is_verified_config(self, config):
+        for verified_config in self.verify_config_list:
+            if config == verified_config:
+                return True
+        return False
 
     def __iter__(self) -> Generator[BaseConfig, Any, None]:
         for index in self._sampler:
-            yield self.config_set[index]
+            new_config = self.config_set[index]
+            if self.skip_verified_config and self.is_verified_config(new_config):
+                logger.warning("Skip the verified config:")
+                logger.warning(new_config.to_dict())
+                continue
+            self.verify_config_list.append(new_config)
+            yield new_config
 
 
 class TuningConfig:
diff --git a/neural_compressor/common/tuning_param.py b/neural_compressor/common/tuning_param.py
index 207811590ee..3f6d9272e4f 100644
--- a/neural_compressor/common/tuning_param.py
+++ b/neural_compressor/common/tuning_param.py
@@ -98,3 +98,6 @@ def is_tunable(self, value: Any) -> bool:
         except Exception as e:
             logger.debug(f"Failed to validate the input_args: {e}")
             return False
+
+    def __str__(self) -> str:
+        return self.name
diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py
index 3b516977bb7..41d2593c224 100644
--- a/neural_compressor/torch/quantization/algorithm_entry.py
+++ b/neural_compressor/torch/quantization/algorithm_entry.py
@@ -41,7 +41,7 @@ def rtn_entry(
     configs_mapping: Dict[Tuple[str, callable], RTNConfig],
     mode: Mode = Mode.QUANTIZE,
     *args,
-    **kwargs
+    **kwargs,
 ) -> torch.nn.Module:
     """The main entry to apply rtn quantization."""
     from neural_compressor.torch.algorithms.weight_only.rtn import RTNQuantizer
@@ -258,7 +258,7 @@ def awq_quantize_entry(
     configs_mapping: Dict[Tuple[str, callable], AWQConfig],
     mode: Mode = Mode.QUANTIZE,
     *args,
-    **kwargs
+    **kwargs,
 ) -> torch.nn.Module:
     logger.info("Quantize model with the AWQ algorithm.")
     from neural_compressor.torch.algorithms.weight_only.awq import AWQQuantizer
@@ -455,7 +455,7 @@ def hqq_entry(
     configs_mapping: Dict[Tuple[str, Callable], HQQConfig],
     mode: Mode = Mode.QUANTIZE,
     *args,
-    **kwargs
+    **kwargs,
 ) -> torch.nn.Module:
 
     from neural_compressor.torch.algorithms.weight_only.hqq import HQQuantizer
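The `BaseConfig.__eq__` added above is what the deduplication in `ConfigLoader` relies on: two configs compare equal only when they are the same type, expose the same `params_list`, and agree on every tunable attribute. The new `TuningParam.__str__` supports this by letting `params_list` entries that are `TuningParam` objects resolve to plain attribute names inside `getattr`. A minimal sketch of the resulting semantics, using `RTNConfig` as a representative `BaseConfig` subclass (the constructor arguments are assumptions drawn from this PR's context, not a verified API reference):

    from neural_compressor.torch.quantization import RTNConfig

    a = RTNConfig(bits=4)
    b = RTNConfig(bits=4)
    c = RTNConfig(bits=8)

    assert a == b        # same type, identical tunable params
    assert a != c        # same type, but `bits` differs
    assert a != "rtn"    # different type: the isinstance guard returns False

Note that defining `__eq__` without a matching `__hash__` makes config instances unhashable, which is consistent with `ConfigLoader` keeping verified configs in a list and scanning it linearly rather than storing them in a set.
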
diff --git a/test/3x/common/test_common.py b/test/3x/common/test_common.py
index d1df7d98b1d..9eb89b9ef23 100644
--- a/test/3x/common/test_common.py
+++ b/test/3x/common/test_common.py
@@ -336,6 +336,14 @@ def test_config_loader(self) -> None:
         for i, config in enumerate(self.loader):
             self.assertEqual(config, self.config_set[i])
 
+    def test_config_loader_skip_verified_config(self) -> None:
+        config_set = [FakeAlgoConfig(weight_bits=[4, 8]), FakeAlgoConfig(weight_bits=8)]
+        config_loader = ConfigLoader(config_set)
+        config_count = 0
+        for i, config in enumerate(config_loader):
+            config_count += 1
+        self.assertEqual(config_count, 2)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/3x/torch/quantization/weight_only/test_mixed_algos.py b/test/3x/torch/quantization/weight_only/test_mixed_algos.py
index d465f8cd9c3..b4789f6c5d9 100644
--- a/test/3x/torch/quantization/weight_only/test_mixed_algos.py
+++ b/test/3x/torch/quantization/weight_only/test_mixed_algos.py
@@ -10,11 +10,8 @@
 
 
 def run_fn(model):
-    # GPTQ uses ValueError to reduce computation when collecting input data of the first block
-    # It's special for UTs, no need to add this wrapper in examples.
-    with pytest.raises(ValueError):
-        model(torch.tensor([[10, 20, 30]], dtype=torch.long))
-        model(torch.tensor([[40, 50, 60]], dtype=torch.long))
+    model(torch.tensor([[10, 20, 30]], dtype=torch.long))
+    model(torch.tensor([[40, 50, 60]], dtype=torch.long))
 
 
 class TestMixedTwoAlgo:
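End to end, the loader change behaves as sketched below. This mirrors `test_config_loader_skip_verified_config` but substitutes `RTNConfig` for the test-only `FakeAlgoConfig`; the import paths and the expansion of a `bits` list into separate candidate configs are assumptions based on this PR's context:

    from neural_compressor.common.base_tuning import ConfigLoader
    from neural_compressor.torch.quantization import RTNConfig

    # [RTNConfig(bits=[4, 8]), RTNConfig(bits=8)] expands into three candidate
    # configs (bits=4, bits=8, bits=8), one of which is a duplicate.
    loader = ConfigLoader(config_set=[RTNConfig(bits=[4, 8]), RTNConfig(bits=8)])
    assert len(list(loader)) == 2  # the duplicate bits=8 config is skipped

    # skip_verified_config=False restores the previous behaviour: every sampled
    # config is yielded, duplicates included.
    loader = ConfigLoader(
        config_set=[RTNConfig(bits=[4, 8]), RTNConfig(bits=8)],
        skip_verified_config=False,
    )
    assert len(list(loader)) == 3

Because `verify_config_list` lives on the loader instance, a second pass over the same `ConfigLoader` would skip every config it has already yielded; create a fresh loader per tuning run if that is not the intent.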