
Commit 67f93ae
Unify 'iters' and 'calib_iters' in the AutoRound config
Signed-off-by: Cheng Penghui <[email protected]>
PenghuiCheng committed Jul 2, 2024
1 parent 317b913 commit 67f93ae
Showing 2 changed files with 10 additions and 4 deletions.
@@ -532,7 +532,7 @@ def default_calib_func(model):
             "autoround_args": {
                 "n_samples": config.nsamples,
                 "seqlen": config.calib_len,
-                "iters": config.iters,
+                "iters": config.calib_iters,
                 "scale_dtype": config.scale_dtype,
                 "enable_quanted_input": not config.disable_quanted_input,
                 "lr": config.lr,
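
For context: the quantization path in the hunk above now reads config.calib_iters when building autoround_args. A hedged usage sketch of the end-to-end effect, assuming the package and its torch dependency are installed and that AutoRoundConfig is exported from intel_extension_for_transformers.transformers as elsewhere in this repo:

    from intel_extension_for_transformers.transformers import AutoRoundConfig

    # After this commit, 'iters' is the single user-facing knob; the config
    # mirrors it into calib_iters, which the snippet above hands to AutoRound.
    config = AutoRoundConfig(iters=50)
    print(config.calib_iters)  # -> 50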
intel_extension_for_transformers/transformers/utils/config.py (9 additions, 3 deletions)
@@ -1065,7 +1065,7 @@ def __init__(
         minmax_lr: float = None,
         disable_quanted_input: bool = False,
         nsamples: int = 512,
-        iters: int = 200,
+        iters: int = None,
         use_ggml: bool = False,
         use_neural_speed: bool = False,
         llm_int8_skip_modules=None,
@@ -1091,7 +1091,6 @@ def __init__(
         self.lr = lr
         self.minmax_lr = minmax_lr
         self.disable_quanted_input = disable_quanted_input
-        self.iters = iters
         self.llm_int8_skip_modules = (
             llm_int8_skip_modules if llm_int8_skip_modules else []
         )
@@ -1101,7 +1100,14 @@ def __init__(
         self.calib_dataloader = kwargs.get("calib_dataloader", None)
         self.calib_len = kwargs.get("calib_len", 2048)
         self.calib_func = kwargs.get("calib_func", None)
-        self.calib_iters = kwargs.get("calib_iters", 100)
+        calib_iters = kwargs.get("calib_iters", None)
+        if iters is not None:
+            self.calib_iters = iters
+            if calib_iters is not None:
+                logger.info("'iters' and 'calib_iters' should not be set simultaneously; "
+                            "using 'iters' as the calibration iteration count.")
+        else:
+            self.calib_iters = 200 if calib_iters is None else calib_iters
         self.scheme = "sym" if self.sym else "asym"
         if isinstance(compute_dtype, torch.dtype):
             self.compute_dtype = convert_dtype_torch2str(compute_dtype)
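
The core of the change is the precedence rule in config.py: 'iters' wins, 'calib_iters' remains as a kwargs fallback, and 200 is the default when neither is given. A minimal self-contained sketch of that rule (the helper name resolve_calib_iters and the assert checks are illustrative, not part of the commit):

    import logging

    logger = logging.getLogger(__name__)

    def resolve_calib_iters(iters=None, calib_iters=None, default=200):
        # 'iters' takes precedence; 'calib_iters' is the legacy fallback.
        if iters is not None:
            if calib_iters is not None:
                logger.info("'iters' and 'calib_iters' should not be set "
                            "simultaneously; using 'iters'.")
            return iters
        return default if calib_iters is None else calib_iters

    assert resolve_calib_iters() == 200                          # neither set
    assert resolve_calib_iters(iters=50) == 50                   # 'iters' only
    assert resolve_calib_iters(calib_iters=120) == 120           # legacy kwarg
    assert resolve_calib_iters(iters=50, calib_iters=120) == 50  # 'iters' wins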
