diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/measure.py b/neural_compressor/torch/algorithms/fp8_quant/_core/measure.py
index f482d759237..66f4e28bdd6 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_core/measure.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_core/measure.py
@@ -28,6 +28,16 @@
 
 
 def patch_module_measure(mod, mconfig, mod_dict):
+    """Replaces the module with a patched module according to mconfig.
+
+    Args:
+        mod (nn.Module): The module that will be replaced with a patched module that measures the inputs.
+        mconfig (e.g. MaxAbsObserver/MaxAbsPerChannelObserver): The observer object that will measure the parameters.
+        mod_dict (dict): Dictionary mapping module name to its patched module.
+
+    Returns:
+        nn.Module: The new module after patching.
+    """
     parent = parent_child_mod_dict[mod].parent
     name = parent_child_mod_dict[mod].name
     patched_mod = mod_dict[mod.__class__.__name__].patched_module(mod, mconfig, name)
@@ -72,6 +82,12 @@ def init_measure_object(mod, name, observer_class, mod_type, skip_measure_output
 
 
 def prepare_model(model, mod_list=None):
+    """Defines the observer class and modules for measurement as preparation.
+
+    Args:
+        model (nn.Module): The model that will be measured.
+        mod_list (list, optional): The specific submodules that will be measured in the model. Defaults to None.
+    """
     config = get_hqt_config(model).cfg
     observer_class = observer_types[config["observer"]]
     if (config["shape_file"] is not None) and (observer_class != ShapeObserver):
@@ -85,6 +101,16 @@
 
 
 def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=None):
+    """Replaces the submodules of the model that appear in mod_list with patched submodules that use the given observer_class,
+    so each submodule will perform measurement on its inputs/outputs in the forward stage.
+    Weight measurement is done during model preparation, as weights are static.
+
+    Args:
+        model (nn.Module): The model that will be measured.
+        mod_list (list): The specific submodules that will be measured in the model.
+        observer_class (e.g. MaxAbsObserver/MaxAbsPerChannelObserver): The observer type that will measure the weights.
+        d_shapes (dict, optional): Defaults to None.
+    """
     top_level_config = get_hqt_config(model)
     config = top_level_config.cfg
     skip_outputs_measurements = config["measure_exclude"] & (MeasureExclude.OUTPUT | MeasureExclude.ALL)
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py b/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py
index efe412cc16c..889208e94a3 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py
@@ -25,6 +25,16 @@
 
 
 def patch_module(mod, qconfig, mod_dict, patched_mod=None):
+    """Replaces the module with a patched module according to mod_dict.
+
+    Args:
+        mod (nn.Module): The module that will be replaced with a patched module that quantizes the inputs/outputs.
+        qconfig (ModuleExtraConfig): The quantization config object with the information on how to quantize the inputs/outputs.
+        mod_dict (dict): Dictionary mapping module name to its patched module.
+
+    Returns:
+        nn.Module: The new patched module after patching.
+    """
     parent = parent_child_mod_dict[mod].parent
     name = parent_child_mod_dict[mod].name
     if patched_mod is None:
@@ -33,6 +43,8 @@
 
 
 def apply_hf_hook(module):
+    """Applies hf_hook on a given module so that its weights are loaded from disk to CPU, where they can be quantized.
+    """
     if hasattr(module, "_hf_hook"):
         module._hf_hook.pre_forward(module)
         module._hf_hook.detach_hook(module)
@@ -43,6 +55,12 @@
 
 
 def quantize_params(mod, mod_extra_config):
+    """Quantizes the weights of the given module according to the quantization info from mod_extra_config.
+
+    Args:
+        mod (nn.Module): The module whose weights will be quantized.
+        mod_extra_config (ModuleExtraConfig): The quantization config object with the information on how to quantize the inputs/outputs.
+    """
     for param_name in mod_extra_config.params:
         quantizer = mod_extra_config.params[param_name]
         param = getattr(mod, param_name)
@@ -55,6 +73,15 @@
 
 
 def prepare_model(model, qconfig, mod_list, hp_dtype=torch.float):
+    """Replaces the model's submodules listed in mod_list with patched quantization modules.
+    Configures the patched modules with the quantization/dequantization methods to apply on their input and output tensors.
+    Quantizes the model parameters, as they are static.
+
+    Args:
+        model (nn.Module): The model to quantize.
+        qconfig (dict): Dict mapping each patched module to its quantization info.
+        mod_list (list): The specific submodules that will be quantized in the model.
+    """
     config = get_hqt_config(model)
     patched_modules = []
     patched_module_types = set()
@@ -82,6 +109,12 @@
 
 
 def quantize(model, mod_list):
+    """Builds a quantization config object that contains, for each submodule, its quantization functions, as preparation for quantization.
+
+    Args:
+        model (nn.Module): The model that will be quantized.
+        mod_list (list, optional): The specific modules that will be quantized in the model.
+    """
     config = get_hqt_config(model)
     generate_model_info(model)
     hp_dtype = config.cfg["hp_dtype"]
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/utils.py b/neural_compressor/torch/algorithms/fp8_quant/_core/utils.py
index 30635109c2e..ae5dad489e7 100644
--- a/neural_compressor/torch/algorithms/fp8_quant/_core/utils.py
+++ b/neural_compressor/torch/algorithms/fp8_quant/_core/utils.py
@@ -42,6 +42,13 @@ def is_substr(substr_list, target):
 
 
 def prepare_model(model):
+    """Receives the parent module to quantize.
+    Replaces its submodules with patched submodules that perform calibration and quantization.
+    Returns the patched parent module, which can perform calibration or quantization according to the configuration.
+
+    Args:
+        model (nn.Module): The model that will be measured/quantized.
+    """
     config = get_hqt_config(model)
     update_mod_dict(config)
     allowlist = set(config.cfg["mod_dict"].keys())