From 6cec70bb2c5fd3079e4d572e22a89b152a229941 Mon Sep 17 00:00:00 2001 From: Frank Cai <49231152+YIYANGCAI@users.noreply.github.com> Date: Tue, 6 Sep 2022 09:27:05 +0800 Subject: [PATCH] pytorch prune new api (#1212) --- neural_compressor/conf/config.py | 38 ++- neural_compressor/experimental/pruning.py | 69 +++-- .../experimental/pytorch_pruner/__init__.py | 16 ++ .../experimental/pytorch_pruner/logger.py | 22 ++ .../experimental/pytorch_pruner/patterns.py | 268 ++++++++++++++++++ .../pytorch_pruner/prune_utils.py | 172 +++++++++++ .../experimental/pytorch_pruner/pruner.py | 180 ++++++++++++ .../experimental/pytorch_pruner/pruning.py | 98 +++++++ .../experimental/pytorch_pruner/scheduler.py | 89 ++++++ test/pruning/test_pytorch_pruning.py | 108 +++++++ 10 files changed, 1031 insertions(+), 29 deletions(-) create mode 100644 neural_compressor/experimental/pytorch_pruner/__init__.py create mode 100644 neural_compressor/experimental/pytorch_pruner/logger.py create mode 100644 neural_compressor/experimental/pytorch_pruner/patterns.py create mode 100644 neural_compressor/experimental/pytorch_pruner/prune_utils.py create mode 100644 neural_compressor/experimental/pytorch_pruner/pruner.py create mode 100644 neural_compressor/experimental/pytorch_pruner/pruning.py create mode 100644 neural_compressor/experimental/pytorch_pruner/scheduler.py create mode 100644 test/pruning/test_pytorch_pruning.py diff --git a/neural_compressor/conf/config.py b/neural_compressor/conf/config.py index 8fcbd2ed811..35ed35f3c8b 100644 --- a/neural_compressor/conf/config.py +++ b/neural_compressor/conf/config.py @@ -50,16 +50,27 @@ def constructor(loader, node): @constructor_register class Pruner(): def __init__(self, start_epoch=None, end_epoch=None, initial_sparsity=None, - target_sparsity=None, update_frequency=1, prune_type='basic_magnitude', - method='per_tensor', names=[], parameters=None): + target_sparsity=None, update_frequency=1, + method='per_tensor', + prune_type='basic_magnitude',##for pytorch pruning, these values should be None + start_step=None, end_step=None, update_frequency_on_step=None, prune_domain=None, + sparsity_decay_type=None, pattern="tile_pattern_1x1", names=None, exclude_names=None, parameters=None): self.start_epoch = start_epoch self.end_epoch = end_epoch self.update_frequency = update_frequency self.target_sparsity = target_sparsity self.initial_sparsity = initial_sparsity self.update_frequency = update_frequency - assert prune_type.replace('_', '') in [i.lower() for i in PRUNERS], \ - 'now only support {}'.format(PRUNERS.keys()) + self.start_step = start_step + self.end_step = end_step + self.update_frequency_on_step = update_frequency_on_step + self.prune_domain = prune_domain + self.sparsity_decay_type = sparsity_decay_type + self.exclude_names = exclude_names + self.pattern = pattern + ## move this to experimental/pruning to support dynamic pruning + # assert prune_type.replace('_', '') in [i.lower() for i in PRUNERS], \ + # 'now only support {}'.format(PRUNERS.keys()) self.prune_type = prune_type self.method = method self.names= names @@ -663,15 +674,33 @@ def percent_to_float(data): weight_compression_schema = Schema({ Optional('initial_sparsity', default=0): And(float, lambda s: s < 1.0 and s >= 0.0), Optional('target_sparsity', default=0.97): float, + Optional('max_sparsity_ratio_per_layer', default=0.98):float, + Optional('prune_type', default="basic_magnitude"): str, Optional('start_epoch', default=0): int, Optional('end_epoch', default=4): int, + Optional('start_step', default=0): int, + Optional('end_step', default=0): int, + Optional('update_frequency', default=1.0): float, + Optional('update_frequency_on_step', default=1):int, + Optional('not_to_prune_names', default=[]):list, + Optional('prune_domain', default="global"): str, + Optional('names', default=[]): list, + Optional('exclude_names', default=None): list, + Optional('prune_layer_type', default=None): list, + Optional('sparsity_decay_type', default="exp"): str, + Optional('pattern', default="tile_pattern_1x1"): str, + Optional('pruners'): And(list, \ lambda s: all(isinstance(i, Pruner) for i in s)) }) +# weight_compression_pytorch_schema = Schema({},ignore_extra_keys=True) + approach_schema = Schema({ Hook('weight_compression', handler=_valid_prune_sparsity): object, + Hook('weight_compression_pytorch', handler=_valid_prune_sparsity): object, Optional('weight_compression'): weight_compression_schema, + Optional('weight_compression_pytorch'): weight_compression_schema, }) default_workspace = './nc_workspace/{}/'.format( @@ -1498,6 +1527,7 @@ class Pruning_Conf(Conf): def __init__(self, cfg=None): if isinstance(cfg, str): + self._read_cfg(cfg) self.usr_cfg = DotDict(self._read_cfg(cfg)) elif isinstance(cfg, DotDict): self.usr_cfg = DotDict(schema.validate(self._convert_cfg( diff --git a/neural_compressor/experimental/pruning.py b/neural_compressor/experimental/pruning.py index e0d139fa957..e72109a111e 100644 --- a/neural_compressor/experimental/pruning.py +++ b/neural_compressor/experimental/pruning.py @@ -23,6 +23,7 @@ from ..model import BaseModel from ..adaptor import FRAMEWORKS from ..conf.config import PruningConf + from warnings import warn class Pruning(Component): @@ -86,9 +87,10 @@ def _on_epoch_end(self): res = [] for pruner in self.pruners: res.append(pruner.on_epoch_end()) - stats, sparsity = self._model.report_sparsity() - logger.info(stats) - logger.info(sparsity) + if hasattr(self, "_model"): + stats, sparsity = self._model.report_sparsity() + logger.info(stats) + logger.info(sparsity) return res def _on_train_end(self): @@ -96,6 +98,11 @@ def _on_train_end(self): for pruner in self.pruners: pruner.on_train_end() + def _on_after_optimizer_step(self): + """ called after optimzier step """ + for pruner in self.pruners: + pruner.on_after_optimizer_step() + def pre_process(self): assert isinstance(self._model, BaseModel), 'need set neural_compressor Model for pruning....' @@ -181,28 +188,40 @@ def generate_hooks(self): def generate_pruners(self): for name in self.cfg.pruning.approach: - assert name == 'weight_compression', 'now we only support weight_compression' - for pruner in self.cfg.pruning.approach.weight_compression.pruners: - if pruner.prune_type == 'basic_magnitude': - self.pruners.append(PRUNERS['BasicMagnitude'](\ - self._model, \ - pruner, - self.cfg.pruning.approach.weight_compression)) - if pruner.prune_type == 'pattern_lock': - self.pruners.append(PRUNERS['PatternLock'](\ - self._model, \ - pruner, - self.cfg.pruning.approach.weight_compression)) - elif pruner.prune_type == 'gradient_sensitivity': - self.pruners.append(PRUNERS['GradientSensitivity'](\ - self._model, \ - pruner, - self.cfg.pruning.approach.weight_compression)) - elif pruner.prune_type == 'group_lasso': - self.pruners.append(PRUNERS['GroupLasso'](\ - self._model, \ - pruner, - self.cfg.pruning.approach.weight_compression)) + assert name == 'weight_compression' or name == "weight_compression_pytorch", \ + 'now we only support weight_compression and weight_compression_pytorch' + + if self.cfg.pruning.approach.weight_compression_pytorch != None: + from .pytorch_pruner.pruning import Pruning as PytorchPruning + self.pytorch_pruner = PytorchPruning(self.cfg) + self.pruners.append(self.pytorch_pruner) + + + if self.cfg.pruning.approach.weight_compression != None: + for pruner in self.cfg.pruning.approach.weight_compression.pruners: + if pruner.prune_type == 'basic_magnitude': + self.pruners.append(PRUNERS['BasicMagnitude'](\ + self._model, \ + pruner, + self.cfg.pruning.approach.weight_compression)) + elif pruner.prune_type == 'pattern_lock': + self.pruners.append(PRUNERS['PatternLock'](\ + self._model, \ + pruner, + self.cfg.pruning.approach.weight_compression)) + elif pruner.prune_type == 'gradient_sensitivity': + self.pruners.append(PRUNERS['GradientSensitivity'](\ + self._model, \ + pruner, + self.cfg.pruning.approach.weight_compression)) + elif pruner.prune_type == 'group_lasso': + self.pruners.append(PRUNERS['GroupLasso'](\ + self._model, \ + pruner, + self.cfg.pruning.approach.weight_compression)) + else: + ##print(pruner.prune_type) + assert False, 'now only support {}'.format(PRUNERS.keys()) def __call__(self): """The main entry point of pruning. diff --git a/neural_compressor/experimental/pytorch_pruner/__init__.py b/neural_compressor/experimental/pytorch_pruner/__init__.py new file mode 100644 index 00000000000..369707c0ef6 --- /dev/null +++ b/neural_compressor/experimental/pytorch_pruner/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/neural_compressor/experimental/pytorch_pruner/logger.py b/neural_compressor/experimental/pytorch_pruner/logger.py new file mode 100644 index 00000000000..88b276b28a6 --- /dev/null +++ b/neural_compressor/experimental/pytorch_pruner/logger.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + from ...utils import logger +except: + import logging + logger = logging.getLogger(__name__) diff --git a/neural_compressor/experimental/pytorch_pruner/patterns.py b/neural_compressor/experimental/pytorch_pruner/patterns.py new file mode 100644 index 00000000000..b5cb09e7a85 --- /dev/null +++ b/neural_compressor/experimental/pytorch_pruner/patterns.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import torch +from .logger import logger + +PATTERNS = {} + + +def register_pattern(name): + """Register a pattern to the registry""" + + def register(pattern): + PATTERNS[name] = pattern + return pattern + + return register + + +def get_pattern(config): + """Get registered pattern class""" + name = config.pattern + name = name.split('_')[-1] + if "x" in name: + return PATTERNS["NxM"](config) + if ":" in name: + return PATTERNS["N:M"](config) + assert False, f"currently only support {PATTERNS.keys()}" + + +class Pattern: + def __init__(self, config): + self.pattern = config.pattern + self.is_global = config.prune_domain == "global" + + def get_masks(self, scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer): + if self.is_global: + return self.get_masks_global(scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer) + else: + return self.get_masks_local(scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer) + + def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer): + raise NotImplementedError + + def get_mask_single(self, score, exact_sparsity_ratio): + flattern_score = torch.flatten(score) + k = int(exact_sparsity_ratio * flattern_score.numel()) + threshold, _ = torch.kthvalue(flattern_score, k) + if not k < 1: + zero = torch.tensor([0.]).to(score.device) + one = torch.tensor([1.]).to(score.device) + mask = torch.where(score <= threshold, zero, one) + else: + mask = torch.ones(score.shape,device=score.device) + return mask + + + def get_masks_local(self, scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer): + masks = {} + for key in scores.keys(): + score = {key: scores[key]} + pre_mask = {key: pre_masks[key]} + mask = self.get_masks_global(score, target_sparsity_ratio, pre_mask, max_sparsity_ratio_per_layer) + masks[key] = mask[key] + return masks + + def get_sparsity_ratio(self, pre_masks): + zero_cnt = 0 + total_cnt = 0 + for key in pre_masks.keys(): + pre_mask = pre_masks[key] + zero_cnt += torch.sum(pre_mask == 0.0).data.item() + total_cnt += pre_masks.numel() + return float(zero_cnt) / total_cnt + + +@register_pattern('NxM') +class PatternNxM(Pattern): + def __init__(self, config): + super(PatternNxM, self).__init__(config) + pattern = self.pattern.split('_')[-1] + self.block_size = [int(pattern.split('x')[0]), int(pattern.split('x')[1])] + + def get_sparsity_ratio(self, pre_masks): + zero_cnt = 0 + total_cnt = 0 + block_size = self.block_size + for key in pre_masks.keys(): + pre_mask = pre_masks[key] + shape = pre_mask.shape + if len(shape) == 4: + shape = pre_mask.reshape(pre_mask.shape[0], -1).shape + if shape[0] % block_size[0] != 0 or shape[1] % block_size[1] != 0: + logger.warning(f"layer {key} is not support under current pattern, ignoring") + continue + + new_shape = [shape[0] // block_size[0], block_size[0], shape[1] // block_size[1], block_size[1]] + pre_mask = pre_mask.reshape(new_shape) + pre_mask_sum = pre_mask.sum(-1).sum(1) + zero_cnt += torch.sum(pre_mask_sum == 0.0).data.item() + total_cnt += pre_mask_sum.numel() + return float(zero_cnt) / total_cnt + + def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer): + block_size = self.block_size + new_scores = {} + not_divided_keys = [] + for key in scores.keys(): + current_score = scores[key] + shape = current_score.shape + if len(shape) == 4: ##default is conv, is transpose conv ok ? + shape = current_score.reshape(current_score.shape[0], -1).shape + if shape[0] % block_size[0] != 0 or shape[1] % block_size[1] != 0: + + not_divided_keys.append(key) + continue + + new_shape = [shape[0] // block_size[0], block_size[0], shape[1] // block_size[1], block_size[1]] + current_score = current_score.reshape(new_shape) + current_score_sum = current_score.sum(-1).sum(1) + new_scores[key] = current_score_sum + global_scores = torch.cat([torch.flatten(v) for v in new_scores.values()]) + k = int(target_sparsity_ratio * global_scores.numel()) + masks = {} + if not k < 1: + threshold, _ = torch.kthvalue(global_scores, k) + for key in new_scores.keys(): + score = new_scores[key] + zero = torch.tensor([0.]).to(score.device) + one = torch.tensor([1.]).to(score.device) + mask = torch.where(score <= threshold, zero, one) + mask = mask.repeat_interleave(block_size[0], dim=0).repeat_interleave(block_size[1], dim=-1) + if torch.sum(mask) / mask.numel() < 1.0 - max_sparsity_ratio_per_layer: + ##to prevent some layer not be purned too much + ##this is differnt with our original implementation + masks[key]=self.get_mask_single(new_scores[key],max_sparsity_ratio_per_layer) + masks[key]=masks[key].repeat_interleave(block_size[0], 0).repeat_interleave(block_size[1], -1) + # if pre_masks != {}:##when use one shot, this is not right + # masks[key] = pre_masks[key] + # else: + # masks[key] = mask + else: + masks[key] = mask + if len(scores[key].shape)==4: + ##we need to revert back + masks[key]=masks[key].reshape(scores[key].shape) + + for key in not_divided_keys: + p = scores[key] + masks[key] = torch.ones(p.shape).to(p.device) + logger.warning(f"{key} shape {scores[key].shape} cannot be divided by {self.pattern}") + + else: + for key in scores.keys(): + p = scores[key] + masks[key] = torch.ones(p.shape).to(p.device) + return masks + + +@register_pattern('N:M') +class PatternNInM(Pattern): + def __init__(self, config): + super(PatternNInM, self).__init__(config) + pattern = self.pattern.split('_')[-1] + self.N = int(pattern.split(':')[0]) + self.M = int(pattern.split(':')[1]) ##m is bigger + + def get_sparsity_ratio(self, pre_masks): + ##simply use elemwise sparsity + non_zero_cnt = 0 + total_cnt = 0 + for key in pre_masks.keys(): + non_zero_cnt += (torch.sum(pre_masks[key])).data.item() + total_cnt += pre_masks[key].numel() + return 1.0-float(non_zero_cnt) / total_cnt + + def get_masks_global(self, scores, target_sparsity_ratio, pre_masks, max_sparsity_ratio_per_layer): + N = self.N + M = self.M + target_sparsity_ratio = target_sparsity_ratio / (float(N / M)) ##recover sparsity for block wise + all_nm_masks = {} + new_scores = {} + not_divided_keys = [] + for key in scores.keys(): + current_score = scores[key] + if len(current_score.shape) == 4: ##TODO need to verify whether it's ok for transposed conv + current_score = current_score.permute(0, 2, 3, 1)##cout,k,k,cin + current_score = current_score.reshape(current_score.shape[0], -1) + shape = current_score.shape + new_shape = [shape[0], shape[1] // M, M] + current_score_new = current_score.reshape(new_shape) + + threshold, _ = torch.kthvalue(current_score_new, N, dim=2) + threshold = threshold.unsqueeze(-1) + if shape[1] % M != 0: + not_divided_keys.append(key) + continue + threshold = threshold.expand(shape[0], shape[1] // M, M) + threshold = threshold.reshape((shape[0], shape[1])) + + one = torch.tensor([1.]).to(current_score.device) + zero = torch.tensor([0.]).to(current_score.device) + mask = torch.where(current_score <= threshold, zero, one) + current_score_new = current_score_new.reshape((shape[0], shape[1])) + ##to get the sum of N scores in each block with M + current_score_new = current_score_new * (1.0 - mask) + current_score_new = current_score_new.reshape(shape[0], shape[1] // M, M) + score_sum = torch.sum(current_score_new, dim=-1) + all_nm_masks[key] = mask + new_scores[key] = score_sum + + global_scores = torch.cat([torch.flatten(v) for v in new_scores.values()]) + k = int(target_sparsity_ratio * global_scores.numel()) + masks = {} + if not k < 1: + threshold, _ = torch.kthvalue(global_scores, k) + for key in new_scores.keys(): + score = new_scores[key] + zero = torch.tensor([0.]).to(score.device) + one = torch.tensor([1.]).to(score.device) + mask = torch.where(score <= threshold, zero, one) + mask = mask.repeat_interleave(M, dim=-1) + ## both zero will be zero + mask = (mask + all_nm_masks[key]) + mask = torch.where(mask <= 0, zero, one) + if torch.sum(mask) / mask.numel() < 1.0 - max_sparsity_ratio_per_layer: + ##trick, to prevent some layer not be purned too much + masks[key] = self.get_mask_single(new_scores[key], max_sparsity_ratio_per_layer) + masks[key] = masks[key].repeat_interleave(M, dim=-1) + ## both zero will be zero + masks[key] = (masks[key] + all_nm_masks[key]) + masks[key] = torch.where(masks[key] <= 0, zero, one) + else: + masks[key] = mask + for key in not_divided_keys: + p = scores[key] + masks[key] = torch.ones(p.shape).to(p.device) + logger.warning(f"{key} shape {scores[key].shape} cannot be divided by {self.pattern}") + + else: + for key in scores.keys(): + p = scores[key] + masks[key] = torch.ones(p.shape).to(p.device) + for key in masks.keys(): + if len(scores[key].shape) == 4 and len(masks[key].shape) == 2: ## need to permute + mask = masks[key] + mask = mask.reshape(scores[key].shape[0], scores[key].shape[2], scores[key].shape[3], + scores[key].shape[1]) + mask = mask.permute(0, 3, 1, 2) + masks[key] = mask + + return masks diff --git a/neural_compressor/experimental/pytorch_pruner/prune_utils.py b/neural_compressor/experimental/pytorch_pruner/prune_utils.py new file mode 100644 index 00000000000..ac7a357a05d --- /dev/null +++ b/neural_compressor/experimental/pytorch_pruner/prune_utils.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import yaml + +try: + from ...conf.dotdict import DotDict +except: + from .dot_dict import DotDict ##TODO +from .logger import logger + + +def check_config(prune_config): + assert prune_config['start_step'] >= 0, "start_step should be greater than 0" + assert prune_config['end_step'] >= -1, "end_step should be greater than 0" + assert prune_config['end_step'] >= prune_config['start_step'], \ + "end_step should be greater than start_step" + assert prune_config['target_sparsity'] >= 0 and prune_config['target_sparsity'] < 1.0, \ + "begin_pruning_step should be in range [0,1)" + assert prune_config['update_frequency_on_step'] > 0, "update_frequency_on_step should be greater than 0" + assert prune_config['max_sparsity_ratio_per_layer'] >= 0 and prune_config['max_sparsity_ratio_per_layer'] <1, \ + "update_frequency_on_step should be greater than 0" + assert prune_config['prune_domain'] == "global" or prune_config['prune_domain'] == "local", \ + "only support 'global' and 'local' prune domain" + if "x" in prune_config["pattern"]: + pattern = prune_config["pattern"].split('_')[-1].split('x') + N = int(pattern[0]) + M = int(pattern[1]) + assert N > 0, "N should be greater than 0" + assert M > 0, "M should be greater than 0" + if ":" in prune_config["pattern"]: + pattern = prune_config["pattern"].split('_')[-1].split(':') + N = int(pattern[0]) + M = int(pattern[1]) + assert N > 0, "N should be greater than 0" + assert M > N, "M should be greater than N" + max_ratio = float(N) / M + assert prune_config['target_sparsity'] <= max_ratio, \ + "in N:M pattern, the max sparsity is N/M={}".format(max_ratio) + prune_config['max_sparsity_ratio_per_layer'] = min(max_ratio, prune_config['max_sparsity_ratio_per_layer']) + + +def reset_non_value(obj, key, default): + if not hasattr(obj, key) or getattr(obj, key) == None: + return default + else: + return getattr(obj, key) + +def get_none_to_default(obj,key,default): + if not hasattr(obj, key) or getattr(obj, key) == None: + return default + else: + return getattr(obj, key) + + +def process_and_check_config(val): + val = val["pruning"]['approach']['weight_compression_pytorch'] + start_step = get_none_to_default(val, "start_step", 0) + end_step = get_none_to_default(val, "end_step", 0) + not_to_prune_names = get_none_to_default(val,"not_to_prune_names", []) + prune_layer_type = get_none_to_default(val,"prune_layer_type", ['Conv2d', 'Linear']) + target_sparsity = get_none_to_default(val,"target_sparsity", 0.0) ## be care of this val + update_frequency_on_step = int(get_none_to_default(val,"update_frequency_on_step", 1)) + prune_domain = get_none_to_default(val,"prune_domain", "global") + prune_type = get_none_to_default(val,"prune_type", "snip_momentum") + sparsity_decay_type = get_none_to_default(val,"sparsity_decay_type", "exp") + max_sparsity_ratio_per_layer =get_none_to_default(val,"max_sparsity_ratio_per_layer", 0.98) + names = get_none_to_default(val,"names", []) + exclude_names =get_none_to_default(val,"exclude_names", []) + pattern =get_none_to_default(val,"pattern", "tile_pattern_4x1") + + pruners_info = [] + for info in val['pruners']: + pruner = {} + pruner['start_step'] = reset_non_value(info, 'start_step', start_step) + pruner['end_step'] = reset_non_value(info, 'end_step', end_step) + pruner['not_to_prune_names'] = reset_non_value(info, 'not_to_prune_names', not_to_prune_names) + pruner['prune_layer_type'] = reset_non_value(info, 'prune_layer_type', prune_layer_type) + pruner['target_sparsity'] = reset_non_value(info, 'target_sparsity', target_sparsity) + pruner['update_frequency_on_step'] = reset_non_value(info, 'update_frequency_on_step', update_frequency_on_step) + pruner['prune_domain'] = reset_non_value(info, 'prune_domain', prune_domain) + pruner['prune_type'] = reset_non_value(info, 'prune_type', prune_type) + pruner['sparsity_decay_type'] = reset_non_value(info, 'sparsity_decay_type', sparsity_decay_type) + pruner['max_sparsity_ratio_per_layer'] = reset_non_value(info, 'max_sparsity_ratio_per_layer', + max_sparsity_ratio_per_layer) + pruner['names'] = reset_non_value(info, 'names', names) + pruner['exclude_names'] = reset_non_value(info, 'exclude_names', + exclude_names) + pruner['pattern'] = reset_non_value(info, 'pattern', + pattern) + check_config(pruner) + pruner_info = DotDict(pruner) + pruners_info.append(pruner_info) + return pruners_info + + +def process_config(config): + if isinstance(config, str): + try: + with open(config, 'r') as f: + + content = f.read() + # try: + # from .schema_check import schema + # except ImportError: + from ...conf.config import schema + + val = yaml.safe_load(content) + schema.validate(val) + except FileNotFoundError as f: + logger.error("{}.".format(f)) + raise RuntimeError( + "The yaml file is not exist. Please check the file name or path." + ) + except Exception as e: + logger.error("{}.".format(e)) + raise RuntimeError( + "The yaml file format is not correct. Please refer to document." + ) + + elif isinstance(config, DotDict): + val = config + else: + assert False, f"not supported type {config}" + + return process_and_check_config(val) + + +def parse_to_prune(model, config): + modules = {} + if config["names"] == None or config["names"] == []: + config["names"] = [".*"] + for raw in config["names"]: + try: + pattern = re.compile(raw) + except: + assert False, f"regular expression match does not support {raw}" + for name, module in filter(lambda t: pattern.search(t[0]), model.named_modules()): + if type(module).__name__ in config["prune_layer_type"]: + modules[name] = module + return modules + + +def parse_not_to_prune(modules, config): + """drop non pruned layers""" + not_to_prune = config["not_to_prune_names"] + not_to_prune.extend(config["exclude_names"]) + + patterns = [re.compile(s) for s in not_to_prune] + if len(patterns) <= 0: + return modules + new_module = {} + for name in modules.keys(): + if any([p.search(name) for p in patterns]): + continue + new_module[name] = modules[name] + return new_module diff --git a/neural_compressor/experimental/pytorch_pruner/pruner.py b/neural_compressor/experimental/pytorch_pruner/pruner.py new file mode 100644 index 00000000000..c0ca412ca9f --- /dev/null +++ b/neural_compressor/experimental/pytorch_pruner/pruner.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from .patterns import get_pattern +from .scheduler import get_scheduler + +from .logger import logger + +PRUNERS = {} + + +def register_pruners(name): + """Register a pruner to the registry""" + + def register(pruner): + PRUNERS[name] = pruner + return pruner + + return register + + +def get_pruner(moduels, config): + """Get registered pruner class""" + name = config["prune_type"] + if name not in PRUNERS.keys(): + assert False, f"does not support {name}, currently only support {PRUNERS.keys()}" + return PRUNERS[name](moduels, config) + + +class Pruner: + def __init__(self, modules, config): + self.modules = modules + self.config = config + self.masks = {} + self.scores = {} + self.reg = None ##TODO need to add reg + self.pattern = get_pattern(config) + self.scheduler = get_scheduler(config) + self.current_sparsity_ratio = 0.0 + self._init() + + def _init(self): + self.global_step = -1 + self.start_step = self.config['start_step'] + self.end_step = self.config['end_step'] + self.update_frequency_on_step = self.config['update_frequency_on_step'] + ##this is different with original code + self.total_prune_cnt = (self.end_step - self.start_step + self.update_frequency_on_step) \ + // self.update_frequency_on_step + self.completed_pruned_cnt = 0 + self.masks = {} + for key in self.modules.keys(): + module = self.modules[key] + self.masks[key] = torch.ones(module.weight.shape).to(module.weight.device) ##TODO support bias or others + + self.target_sparsity_ratio = self.config['target_sparsity'] + + self.max_sparsity_ratio_per_layer = self.config['max_sparsity_ratio_per_layer'] + + def on_epoch_begin(self, epoch): + pass + + def mask_weights(self): + with torch.no_grad(): + for key in self.modules.keys(): + module = self.modules[key] + module.weight.data = module.weight.data * self.masks[key] + + def on_step_begin(self, local_step): + self.global_step += 1 + if not self.check_is_pruned_step(self.global_step): + return + + if self.current_sparsity_ratio > self.target_sparsity_ratio: + return + + current_target_sparsity_ratio = self.scheduler.update_sparsity_ratio(self.target_sparsity_ratio, + self.completed_pruned_cnt, + self.total_prune_cnt, self.masks) + logger.info(f"current target ratio is {current_target_sparsity_ratio}") + self.update_scores() + self.completed_pruned_cnt += 1 + if self.scores == {}: + return + self.masks = self.pattern.get_masks(self.scores, current_target_sparsity_ratio, self.masks, + self.max_sparsity_ratio_per_layer) + self.mask_weights() + + self.current_sparsity_ratio = self.pattern.get_sparsity_ratio(self.masks) + logger.info(f"current sparsity ratio is {self.current_sparsity_ratio}") + + def on_epoch_end(self): + pass + + def on_step_end(self): + pass + + def on_before_optimizer_step(self): + pass + + def on_after_optimizer_step(self): + self.mask_weights() + + def on_train_begin(self): + pass + + def on_train_end(self): + pass + + def check_is_pruned_step(self, step): + if step < self.start_step or step > self.end_step: + return False + if int(step - self.start_step) % self.update_frequency_on_step == 0: + return True + return False + + def update_scores(self): + pass + + +@register_pruners('snip') +class SnipPruner(Pruner): + def __init__(self, modules, config): + super(SnipPruner, self).__init__(modules, config) + assert self.config.end_step > 0, "gradient based criteria does not work on step 0" + self.scores = {} + + def on_after_optimizer_step(self): + with torch.no_grad(): + for key in self.modules.keys(): + p = self.modules[key].weight + self.scores[key] = torch.abs(p * p.grad) + self.mask_weights() + + +@register_pruners('snip_momentum') +class SnipMomentumPruner(Pruner): + def __init__(self, modules, config): + super(SnipMomentumPruner, self).__init__(modules, config) + assert self.config.end_step > 0, "gradient based criteria does not work on step 0" + # self.scores = {} + for key in modules.keys(): + p = modules[key].weight + self.scores[key] = torch.zeros(p.shape).to(p.device) + + def on_after_optimizer_step(self): + with torch.no_grad(): + for key in self.modules.keys(): + p = self.modules[key].weight + self.scores[key] *= 0.9 ##magic number + self.scores[key] += 1.0 * torch.abs(p * p.grad) + self.mask_weights() + + +@register_pruners('magnitude') +class MagnitudePruner(Pruner): + def __init__(self, modules, config): + super(MagnitudePruner, self).__init__(modules, config) + self.scores = {} + + def update_scores(self): + with torch.no_grad(): + for key in self.modules.keys(): + p = self.modules[key].weight.data + self.scores[key] = p diff --git a/neural_compressor/experimental/pytorch_pruner/pruning.py b/neural_compressor/experimental/pytorch_pruner/pruning.py new file mode 100644 index 00000000000..c22e4aae5bd --- /dev/null +++ b/neural_compressor/experimental/pytorch_pruner/pruning.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.nn + +from .prune_utils import process_config, parse_to_prune, parse_not_to_prune +from .pruner import get_pruner +from .logger import logger + +class Pruning: + def __init__(self, config): + self.model = None + self.config_file_path = config + self.pruners = [] + self.pruner_info = process_config(self.config_file_path) + + def update_items_for_all_pruners(self, **kwargs): + for item in self.pruner_info: + for key in kwargs: + if key in item.keys(): + item[key] = kwargs[key] + + #def _call_pruners(self, func): + # def warpper(self, *args, **kw): + # func_name = f"{func.__name__}" + # func(self, *args, **kw) + # for prune in self.pruners: + # prun_func = getattr(prune, func_name) + # prun_func(*args, **kw) + # + # return warpper + + def _generate_pruners(self): + assert isinstance(self.model, torch.nn.Module) + + for info in self.pruner_info: + modules = parse_to_prune(self.model, info) + modules = parse_not_to_prune(modules, info) + if modules == {}: + logger.warning("one pruner hooks no layers, please have a check") + + self.pruners.append(get_pruner(modules, info)) + info['modules'] = [key for key in modules.keys()] + info['len_of_modules'] = len(info['modules']) + logger.info(info) + + #@_call_pruners + def on_train_begin(self): + self._generate_pruners() ##TODO is there better place to place + + #@_call_pruners + def on_epoch_begin(self, epoch): + for pruner in self.pruners: + pruner.on_epoch_begin(epoch) + + #@_call_pruners + def on_step_begin(self, local_step): + for pruner in self.pruners: + pruner.on_step_begin(local_step) + + #@_call_pruners + def on_before_optimizer_step(self): + for pruner in self.pruners: + pruner.on_before_optimizer_step() + + #@_call_pruners + def on_step_end(self): + for pruner in self.pruners: + pruner.on_step_end() + + #@_call_pruners + def on_epoch_end(self): + for pruner in self.pruners: + pruner.on_epoch_end() + + #@_call_pruners + def on_train_end(self): + for pruner in self.pruners: + pruner.on_train_end() + + #@_call_pruners + def on_after_optimizer_step(self): + for pruner in self.pruners: + pruner.on_after_optimizer_step() diff --git a/neural_compressor/experimental/pytorch_pruner/scheduler.py b/neural_compressor/experimental/pytorch_pruner/scheduler.py new file mode 100644 index 00000000000..4e9e183649a --- /dev/null +++ b/neural_compressor/experimental/pytorch_pruner/scheduler.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +SCHEDULERS = {} + + +def register_scheduler(name): + """Register a scheduler to the registry""" + + def register(scheduler): + SCHEDULERS[name] = scheduler + return scheduler + + return register + + +def get_scheduler(config): + """Get registered scheduler class""" + name = "iterative" + if config.start_step == config.end_step: + name = "oneshot" + return SCHEDULERS[name](config) + + +class Scheduler: + def __init__(self, config): + self.config = config + + def update_sparsity_ratio(self, aggressive_ratio, current_prune_step, total_prune_steps, masks): + raise NotImplementedError + + +@register_scheduler('oneshot') +class OneshotScheduler(Scheduler): + def __init__(self, config): + super(OneshotScheduler, self).__init__(config) + + def update_sparsity_ratio(self, aggressive_ratio, current_prune_step, total_prune_steps, masks): + return aggressive_ratio + + +@register_scheduler('iterative') +class IterativeScheduler(Scheduler): + def __init__(self, config): + super(IterativeScheduler, self).__init__(config) + # self.decay_type = config["sparsity_decay_type"] + + def update_sparsity_ratio(self, target_ratio, current_prune_step, total_prune_steps, masks): + aggressive_ratio = target_ratio + # if self.config.prune_domain == "global": + # aggressive_ratio += 0.02 + + aggressive_ratio = min(self.config.max_sparsity_ratio_per_layer, + aggressive_ratio) ##lagacy issue + + decay_type = self.config.sparsity_decay_type + if decay_type == "cos": + current_target_sparsity = (aggressive_ratio) * ( + 1.0 - math.cos(float(current_prune_step) / total_prune_steps * (math.pi / 2))) + elif decay_type == "exp": + target_dense_change_ratio = (1.0 - aggressive_ratio) ** (1 / total_prune_steps) + current_target_sparsity = 1.0 - target_dense_change_ratio ** current_prune_step + + elif decay_type == "linear": + current_target_sparsity = (aggressive_ratio) * float(current_prune_step) / total_prune_steps + + elif decay_type == "cube": + current_target_sparsity = (aggressive_ratio) * ((float(current_prune_step) / total_prune_steps) ** 3) + else: + assert False, "{} is not supported".format(decay_type) + + current_target_sparsity = min(target_ratio, current_target_sparsity) + return current_target_sparsity diff --git a/test/pruning/test_pytorch_pruning.py b/test/pruning/test_pytorch_pruning.py new file mode 100644 index 00000000000..057ba388c26 --- /dev/null +++ b/test/pruning/test_pytorch_pruning.py @@ -0,0 +1,108 @@ +import os +import shutil +import unittest + +import torch +import torchvision +import torch.nn as nn + +from neural_compressor.data import DATASETS +from neural_compressor.experimental.data.dataloaders.pytorch_dataloader import PyTorchDataLoader + +def build_fake_yaml(): + fake_yaml = """ + model: + name: imagenet_prune + framework: pytorch + + pruning: + approach: + weight_compression_pytorch: + initial_sparsity: 0.0 + target_sparsity: 0.9 + start_step: 0 + end_step: 10 + not_to_prune_names: ["classifier"] + exclude_names: [".*query",".*key", ".*value"] + update_frequency_on_step: 1 + sparsity_decay_type: "exp" + pruners: + - !Pruner + start_step: 0 + end_step: 10 + prune_type: "magnitude" + names: ['layer1.*'] + prune_domain: "global" + pattern: "tile_pattern_4x1" + + - !Pruner + start_step: 1 + end_step: 1 + target_sparsity: 0.5 + prune_type: "snip_momentum" + update_frequency: 2 + names: ['layer2.*'] + prune_domain: local + pattern: "tile_pattern_2:4" + - !Pruner + start_step: 2 + end_step: 8 + target_sparsity: 0.8 + prune_type: "snip" + names: ['layer3.*'] + prune_domain: "local" + pattern: "tile_pattern_16x1" + sparsity_decay_type: "cube" + """ + with open('fake.yaml', 'w', encoding="utf-8") as f: + f.write(fake_yaml) + + + +class TestPytorchPruning(unittest.TestCase): + + model = torchvision.models.resnet18() + + @classmethod + def setUpClass(cls): + build_fake_yaml() + + @classmethod + def tearDownClass(cls): + os.remove('fake.yaml') + shutil.rmtree('./saved', ignore_errors=True) + shutil.rmtree('runs', ignore_errors=True) + + def test_pytorch_pruning(self): + from neural_compressor.experimental.pytorch_pruner.pruning import Pruning + prune = Pruning('fake.yaml') + ##prune.generate_pruners() + + prune.model = self.model + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(self.model.parameters(), lr=0.0001) + datasets = DATASETS('pytorch') + dummy_dataset = datasets['dummy'](shape=(10, 3, 224, 224), low=0., high=1., label=True) + dummy_dataloader = PyTorchDataLoader(dummy_dataset) + prune.on_train_begin() + for epoch in range(2): + self.model.train() + prune.on_epoch_begin(epoch) + local_step = 0 + for image, target in dummy_dataloader: + prune.on_step_begin(local_step) + output = self.model(image) + loss = criterion(output, target) + optimizer.zero_grad() + loss.backward() + prune.on_before_optimizer_step() + optimizer.step() + prune.on_after_optimizer_step() + prune.on_step_end() + local_step += 1 + + prune.on_epoch_end() + + +if __name__ == "__main__": + unittest.main()