diff --git a/neural_compressor/adaptor/torch_utils/smooth_quant.py b/neural_compressor/adaptor/torch_utils/smooth_quant.py index 3d2883f9b5f..3dd02bf897a 100644 --- a/neural_compressor/adaptor/torch_utils/smooth_quant.py +++ b/neural_compressor/adaptor/torch_utils/smooth_quant.py @@ -174,14 +174,13 @@ def get_module(model, key): model (torch.nn.Module): original model key (str): module name to be replaced """ - attrs = key.split(".") module = model - for attr in attrs: - try: - attr = int(attr) - module = module[attr] - except: - module = getattr(module, attr) + name_list = key.split(".") + for name in name_list: + if hasattr(module, name): + module = getattr(module, name) + else: + module = module return module @@ -193,15 +192,14 @@ def set_module(model, key, new_module): key (str): module name to be replaced new_module (torch.nn.Module): new module to be inserted """ - attrs = key.split(".") module = model - for attr in attrs[:-1]: - try: - attr = int(attr) - module = module[attr] - except: - module = getattr(module, attr) - setattr(module, attrs[-1], new_module) + name_list = key.split(".") + for name in name_list[:-1]: + if hasattr(module, name): + module = getattr(module, name) + else: + module = module + setattr(module, name_list[-1], new_module) def cal_scale(input_max, weights, alpha, scale_type="orig"): diff --git a/neural_compressor/adaptor/torch_utils/util.py b/neural_compressor/adaptor/torch_utils/util.py index 62e27e9ec1e..9b63e51f03b 100644 --- a/neural_compressor/adaptor/torch_utils/util.py +++ b/neural_compressor/adaptor/torch_utils/util.py @@ -620,7 +620,6 @@ def set_module(model, op_name, new_module): else: module = module setattr(module, name_list[-1], new_module) - return module def simple_inference(model, input):