diff --git a/neural_compressor/experimental/pytorch_pruner/patterns.py b/neural_compressor/experimental/pytorch_pruner/patterns.py index b5cb09e7a85..27b0ca5ee2f 100644 --- a/neural_compressor/experimental/pytorch_pruner/patterns.py +++ b/neural_compressor/experimental/pytorch_pruner/patterns.py @@ -67,12 +67,16 @@ def get_mask_single(self, score, exact_sparsity_ratio): one = torch.tensor([1.]).to(score.device) mask = torch.where(score <= threshold, zero, one) else: - mask = torch.ones(score.shape,device=score.device) + mask = torch.ones(score.shape, device=score.device) return mask + def get_block_size_dict(self, data): + raise NotImplementedError def get_masks_local(self, scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer): masks = {} + if isinstance(self, PatternNxM) and not isinstance(self.block_size, dict): + self.block_size = self.get_block_size_dict(pre_masks) for key in scores.keys(): score = {key: scores[key]} pre_mask = {key: pre_masks[key]} @@ -89,19 +93,58 @@ def get_sparsity_ratio(self, pre_masks): total_cnt += pre_masks.numel() return float(zero_cnt) / total_cnt + def get_pattern_lock_masks(self, modules): + """ + basic way is to simply lock the zero position + """ + pattern_lock_masks = {} + for key in modules.keys(): + weight = modules[key].weight + shape = weight.shape + mask = torch.ones(shape) + mask[weight == 0] = 0.0 + pattern_lock_masks[key] = mask.to(weight.device) + return pattern_lock_masks + @register_pattern('NxM') class PatternNxM(Pattern): def __init__(self, config): super(PatternNxM, self).__init__(config) pattern = self.pattern.split('_')[-1] - self.block_size = [int(pattern.split('x')[0]), int(pattern.split('x')[1])] + self.N = pattern.split('x')[0] + self.M = pattern.split('x')[1] + if self.N == "channel": ##channel-wise pruning mode + self.block_size = ["channel", int(self.M)] + elif self.M == "channel": ##channel-wise pruning mode + self.block_size = [int(self.N), "channel"] + else: + self.block_size = [int(pattern.split('x')[0]), int(pattern.split('x')[1])] + + def get_block_size_dict(self, data): + block_sizes_dict = {} + if self.N == "channel" or self.M == "channel": + for key in data.keys(): + if isinstance(data[key], torch.nn.Module): + shape = data[key].weight.shape + else: + shape = data[key].shape + if self.N == "channel": + block_sizes_dict[key] = [shape[0], 1] + else: + block_sizes_dict[key] = [1, shape[1]] + return block_sizes_dict + for key in data.keys(): + block_sizes_dict[key] = self.block_size + return block_sizes_dict def get_sparsity_ratio(self, pre_masks): zero_cnt = 0 total_cnt = 0 - block_size = self.block_size + if isinstance(self.block_size, list): + self.block_size = self.get_block_size_dict(pre_masks) for key in pre_masks.keys(): + block_size = self.block_size[key] pre_mask = pre_masks[key] shape = pre_mask.shape if len(shape) == 4: @@ -117,23 +160,28 @@ def get_sparsity_ratio(self, pre_masks): total_cnt += pre_mask_sum.numel() return float(zero_cnt) / total_cnt - def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer): - block_size = self.block_size + def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer, + keep_pre_mask=False): + if isinstance(self.block_size, list): + self.block_size = self.get_block_size_dict(scores) new_scores = {} not_divided_keys = [] for key in scores.keys(): + block_size = self.block_size[key] current_score = scores[key] + if len(current_score.shape) == 4: ##TODO need to verify whether it's ok for transposed conv + current_score = current_score.permute(0, 2, 3, 1) ##cout,k,k,cin + current_score = current_score.reshape(current_score.shape[0], -1) shape = current_score.shape - if len(shape) == 4: ##default is conv, is transpose conv ok ? - shape = current_score.reshape(current_score.shape[0], -1).shape - if shape[0] % block_size[0] != 0 or shape[1] % block_size[1] != 0: - + if shape[0] % block_size[0] != 0 or shape[1] % block_size[1] != 0: ## only consider input channel not_divided_keys.append(key) continue - new_shape = [shape[0] // block_size[0], block_size[0], shape[1] // block_size[1], block_size[1]] + new_shape = [shape[0] // block_size[0], block_size[0], shape[1] // block_size[1], + block_size[1]] current_score = current_score.reshape(new_shape) - current_score_sum = current_score.sum(-1).sum(1) + current_score_sum = current_score.mean(-1).mean( + 1) ##TODO sum or mean is quite different for per channel pruning new_scores[key] = current_score_sum global_scores = torch.cat([torch.flatten(v) for v in new_scores.values()]) k = int(target_sparsity_ratio * global_scores.numel()) @@ -141,25 +189,26 @@ def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsit if not k < 1: threshold, _ = torch.kthvalue(global_scores, k) for key in new_scores.keys(): + block_size = self.block_size[key] score = new_scores[key] zero = torch.tensor([0.]).to(score.device) one = torch.tensor([1.]).to(score.device) mask = torch.where(score <= threshold, zero, one) mask = mask.repeat_interleave(block_size[0], dim=0).repeat_interleave(block_size[1], dim=-1) - if torch.sum(mask) / mask.numel() < 1.0 - max_sparsity_ratio_per_layer: + if torch.sum(mask) / mask.numel() < 1.0 - max_sparsity_ratio_per_layer: ##to prevent some layer not be purned too much ##this is differnt with our original implementation - masks[key]=self.get_mask_single(new_scores[key],max_sparsity_ratio_per_layer) - masks[key]=masks[key].repeat_interleave(block_size[0], 0).repeat_interleave(block_size[1], -1) + masks[key] = self.get_mask_single(new_scores[key], max_sparsity_ratio_per_layer) + masks[key] = masks[key].repeat_interleave(block_size[0], 0).repeat_interleave(block_size[1], -1) # if pre_masks != {}:##when use one shot, this is not right # masks[key] = pre_masks[key] # else: # masks[key] = mask else: masks[key] = mask - if len(scores[key].shape)==4: - ##we need to revert back - masks[key]=masks[key].reshape(scores[key].shape) + # if len(scores[key].shape) == 4: + # ##we need to revert back + # masks[key] = masks[key].reshape(scores[key].shape) for key in not_divided_keys: p = scores[key] @@ -170,8 +219,41 @@ def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsit for key in scores.keys(): p = scores[key] masks[key] = torch.ones(p.shape).to(p.device) + + for key in masks.keys(): + if len(scores[key].shape) == 4 and len(masks[key].shape) == 2: ## need to permute + mask = masks[key] + mask = mask.reshape(scores[key].shape[0], scores[key].shape[2], scores[key].shape[3], + scores[key].shape[1]) + mask = mask.permute(0, 3, 1, 2) + masks[key] = mask return masks + def get_pattern_lock_masks(self, modules): + pattern_lock_masks = {} + if isinstance(self.block_size, list): + self.block_size = self.get_block_size_dict(modules) + for key in modules.keys(): + block_size = self.block_size[key] + weight = modules[key].weight + if len(weight.shape) == 4: # conv + weight = weight.permute(0, 2, 3, 1) + weight = weight.reshape(weight.shape[0], -1) + shape = weight.shape + new_shape = [shape[0] // block_size[0], block_size[0], shape[1] // block_size[1], block_size[1]] + p = weight.reshape(new_shape) + p_mag = p.abs() # avoid the scene which sum is zero but weights are not + weight_block_sum = p_mag.sum(-1).sum(1) + mask = torch.ones(weight_block_sum.shape) + mask[weight_block_sum == 0] = 0.0 + mask = mask.repeat_interleave(block_size[0], dim=0).repeat_interleave(block_size[1], dim=-1) + orig_shape = modules[key].weight.shape + if len(orig_shape) == 4: + mask = mask.reshape(orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1]) + mask = mask.permute(0, 3, 1, 2) + pattern_lock_masks[key] = mask.to(weight.device) + return pattern_lock_masks + @register_pattern('N:M') class PatternNInM(Pattern): @@ -188,7 +270,7 @@ def get_sparsity_ratio(self, pre_masks): for key in pre_masks.keys(): non_zero_cnt += (torch.sum(pre_masks[key])).data.item() total_cnt += pre_masks[key].numel() - return 1.0-float(non_zero_cnt) / total_cnt + return 1.0 - float(non_zero_cnt) / total_cnt def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer): N = self.N @@ -199,18 +281,20 @@ def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsit not_divided_keys = [] for key in scores.keys(): current_score = scores[key] + shape = current_score.shape + if shape[1] % M != 0: + not_divided_keys.append(key) + continue if len(current_score.shape) == 4: ##TODO need to verify whether it's ok for transposed conv - current_score = current_score.permute(0, 2, 3, 1)##cout,k,k,cin + current_score = current_score.permute(0, 2, 3, 1) ##cout,k,k,cin current_score = current_score.reshape(current_score.shape[0], -1) - shape = current_score.shape + shape = current_score.shape new_shape = [shape[0], shape[1] // M, M] current_score_new = current_score.reshape(new_shape) threshold, _ = torch.kthvalue(current_score_new, N, dim=2) threshold = threshold.unsqueeze(-1) - if shape[1] % M != 0: - not_divided_keys.append(key) - continue + threshold = threshold.expand(shape[0], shape[1] // M, M) threshold = threshold.reshape((shape[0], shape[1])) @@ -221,7 +305,7 @@ def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsit ##to get the sum of N scores in each block with M current_score_new = current_score_new * (1.0 - mask) current_score_new = current_score_new.reshape(shape[0], shape[1] // M, M) - score_sum = torch.sum(current_score_new, dim=-1) + score_sum = torch.mean(current_score_new, dim=-1) all_nm_masks[key] = mask new_scores[key] = score_sum @@ -239,7 +323,7 @@ def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsit ## both zero will be zero mask = (mask + all_nm_masks[key]) mask = torch.where(mask <= 0, zero, one) - if torch.sum(mask) / mask.numel() < 1.0 - max_sparsity_ratio_per_layer: + if torch.sum(mask) / mask.numel() < 1.0 - max_sparsity_ratio_per_layer: ##trick, to prevent some layer not be purned too much masks[key] = self.get_mask_single(new_scores[key], max_sparsity_ratio_per_layer) masks[key] = masks[key].repeat_interleave(M, dim=-1) @@ -266,3 +350,34 @@ def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsit masks[key] = mask return masks + + def get_pattern_lock_masks(self, modules): + pattern_lock_masks = {} + N, M = self.N, self.M + for key in modules.keys(): + weight = modules[key].weight + if len(weight.shape) == 4: # conv + weight = weight.permute(0, 2, 3, 1) + weight = weight.reshape(weight.shape[0], -1) + shape = weight.shape + ##TODO need to check whether it can be divisible later + new_shape = [shape[0], shape[1] // M, M] + weight_new = weight.reshape(new_shape) + mask1 = torch.ones(weight_new.shape) + mask2 = torch.ones(weight_new.shape) + nonzeros = torch.count_nonzero(weight_new, dim=-1) + zeros = M - nonzeros + mask1[weight_new == 0] = 0.0 + mask2[zeros >= N] = 0.0 + mask3 = mask1 + mask2 # zero in mask3 means its block has been completely pruned. + zero = torch.tensor([0.]).to(weight.device) + one = torch.tensor([1.]).to(weight.device) + mask = torch.where(mask3 == 0, zero, one) + mask = mask.reshape(shape) + orig_shape = modules[key].weight.shape + if len(orig_shape) == 4: + mask = mask.reshape(orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1]) + mask = mask.permute(0, 3, 1, 2) + + pattern_lock_masks[key] = mask.to(weight.device) + return pattern_lock_masks diff --git a/neural_compressor/experimental/pytorch_pruner/prune_utils.py b/neural_compressor/experimental/pytorch_pruner/prune_utils.py index c9a9b6ac4f8..036e2a2b72a 100644 --- a/neural_compressor/experimental/pytorch_pruner/prune_utils.py +++ b/neural_compressor/experimental/pytorch_pruner/prune_utils.py @@ -39,14 +39,23 @@ def check_config(prune_config): "only support 'global' and 'local' prune domain" if "x" in prune_config["pattern"]: pattern = prune_config["pattern"].split('_')[-1].split('x') - N = int(pattern[0]) - M = int(pattern[1]) - assert N > 0, "N should be greater than 0" - assert M > 0, "M should be greater than 0" + if pattern[0]=="channel" or pattern[1]=="channel": + pass + else: + try: + N = int(pattern[0]) + M = int(pattern[1]) + except: + assert False, "N or M can't convert to int" + assert N > 0, "N should be greater than 0" + assert M > 0, "M should be greater than 0" if ":" in prune_config["pattern"]: pattern = prune_config["pattern"].split('_')[-1].split(':') - N = int(pattern[0]) - M = int(pattern[1]) + try: + N = int(pattern[0]) + M = int(pattern[1]) + except: + assert False, "N or M can't convert to int" assert N > 0, "N should be greater than 0" assert M > N, "M should be greater than N" max_ratio = float(N) / M @@ -101,8 +110,7 @@ def process_and_check_config(val): pruner['extra_excluded_names'] = reset_non_value_to_default(info, 'extra_excluded_names', extra_excluded_names) pruner['pattern'] = reset_non_value_to_default(info, 'pattern', - pattern) - + pattern) check_config(pruner) pruner_info = DotDict(pruner) pruners_info.append(pruner_info) @@ -115,10 +123,11 @@ def process_config(config): with open(config, 'r') as f: content = f.read() try: - from ...conf.config import schema - except ImportError: from .schema_check import schema + except ImportError: + from ...conf.config import schema + val = yaml.safe_load(content) schema.validate(val) except FileNotFoundError as f: diff --git a/neural_compressor/experimental/pytorch_pruner/pruner.py b/neural_compressor/experimental/pytorch_pruner/pruner.py index ee57b99fda7..d6098d8820f 100644 --- a/neural_compressor/experimental/pytorch_pruner/pruner.py +++ b/neural_compressor/experimental/pytorch_pruner/pruner.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +# !/usr/bin/env python # -*- coding: utf-8 -*- # # Copyright (c) 2022 Intel Corporation @@ -34,12 +34,12 @@ def register(pruner): return register -def get_pruner(moduels, config): +def get_pruner(modules, config): """Get registered pruner class""" name = config["prune_type"] if name not in PRUNERS.keys(): assert False, f"does not support {name}, currently only support {PRUNERS.keys()}" - return PRUNERS[name](moduels, config) + return PRUNERS[name](modules, config) class Pruner: @@ -122,6 +122,12 @@ def on_train_begin(self): def on_train_end(self): pass + def on_before_eval(self): + pass + + def on_after_eval(self): + pass + def check_is_pruned_step(self, step): if step < self.start_step or step > self.end_step: return False @@ -133,23 +139,37 @@ def update_scores(self): pass +@register_pruners('magnitude') +class MagnitudePruner(Pruner): + def __init__(self, modules, config): + super(MagnitudePruner, self).__init__(modules, config) + self.scores = {} + + def update_scores(self): + with torch.no_grad(): + for key in self.modules.keys(): + p = self.modules[key].weight.data + self.scores[key] = p + + @register_pruners('snip') class SnipPruner(Pruner): """ - please refer to SNIP: Single-shot Network Pruning based on Connection Sensitivity + please refer to SNIP: Single-shot Network Pruning based on Connection Sensitivity (https://arxiv.org/abs/1810.02340) """ + def __init__(self, modules, config): super(SnipPruner, self).__init__(modules, config) assert self.config.end_step > 0, "gradient based criteria does not work on step 0" self.scores = {} def on_after_optimizer_step(self): + self.mask_weights() with torch.no_grad(): for key in self.modules.keys(): p = self.modules[key].weight self.scores[key] = torch.abs(p * p.grad) - self.mask_weights() @register_pruners('snip_momentum') @@ -163,22 +183,25 @@ def __init__(self, modules, config): self.scores[key] = torch.zeros(p.shape).to(p.device) def on_after_optimizer_step(self): + self.mask_weights() with torch.no_grad(): for key in self.modules.keys(): p = self.modules[key].weight self.scores[key] *= 0.9 ##magic number self.scores[key] += 1.0 * torch.abs(p * p.grad) - self.mask_weights() -@register_pruners('magnitude') -class MagnitudePruner(Pruner): +@register_pruners('pattern_lock') +class PatternLockPruner(Pruner): def __init__(self, modules, config): - super(MagnitudePruner, self).__init__(modules, config) - self.scores = {} + super(PatternLockPruner, self).__init__(modules, config) + assert self.config.end_step == self.config.start_step, "pattern_lock pruner only supports one shot mode" - def update_scores(self): - with torch.no_grad(): - for key in self.modules.keys(): - p = self.modules[key].weight.data - self.scores[key] = p + def on_step_begin(self, local_step): + self.global_step += 1 + if not self.check_is_pruned_step(self.global_step): + return + self.masks = self.pattern.get_pattern_lock_masks(self.modules) + + def on_after_optimizer_step(self): + self.mask_weights() diff --git a/neural_compressor/experimental/pytorch_pruner/pruning.py b/neural_compressor/experimental/pytorch_pruner/pruning.py index c22e4aae5bd..dac24f4a6cd 100644 --- a/neural_compressor/experimental/pytorch_pruner/pruning.py +++ b/neural_compressor/experimental/pytorch_pruner/pruning.py @@ -21,6 +21,7 @@ from .pruner import get_pruner from .logger import logger + class Pruning: def __init__(self, config): self.model = None @@ -34,7 +35,7 @@ def update_items_for_all_pruners(self, **kwargs): if key in item.keys(): item[key] = kwargs[key] - #def _call_pruners(self, func): + # def _call_pruners(self, func): # def warpper(self, *args, **kw): # func_name = f"{func.__name__}" # func(self, *args, **kw) @@ -44,6 +45,34 @@ def update_items_for_all_pruners(self, **kwargs): # # return warpper + def get_sparsity_ratio(self): + pattern_sparsity_cnt = 0 + element_sparsity_cnt = 0 + for pruner in self.pruners: + modules = pruner.modules + sparsity_ratio = pruner.pattern.get_sparsity_ratio(pruner.masks) + cnt = 0 + for key in modules.keys(): + cnt += modules[key].weight.numel() + pattern_sparsity_cnt += int(cnt * sparsity_ratio) + for key in pruner.masks.keys(): + element_sparsity_cnt += torch.sum(pruner.masks[key] == 0).data.item() + + linear_conv_cnt = 0 + param_cnt = 0 + for name, module in self.model.named_modules(): + if type(module).__name__ in ["Linear"] or "Conv" in type(module).__name__: + linear_conv_cnt += module.weight.numel() + + for n, param in self.model.named_parameters(): + param_cnt += param.numel() + blockwise_over_matmul_gemm_conv = float(pattern_sparsity_cnt) / linear_conv_cnt + elementwise_over_matmul_gemm_conv = float(element_sparsity_cnt) / linear_conv_cnt + elementwise_over_all = float( + element_sparsity_cnt) / param_cnt + + return elementwise_over_matmul_gemm_conv, elementwise_over_all, blockwise_over_matmul_gemm_conv + def _generate_pruners(self): assert isinstance(self.model, torch.nn.Module) @@ -58,41 +87,52 @@ def _generate_pruners(self): info['len_of_modules'] = len(info['modules']) logger.info(info) - #@_call_pruners + # @_call_pruners def on_train_begin(self): self._generate_pruners() ##TODO is there better place to place - #@_call_pruners + # @_call_pruners def on_epoch_begin(self, epoch): for pruner in self.pruners: - pruner.on_epoch_begin(epoch) + pruner.on_epoch_begin(epoch) - #@_call_pruners + + # @_call_pruners def on_step_begin(self, local_step): for pruner in self.pruners: pruner.on_step_begin(local_step) - #@_call_pruners + # @_call_pruners def on_before_optimizer_step(self): for pruner in self.pruners: pruner.on_before_optimizer_step() - #@_call_pruners + # @_call_pruners def on_step_end(self): for pruner in self.pruners: pruner.on_step_end() - #@_call_pruners + # @_call_pruners def on_epoch_end(self): for pruner in self.pruners: pruner.on_epoch_end() - #@_call_pruners + # @_call_pruners def on_train_end(self): for pruner in self.pruners: pruner.on_train_end() - #@_call_pruners + # @_call_pruners + + def on_before_eval(self): + for pruner in self.pruners: + pruner.on_before_eval() + + def on_after_eval(self): + for pruner in self.pruners: + pruner.on_after_eval() + + # @_call_pruners def on_after_optimizer_step(self): for pruner in self.pruners: pruner.on_after_optimizer_step() diff --git a/test/pruning/test_pytorch_pruning.py b/test/pruning/test_pytorch_pruning.py index bc2422ce94c..73739de75ab 100644 --- a/test/pruning/test_pytorch_pruning.py +++ b/test/pruning/test_pytorch_pruning.py @@ -9,8 +9,9 @@ from neural_compressor.data import DATASETS from neural_compressor.experimental.data.dataloaders.pytorch_dataloader import PyTorchDataLoader -def build_fake_yaml(): - fake_yaml = """ + +def build_fake_yaml_basic(): + fake_snip_yaml = """ model: name: imagenet_prune framework: pytorch @@ -23,16 +24,17 @@ def build_fake_yaml(): start_step: 0 end_step: 10 excluded_names: ["classifier"] - + update_frequency_on_step: 1 sparsity_decay_type: "exp" pruners: - !Pruner start_step: 0 + sparsity_decay_type: "cos" end_step: 10 prune_type: "magnitude" names: ['layer1.*'] - extra_excluded_names: ['layer2.*'] + extra_excluded_names: ['layer2.*'] prune_domain: "global" pattern: "tile_pattern_4x1" @@ -45,6 +47,7 @@ def build_fake_yaml(): names: ['layer2.*'] prune_domain: local pattern: "tile_pattern_2:4" + - !Pruner start_step: 2 end_step: 8 @@ -54,10 +57,62 @@ def build_fake_yaml(): prune_domain: "local" pattern: "tile_pattern_16x1" sparsity_decay_type: "cube" - """ - with open('fake.yaml', 'w', encoding="utf-8") as f: - f.write(fake_yaml) + """ + with open('fake_snip.yaml', 'w', encoding="utf-8") as f: + f.write(fake_snip_yaml) + +def build_fake_yaml_channel(): + fake_channel_pruning_yaml = """ + model: + name: imagenet_prune + framework: pytorch + + pruning: + approach: + weight_compression_pytorch: + initial_sparsity: 0.0 + target_sparsity: 0.9 + start_step: 0 + end_step: 10 + excluded_names: ["classifier"] + + update_frequency_on_step: 1 + sparsity_decay_type: "exp" + pruners: + - !Pruner + start_step: 5 + end_step: 5 + prune_type: "pattern_lock" + names: ['layer1.*'] + extra_excluded_names: ['layer2.*'] + prune_domain: "global" + pattern: "channelx1" + + - !Pruner + start_step: 1 + end_step: 1 + target_sparsity: 0.5 + prune_type: "pattern_lock" + update_frequency: 2 + names: ['layer2.*'] + prune_domain: local + pattern: "2:4" + + - !Pruner + start_step: 2 + end_step: 8 + target_sparsity: 0.8 + prune_type: "snip" + names: ['layer3.*'] + prune_domain: "local" + pattern: "1xchannel" + sparsity_decay_type: "cube" + + """ + + with open('fake_channel_pruning.yaml', 'w', encoding="utf-8") as f: + f.write(fake_channel_pruning_yaml) class TestPytorchPruning(unittest.TestCase): @@ -66,19 +121,23 @@ class TestPytorchPruning(unittest.TestCase): @classmethod def setUpClass(cls): - build_fake_yaml() + build_fake_yaml_basic() + build_fake_yaml_channel() + @classmethod def tearDownClass(cls): - os.remove('fake.yaml') + os.remove('fake_channel_pruning.yaml') + os.remove('fake_snip.yaml') shutil.rmtree('./saved', ignore_errors=True) shutil.rmtree('runs', ignore_errors=True) - def test_pytorch_pruning(self): + def test_pytorch_pruning_basic(self): from neural_compressor.experimental.pytorch_pruner.pruning import Pruning - prune = Pruning('fake.yaml') - ##prune.generate_pruners() + prune = Pruning("fake_snip.yaml") + ##prune.generate_pruners() + prune.update_items_for_all_pruners(start_step=1) prune.model = self.model criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(self.model.parameters(), lr=0.0001) @@ -86,6 +145,7 @@ def test_pytorch_pruning(self): dummy_dataset = datasets['dummy'](shape=(10, 3, 224, 224), low=0., high=1., label=True) dummy_dataloader = PyTorchDataLoader(dummy_dataset) prune.on_train_begin() + prune.update_items_for_all_pruners(update_frequency_on_step=1) for epoch in range(2): self.model.train() prune.on_epoch_begin(epoch) @@ -103,7 +163,41 @@ def test_pytorch_pruning(self): local_step += 1 prune.on_epoch_end() + prune.get_sparsity_ratio() + prune.on_train_end() + prune.on_before_eval() + prune.on_after_eval() + + def test_pytorch_pruner_channel_pruning(self): + from neural_compressor.experimental.pytorch_pruner.pruning import Pruning + prune = Pruning("fake_channel_pruning.yaml") + ##prune.generate_pruners() + prune.model = self.model + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(self.model.parameters(), lr=0.0001) + datasets = DATASETS('pytorch') + dummy_dataset = datasets['dummy'](shape=(10, 3, 224, 224), low=0., high=1., label=True) + dummy_dataloader = PyTorchDataLoader(dummy_dataset) + prune.on_train_begin() + for epoch in range(2): + self.model.train() + prune.on_epoch_begin(epoch) + local_step = 0 + for image, target in dummy_dataloader: + prune.on_step_begin(local_step) + output = self.model(image) + loss = criterion(output, target) + optimizer.zero_grad() + loss.backward() + prune.on_before_optimizer_step() + optimizer.step() + prune.on_after_optimizer_step() + prune.on_step_end() + local_step += 1 + prune.on_epoch_end() if __name__ == "__main__": unittest.main() + +