Weight pruning is a technique that makes Deep Neural Network (DNN) inference more computationally efficient by reducing the number of model parameters over the course of training. However, most weight pruning techniques do not speed up DNN training, and some even require more iterations to reach convergence. In this work, we propose a novel Structured Data Gradient Pruning (SDGP) method that speeds up training without impacting model convergence. SDGP enforces an N:M sparsity structure, in which at most N out of every group of M elements can be nonzero, which makes it amenable to hardware acceleration. Modern accelerators such as the Nvidia A100 GPU support 2:4 structured sparsity, i.e. 2 nonzeros in every group of 4 elements along a reduction. Assuming hardware support for 2:4 sparsity, our approach can reduce total training time by 15-25% without a significant loss in model accuracy.
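The N:M constraint can be illustrated with a short NumPy sketch. This is not the sdgp.py implementation: it assumes magnitude-based selection over groups of M consecutive elements along the last axis, and the helper names `prune_n_m` and `is_n_m_sparse` are hypothetical.

```python
import numpy as np


def prune_n_m(x: np.ndarray, n: int = 2, m: int = 4) -> np.ndarray:
    """Return a copy of x in which each group of m consecutive elements along
    the last axis keeps only its n largest-magnitude entries."""
    if x.shape[-1] % m != 0:
        raise ValueError("last dimension must be divisible by m")
    if n >= m:
        return x.copy()  # nothing to prune
    groups = x.reshape(-1, m)  # one row per group of m elements
    # Positions of the (m - n) smallest magnitudes in each group; a stable sort
    # makes tie-breaking deterministic.
    drop = np.argsort(np.abs(groups), axis=1, kind="stable")[:, : m - n]
    pruned = groups.copy()
    np.put_along_axis(pruned, drop, 0, axis=1)
    return pruned.reshape(x.shape)


def is_n_m_sparse(x: np.ndarray, n: int = 2, m: int = 4) -> bool:
    """Check that no group of m consecutive elements has more than n nonzeros."""
    groups = x.reshape(-1, m)
    return bool(np.all(np.count_nonzero(groups, axis=1) <= n))


x = np.array([[0.1, -0.9, 0.3, 0.05, 1.0, 2.0, -3.0, 0.5]])
print(prune_n_m(x))  # each group of 4 now has at most 2 nonzeros (2:4 pattern)
```

Grouping along the last axis matches the "N nonzeros per M elements in a reduction" layout only if that axis is the one being reduced in the subsequent matrix multiply; which axis SDGP groups along is an assumption here.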
See sdgp.py for details on how the data gradients are pruned during backpropagation. To make the group-level sorting used for pruning efficient, we implemented a custom CUDA kernel. The code has only been tested with CUDA 11.3, PyTorch 1.10.2, and Python 3.9.
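To show where the pruning sits in backpropagation, here is a hypothetical NumPy sketch for a single linear layer `y = x @ w.T`. It assumes that the incoming data gradient dL/dy is pruned to N:M by magnitude along the output-feature axis (the reduction axis of the `grad_x` product) and that the pruned gradient feeds both gradient products. The axis choice and all function names are our assumptions; the actual sdgp.py code and its CUDA kernel may differ.

```python
import numpy as np


def prune_groups(g: np.ndarray, n: int, m: int) -> np.ndarray:
    """Keep the n largest-|g| entries in each group of m along the last axis."""
    groups = g.reshape(-1, m)
    keep = np.argsort(-np.abs(groups), axis=1, kind="stable")[:, :n]
    mask = np.zeros(groups.shape, dtype=bool)
    np.put_along_axis(mask, keep, True, axis=1)
    return (groups * mask).reshape(g.shape)


def linear_backward_sdgp(x, w, grad_y, n=2, m=4):
    """Backward pass of y = x @ w.T with the data gradient pruned to n:m.

    x: (batch, in_features), w: (out_features, in_features),
    grad_y: (batch, out_features); out_features must be divisible by m.
    """
    # Prune dL/dy along out_features, the reduction axis of grad_x below.
    g = prune_groups(grad_y, n, m)
    grad_x = g @ w    # (batch, in_features), passed on to the previous layer
    grad_w = g.T @ x  # (out_features, in_features), used by the optimizer
    return grad_x, grad_w


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 16))
w = rng.standard_normal((12, 16))
grad_y = rng.standard_normal((8, 12))
grad_x, grad_w = linear_backward_sdgp(x, w, grad_y)
print(grad_x.shape, grad_w.shape)
```

With `n == m` the pruning is a no-op and the function reduces to the ordinary dense backward pass, which makes a convenient sanity check.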
Training generally follows the configuration details in the excellent ffcv library. To fit ImageNet into a system with 256 GB of RAM using the ffcv data loader, we changed the dataset write settings (max resolution, compress probability, JPEG quality) from (500, 0.50, 90), which takes 337 GB, to (448, 0.60, 90), which takes 229 GB. We did not observe any decrease in accuracy compared to the results posted in the ffcv repository for either ResNet-18 or ResNet-50 when training on these slightly smaller images.
SDGP Prune Function | Nonzeros (N) | Group size (M) | Top-1 Acc. (%) | Config | Checkpoint |
---|---|---|---|---|---|
None (dense) | 4 | 4 | 95.3 | link | link |
Random | 2 | 4 | 94.5 | link | link |
Magnitude | 2 | 4 | 95.2 | link | link |
Rescale Mag. | 1 | 4 | 95.1 | link | link |
Rescale Mag. | 2 | 4 | 95.2 | link | link |
Rescale Mag. | 1 | 8 | 94.7 | link | link |
Rescale Mag. | 2 | 8 | 95.1 | link | link |
Rescale Mag. | 4 | 8 | 95.2 | link | link |
Rescale Mag. | 2 | 16 | 95.1 | link | link |
Rescale Mag. | 4 | 16 | 95.2 | link | link |
Rescale Mag. | 8 | 16 | 95.2 | link | link |
Rescale Mag. | 4 | 32 | 94.9 | link | link |
Rescale Mag. | 8 | 32 | 95.3 | link | link |
Rescale Mag. | 16 | 32 | 95.3 | link | link |
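The prune functions compared in the table are defined in sdgp.py. As a hedged illustration only, the sketch below shows random N:M selection (the "Random" row) in NumPy, with an optional m/n rescaling of the surviving entries. The rescaling is our own example of an unbiased correction for random selection and is not necessarily what "Rescale Mag." does; `random_prune_n_m` is a hypothetical name.

```python
import numpy as np


def random_prune_n_m(g, n=2, m=4, rng=None, rescale=False):
    """Keep n entries chosen uniformly at random in each group of m along the
    last axis; optionally scale the survivors by m / n."""
    rng = np.random.default_rng() if rng is None else rng
    groups = g.reshape(-1, m)
    # argsort of i.i.d. uniform noise is a random permutation per group;
    # its first n positions form a uniformly random n-subset.
    keep = np.argsort(rng.random(groups.shape), axis=1)[:, :n]
    mask = np.zeros(groups.shape, dtype=bool)
    np.put_along_axis(mask, keep, True, axis=1)
    out = groups * mask
    if rescale:
        # Each entry survives with probability n / m, so multiplying by m / n
        # makes the pruned gradient equal to the dense one in expectation.
        out = out * (m / n)
    return out.reshape(g.shape)


# Every configuration in the table keeps a fraction N / M of the entries.
for n, m in [(1, 4), (2, 4), (2, 8), (8, 32), (16, 32)]:
    print(f"{n}:{m} keeps {n / m:.1%}")
```

Magnitude-based selection instead keeps the largest entries in each group; in the results above it stays much closer to the dense baseline than random selection at the same 2:4 density.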

ImageNet results:

Model | SDGP Prune Function | Nonzeros (N) | Group size (M) | Top-1 Acc. (%) | Config | Checkpoint |
---|---|---|---|---|---|---|
ResNet-18 | None (dense) | 4 | 4 | 71.4 | link | link |
ResNet-18 | Random | 2 | 4 | 64.3 | link | link |
ResNet-18 | Magnitude | 2 | 4 | 72.1 | link | link |
ResNet-18 | Rescale Mag. | 2 | 4 | 72.4 | link | link |
ResNet-50 | None (dense) | 4 | 4 | 78.1 | link | link |
ResNet-50 | Random | 2 | 4 | 70.3 | link | link |
ResNet-50 | Magnitude | 2 | 4 | 77.7 | link | link |
ResNet-50 | Rescale Mag. | 2 | 4 | 77.6 | link | link |
RegNetX-400MF | None (dense) | 4 | 4 | 73.3 | link | link |
RegNetX-400MF | Random | 2 | 4 | 64.3 | link | link |
RegNetX-400MF | Magnitude | 2 | 4 | 72.1 | link | link |
RegNetX-400MF | Rescale Mag. | 2 | 4 | 72.4 | link | link |