Align sparsity with block-wise masks in progressive pruning. (#1250)
Signed-off-by: YIYANGCAI <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
YIYANGCAI and pre-commit-ci[bot] authored Sep 18, 2023
1 parent ec06411 commit fcdc29a
Showing 1 changed file with 32 additions and 1 deletion.

neural_compressor/compression/pruner/pruners/progressive.py
@@ -60,7 +60,8 @@ def _init(self):
         self.progressive_steps = self.progressive_configs["progressive_steps"]
         self.progressive_type = self.progressive_configs["progressive_type"]
         self.use_global = self.progressive_configs["use_global"]
-        self.progressive_logger = False
+        self.progressive_logger = True
+        self.align_masks_flag = False
         self._init_for_progressive()

     def _init_for_progressive(self):
@@ -77,6 +78,11 @@ def _init_for_progressive(self):
             self.use_progressive = False
             return

+        if self.pruning_frequency == 1:
+            logger.info("The current progressive setting will degrade to non-progressive pruning.")
+            self.use_progressive = False
+            return
+
         # step 3: log hyper-parameters and check validity.
         if self.use_progressive:
             logger.info("Progressive pruning is enabled!")
@@ -225,6 +231,10 @@ def on_step_begin(self, local_step):
         Implement at the start of each step.
         """
+        if self.global_step > self.end_step and self.align_masks_flag is False:
+            self.align_masks_after_pruning()
+            self.align_masks_flag = True
+
         if self.handled_global_step == self.global_step:
             return
@@ -270,3 +280,24 @@ def print_progressive_sparsity(self):
         """Output the progressive sparsity."""
         cur_sp = self.pattern.get_sparsity_ratio_progressive(self.progressive_masks)
         logger.info("Step: {} -> Current progressive sparsity: {}".format(self.global_step, cur_sp))
+
+    def obtain_weight_sparsity(self, modules):
+        total_numels = 0
+        sparse_numels = 0
+        for key in modules.keys():
+            total_numels += modules[key].weight.data.numel()
+            sparse_numels += torch.sum(torch.where(modules[key].weight.data == 0, 1, 0)).item()
+        return sparse_numels / total_numels
+
+    def align_masks_after_pruning(self):
+        """Implement at the end of the training phase."""
+        if not self.use_progressive:
+            return
+        # If training ends while a progressive mask is still applied, re-mask with self.masks to align.
+        # step 1: calculate sparsity under the progressive masks
+        sparsity1 = self.obtain_weight_sparsity(self.modules)
+        # step 2: use the block-wise masks to re-mask the weights
+        self.mask_weights_general(self.masks)
+        # step 3: calculate sparsity under the block-wise masks
+        sparsity2 = self.obtain_weight_sparsity(self.modules)
+        logger.info(f"Replaced progressive masks with complete block-wise masks. Sparsity update: {sparsity1} => {sparsity2}")

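For readers outside the codebase, the effect of the alignment can be sketched in a few lines. This is a minimal illustration, not Intel Neural Compressor API: TinyLayer, the 4x8 weight, and the hand-built masks below are assumptions; only the sparsity measurement mirrors the obtain_weight_sparsity helper added above.

import torch


class TinyLayer(torch.nn.Module):
    """Hypothetical stand-in for one entry of the pruner's self.modules."""

    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 8))


def obtain_weight_sparsity(modules):
    # Same measurement as the helper in this commit: share of exactly-zero weights.
    total = sum(m.weight.data.numel() for m in modules.values())
    zeros = sum(torch.sum(m.weight.data == 0).item() for m in modules.values())
    return zeros / total


modules = {"fc": TinyLayer()}

# Target block-wise mask: prune one whole 1x4 block per row (50% sparsity).
block_mask = torch.ones(4, 8)
block_mask[:, :4] = 0

# Mid-progressive mask: only half of each block has been zeroed so far (25%).
progressive_mask = torch.ones(4, 8)
progressive_mask[:, :2] = 0

modules["fc"].weight.data *= progressive_mask
print("sparsity under progressive masks:", obtain_weight_sparsity(modules))  # 0.25

# What align_masks_after_pruning does, in effect: re-mask with the complete
# block-wise masks so the final sparsity lands on the configured target.
modules["fc"].weight.data *= block_mask
print("sparsity under block-wise masks:", obtain_weight_sparsity(modules))  # 0.5

Because a progressive mask is an intermediate stage of its block-wise mask, re-applying the complete masks can only zero additional weights, which matches the one-way sparsity update the new log line reports.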