Skip to content

Commit

Permalink
[SW-191317] Raise an exception based on the hqt config object attached to the model
Browse files Browse the repository at this point in the history
Change-Id: I06ba8fa912c811c88912987c11e5c12ef328348a
  • Loading branch information
ulivne committed Jul 3, 2024
1 parent 52a98f4 commit 768c2a4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 7 additions & 1 deletion neural_compressor/torch/algorithms/fp8_quant/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@

def save_calib_result(model):
    """Persist the calibration (measurement) results collected on *model*.

    Delegates to the Habana quantization toolkit's ``finish_measurements``
    when the model carries a valid hqt configuration object; otherwise the
    model was not measured on HPU and saving is unsupported.

    Args:
        model: A model previously prepared/measured with hqt. Expected to
            expose a ``__hqt_config__`` attribute of type ``Fp8cfg``.

    Raises:
        NotImplementedError: If *model* has no valid hqt config object
            (i.e. calibration did not run on HPU).
    """
    # NOTE: third-party import kept function-local so this module can be
    # imported on machines without the Habana toolkit installed.
    import habana_quantization_toolkit as hqt

    # Only save when hqt actually attached its Fp8cfg config to the model;
    # a missing/mistyped config means there are no HPU measurements to save.
    if (hasattr(model, "__hqt_config__") and
            isinstance(model.__hqt_config__, hqt._quant_common.quant_config.Fp8cfg)):
        # TODO SW-184714 modify hqt notation to inc notation once code is ported
        hqt.finish_measurements(model)
    else:
        raise NotImplementedError("Saving calibration results currently supported only in HPU.")



def update_mode(config_path, measure_step=False, quant_step=False):
Expand Down
4 changes: 0 additions & 4 deletions neural_compressor/torch/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,5 @@ def convert(


def finalize_calibration(model):
    """Finalize a calibration run by saving measurement results for *model*.

    Currently only FP8 measurement is supported: the model must carry a
    ``quant_config`` attribute of type ``FP8Config``.

    Raises:
        NotImplementedError: If *model* was not configured for FP8 measurement.
    """
    is_fp8_measured = hasattr(model, "quant_config") and isinstance(model.quant_config, FP8Config)
    if not is_fp8_measured:
        raise NotImplementedError("`finalize_calibration` only supports FP8 measurement now.")

    from neural_compressor.torch.algorithms.fp8_quant import save_calib_result

    save_calib_result(model)

0 comments on commit 768c2a4

Please sign in to comment.