Pytorch-Pruner-bugfix-patternlock-channelpruning (#1355)
YIYANGCAI authored Nov 3, 2022
1 parent 67c7073 commit f46bb12
Showing 5 changed files with 353 additions and 72 deletions.
165 changes: 140 additions & 25 deletions neural_compressor/experimental/pytorch_pruner/patterns.py
@@ -67,12 +67,16 @@ def get_mask_single(self, score, exact_sparsity_ratio):
            one = torch.tensor([1.]).to(score.device)
            mask = torch.where(score <= threshold, zero, one)
        else:
-            mask = torch.ones(score.shape,device=score.device)
+            mask = torch.ones(score.shape, device=score.device)
        return mask

+    def get_block_size_dict(self, data):
+        raise NotImplementedError
+
    def get_masks_local(self, scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer):
        masks = {}
+        if isinstance(self, PatternNxM) and not isinstance(self.block_size, dict):
+            self.block_size = self.get_block_size_dict(pre_masks)
        for key in scores.keys():
            score = {key: scores[key]}
            pre_mask = {key: pre_masks[key]}
@@ -89,19 +93,58 @@ def get_sparsity_ratio(self, pre_masks):
            total_cnt += pre_masks.numel()
        return float(zero_cnt) / total_cnt

+    def get_pattern_lock_masks(self, modules):
+        """
+        The basic approach is to simply lock the zero positions.
+        """
+        pattern_lock_masks = {}
+        for key in modules.keys():
+            weight = modules[key].weight
+            shape = weight.shape
+            mask = torch.ones(shape)
+            mask[weight == 0] = 0.0
+            pattern_lock_masks[key] = mask.to(weight.device)
+        return pattern_lock_masks


@register_pattern('NxM')
class PatternNxM(Pattern):
    def __init__(self, config):
        super(PatternNxM, self).__init__(config)
        pattern = self.pattern.split('_')[-1]
-        self.block_size = [int(pattern.split('x')[0]), int(pattern.split('x')[1])]
+        self.N = pattern.split('x')[0]
+        self.M = pattern.split('x')[1]
+        if self.N == "channel":  ##channel-wise pruning mode
+            self.block_size = ["channel", int(self.M)]
+        elif self.M == "channel":  ##channel-wise pruning mode
+            self.block_size = [int(self.N), "channel"]
+        else:
+            self.block_size = [int(pattern.split('x')[0]), int(pattern.split('x')[1])]
+
+    def get_block_size_dict(self, data):
+        block_sizes_dict = {}
+        if self.N == "channel" or self.M == "channel":
+            for key in data.keys():
+                if isinstance(data[key], torch.nn.Module):
+                    shape = data[key].weight.shape
+                else:
+                    shape = data[key].shape
+                if self.N == "channel":
+                    block_sizes_dict[key] = [shape[0], 1]
+                else:
+                    block_sizes_dict[key] = [1, shape[1]]
+            return block_sizes_dict
+        for key in data.keys():
+            block_sizes_dict[key] = self.block_size
+        return block_sizes_dict
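
For illustration, a minimal standalone sketch (hypothetical layer shapes, not the library API) of what this channel-wise block-size resolution computes — a "channel" token collapses one block dimension to the layer's full channel count, so a block becomes a whole column (one input channel) or a whole row (one output channel):

import torch

scores = {"fc1": torch.randn(4, 6), "fc2": torch.randn(8, 6)}  # [cout, cin], hypothetical

def block_size_for(pattern, shape):
    # Mirrors PatternNxM above: "channel" collapses one block dimension.
    N, M = pattern.split('_')[-1].split('x')
    if N == "channel":      # one block = a whole column (one input channel)
        return [shape[0], 1]
    if M == "channel":      # one block = a whole row (one output channel)
        return [1, shape[1]]
    return [int(N), int(M)]

for name, s in scores.items():
    print(name, block_size_for("channelx1", s.shape))  # -> [4, 1] and [8, 1]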

    def get_sparsity_ratio(self, pre_masks):
        zero_cnt = 0
        total_cnt = 0
-        block_size = self.block_size
+        if isinstance(self.block_size, list):
+            self.block_size = self.get_block_size_dict(pre_masks)
        for key in pre_masks.keys():
+            block_size = self.block_size[key]
            pre_mask = pre_masks[key]
            shape = pre_mask.shape
            if len(shape) == 4:
@@ -117,49 +160,55 @@ def get_sparsity_ratio(self, pre_masks):
                total_cnt += pre_mask_sum.numel()
        return float(zero_cnt) / total_cnt

-    def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer):
-        block_size = self.block_size
+    def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer,
+                         keep_pre_mask=False):
+        if isinstance(self.block_size, list):
+            self.block_size = self.get_block_size_dict(scores)
        new_scores = {}
        not_divided_keys = []
        for key in scores.keys():
+            block_size = self.block_size[key]
            current_score = scores[key]
-            if len(current_score.shape) == 4:  ##TODO need to verify whether it's ok for transposed conv
-                current_score = current_score.permute(0, 2, 3, 1)  ##cout,k,k,cin
-                current_score = current_score.reshape(current_score.shape[0], -1)
            shape = current_score.shape
+            if len(shape) == 4:  ##default is conv; is transposed conv ok?
+                shape = current_score.reshape(current_score.shape[0], -1).shape
-            if shape[0] % block_size[0] != 0 or shape[1] % block_size[1] != 0:
+
+            if shape[0] % block_size[0] != 0 or shape[1] % block_size[1] != 0:  ## only consider the input channel
                not_divided_keys.append(key)
                continue

-            new_shape = [shape[0] // block_size[0], block_size[0], shape[1] // block_size[1], block_size[1]]
+            new_shape = [shape[0] // block_size[0], block_size[0], shape[1] // block_size[1],
+                         block_size[1]]
            current_score = current_score.reshape(new_shape)
-            current_score_sum = current_score.sum(-1).sum(1)
+            current_score_sum = current_score.mean(-1).mean(
+                1)  ##TODO sum or mean is quite different for per-channel pruning
            new_scores[key] = current_score_sum
        global_scores = torch.cat([torch.flatten(v) for v in new_scores.values()])
        k = int(target_sparsity_ratio * global_scores.numel())
        masks = {}
        if not k < 1:
            threshold, _ = torch.kthvalue(global_scores, k)
            for key in new_scores.keys():
+                block_size = self.block_size[key]
                score = new_scores[key]
                zero = torch.tensor([0.]).to(score.device)
                one = torch.tensor([1.]).to(score.device)
                mask = torch.where(score <= threshold, zero, one)
                mask = mask.repeat_interleave(block_size[0], dim=0).repeat_interleave(block_size[1], dim=-1)
-                if torch.sum(mask) / mask.numel() < 1.0 - max_sparsity_ratio_per_layer:
+                if torch.sum(mask) / mask.numel() < 1.0 - max_sparsity_ratio_per_layer:
                    ##to prevent some layer from being pruned too much
                    ##this is different from our original implementation
-                    masks[key]=self.get_mask_single(new_scores[key],max_sparsity_ratio_per_layer)
-                    masks[key]=masks[key].repeat_interleave(block_size[0], 0).repeat_interleave(block_size[1], -1)
+                    masks[key] = self.get_mask_single(new_scores[key], max_sparsity_ratio_per_layer)
+                    masks[key] = masks[key].repeat_interleave(block_size[0], 0).repeat_interleave(block_size[1], -1)
+                    # if pre_masks != {}:  ##when using one shot, this is not right
+                    #     masks[key] = pre_masks[key]
+                    # else:
+                    #     masks[key] = mask
                else:
                    masks[key] = mask
-                if len(scores[key].shape)==4:
-                    ##we need to revert back
-                    masks[key]=masks[key].reshape(scores[key].shape)
+                # if len(scores[key].shape) == 4:
+                #     ##we need to revert back
+                #     masks[key] = masks[key].reshape(scores[key].shape)

            for key in not_divided_keys:
                p = scores[key]
@@ -170,8 +219,41 @@ def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer,
            for key in scores.keys():
                p = scores[key]
                masks[key] = torch.ones(p.shape).to(p.device)
+
+        for key in masks.keys():
+            if len(scores[key].shape) == 4 and len(masks[key].shape) == 2:  ## need to permute
+                mask = masks[key]
+                mask = mask.reshape(scores[key].shape[0], scores[key].shape[2], scores[key].shape[3],
+                                    scores[key].shape[1])
+                mask = mask.permute(0, 3, 1, 2)
+                masks[key] = mask
        return masks

+    def get_pattern_lock_masks(self, modules):
+        pattern_lock_masks = {}
+        if isinstance(self.block_size, list):
+            self.block_size = self.get_block_size_dict(modules)
+        for key in modules.keys():
+            block_size = self.block_size[key]
+            weight = modules[key].weight
+            if len(weight.shape) == 4:  # conv
+                weight = weight.permute(0, 2, 3, 1)
+                weight = weight.reshape(weight.shape[0], -1)
+            shape = weight.shape
+            new_shape = [shape[0] // block_size[0], block_size[0], shape[1] // block_size[1], block_size[1]]
+            p = weight.reshape(new_shape)
+            p_mag = p.abs()  # avoid the case where the sum is zero but the weights are not
+            weight_block_sum = p_mag.sum(-1).sum(1)
+            mask = torch.ones(weight_block_sum.shape)
+            mask[weight_block_sum == 0] = 0.0
+            mask = mask.repeat_interleave(block_size[0], dim=0).repeat_interleave(block_size[1], dim=-1)
+            orig_shape = modules[key].weight.shape
+            if len(orig_shape) == 4:
+                mask = mask.reshape(orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1])
+                mask = mask.permute(0, 3, 1, 2)
+            pattern_lock_masks[key] = mask.to(weight.device)
+        return pattern_lock_masks
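
For illustration, a standalone sketch of the tensor mechanics these methods rely on — block pooling, mask expansion with repeat_interleave, and the permute/reshape round trip for 4D conv weights that this commit fixes. Sizes are hypothetical and the thresholding is a simple stand-in for the global kthvalue step:

import torch

score = torch.randn(8, 2, 3, 3)              # conv-weight-like scores: [cout, cin, k, k]
flat = score.permute(0, 2, 3, 1).reshape(score.shape[0], -1)   # -> [cout, k*k*cin]

block = [2, 3]                               # a hypothetical NxM block; divisibility assumed
pooled = flat.reshape(flat.shape[0] // block[0], block[0],
                      flat.shape[1] // block[1], block[1]).mean(-1).mean(1)
mask = (pooled > pooled.median()).float()    # stand-in for the global kthvalue threshold
mask = mask.repeat_interleave(block[0], dim=0).repeat_interleave(block[1], dim=-1)

# revert: [cout, k*k*cin] -> [cout, k, k, cin] -> [cout, cin, k, k]
mask = mask.reshape(score.shape[0], score.shape[2], score.shape[3], score.shape[1])
mask = mask.permute(0, 3, 1, 2)
assert mask.shape == score.shape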


@register_pattern('N:M')
class PatternNInM(Pattern):
@@ -188,7 +270,7 @@ def get_sparsity_ratio(self, pre_masks):
        for key in pre_masks.keys():
            non_zero_cnt += (torch.sum(pre_masks[key])).data.item()
            total_cnt += pre_masks[key].numel()
-        return 1.0-float(non_zero_cnt) / total_cnt
+        return 1.0 - float(non_zero_cnt) / total_cnt

    def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer):
        N = self.N
@@ -199,18 +281,20 @@ def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer):
        not_divided_keys = []
        for key in scores.keys():
            current_score = scores[key]
-            shape = current_score.shape
-            if shape[1] % M != 0:
-                not_divided_keys.append(key)
-                continue
            if len(current_score.shape) == 4:  ##TODO need to verify whether it's ok for transposed conv
-                current_score = current_score.permute(0, 2, 3, 1)##cout,k,k,cin
+                current_score = current_score.permute(0, 2, 3, 1)  ##cout,k,k,cin
                current_score = current_score.reshape(current_score.shape[0], -1)
-            shape = current_score.shape
+            shape = current_score.shape
            new_shape = [shape[0], shape[1] // M, M]
            current_score_new = current_score.reshape(new_shape)

            threshold, _ = torch.kthvalue(current_score_new, N, dim=2)
            threshold = threshold.unsqueeze(-1)
+            if shape[1] % M != 0:
+                not_divided_keys.append(key)
+                continue
+
            threshold = threshold.expand(shape[0], shape[1] // M, M)
            threshold = threshold.reshape((shape[0], shape[1]))
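
For illustration, the kthvalue-based N:M selection in isolation (hypothetical sizes, divisibility assumed): within every block of M consecutive scores along the input dimension, the N smallest are pruned:

import torch

N, M = 2, 4
score = torch.rand(3, 8)                        # [cout, cin], cin % M == 0
blocks = score.reshape(score.shape[0], -1, M)   # [cout, cin//M, M]
threshold, _ = torch.kthvalue(blocks, N, dim=2) # N-th smallest score per block
threshold = threshold.unsqueeze(-1).expand_as(blocks).reshape(score.shape)
mask = (score > threshold).float()              # zero out the N smallest per block
print(mask.reshape(-1, M).sum(-1))              # every block keeps M - N = 2 weights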

@@ -221,7 +305,7 @@ def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer):
            ##to get the sum of N scores in each block of M
            current_score_new = current_score_new * (1.0 - mask)
            current_score_new = current_score_new.reshape(shape[0], shape[1] // M, M)
-            score_sum = torch.sum(current_score_new, dim=-1)
+            score_sum = torch.mean(current_score_new, dim=-1)
            all_nm_masks[key] = mask
            new_scores[key] = score_sum

@@ -239,7 +323,7 @@ def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer):
                ## both zero will be zero
                mask = (mask + all_nm_masks[key])
                mask = torch.where(mask <= 0, zero, one)
-                if torch.sum(mask) / mask.numel() < 1.0 - max_sparsity_ratio_per_layer:
+                if torch.sum(mask) / mask.numel() < 1.0 - max_sparsity_ratio_per_layer:
                    ##trick, to prevent some layer from being pruned too much
                    masks[key] = self.get_mask_single(new_scores[key], max_sparsity_ratio_per_layer)
                    masks[key] = masks[key].repeat_interleave(M, dim=-1)
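
For illustration, a minimal sketch of this per-layer cap (hypothetical scores, mirroring what get_mask_single does): if the global threshold would prune a layer beyond max_sparsity_ratio_per_layer, the layer is re-masked at exactly the cap:

import torch

max_sparsity_ratio_per_layer = 0.5
score = torch.arange(16.).reshape(4, 4) / 16    # hypothetical layer scores
mask = (score > 0.9).float()                    # global threshold prunes 15/16 here
if mask.sum() / mask.numel() < 1.0 - max_sparsity_ratio_per_layer:
    # re-mask the layer at exactly the per-layer cap, as get_mask_single does
    k = int(max_sparsity_ratio_per_layer * score.numel())
    threshold, _ = torch.kthvalue(score.flatten(), k)
    mask = (score > threshold).float()
print(mask.sum() / mask.numel())                # tensor(0.5000)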
@@ -266,3 +350,34 @@ def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer):
                    masks[key] = mask

        return masks

+    def get_pattern_lock_masks(self, modules):
+        pattern_lock_masks = {}
+        N, M = self.N, self.M
+        for key in modules.keys():
+            weight = modules[key].weight
+            if len(weight.shape) == 4:  # conv
+                weight = weight.permute(0, 2, 3, 1)
+                weight = weight.reshape(weight.shape[0], -1)
+            shape = weight.shape
+            ##TODO need to check divisibility by M later
+            new_shape = [shape[0], shape[1] // M, M]
+            weight_new = weight.reshape(new_shape)
+            mask1 = torch.ones(weight_new.shape)
+            mask2 = torch.ones(weight_new.shape)
+            nonzeros = torch.count_nonzero(weight_new, dim=-1)
+            zeros = M - nonzeros
+            mask1[weight_new == 0] = 0.0
+            mask2[zeros >= N] = 0.0
+            mask3 = mask1 + mask2  # zero in mask3 means its block has been completely pruned
+            zero = torch.tensor([0.]).to(weight.device)
+            one = torch.tensor([1.]).to(weight.device)
+            mask = torch.where(mask3 == 0, zero, one)
+            mask = mask.reshape(shape)
+            orig_shape = modules[key].weight.shape
+            if len(orig_shape) == 4:
+                mask = mask.reshape(orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1])
+                mask = mask.permute(0, 3, 1, 2)
+
+            pattern_lock_masks[key] = mask.to(weight.device)
+        return pattern_lock_masks
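
For illustration, the N:M pattern-lock rule above in a standalone sketch (hypothetical weights): an entry is frozen at zero only if it is itself zero and its block of M already contains at least N zeros:

import torch

N, M = 2, 4
weight = torch.tensor([[0., 0., .3, .5],    # block with 2 zeros: both stay pruned
                       [0., .1, .2, .4]])   # block with 1 zero: nothing is locked
blocks = weight.reshape(-1, M)              # here: one block per row
mask1 = (blocks != 0).float()               # elementwise zero positions
zeros = M - torch.count_nonzero(blocks, dim=-1)
mask2 = (zeros < N).float().unsqueeze(-1)   # whole-block condition, broadcast per row
mask = ((mask1 + mask2) != 0).float()       # zero only where both conditions hold
print(mask)                                 # [[0, 0, 1, 1], [1, 1, 1, 1]]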
29 changes: 19 additions & 10 deletions neural_compressor/experimental/pytorch_pruner/prune_utils.py
@@ -39,14 +39,23 @@ def check_config(prune_config):
"only support 'global' and 'local' prune domain"
if "x" in prune_config["pattern"]:
pattern = prune_config["pattern"].split('_')[-1].split('x')
N = int(pattern[0])
M = int(pattern[1])
assert N > 0, "N should be greater than 0"
assert M > 0, "M should be greater than 0"
if pattern[0]=="channel" or pattern[1]=="channel":
pass
else:
try:
N = int(pattern[0])
M = int(pattern[1])
except:
assert False, "N or M can't convert to int"
assert N > 0, "N should be greater than 0"
assert M > 0, "M should be greater than 0"
if ":" in prune_config["pattern"]:
pattern = prune_config["pattern"].split('_')[-1].split(':')
N = int(pattern[0])
M = int(pattern[1])
try:
N = int(pattern[0])
M = int(pattern[1])
except:
assert False, "N or M can't convert to int"
assert N > 0, "N should be greater than 0"
assert M > N, "M should be greater than N"
max_ratio = float(N) / M
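
For a quick sense of what these checks accept, an illustrative re-implementation of the pattern-string validation (not the library API):

def pattern_ok(pattern: str) -> bool:
    # Mirrors check_config above: NxM needs positive ints or a "channel" token;
    # N:M needs M > N > 0.
    body = pattern.split('_')[-1]
    if 'x' in body:
        n, m = body.split('x')
        if "channel" in (n, m):
            return True
        return n.isdigit() and m.isdigit() and int(n) > 0 and int(m) > 0
    if ':' in body:
        n, m = body.split(':')
        return n.isdigit() and m.isdigit() and int(n) > 0 and int(m) > int(n)
    return False

for p in ["4x1", "channelx1", "2:4", "4:2"]:
    print(p, pattern_ok(p))   # True, True, True, False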
@@ -101,8 +110,7 @@ def process_and_check_config(val):
            pruner['extra_excluded_names'] = reset_non_value_to_default(info, 'extra_excluded_names',
                                                                        extra_excluded_names)
            pruner['pattern'] = reset_non_value_to_default(info, 'pattern',
-                                                          pattern)
-
+                                                          pattern)
            check_config(pruner)
            pruner_info = DotDict(pruner)
            pruners_info.append(pruner_info)
@@ -115,10 +123,11 @@ def process_config(config):
            with open(config, 'r') as f:
                content = f.read()
                try:
-                    from ...conf.config import schema
-                except ImportError:
                    from .schema_check import schema
+
+                except ImportError:
+                    from ...conf.config import schema

                val = yaml.safe_load(content)
                schema.validate(val)
        except FileNotFoundError as f:
