diff --git a/deepspeed/compression/compress.py b/deepspeed/compression/compress.py
index 9c4632f8aef3..c1cdeade3b40 100644
--- a/deepspeed/compression/compress.py
+++ b/deepspeed/compression/compress.py
@@ -11,6 +11,11 @@
import os
import json
+try:
+ import neural_compressor as nc
+except ImportError as e:
+ nc = None
+
def check_deepspeed_config(config):
if isinstance(config, dict):
@@ -117,6 +122,26 @@ def init_compression(model, deepspeed_config, teacher_model=None, mpu=None):
layer_added_compress_methods = get_compress_methods(c_model, compress_methods, mpu=mpu)
compression_preparation(c_model, layer_added_compress_methods, mpu)
+ # For sparse pruning snip_momentum method
+ shared_parameters = compress_methods[SPARSE_PRUNING][SHARED_PARAMETERS]
+ if shared_parameters[SPARSE_PRUNING_ENABLED] and \
+ shared_parameters[SPARSE_PRUNING_METHOD] == SPARSE_PRUNING_METHOD_SNIP_MOMENTUM:
+
+ assert nc is not None, "please ensure the neural_compressor python package is installed by pip or conda if user wants to use snip_momentum sparse pruning"
+
+ from .helper import generate_pruners, register_on_step_begin
+ from nc import WeightPruningConfig
+
+ config = WeightPruningConfig(target_sparsity=1 - shared_parameters[SPARSE_PRUNING_DENSE_RATIO],
+ pattern=shared_parameters[SPARSE_PRUNING_BLOCK_PATTERN],
+ pruning_frequency=shared_parameters[SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE],
+ start_step=shared_parameters[SPARSE_PRUNING_SCHEDULE_OFFSET],
+ end_step=shared_parameters[SPARSE_PRUNING_SCHEDULE_OFFSET_END],
+ excluded_op_names=shared_parameters[SPARSE_PRUNING_EXCLUDED_MODULES])
+ pruners = generate_pruners(config, c_model)
+ c_model.pruners = pruners
+ register_on_step_begin(c_model)
+
return model
diff --git a/deepspeed/compression/config.py b/deepspeed/compression/config.py
index d6e241bd0f80..e1fa5ef4bdb5 100644
--- a/deepspeed/compression/config.py
+++ b/deepspeed/compression/config.py
@@ -5,7 +5,7 @@
from .constants import *
import copy
-from ..runtime.config_utils import get_scalar_param
+from ..runtime.config_utils import get_scalar_param, get_list_param
def get_compression_config(param_dict):
@@ -221,15 +221,17 @@ def get_sparse_pruning(param_dict):
# shared parameters
output[SHARED_PARAMETERS] = get_sparse_pruning_shared_parameters(sub_param_dict)
# each sub-groups
- if output[SHARED_PARAMETERS][SPARSE_PRUNING_ENABLED]:
+ if output[SHARED_PARAMETERS][SPARSE_PRUNING_ENABLED] and output[SHARED_PARAMETERS][
+ SPARSE_PRUNING_METHOD] != SPARSE_PRUNING_METHOD_SNIP_MOMENTUM:
assert DIFFERENT_GROUPS in sub_param_dict.keys(
- ), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
+ ), f"Sparse Pruning is enabled and not snip_momentum method, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_sparse_pruning_different_groups(sub_param_dict)
return output
def get_sparse_pruning_shared_parameters(param_dict):
output = {}
+
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[SPARSE_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_ENABLED,
@@ -237,10 +239,26 @@ def get_sparse_pruning_shared_parameters(param_dict):
output[SPARSE_PRUNING_METHOD] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_METHOD,
SPARSE_PRUNING_METHOD_DEFAULT)
assert output[SPARSE_PRUNING_METHOD] in [
- SPARSE_PRUNING_METHOD_L1, SPARSE_PRUNING_METHOD_TOPK
- ], f"Invalid sparse pruning method. Supported types: [{SPARSE_PRUNING_METHOD_L1}, {SPARSE_PRUNING_METHOD_TOPK}]"
+ SPARSE_PRUNING_METHOD_L1, SPARSE_PRUNING_METHOD_TOPK, SPARSE_PRUNING_METHOD_SNIP_MOMENTUM
+ ], f"Invalid sparse pruning method. Supported types: [{SPARSE_PRUNING_METHOD_L1}, {SPARSE_PRUNING_METHOD_TOPK}, {SPARSE_PRUNING_METHOD_SNIP_MOMENTUM}]"
output[SPARSE_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_SCHEDULE_OFFSET,
SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT)
+ if output[SPARSE_PRUNING_METHOD] == SPARSE_PRUNING_METHOD_SNIP_MOMENTUM:
+ output[SPARSE_PRUNING_BLOCK_PATTERN] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_BLOCK_PATTERN,
+ SPARSE_PRUNING_BLOCK_PATTERN_DEFAULT)
+ output[SPARSE_PRUNING_DENSE_RATIO] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_DENSE_RATIO,
+ SPARSE_PRUNING_DENSE_RATIO_DEFAULT)
+ assert output[SPARSE_PRUNING_DENSE_RATIO] > 0 and output[
+ SPARSE_PRUNING_DENSE_RATIO] < 1, f"Invalid dense_ratio value. Must be less than 1"
+ output[SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE] = get_scalar_param(
+ sub_param_dict, SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE, SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE_DEFAULT)
+ output[SPARSE_PRUNING_EXCLUDED_MODULES] = get_list_param(sub_param_dict, SPARSE_PRUNING_EXCLUDED_MODULES,
+ SPARSE_PRUNING_EXCLUDED_MODULES_DEFAULT)
+ output[SPARSE_PRUNING_SCHEDULE_OFFSET_END] = get_scalar_param(sub_param_dict,
+ SPARSE_PRUNING_SCHEDULE_OFFSET_END,
+ output[SPARSE_PRUNING_SCHEDULE_OFFSET])
+ assert output[SPARSE_PRUNING_SCHEDULE_OFFSET] <= output[
+ SPARSE_PRUNING_SCHEDULE_OFFSET_END], f"Invalid schedule_offset and schedule_offset_end values"
else:
output[SPARSE_PRUNING_ENABLED] = SPARSE_PRUNING_ENABLED_DEFAULT
output[SPARSE_PRUNING_METHOD] = SPARSE_PRUNING_METHOD_DEFAULT
diff --git a/deepspeed/compression/constants.py b/deepspeed/compression/constants.py
index 7bce1f4e4c4e..67375d510a4b 100644
--- a/deepspeed/compression/constants.py
+++ b/deepspeed/compression/constants.py
@@ -12,6 +12,7 @@
DIFFERENT_GROUPS = "different_groups"
TECHNIQUE_ENABLED = "enabled"
TECHNIQUE_SCHEDULE_OFFSET = "schedule_offset"
+TECHNIQUE_SCHEDULE_OFFSET_END = "schedule_offset_end"
DIFFERENT_GROUPS_PARAMETERS = "params"
DIFFERENT_GROUPS_MODULE_SCOPE = "modules"
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT = "*"
@@ -111,11 +112,25 @@
SPARSE_PRUNING_METHOD_DEFAULT = "l1"
SPARSE_PRUNING_METHOD_L1 = "l1"
SPARSE_PRUNING_METHOD_TOPK = "topk"
+SPARSE_PRUNING_METHOD_SNIP_MOMENTUM = "snip_momentum"
+
+SPARSE_PRUNING_BLOCK_PATTERN = "block_pattern"
+SPARSE_PRUNING_BLOCK_PATTERN_DEFAULT = "4x1"
+
+SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE = "schedule_offset_stride"
+SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE_DEFAULT = 1
SPARSE_PRUNING_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET
SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT = 1000
+SPARSE_PRUNING_SCHEDULE_OFFSET_END = TECHNIQUE_SCHEDULE_OFFSET_END
+SPARSE_PRUNING_SCHEDULE_OFFSET_END_DEFAULT = SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT
+
SPARSE_PRUNING_DENSE_RATIO = "dense_ratio"
+SPARSE_PRUNING_DENSE_RATIO_DEFAULT = 0.1
+
+SPARSE_PRUNING_EXCLUDED_MODULES = "excluded_modules"
+SPARSE_PRUNING_EXCLUDED_MODULES_DEFAULT = []
###
# Row Pruning
###
diff --git a/deepspeed/compression/helper.py b/deepspeed/compression/helper.py
index fdca916e9f15..ac06059ee2dd 100644
--- a/deepspeed/compression/helper.py
+++ b/deepspeed/compression/helper.py
@@ -6,6 +6,12 @@
import torch
from .basic_layer import Embedding_Compress, LinearLayer_Compress, Conv2dLayer_Compress, BNLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress
from .constants import *
+from deepspeed.utils import logger
+
+try:
+ from neural_compressor.compression import pruner as nc_pruner
+except ImportError as e:
+ nc_pruner = None
def recursive_getattr(model, module_name):
@@ -246,3 +252,71 @@ def convert_conv1d_to_linear(model, convert_type):
recursive_setattr(c_model, name, new_module)
return model
+
+
+def generate_pruners(config, model):
+ """Generate pruners.
+ Args:
+ config (`neural_compressor.WeightPruningConfig`)
+ The object to the class WeightPruningConfig.
+ model (`torch.nn.module`)
+ The torch module object to be pruned.
+ """
+ assert nc_pruner is not None, "please ensure the neural_compressor python package is installed by pip or conda if user wants to use snip_momentum sparse pruning"
+ from nc_pruner.utils import process_config, parse_to_prune
+ from nc_pruner.pruners import get_pruner
+ assert isinstance(model, torch.nn.Module)
+ pruners_info = process_config(config)
+ pruners = []
+ for info in pruners_info:
+ modules = parse_to_prune(info, model)
+ if modules == {}:
+ logger.warning("one pruner hooks no layers, please have a check")
+
+ pruners.append(get_pruner(info, modules))
+ info['modules'] = [key for key in modules.keys()]
+ info['len_of_modules'] = len(info['modules'])
+ logger.info(info)
+ return pruners
+
+
+def register_on_step_begin(model):
+ """Mount on_step_begin to the model.
+ Args:
+ model (`torch.nn.module`)
+ The torch module object to be pruned.
+ """
+
+ def hook(module, input):
+ for pruner in module.pruners:
+ pruner.on_step_begin(0)
+
+ hook_handle = model.register_forward_pre_hook(hook)
+ return hook_handle
+
+
+def rewrite_optimizer_step(opt: torch.optim.Optimizer):
+ """Mount on_before/after_optimizer_step to the optimizer.
+ Args:
+ model (`torch.opt.Optimizer`)
+ The torch optimizer object to be hooked.
+ """
+
+ def new_step(self, closure=None):
+ if hasattr(self, "pruners"):
+ for pruner in self.pruners:
+ pruner.on_before_optimizer_step()
+
+ if closure is not None:
+ res = self.orig_step(closure)
+ else:
+ res = self.orig_step()
+ if hasattr(self, "pruners"):
+ for pruner in self.pruners:
+ pruner.on_after_optimizer_step()
+ return res
+
+ opt.orig_step = opt.step
+ import types
+ opt.step = types.MethodType(new_step, opt)
+ return opt
diff --git a/deepspeed/compression/scheduler.py b/deepspeed/compression/scheduler.py
index 582ecd8f6f5e..28830d0c8888 100644
--- a/deepspeed/compression/scheduler.py
+++ b/deepspeed/compression/scheduler.py
@@ -100,7 +100,9 @@ def check_sparse_pruning(self):
return
else:
shared_parameters = sp[SHARED_PARAMETERS]
- if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
+ if self.training_steps >= shared_parameters[
+ TECHNIQUE_SCHEDULE_OFFSET] and self.training_steps <= shared_parameters[
+ TECHNIQUE_SCHEDULE_OFFSET_END]:
for group_name, module_name_list, method_parameters in sp[DIFFERENT_GROUPS]:
for module_name in module_name_list:
module = recursive_getattr(self.model, module_name)
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 29223423d2f4..7f619428c147 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -314,6 +314,12 @@ def __init__(
elif self.bfloat16_enabled():
self.optimizer = self._configure_bf16_optimizer(optimizer=None)
+ # Hook optimizer for snip_momentum pruning
+ if hasattr(model, 'pruners'):
+ from ..compression.helper import rewrite_optimizer_step
+ self.optimizer.pruners = model.pruners
+ rewrite_optimizer_step(self.optimizer)
+
# Bookkeeping for sparse support
self.sparse_tensor_module_names = set()
# if self.sparse_gradients_enabled():
diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md
index 84f2f833f212..a0fd279ff969 100755
--- a/docs/_pages/config-json.md
+++ b/docs/_pages/config-json.md
@@ -1435,6 +1435,25 @@ Different quantization sets, this is used for different quantization parameters.
}
```
+```json
+"compression_training": {
+ "sparse_pruning":{
+ "shared_parameters":{
+ "enabled": true,
+ "schedule_offset": 30,
+ "schedule_offset_end": 90,
+ "schedule_offset_stride": 15,
+ "method": "snip_momentum",
+ "block_pattern": "4x1",
+ "dense_ratio": 0.4,
+ "excluded_modules": ['classifier', 'pooler']
+ },
+ "different_groups":{
+ }
+ }
+}
+```
+
**shared_parameters**: [dictionary]
Shared parameters for all sparse pruning groups.
@@ -1443,11 +1462,17 @@ Shared parameters for all sparse pruning groups.
| ----- | ----- | ----- |
| **enabled**: [boolean] | Enable sparse pruning or not. | `false` |
| **schedule_offset**: [integer] | Enable sparse pruning after scheduled steps (can be treated as warmup steps). | `0` |
-| **method**: [string] | Choose different pruning methods, l1 (static, magnitude based) or topk (dynamic, learnable). | `"l1"` |
+| **schedule_offset_end**: [integer] | Disable sparse pruning after scheduled steps, mandotory for `snip_momentum`. | `0` |
+| **schedule_offset_stride**: [integer] | The stride of pruning on training steps, mandotory for `snip_momentum`. | `"1"` |
+| **method**: [string] | Choose different pruning methods, l1 (static, magnitude based), topk (dynamic, learnable) or snip_momentum (structured pruning). | `"l1"` |
+| **block_pattern**: [string] | Choose different structured pruning block patterns, NxM or N:M (N and M are integers). For instance, "4x1" or "2:4" are common block patterns, mandotory for `snip_momentum`. | `"4x1"` |
+| **dense_ratio**: [float] | Used to get the targeted global sparsity ratio, mandotory for `snip_momentum`. | `"0.1"` |
+| **excluded_modules**: [list] | Excluded pruning scope on some special modules like output layer. | `[]` |
**different_groups**: [dictionary]
Different pruning sets, this is used for different pruning parameters. In this example, we give one set. In practice, you can choose the number of sets based on your requirements.
+Note for `snip_momentum` method, you can leave it as empty.
| Fields | Value | Default |
| ----- | ----- | ----- |
diff --git a/docs/_tutorials/model-compression.md b/docs/_tutorials/model-compression.md
index 20f2e6a6b25b..529984664306 100644
--- a/docs/_tutorials/model-compression.md
+++ b/docs/_tutorials/model-compression.md
@@ -158,7 +158,7 @@ Pruning aims to reduce the number of parameters and operations involved in gener
| **Method** | **Type** |
| --------------------- | ------------ |
-| [Sparse pruning](#141-sparse-pruning) | Unstructured |
+| [Sparse pruning](#141-sparse-pruning) | Unstructured and Structured |
| [Row pruning](#142-row-pruning) | Structured |
| [Head pruning](#143-head-pruning) | Structured |
| [Channel pruning](#144-channel-pruning) | Structured |
@@ -166,7 +166,7 @@ Pruning aims to reduce the number of parameters and operations involved in gener
#### 1.4.1 Sparse Pruning
**What is sparse pruning**
-Sparse pruning means we set some of the elements in each weight matrix with zero values. There is no structure pattern in the zero values. One way to perform pruning is based on the absolute value of the weight parameters, see for instance [this paper](https://arxiv.org/abs/1506.02626).
+Sparse pruning means we set some of the elements in each weight matrix with zero values. Relying on the pruning method user chosen, the zero values may have structured pattern or unstructured pattern. One way to perform pruning is based on the absolute value of the weight parameters, see for instance [this paper](https://arxiv.org/abs/1506.02626). Another way to perform pruning is based on the weights' effect to the loss function when they are masked, see for instance [this paper](https://arxiv.org/abs/1810.02340).
**When to use sparse pruning**
@@ -178,11 +178,13 @@ Sparse pruning can be enabled and configured using the DeepSpeed config JSON fil
(1)`schedule_offset`, we empirically find that when using `method: topk`, it’s better to set the `schedule_offset` to a large value such as 10% of the total training steps.
-(2)`method`, we support L1 norm and topk methods. Users are welcome to contribute more methods.
+(2)`method`, we support L1 norm, topk and snip_momentum methods. Users are welcome to contribute more methods.
-(3)`sp1`, users can expand more groups such as `sp2`, `sp3`, etc.
+(3)`sp1`, users can expand more groups such as `sp2`, `sp3`, etc. Note this is not needed for snip_momentum method.
-(4)`dense_ratio`, for unstructured sparse pruning, the dense ratio could be less than 0.1 for BRET-base model while still yielding a good accuracy. For ResNet-50, the dense ratio could be as low as 0.3 while still having good accuracy on ImageNet.
+(4)`dense_ratio`, for unstructured sparse pruning, the dense ratio could be less than 0.1 for BRET-base model while still yielding a good accuracy. For ResNet-50, the dense ratio could be as low as 0.3 while still having good accuracy on ImageNet. for structured sparse pruning like snip_momentum, the dense ratio should be specified in shared_parameters and is used to calculate the global sparsity ratio.
+
+(5)`frequency`, `block_pattern` and `schedule_offset_end`, they are used to specify the pruning frequency on steps, the block-wise pruning pattern (NxM and N in M), and the end steps for pruning. For snip_momentum method, these configurations are mandotory.
The client code change is the same as [weight quantization](#12-weight-quantization).
diff --git a/requirements/requirements-sparse_pruning.txt b/requirements/requirements-sparse_pruning.txt
new file mode 100755
index 000000000000..3b96b4134cdb
--- /dev/null
+++ b/requirements/requirements-sparse_pruning.txt
@@ -0,0 +1 @@
+neural-compressor==2.1.0
diff --git a/setup.py b/setup.py
index d9bcd4ddb7b1..5d0aba18f2bb 100755
--- a/setup.py
+++ b/setup.py
@@ -65,6 +65,7 @@ def fetch_requirements(path):
'autotuning': fetch_requirements('requirements/requirements-autotuning.txt'),
'autotuning_ml': fetch_requirements('requirements/requirements-autotuning-ml.txt'),
'sparse_attn': fetch_requirements('requirements/requirements-sparse_attn.txt'),
+ 'sparse': fetch_requirements('requirements/requirements-sparse_pruning.txt'),
'inf': fetch_requirements('requirements/requirements-inf.txt'),
'sd': fetch_requirements('requirements/requirements-sd.txt')
}