[SW-189684] Add description to functions in HQT
Change-Id: Id5822a21abd1f60f28999574c2ca0e89acc70bf6
Yantom1 committed Jul 29, 2024
1 parent 7bf9521 commit ad0625b
Showing 3 changed files with 66 additions and 0 deletions.
26 changes: 26 additions & 0 deletions neural_compressor/torch/algorithms/fp8_quant/_core/measure.py
@@ -28,6 +28,16 @@


def patch_module_measure(mod, mconfig, mod_dict):
"""Replaces the module with patched module according to mconfig.
Args:
mod (nn.module): The module that will be replaced with patched module that measures the inputs.
mconfig (e.g. MaxAbsObserver/MaxAbsPerChannelObserver): The observer object that will measure the parameters.
mod_dict (dict): dictionary from module name to its patched module.
Returns:
nn.module: The new module after patching.
"""
parent = parent_child_mod_dict[mod].parent
name = parent_child_mod_dict[mod].name
patched_mod = mod_dict[mod.__class__.__name__].patched_module(mod, mconfig, name)
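
The mechanism behind patching is an attribute swap on the parent module. Below is a minimal sketch of that swap in plain PyTorch; `MeasuredWrapper` is a hypothetical stand-in for HQT's patched-module classes, not the real implementation:

```python
import torch
from torch import nn

class MeasuredWrapper(nn.Module):
    """Illustrative stand-in for a patched module: records input max-abs."""
    def __init__(self, mod):
        super().__init__()
        self.mod = mod
        self.register_buffer("input_max_abs", torch.tensor(0.0))

    def forward(self, x):
        self.input_max_abs = torch.maximum(self.input_max_abs, x.abs().max())
        return self.mod(x)

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
# Rebind the child on its parent, as patch_module_measure does via the
# parent/name pair looked up in parent_child_mod_dict.
setattr(model, "0", MeasuredWrapper(model[0]))
model(torch.randn(2, 4))
print(model[0].input_max_abs)  # max-abs seen at the Linear input
```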
@@ -72,6 +82,12 @@ def init_measure_object(mod, name, observer_class, mod_type, skip_measure_output


def prepare_model(model, mod_list=None):
"""Defines the observer class and modules for measurement as preparation.
Args:
model (nn.module): The model that will be measured.
mod_list (list, optional): The specific submodules that will be measured in the model. Defaults to None.
"""
config = get_hqt_config(model).cfg
observer_class = observer_types[config["observer"]]
if (config["shape_file"] is not None) and (observer_class != ShapeObserver):
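
The `observer_types[config["observer"]]` lookup is a plain registry pattern. A hypothetical sketch, with placeholder observer classes and config keys rather than HQT's actual ones:

```python
class MaxAbsObserver:
    """Tracks the maximum absolute value seen per tensor."""

class ShapeObserver:
    """Records tensor shapes instead of value statistics."""

observer_types = {"maxabs": MaxAbsObserver, "shape": ShapeObserver}

def pick_observer(cfg):
    # A shape_file in the config forces ShapeObserver in the real code;
    # this sketch only reproduces the registry lookup itself.
    return observer_types[cfg["observer"]]

assert pick_observer({"observer": "maxabs"}) is MaxAbsObserver
```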
@@ -85,6 +101,16 @@


def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=None):
"""Replace the submodules of the model that appear in mod_list with a patched submodule that uses the given observer_class
so the submodule will preform measurement on inputs/outputs in forward stage.
Weights measurement is done during model preparation as they are static.
Args:
model (nn.module): The model that will be measured.
mod_list (list): The specific submodules that will be measured in the model.
observer_class (e.g. MaxAbsObserver/MaxAbsPerChannelObserver): The observer type that will measure the weights.
d_shapes (dict, optional): Defaults to None.
"""
top_level_config = get_hqt_config(model)
config = top_level_config.cfg
skip_outputs_measurements = config["measure_exclude"] & (MeasureExclude.OUTPUT | MeasureExclude.ALL)
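
Because weights are static, they can be measured once at preparation time, while inputs/outputs need hooks that fire on every forward. A rough equivalent of that split using vanilla PyTorch hooks (`register_measurement` and its stats layout are illustrative assumptions):

```python
import torch
from torch import nn

def register_measurement(model, mod_list):
    """Measure static weights eagerly; observe inputs via forward pre-hooks."""
    stats = {}
    for name, mod in model.named_modules():
        if name not in mod_list:
            continue
        # Weights are static, so a single measurement at prepare time suffices.
        stats[name] = {"weight": mod.weight.abs().max().item(), "input": 0.0}

        def pre_hook(m, args, _name=name):
            # Inputs vary per batch: keep a running max across forwards.
            stats[_name]["input"] = max(stats[_name]["input"],
                                        args[0].abs().max().item())
        mod.register_forward_pre_hook(pre_hook)
    return stats

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
stats = register_measurement(model, ["0", "1"])
model(torch.randn(2, 4))
```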
33 changes: 33 additions & 0 deletions neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py
@@ -25,6 +25,16 @@


def patch_module(mod, qconfig, mod_dict, patched_mod=None):
"""Replaces the module with patched module according to mod_dict.
Args:
mod (nn.module): The module that will be replaced with a patched module that quantize the inputs/outputs.
qconfig (ModuleExtraConfig): The quantization config object with the information how to quantize the inputs/outputs.
mod_dict (dict): dictionary from module name to its patched module.
Returns:
nn.module: The new patched module after patching.
"""
parent = parent_child_mod_dict[mod].parent
name = parent_child_mod_dict[mod].name
if patched_mod is None:
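
The `mod_dict` dispatch keys on the module's class name, so each supported type maps to its own patched replacement. A simplified sketch of that dispatch; `PatchedLinear` and the dict contents are placeholders (the real entries expose a `patched_module` constructor):

```python
from torch import nn

class PatchedLinear(nn.Module):
    """Placeholder quantizing wrapper; HQT's real classes differ."""
    def __init__(self, mod, qconfig, name):
        super().__init__()
        self.mod, self.qconfig, self.name = mod, qconfig, name

    def forward(self, x):
        # A real patched module would quantize x, run the low-precision
        # kernel, and dequantize the output according to qconfig.
        return self.mod(x)

# Dispatch keyed on the class name, as in mod_dict[mod.__class__.__name__].
mod_dict = {"Linear": PatchedLinear}

def build_patched(mod, qconfig, name):
    return mod_dict[mod.__class__.__name__](mod, qconfig, name)

patched = build_patched(nn.Linear(4, 4), qconfig=None, name="fc")
```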
@@ -33,6 +43,8 @@ def patch_module(mod, qconfig, mod_dict, patched_mod=None):


def apply_hf_hook(module):
"""Applies hf_hook on a given module so its weights will be loaded from disk to cpu and then we can quantize it.
"""
if hasattr(module, "_hf_hook"):
module._hf_hook.pre_forward(module)
module._hf_hook.detach_hook(module)
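
The two calls follow accelerate's offloading protocol: `pre_forward` is what normally moves offloaded weights onto the device right before a forward pass, and `detach_hook` keeps the hook from offloading them again. A hedged standalone sketch, assuming the model was dispatched by accelerate (so that `_hf_hook` exists):

```python
def materialize_for_quantization(module):
    # Sketch: force-materialize weights that accelerate offloaded to
    # disk/CPU, then stop the hook from offloading them again.
    if hasattr(module, "_hf_hook"):
        module._hf_hook.pre_forward(module)   # bring weights in
        module._hf_hook.detach_hook(module)   # keep them resident
    for child in module.children():           # hooks attach per-module
        materialize_for_quantization(child)
```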
@@ -43,6 +55,12 @@ def apply_hf_hook(module):


def quantize_params(mod, mod_extra_config):
"""Quantizes the weights of the given module according to the quantization info from mod_extra_config.
Args:
mod (nn.module): The module that its weights will be quantized.
mod_extra_config (ModuleExtraConfig): The quantization config object with the information how to quantize the inputs/outputs.
"""
for param_name in mod_extra_config.params:
quantizer = mod_extra_config.params[param_name]
param = getattr(mod, param_name)
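
For fp8, quantizing a parameter boils down to choosing a scale from the observed max-abs, casting, and writing the result back onto the module. A minimal per-tensor sketch (the write-back here is fake-quantized for illustration; HQT's quantizers in `mod_extra_config.params` are the real mechanism):

```python
import torch
from torch import nn

def quantize_weight_fp8(mod, param_name="weight"):
    """Per-tensor fp8 weight quantization sketch (requires torch >= 2.1)."""
    w = getattr(mod, param_name).detach()
    # Scale so the largest weight maps to fp8's max representable value.
    scale = w.abs().max() / torch.finfo(torch.float8_e4m3fn).max
    q = (w / scale).to(torch.float8_e4m3fn)
    # Stored fake-quantized (cast back) purely for illustration; HQT keeps
    # real quantizer objects and fp8 storage instead.
    setattr(mod, param_name,
            nn.Parameter(q.to(w.dtype) * scale, requires_grad=False))
    return scale

scale = quantize_weight_fp8(nn.Linear(8, 8))
```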
@@ -55,6 +73,15 @@


def prepare_model(model, qconfig, mod_list, hp_dtype=torch.float):
"""Replaces the model submodules according to the mod_list with patched quantization modules.
Configures patched modules with the quantization/dequantization methods to apply on their input and output tensors.
Quantizes the model parameters as they are static.
Args:
model (nn.module): The model to quantize.
qconfig (dict): Dict that maps between patched module and its quantization info.
mod_list (list): The specific submodules that will be quantized in the model.
"""
config = get_hqt_config(model)
patched_modules = []
patched_module_types = set()
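
Putting the pieces together, preparation walks the listed submodules, quantizes their static parameters, and rebinds each one on its parent to its patched version. A condensed, hypothetical version of that loop, reusing `quantize_weight_fp8` and `build_patched` from the sketches above:

```python
from torch import nn

def prepare_sketch(model, qconfig, mod_list):
    """Patch listed submodules and quantize their static params in one pass."""
    for name, mod in list(model.named_modules()):
        if name not in mod_list:
            continue
        quantize_weight_fp8(mod)  # params are static: quantize once, here
        parent_name, _, child = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child, build_patched(mod, qconfig.get(name), name))

model = nn.Sequential(nn.Linear(4, 4))
prepare_sketch(model, qconfig={}, mod_list=["0"])
```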
@@ -82,6 +109,12 @@ def prepare_model(model, qconfig, mod_list, hp_dtype=torch.float):


def quantize(model, mod_list):
"""Builds quantization config object that contains for each submodule its quantization functions as preparation for quantization.
Args:
model (nn.module): The model that will be quantized.
mod_list (list, optional): The specific modules that will be quantized in the model.
"""
config = get_hqt_config(model)
generate_model_info(model)
hp_dtype = config.cfg["hp_dtype"]
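
At the top level the flow is two-phase: observe on calibration data first, then patch and quantize using the collected statistics. A conceptual usage sketch built from the helpers above; HQT's real entry points and config plumbing differ:

```python
import torch
from torch import nn

# Reusing the sketches above; this only illustrates the ordering of
# measure-then-quantize, not the actual HQT API surface.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
stats = register_measurement(model, ["0", "1"])   # phase 1: install observers
for _ in range(4):                                # calibration forwards
    model(torch.randn(2, 4))
prepare_sketch(model, qconfig={}, mod_list=["0", "1"])  # phase 2: patch + quantize
```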
7 changes: 7 additions & 0 deletions neural_compressor/torch/algorithms/fp8_quant/_core/utils.py
@@ -42,6 +42,13 @@ def is_substr(substr_list, target):


def prepare_model(model):
"""Receives the parent module to quantize.
Replaces its submodules with patched submodules that perform calibration and quantization.
Returns the patched parent module that can perform calibration or quantization according to the configuration.
Args:
model (nn.module): The model that will be measured/quantized.
"""
config = get_hqt_config(model)
update_mod_dict(config)
allowlist = set(config.cfg["mod_dict"].keys())
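
This makes `utils.prepare_model` the dispatcher: one entry point that routes to the measurement path or the quantization path depending on configuration. A hedged sketch of that branching, reusing the earlier helpers; the `mode` key and its values are assumptions, not verified against the source:

```python
def prepare_model_sketch(model, cfg):
    """Single entry point that routes to measurement or quantization."""
    # Allowlist derived from mod_dict keys, echoing the line above.
    allow = set(cfg.get("mod_dict", {}).keys())
    mod_list = [n for n, m in model.named_modules()
                if m.__class__.__name__ in allow]
    if cfg.get("mode") == "measure":
        register_measurement(model, mod_list)
    elif cfg.get("mode") == "quantize":
        prepare_sketch(model, cfg.get("qconfig", {}), mod_list)
    return model
```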
