diff --git a/src/sparseml/modifiers/quantization/quantization/pytorch.py b/src/sparseml/modifiers/quantization/quantization/pytorch.py
index 246fd3ce52a..63091a62bfb 100644
--- a/src/sparseml/modifiers/quantization/quantization/pytorch.py
+++ b/src/sparseml/modifiers/quantization/quantization/pytorch.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any
+from typing import Any, Optional
 
 from torch.nn import Module
 
@@ -22,6 +22,7 @@
     freeze_module_quantization,
     set_module_for_calibration,
 )
+from compressed_tensors.quantization.observers.helpers import get_observer_token_count
 from sparseml.core import Event, EventType, State
 from sparseml.modifiers.quantization.quantization.base import QuantizationModifier
 from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward
@@ -74,6 +75,9 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         if self.calculate_start() == -1: # one-shot
             module.apply(set_module_for_calibration)
             self._calibrate_if_possible(module)
+            self._check_token_distribution(
+                module, threshold=kwargs.get("min_tokens_per_module")
+            )
             module.apply(freeze_module_quantization)
 
         return True
@@ -139,3 +143,41 @@ def _calibrate(self, module: Module):
 
         if module_training:
             module.train()
+
+    def _check_token_distribution(
+        self, model: Module, threshold: Optional[float] = None
+    ):
+        """
+        A helper function that warns when a module has seen
+        fewer than threshold % of all the tokens throughout
+        the calibration process.
+
+        Checks are only triggered if threshold is not None.
+
+        :param model: the model to validate
+        :param threshold: the minimum percentage of tokens
+            (out of all the tokens in a batch) a module should
+            receive during calibration
+        """
+        if threshold is None:
+            _LOGGER.debug("Skipping token distribution check. threshold is None.")
+            return
+
+        all_tokens = self.calibration_dataloader_.dataset["input_ids"]
+        total_token_count = sum(len(sample) for sample in all_tokens)
+        counter = get_observer_token_count(model)
+        for module_name, token_count in counter.items():
+            if token_count is None:
+                # the module has not been observed
+                # or its token_count is not being recorded
+                # by the observer (refer to the observer's
+                # implementation in the source code)
+                continue
+            if token_count / total_token_count < threshold:
+                _LOGGER.warning(
+                    f"The module_name: {module_name} "
+                    f"received less than {int(threshold * 100)}% "
+                    "of calibration batch tokens "
+                    f"({token_count}/{total_token_count} tokens). "
+                    "This could harm the quantization quality."
+                )
diff --git a/src/sparseml/transformers/finetune/data/data_args.py b/src/sparseml/transformers/finetune/data/data_args.py
index 9517a19e4de..6f45e7446c4 100644
--- a/src/sparseml/transformers/finetune/data/data_args.py
+++ b/src/sparseml/transformers/finetune/data/data_args.py
@@ -167,3 +167,14 @@ class DataTrainingArguments(CustomDataTrainingArguments):
             ),
         },
     )
+    min_tokens_per_module: Optional[float] = field(
+        default=0.2,
+        metadata={
+            "help": (
+                "The minimum percentage of tokens (out of the total number) "
+                "that the module should 'receive' throughout the forward "
+                "pass of the calibration. If a module receives fewer tokens, "
+                "a warning will be logged."
+            ),
+        },
+    )
diff --git a/src/sparseml/transformers/finetune/session_mixin.py b/src/sparseml/transformers/finetune/session_mixin.py
index 7436261980e..5811909cd07 100644
--- a/src/sparseml/transformers/finetune/session_mixin.py
+++ b/src/sparseml/transformers/finetune/session_mixin.py
@@ -116,6 +116,8 @@ def __init__(
         if self.is_fsdp_enabled:
             self._prepare_model_for_fsdp()
 
+        self.min_tokens_per_module = data_args.min_tokens_per_module
+
     def initialize_session(
         self,
         epoch: float,
@@ -416,6 +418,7 @@ def one_shot(self, calib_data: DataLoader, stage: Optional[str] = None):
             start=-1,
             copy_data=False,
             accelerator=self.accelerator,
+            min_tokens_per_module=self.min_tokens_per_module,
         )
 
         # log model sparsity