Skip to content

Commit

Permalink
fix bug in get/set_module (#1268)
Browse files Browse the repository at this point in the history
Signed-off-by: Xin He <[email protected]>
  • Loading branch information
xin3he authored Sep 25, 2023
1 parent 4b920d5 commit dffcfe1
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 16 deletions.
28 changes: 13 additions & 15 deletions neural_compressor/adaptor/torch_utils/smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,13 @@ def get_module(model, key):
model (torch.nn.Module): original model
key (str): module name to be replaced
"""
attrs = key.split(".")
module = model
for attr in attrs:
try:
attr = int(attr)
module = module[attr]
except:
module = getattr(module, attr)
name_list = key.split(".")
for name in name_list:
if hasattr(module, name):
module = getattr(module, name)
else:
module = module
return module


Expand All @@ -193,15 +192,14 @@ def set_module(model, key, new_module):
key (str): module name to be replaced
new_module (torch.nn.Module): new module to be inserted
"""
attrs = key.split(".")
module = model
for attr in attrs[:-1]:
try:
attr = int(attr)
module = module[attr]
except:
module = getattr(module, attr)
setattr(module, attrs[-1], new_module)
name_list = key.split(".")
for name in name_list[:-1]:
if hasattr(module, name):
module = getattr(module, name)
else:
module = module
setattr(module, name_list[-1], new_module)


def cal_scale(input_max, weights, alpha, scale_type="orig"):
Expand Down
1 change: 0 additions & 1 deletion neural_compressor/adaptor/torch_utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,6 @@ def set_module(model, op_name, new_module):
else:
module = module
setattr(module, name_list[-1], new_module)
return module


def simple_inference(model, input):
Expand Down

0 comments on commit dffcfe1

Please sign in to comment.