Skip to content

Commit

Permalink
Add GroupFisher pruning algorithm. (#459)
Browse files Browse the repository at this point in the history
* init

* support expand dwconv

* add tools

* init

* add import

* add configs

* add ut and fix bug

* update

* update finetune config

* update impl imports

* add deploy configs and result

* add _train_step

* detla_type -> normalization_type

* change img link

* add prune to config

* add json dump when GroupFisherSubModel init

* update prune config

* update finetune config

* update deploy config

* update prune config

* update readme

* mutable_cfg -> fix_subnet

* update readme

* impl -> implementations

* update script.sh

* rm gen_fake_cfg

* add Implementation to readme

* update docstring

* add finetune_lr to config

* update readme

* fix error in config

* update links

* update configs

* refine

* fix spell error

* add test to readme

* update README

* update readme

* update readme

* update cite format

* fix for ci

* update to pass ci

* update readme

---------

Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: Your Name <you@example.com>
  • Loading branch information
3 people authored Feb 20, 2023
1 parent 18754f3 commit 7acc046
Show file tree
Hide file tree
Showing 71 changed files with 3,087 additions and 22 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,5 @@ venv.bak/
# Srun
*.out
batchscript-*
work_dir
mmdeploy
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,5 @@ repos:
^test
| ^docs
| ^configs
| ^.*/configs*
)
214 changes: 214 additions & 0 deletions configs/pruning/base/group_fisher/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
# Group_fisher pruning

