From 039af39eeccba9705d4865917a38553f701fcc7b Mon Sep 17 00:00:00 2001 From: Nir David Date: Thu, 25 Jul 2024 12:18:23 +0300 Subject: [PATCH] [SW-194200] Save scale file only with new scales Change-Id: I14a4ef94d188b13c2fbf4ea77d2b42cb5bd6d952 --- neural_compressor/torch/algorithms/fp8_quant/_core/scale.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale.py index c0b51cd9e74..910e05370b2 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/scale.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale.py @@ -103,6 +103,7 @@ def get_config( ) scales = convert_scales_to_tensors_dict(scales_obj, scales_file_format, params["hp_dtype"]) model_dict = dict(model.named_modules()) + save_file = False for mname in mod_list: mod = model_dict[mname] set_hqt_config(mod, top_level_config) # set config in the module, as it consumed by the patched module @@ -123,6 +124,7 @@ def get_config( scales_obj[mname] = ModuleConfig( **format_functions_rec((torch.Tensor, scales_file_format))(scales[mname].__dict__) ) + save_file = True logger.debug( "Preparing quantization functions for layer %s layer_type=%s", @@ -138,7 +140,7 @@ def get_config( params, ) qconfig[mname] = mod_extra_config - if scales_file is not None: + if save_file and scales_file is not None: save_scales(model, scales_obj, scales_file_format, scales_file + ".npz") save_scales(model, scales_obj, scales_file_format, scales_file + ".json") return qconfig