diff --git a/neural_compressor/torch/algorithms/weight_only/autoround.py b/neural_compressor/torch/algorithms/weight_only/autoround.py
index 7d8a0e069ea..9d3893fad81 100644
--- a/neural_compressor/torch/algorithms/weight_only/autoround.py
+++ b/neural_compressor/torch/algorithms/weight_only/autoround.py
@@ -144,18 +144,13 @@ def __init__(self, model) -> None:
         self.device = "cpu"
         self.orig_model = model
+
     def forward(self, *args, **kwargs):
-        if kwargs and len(args) == 0:
-            # Handle cases where input data is a dict
+        if args:
+            for arg in args:
+                self.data_pairs.append(arg)
+        if kwargs:
             self.data_pairs.append(kwargs)
-        elif args and len(args) == 1:
-            # Handle cases where input data is a Tensor
-            self.data_pairs.append(args[0])
-        else:
-            logger.error("Handle cases where input data is neither a Tensor nor a dict")
-
         return self.orig_model.forward(*args, **kwargs)
-
-
 class AutoRoundQuantizer(Quantizer):
     def __init__(
         self,
@@ -256,9 +251,16 @@ def prepare(self, model: torch.nn.Module, *args, **kwargs):
         Returns:
             A prepared model.
         """
-        self.rounder = AutoRoundProcessor(
+        prepare_model = InputCaptureModule(model)
+        return prepare_model
+
+    def convert(self, model: torch.nn.Module, *args, **kwargs):
+        dataloader = torch.utils.data.DataLoader(model.data_pairs)
+        model = model.orig_model
+        rounder = AutoRound(
             model=model,
             tokenizer=None,
+            dataset=dataloader,
             weight_config=self.quant_config or {},
             enable_full_range=self.enable_full_range,
             batch_size=self.batch_size,
@@ -282,15 +284,7 @@ def prepare(self, model: torch.nn.Module, *args, **kwargs):
             data_type=self.data_type,
             scale_dtype=self.scale_dtype,
         )
-
-        self.rounder.prepare()
-        prepare_model = InputCaptureModule(model)
-        return prepare_model
-
-    def convert(self, model: torch.nn.Module, *args, **kwargs):
-        self.rounder.model_input = model.data_pairs
-        model = model.orig_model
-        model, weight_config = self.rounder.convert()
+        model, weight_config = rounder.quantize()
         model.autoround_config = weight_config
         if self.device == "xpu":
             model = pack_model(model, weight_config, device=self.device, inplace=True,
@@ -384,183 +378,3 @@ def get_autoround_default_run_fn(
             f"Valid samples size:{total_cnt}, Target sample size:{n_samples}"
         )
-
-class AutoRoundProcessor(AutoRound):
-
-    @torch.no_grad()
-    def cache_inter_data(self, block_names, n_samples, layer_names=[], last_cache_name=None):
-        """Save the inputs of block_name for calibration. For layers, we cache both of inputs and output.
-
-        This method temporarily replaces the forward method of the model to capture
-        the inputs passing through the specified block. It then calibrates the model
-        using a specified number of samples. Finally, it restores the original forward
-        method and returns the inputs for the specified block.
-        Args:
-            block_names (list): The names of the blocks for which inputs are to be saved.
-            layer_names (list):The names of the layers for which inputs are to be saved.
-            n_samples (int): The number of samples to use for calibration.
-            last_cache_name (str, optional): The name of the last layer to be cached,
-                                             we could break the forward in this layer to save time
-
-        Returns:
-            dict: A dictionary containing the inputs for the specified block.
-        """
-        self.inputs = {}
-        self.to_cached_layers = block_names + layer_names
-        tmp_dtype = None
-        ## have bug if block name is not the first block
-        if (len(block_names) > 1 or len(layer_names) > 0) and self.low_gpu_mem_usage:
-            tmp_dtype = self.model.dtype
-            self.model = self.model.to(torch.bfloat16) if self.amp else self.model.to(torch.float32)
-
-        self.last_cache_name = last_cache_name
-        if last_cache_name is None and len(block_names) + len(layer_names) == 1:
-            self.last_cache_name = block_names[0] if len(block_names) == 1 else layer_names[0]
-        calib_bs = self.train_bs
-        self.hook_handles = []
-        self._replace_forward()
-        for data in self.model_input:
-            if isinstance(data, torch.Tensor):
-                self.model(data)
-            else:
-                self.model(**data)
-        self._recover_forward()
-        res = self.inputs
-        del self.model_input
-        del self.last_cache_name
-        del self.to_cached_layers
-        if tmp_dtype is not None:
-            self.model = self.model.to(tmp_dtype)
-
-        return res
-
-    @torch.no_grad()
-    def prepare(self):
-        """Prepares a given model for quantization."""
-        self.block_names = get_block_names(self.model)
-        if len(self.block_names) == 0:
-            logger.warning("could not find blocks, exit with original model")
-            return self.model, self.weight_config
-
-        if self.amp:
-            self.model = self.model.to(self.amp_dtype)
-
-        self.layer_names = self.get_quantized_layer_names_outside_blocks()
-        self.start_time = time.time()
-        # all_inputs = self.try_cache_inter_data_gpucpu([block_names[0]], self.n_samples, layer_names=layer_names)
-
-        # try_cache_inter_data_gpucpu
-        # ([block_names[0]], self.n_samples, layer_names=layer_names)
-        # self, block_names, n_samples, layer_names=[], last_cache_name=None
-        last_cache_name = None
-        cache_block_names = [self.block_names[0]]
-        try:
-            self.model = self.model.to(self.device)
-            # all_inputs = self.cache_inter_data(
-            #     block_names[0], self.n_samples, layer_names=layer_names, last_cache_name=last_cache_name
-            # )
-            # cache_inter_data cache_inter_data(self, block_names, n_samples, layer_names=[], last_cache_name=None):
-            self.inputs = {}
-            self.to_cached_layers = cache_block_names + self.layer_names
-            self.tmp_dtype = None
-            ## have bug if block name is not the first block
-            if (len(cache_block_names) > 1 or len(self.layer_names) > 0) and self.low_gpu_mem_usage:
-                self.tmp_dtype = self.model.dtype
-                self.model = self.model.to(torch.bfloat16) if self.amp else self.model.to(torch.float32)
-
-            self.last_cache_name = last_cache_name
-            if last_cache_name is None and len(cache_block_names) + len(self.layer_names) == 1:
-                self.last_cache_name = cache_block_names[0] if len(cache_block_names) == 1 else self.layer_names[0]
-            # calib_bs = self.train_bs
-            self.hook_handles = []
-            self._replace_forward()
-            self.prepared_gpu = True
-            # self.calib(self.n_samples, calib_bs)
-
-        except:
-            logger.info("switch to cpu to cache inputs")
-            self.model = self.model.to("cpu")
-            torch.cuda.empty_cache()
-            # all_inputs = self.cache_inter_data(
-            #     self.block_names[0], self.n_samples, layer_names=self.layer_names, last_cache_name=last_cache_name
-            # )
-            self.inputs = {}
-            self.to_cached_layers = cache_block_names + self.layer_names
-            self.tmp_dtype = None
-            ## have bug if block name is not the first block
-            if (len(cache_block_names) > 1 or len(self.layer_names) > 0) and self.low_gpu_mem_usage:
-                self.tmp_dtype = self.model.dtype
-                self.model = self.model.to(torch.bfloat16) if self.amp else self.model.to(torch.float32)
-
-            self.last_cache_name = last_cache_name
-            if last_cache_name is None and len(cache_block_names) + len(self.layer_names) == 1:
-                self.last_cache_name = cache_block_names[0] if len(cache_block_names) == 1 else self.layer_names[0]
-            # calib_bs = self.train_bs
-            self.hook_handles = []
-            self._replace_forward()
-            cache_block_names
-            # self.calib(n_samples, calib_bs)
-
-    def convert(self):
-        """Converts a prepared model to a quantized model."""
-        self._recover_forward()
-        res = self.inputs
-        del self.last_cache_name
-        del self.to_cached_layers
-        if self.tmp_dtype is not None:
-            self.model = self.model.to(self.tmp_dtype)
-        if self.prepared_gpu is True:
-            self.model = self.model.to("cpu")
-
-        all_inputs = res
-
-        del self.inputs
-        inputs = all_inputs[self.block_names[0]]
-
-        all_inputs.pop(self.block_names[0])
-        self.inputs = None
-        del self.inputs
-        if "input_ids" in inputs.keys():
-            total_samples = len(inputs["input_ids"])
-            self.n_samples = total_samples
-            if total_samples < self.train_bs:
-                self.train_bs = total_samples
-                logger.warning(f"force the train batch size to {total_samples} ")
-
-        self.model = self.model.to("cpu")
-        torch.cuda.empty_cache()
-        self.quant_blocks(
-            self.model,
-            inputs,
-            self.block_names,
-            n_blocks=self.n_blocks,
-            device=self.device,
-        )
-
-        self.quant_layers(self.layer_names, all_inputs)
-
-        self.dump_data_to_weight_config()
-
-        end_time = time.time()
-        cost_time = end_time - self.start_time
-        logger.info(f"quantization tuning time {cost_time}")
-
-        ## dump a summary
-        quantized_layers = []
-        unquantized_layers = []
-        for n, m in self.model.named_modules():
-            if isinstance(m, tuple(self.supported_types)):
-                if self.weight_config[n]["bits"] == 16:
-                    unquantized_layers.append(n)
-                else:
-                    quantized_layers.append(n)
-        summary_info = (
-            f"Summary: quantized {len(quantized_layers)}/{len(quantized_layers) + len(unquantized_layers)} in the model"
-        )
-        if len(unquantized_layers) > 0:
-            summary_info += f", {unquantized_layers} have not been quantized"
-        logger.info(summary_info)
-
-        self.quantized = True
-        ##self.model = self.model.to(self.model_orig_dtype)##keep it as amp dtype
-        return self.model, self.weight_config
diff --git a/test/3x/torch/quantization/weight_only/test_autoround.py b/test/3x/torch/quantization/weight_only/test_autoround.py
index d61e2a29000..e7455e246c6 100644
--- a/test/3x/torch/quantization/weight_only/test_autoround.py
+++ b/test/3x/torch/quantization/weight_only/test_autoround.py
@@ -46,8 +46,8 @@ def setup_method(self, method):
 
     @pytest.mark.parametrize("quant_lm_head", [True, False])
     def test_autoround(self, quant_lm_head):
-        gpt_j_model = copy.deepcopy(self.gptj)
-        quant_config = AutoRoundConfig(n_samples=20, seqlen=10, iters=10, scale_dtype="fp32")
+        fp32_model = copy.deepcopy(self.gptj)
+        quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32")
         if quant_lm_head is False:
             quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
         logger.info(f"Test AutoRound with config {quant_config}")
@@ -56,10 +56,9 @@ def test_autoround(self, quant_lm_head):
         run_args = (
             self.tokenizer,
             "NeelNanda/pile-10k",
-            20,
+            32,
             10,
         )
-        fp32_model = gpt_j_model
 
         # prepare + convert API
         model = prepare(model=fp32_model, quant_config=quant_config)
@@ -78,7 +77,7 @@ def test_autoround(self, quant_lm_head):
 
     def test_autoround_with_quantize_API(self):
         gpt_j_model = copy.deepcopy(self.gptj)
-        quant_config = get_default_AutoRound_config()
+        quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32")
         quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
         logger.info(f"Test AutoRound with config {quant_config}")
 
@@ -91,7 +90,7 @@ def test_autoround_with_quantize_API(self):
             run_args=(
                 self.tokenizer,
                 "NeelNanda/pile-10k",
-                20,
+                32,
                 10,
             ),
         )
@@ -101,7 +100,7 @@ def test_autoround_with_quantize_API(self):
 
     def test_save_and_load(self):
         fp32_model = copy.deepcopy(self.gptj)
-        quant_config = get_default_AutoRound_config()
+        quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32")
         # quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
         logger.info(f"Test AutoRound with config {quant_config}")
 
@@ -109,7 +108,7 @@ def test_save_and_load(self):
         run_args = (
             self.tokenizer,
             "NeelNanda/pile-10k",
-            20,
+            32,
             10,
         )
         # quantizer execute
@@ -144,10 +143,10 @@ def test_conv1d(self):
         run_args = (
             tokenizer,
             "NeelNanda/pile-10k",
-            20,
+            32,
             10,
         )
-        quant_config = get_default_AutoRound_config()
+        quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32")
         model = prepare(model=model, quant_config=quant_config)
         run_fn(model, *run_args)
         q_model = convert(model)
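
For reference, below is a minimal end-to-end sketch of the flow this patch introduces, mirroring the updated tests: prepare() now returns an InputCaptureModule whose forward() records every positional argument and kwargs dict seen during calibration in data_pairs, and convert() wraps those captured pairs in a torch DataLoader, passes it to AutoRound as dataset, and calls rounder.quantize() on the original model. The checkpoint name is illustrative and the import paths are assumed from the 3.x torch API exercised by the tests, not confirmed by this diff.

# Sketch only: the model/tokenizer checkpoint and import paths are assumptions.
import transformers

from neural_compressor.torch.algorithms.weight_only.autoround import get_autoround_default_run_fn
from neural_compressor.torch.quantization import AutoRoundConfig, convert, prepare

model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"  # illustrative checkpoint
fp32_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32")

# prepare() wraps the model in InputCaptureModule; nothing is quantized yet.
model = prepare(model=fp32_model, quant_config=quant_config)

# Any calibration function works: each call into the wrapper appends its
# args/kwargs to model.data_pairs. Here the default run_fn from this file is used.
get_autoround_default_run_fn(model, tokenizer, "NeelNanda/pile-10k", 32, 10)

# convert() builds a DataLoader from the captured pairs, hands it to
# AutoRound(dataset=...), and runs rounder.quantize() on model.orig_model.
q_model = convert(model)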