> [Group Fisher Pruning for Practical Network Compression.](https://arxiv.org/pdf/2108.00708.pdf)
## Abstract

Network compression has been widely studied since it is able to reduce the memory and computation cost during inference. However, previous methods seldom deal with complicated structures like residual connections, group/depthwise convolution and feature pyramid network, where channels of multiple layers are coupled and need to be pruned simultaneously. In this paper, we present a general channel pruning approach that can be applied to various complicated structures. Particularly, we propose a layer grouping algorithm to find coupled channels automatically. Then we derive a unified metric based on Fisher information to evaluate the importance of a single channel and coupled channels. Moreover, we find that inference speedup on GPUs is more correlated with the reduction of memory rather than FLOPs, and thus we employ the memory reduction of each channel to normalize the importance. Our method can be used to prune any structures including those with coupled channels. We conduct extensive experiments on various backbones, including the classic ResNet and ResNeXt, mobilefriendly MobileNetV2, and the NAS-based RegNet, both on image classification and object detection which is under-explored. Experimental results validate that our method can effectively prune sophisticated networks, boosting inference speed without sacrificing accuracy.

![pipeline](/~https://github.com/jshilong/FisherPruning/blob/main/resources/structures.png?raw=true)

## Results and models

### Classification on ImageNet

| Model | Top-1 | Gap | Flop(G) | Remain(%) | Parameters(M) | Remain(%) | Config | Download |
| ------------------------ | ----- | ----- | ------- | --------- | ------------- | --------- | ------------------------------------- | ----------------------------------------------------- |
| ResNet50 | 76.55 | - | 4.11 | - | 25.6 | - | [mmcls][cls_r50_c] | [model][cls_r50_m] |
| ResNet50_pruned_act | 75.22 | -1.33 | 2.06 | 50.1% | 16.3 | 63.7% | [prune][r_a_pc] \| [finetune][r_a_fc] | [pruned][r_a_p] \| [finetuned][r_a_f] \| [log][r_a_l] |
| ResNet50_pruned_flops | 75.61 | -0.94 | 2.06 | 50.1% | 16.3 | 63.7% | [prune][r_f_pc] \| [finetune][r_f_fc] | [pruned][r_f_p] \| [finetuned][r_f_f] \| [log][r_f_l] |
| MobileNetV2 | 71.86 | - | 0.313 | - | 3.51 | - | [mmcls][cls_m_c] | [model][cls_m_m] |
| MobileNetV2_pruned_act | 70.82 | -1.04 | 0.207 | 66.1% | 3.18 | 90.6% | [prune][m_a_pc] \| [finetune][m_a_fc] | [pruned][m_a_p] \| [finetuned][m_a_f] \| [log][m_a_l] |
| MobileNetV2_pruned_flops | 70.87 | -0.99 | 0.207 | 66.1% | 2.82 | 88.7% | [prune][m_f_pc] \| [finetune][m_f_fc] | [pruned][m_f_p] \| [finetuned][m_f_f] \| [log][m_f_l] |

### Detection on COCO

| Model(Detector-Backbone) | AP | Gap | Flop(G) | Remain(%) | Parameters(M) | Remain(%) | Config | Download |
| ------------------------------ | ---- | ---- | ------- | --------- | ------------- | --------- | --------------------------------------- | -------------------------------------------------------- |
| RetinaNet-R50-FPN | 36.5 | - | 250 | - | 63.8 | - | [mmdet][det_rt_c] | [model][det_rt_m] |
| RetinaNet-R50-FPN_pruned_act | 36.5 | 0.0 | 126 | 50.4% | 34.6 | 54.2% | [prune][rt_a_pc] \| [finetune][rt_a_fc] | [pruned][rt_a_p] \| [finetuned][rt_a_f] \| [log][rt_a_l] |
| RetinaNet-R50-FPN_pruned_flops | 36.6 | +0.1 | 126 | 50.4% | 34.9 | 54.7% | [prune][rt_f_pc] \| [finetune][rt_f_fc] | [pruned][rt_f_p] \| [finetuned][rt_f_f] \| [log][rt_f_l] |

**Note**

- Because the pruning papers use different pretraining and finetuning settings, It is hard to compare them fairly. As a result, we prefer to apply algorithms on the openmmlab settings.
- This may make the experiment results are different from that in the original papers.

## Get Started

We have three steps to apply GroupFisher to your model, including Prune, Finetune, Deploy.

Note: please use torch>=1.12, as we need fxtracer to parse the models automatically.

### Prune

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh \
{config_folder}/group_fisher_{normalization_type}_prune_{model_name}.py 8 \
--work-dir $WORK_DIR
```

In the pruning config file. You have to fill some args as below.

```python
"""
_base_ (str): The path to your pretrained model checkpoint.
pretrained_path (str): The path to your pretrained model checkpoint.
interval (int): Interval between pruning two channels. You should ensure you
can reach your target pruning ratio when the training ends.
normalization_type (str): GroupFisher uses two methods to normlized the channel
importance, including ['flops','act']. The former uses flops, while the
latter uses the memory occupation of activation feature maps.
lr_ratio (float): Ratio to decrease lr rate. As pruning progress is unstable,
you need to decrease the original lr rate until the pruning training work
steadly without getting nan.
target_flop_ratio (float): The target flop ratio to prune your model.
input_shape (Tuple): input shape to measure the flops.
"""
```

After the pruning process, you will get a checkpoint of the pruned model named flops\_{target_flop_ratio}.pth in your workdir.

### Finetune

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh \
{config_folder}/group_fisher_{normalization_type}_finetune_{model_name}.py 8 \
--work-dir $WORK_DIR
```

There are also some args for you to fill in the config file as below.

```python
"""
_base_(str): The path to your pruning config file.
pruned_path (str): The path to the checkpoint of the pruned model.
finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr
rate of the pretrain.
"""
```

After finetuning, except a checkpoint of the best model, there is also a fix_subnet.json, which records the pruned model structure. It will be used when deploying.

### Test

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_test.sh \
{config_folder}/group_fisher_{normalization_type}_finetune_{model_name}.py {checkpoint_path} 8
```

### Deploy

First, we assume you are fimilar to mmdeploy. For a pruned model, you only need to use the pruning deploy config to instead the pretrain config to deploy the pruned version of your model.

```bash
python {mmdeploy}/tools/deploy.py \
{mmdeploy}/{mmdeploy_config}.py \
{config_folder}/group_fisher_{normalization_type}_deploy_{model_name}.py \
{path_to_finetuned_checkpoint}.pth \
{mmdeploy}/tests/data/tiger.jpeg
```

The deploy config has some args as below:

```python
"""
_base_ (str): The path to your pretrain config file.
fix_subnet (Union[dict,str]): The dict store the pruning structure or the
json file including it.
divisor (int): The divisor the make the channel number divisible.
"""
```

The divisor is important for the actual inference speed, and we suggest you to test it in \[1,2,4,8,16,32\] to find the fastest divisor.

## Implementation

All the modules of GroupFisher is placesded in mmrazor/implementations/pruning/group_fisher/.

| File | Module | Feature |
| -------------------- | -------------------------------------------------------------------- | --------------------------------------------------------------------------------------- |
| algorithm.py | GroupFisherAlgorithm | Dicide when to prune a channel according to the interval and the current iteration. |
| mutator.py | GroupFisherChannelMutator | Select the unit with the channel of the minimal importance and to prune it. |
| unit.py | GroupFisherChannelUnit | Compute fisher info |
| ops.py <br> counters | GroupFisherConv2d <br> GroupFisherLinear <br> corresbonding counters | Collect model info to compute fisher info, including activation, grad and tensor shape. |

There are also some modules to support GroupFisher. These modules may be refactored and moved to other folders as common modules for all pruning algorithms.

| File | Module | Feature |
| ------------------------- | ---------------------------------------- | ------------------------------------------------------------------- |
| hook.py | PruningStructureHook<br>ResourceInfoHook | Display pruning Structure iteratively. |
| prune_sub_model.py | GroupFisherSubModel | Convert a pruning algorithm(architecture) to a pruned static model. |
| prune_deploy_sub_model.py | GroupFisherDeploySubModel | Init a pruned static model for mmdeploy. |

## Citation

```latex
@InProceedings{Liu:2021,
TITLE = {Group Fisher Pruning for Practical Network Compression},
AUTHOR = {Liu, Liyang
AND Zhang, Shilong
AND Kuang, Zhanghui
AND Zhou, Aojun
AND Xue, Jing-hao
AND Wang, Xinjiang
AND Chen, Yimin
AND Yang, Wenming
AND Liao, Qingmin
AND Zhang, Wayne},
BOOKTITLE = {Proceedings of the 38th International Conference on Machine Learning},
YEAR = {2021},
SERIES = {Proceedings of Machine Learning Research},
MONTH = {18--24 Jul},
PUBLISHER = {PMLR},
}
```

<!-- model links
{model}_{prune_mode}_{file type}
model: r: resnet50, m: mobilenetv2, rt:retinanet
prune_mode: a: act, f: flops
file_type: p: pruned model, f:finetuned_model, l: log, pc: prune config, fc: finetune config.
repo link
{repo}_{model}_{file type}
-->

[cls_m_c]: /~https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py
[cls_m_m]: https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth
[cls_r50_c]: /~https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet50_8xb32_in1k.py
[cls_r50_m]: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth
[det_rt_c]: /~https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_r50_fpn_1x_coco.py
[det_rt_m]: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth
[m_a_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.pth
[m_a_fc]: ../../mmcls/group_fisher/mobilenet/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.py
[m_a_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/20230130_203443.json
[m_a_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/group_fisher_act_prune_mobilenet-v2_8xb32_in1k.pth
[m_a_pc]: ../../mmcls/group_fisher/mobilenet/group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py
[m_f_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.pth
[m_f_fc]: ../../mmcls/group_fisher/mobilenet/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.py
[m_f_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/20230201_211550.json
[m_f_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.pth
[m_f_pc]: ../../mmcls/group_fisher/mobilenet/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py
[rt_a_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/act/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.pth
[rt_a_fc]: ../../mmdet/group_fisher/retinanet/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.py
[rt_a_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/retinanet/act/20230113_231904.json
[rt_a_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/act/group_fisher_act_prune_retinanet_r50_fpn_1x_coco.pth
[rt_a_pc]: ../../mmdet/group_fisher/retinanet/group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py
[rt_f_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.pth
[rt_f_fc]: ../../mmdet/group_fisher/retinanet/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.py
[rt_f_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/20230129_101502.json
[rt_f_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.pth
[rt_f_pc]: ../../mmdet/group_fisher/retinanet/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py
[r_a_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_finetune_resnet50_8xb32_in1k.pth
[r_a_fc]: ../../mmcls/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k.py
[r_a_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/resnet50/act/20230130_175426.json
[r_a_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_prune_resnet50_8xb32_in1k.pth
[r_a_pc]: ../../mmcls/group_fisher/resnet50/group_fisher_act_prune_resnet50_8xb32_in1k.py
[r_f_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/group_fisher_flops_finetune_resnet50_8xb32_in1k.pth
[r_f_fc]: ../../mmcls/group_fisher/resnet50/group_fisher_flops_finetune_resnet50_8xb32_in1k.py
[r_f_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/20230129_190931.json
[r_f_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/group_fisher_flops_prune_resnet50_8xb32_in1k.pth
[r_f_pc]: ../../mmcls/group_fisher/resnet50/group_fisher_flops_prune_resnet50_8xb32_in1k.py
24 changes: 24 additions & 0 deletions configs/pruning/base/group_fisher/group_fisher_deploy_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#############################################################################
"""You have to fill these args.
_base_(str): The path to your pretrain config file.
fix_subnet (Union[dict,str]): The dict store the pruning structure or the
json file including it.
divisor (int): The divisor the make the channel number divisible.
"""

_base_ = ''
fix_subnet = {}
divisor = 8
##############################################################################

architecture = _base_.model

model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherDeploySubModel',
architecture=architecture,
fix_subnet=fix_subnet,
divisor=divisor,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#############################################################################
"""# You have to fill these args.
_base_(str): The path to your pruning config file.
pruned_path (str): The path to the checkpoint of the pruned model.
finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr
rate of the pretrain.
"""

_base_ = ''
pruned_path = ''
finetune_lr = 0.1
##############################################################################

algorithm = _base_.model
algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)

model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherSubModel',
algorithm=algorithm,
)

# restore lr
optim_wrapper = dict(optimizer=dict(lr=finetune_lr))

# remove pruning related hooks
custom_hooks = _base_.custom_hooks[:-2]

# delete ddp
model_wrapper_cfg = None
75 changes: 75 additions & 0 deletions configs/pruning/base/group_fisher/group_fisher_prune_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#############################################################################
"""You have to fill these args.
_base_ (str): The path to your pretrained model checkpoint.
pretrained_path (str): The path to your pretrained model checkpoint.
interval (int): Interval between pruning two channels. You should ensure you
can reach your target pruning ratio when the training ends.
normalization_type (str): GroupFisher uses two methods to normlized the channel
importance, including ['flops','act']. The former uses flops, while the
latter uses the memory occupation of activation feature maps.
lr_ratio (float): Ratio to decrease lr rate. As pruning progress is unstable,
you need to decrease the original lr rate until the pruning training work
steadly without getting nan.
target_flop_ratio (float): The target flop ratio to prune your model.
input_shape (Tuple): input shape to measure the flops.
"""

_base_ = ''
pretrained_path = ''

interval = 10
normalization_type = 'act'
lr_ratio = 0.1

target_flop_ratio = 0.5
input_shape = (1, 3, 224, 224)
##############################################################################

architecture = _base_.model

if hasattr(_base_, 'data_preprocessor'):
architecture.update({'data_preprocessor': _base_.data_preprocessor})
data_preprocessor = None

architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path)
architecture['_scope_'] = _base_.default_scope

model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherAlgorithm',
architecture=architecture,
interval=interval,
mutator=dict(
type='GroupFisherChannelMutator',
parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'),
channel_unit_cfg=dict(
type='GroupFisherChannelUnit',
default_args=dict(normalization_type=normalization_type, ),
),
),
)

model_wrapper_cfg = dict(
type='mmrazor.GroupFisherDDP',
broadcast_buffers=False,
)

optim_wrapper = dict(
optimizer=dict(lr=_base_.optim_wrapper.optimizer.lr * lr_ratio))

custom_hooks = getattr(_base_, 'custom_hooks', []) + [
dict(type='mmrazor.PruningStructureHook'),
dict(
type='mmrazor.ResourceInfoHook',
interval=interval,
demo_input=dict(
type='mmrazor.DefaultDemoInput',
input_shape=input_shape,
),
save_ckpt_thr=[target_flop_ratio],
),
]
Loading

0 comments on commit 7acc046

Please sign in to comment.