
Commit

increase sq auto alpha running speed (#1399)
Signed-off-by: Guo, Heng <[email protected]>
n1ck-guo authored Nov 22, 2023
1 parent fcbac41 commit 173c188
Showing 1 changed file with 51 additions and 9 deletions.
60 changes: 51 additions & 9 deletions neural_compressor/adaptor/torch_utils/smooth_quant.py
@@ -32,6 +32,21 @@
logger = logging.getLogger()
from collections import UserDict, defaultdict

from tqdm import tqdm


def enough_memo_store_scale(device, need_space):
if device == "cuda": # pragma: no cover
current_gpu_index = torch.cuda.current_device()
total_memory = torch.cuda.get_device_properties(current_gpu_index).total_memory
used_memory = torch.cuda.memory_allocated(current_gpu_index)
free_space = total_memory - used_memory
else:
import psutil

free_space = psutil.virtual_memory().free
return free_space >= need_space


def move_input_to_device(input, device=torch.device("cpu")):
if isinstance(input, dict) or isinstance(input, UserDict):
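
A minimal usage sketch of the new helper (the channel count is an assumed value, not taken from the commit): it estimates the bytes needed to cache one float32 scale per input channel and asks whether that fits in free CPU or GPU memory.

```python
from neural_compressor.adaptor.torch_utils.smooth_quant import enough_memo_store_scale

# Assumed hidden size; one float32 scale value per input channel.
num_channels = 4096
need_space = 4 * num_channels  # 4 bytes per float32 entry

# "cpu" consults psutil.virtual_memory().free; "cuda" compares allocated vs. total device memory.
if enough_memo_store_scale("cpu", need_space):
    print("enough free memory to cache the scales")
```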
@@ -333,6 +348,9 @@ def __init__(self, model, dataloader=None, example_inputs=None, q_func=None, tra
self.weight_clip = True
self.default_alpha = 0.5

self._save_scale = False
self.weight_scale_dict = {}

def _get_device(self):
"""Get the model device
:return:Model device."""
@@ -562,12 +580,7 @@ def _cal_scales(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=False):
weight_scales_info = {}
absorb_scales_info = {}
for index, key in enumerate(absorb_to_layer.keys()):
if isinstance(alpha, float):
alpha_tmp = alpha
elif isinstance(alpha, dict):
alpha_tmp = alpha[key]
else:
alpha_tmp = alpha
alpha_tmp = alpha[key] if isinstance(alpha, dict) else alpha
if alpha_tmp < 0:
scale = torch.ones((1), device=self.device)
else:
@@ -591,13 +604,24 @@ def _cal_scales(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=False):
self.max_value_info[key]["absorbed_layer"] = layer_names
continue

scale = cal_scale(input_max, weights, alpha_tmp)
if self._save_scale:
if key in self.weight_scale_dict and alpha_tmp in self.weight_scale_dict[key]:
scale = self.weight_scale_dict[key][alpha_tmp]
else:
scale = cal_scale(input_max, weights, alpha_tmp)
else:
scale = cal_scale(input_max, weights, alpha_tmp)

absorb_scales_info[key] = 1.0 / scale
absorb_scales_info[key][scale == 0] = 0
layer_names = absorb_to_layer[key]
for layer_name in layer_names:
##self._scale_layer_weight(layer_name, scale)
weight_scales_info[layer_name] = scale
if self._save_scale:
if layer_name not in self.weight_scale_dict:
self.weight_scale_dict[layer_name] = {}
self.weight_scale_dict[layer_name][alpha_tmp] = scale
return absorb_scales_info, weight_scales_info

def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=False):
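
The caching added here is a simple two-level memoization: scales are keyed first by layer name, then by the alpha they were computed with, so repeated tuning passes over the same (layer, alpha) pair skip cal_scale(). A standalone sketch of that pattern (the names are illustrative, not the module's API):

```python
weight_scale_dict = {}

def get_or_compute_scale(layer_name, alpha, compute_fn):
    """Return a cached scale for (layer_name, alpha), computing it on first use."""
    per_layer = weight_scale_dict.setdefault(layer_name, {})
    if alpha not in per_layer:
        per_layer[alpha] = compute_fn()  # e.g. cal_scale(input_max, weights, alpha)
    return per_layer[alpha]
```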
@@ -869,8 +893,9 @@ def _auto_tune_alpha(
logger.info(f"Auto-tuning failed due to no dataloader, using {best_alphas} instead.")
self._qdq_model_unwrapper_for_auto()
return best_alphas
bar = tqdm(self.dataloader, total=calib_sample_num, desc="auto tune alpha")
try:
for input, label in self.dataloader:
for input, label in bar:
loss_alphas = {}
best_alphas_per_module = best_alphas
if isinstance(best_alphas, dict):
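
The loop change above wraps the calibration dataloader in a tqdm progress bar capped at calib_sample_num. A minimal standalone sketch of that pattern (the dataloader stand-in is illustrative):

```python
from tqdm import tqdm

calib_sample_num = 32
dataloader = [(i, None) for i in range(100)]  # stand-in for a real calibration dataloader

total_cnt = 0
for inp, label in tqdm(dataloader, total=calib_sample_num, desc="auto tune alpha"):
    total_cnt += 1  # the real loop adds the batch size here
    if total_cnt >= calib_sample_num:
        break
```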
@@ -899,10 +924,12 @@ def _auto_tune_alpha(
self.absorb_to_layer, input_maxes, best_alphas, tuning=True
)
self._update_scales_for_auto(absorb_input_scales, weight_scales)
# no need to reset weight_scale_dict here: the original layer's weights are used and do not change, so cached scales stay valid
# self.weight_scale_dict = {}
if total_cnt >= calib_sample_num:
break
except:
for input in self.dataloader:
for input in bar:
loss_alphas = {}
best_alphas_per_module = best_alphas
if isinstance(best_alphas, dict):
@@ -932,6 +959,7 @@ def _auto_tune_alpha(
self.absorb_to_layer, input_maxes, best_alphas, tuning=True
)
self._update_scales_for_auto(absorb_input_scales, weight_scales)
# self.weight_scale_dict = {}
if total_cnt >= calib_sample_num:
break

@@ -1036,6 +1064,18 @@ def transform(
for d in diff_modules:
del self.absorb_to_layer[d]

scale_memo_use = 0
for key in self.absorb_to_layer:
layer_name = self.absorb_to_layer[key][0]
input_max = input_maxes_abs[layer_name]
scale_memo_use += 4 * input_max.shape[0] * len(self.absorb_to_layer[key])
if alpha == "auto":
alpha_space = (auto_alpha_args["alpha_max"] - auto_alpha_args["alpha_min"]) / auto_alpha_args[
"alpha_step"
] + 1
scale_memo_use *= alpha_space
self._save_scale = enough_memo_store_scale(self.device, scale_memo_use)

if alpha == "auto":
self.alpha_per_layer = self._auto_tune_alpha(
input_maxes_abs, calib_sample_num=32, **auto_alpha_args
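
A back-of-the-envelope instance of the memory estimate added above, with assumed model dimensions (100 absorb groups, hidden size 4096, one layer per group, alpha swept from 0.0 to 1.0 in steps of 0.05; none of these numbers come from the commit):

```python
auto_alpha_args = {"alpha_min": 0.0, "alpha_max": 1.0, "alpha_step": 0.05}

scale_memo_use = 4 * 4096 * 1 * 100  # 4 bytes per float32 scale entry, per channel, per layer, per group
alpha_space = (auto_alpha_args["alpha_max"] - auto_alpha_args["alpha_min"]) / auto_alpha_args["alpha_step"] + 1
scale_memo_use *= alpha_space        # one cached scale per candidate alpha (21 alphas here)

print(f"{scale_memo_use / 2**20:.1f} MiB")  # roughly 32.8 MiB for this hypothetical model
```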
@@ -1047,6 +1087,8 @@ def transform(
if example_inputs is not None:
out_pre_sq = model_forward_per_sample(self.model, example_inputs, self.device)

if folding:
self._save_scale = False
if self.record_max_info:
# max_info is recorded in self.max_value_info
self._adjust_parameters(self.absorb_to_layer, input_maxes_abs, alpha)
