Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GroupFisher pruning algorithm. #459

Merged
merged 43 commits into from
Feb 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
1f2e662
init
Feb 15, 2023
b6c1beb
support expand dwconv
Feb 15, 2023
dad2727
add tools
Feb 16, 2023
fdd36a5
init
Feb 14, 2023
065a7ba
add import
Feb 14, 2023
2b6d08a
add configs
Feb 14, 2023
d540563
add ut and fix bug
Feb 14, 2023
711a287
update
Feb 15, 2023
68ca066
update finetune config
Feb 15, 2023
be85021
update impl imports
Feb 16, 2023
192a1a5
add deploy configs and result
Feb 16, 2023
581e77b
add _train_step
Feb 16, 2023
0725f81
detla_type -> normalization_type
Feb 16, 2023
b011204
change img link
Feb 16, 2023
658a828
add prune to config
Feb 16, 2023
cba0481
add json dump when GroupFisherSubModel init
Feb 16, 2023
831c633
update prune config
Feb 16, 2023
1d9b3e2
update finetune config
Feb 16, 2023
ee75a33
update deploy config
Feb 16, 2023
83ea9b8
update prune config
Feb 16, 2023
bafb0ee
update readme
Feb 16, 2023
50b61b8
mutable_cfg -> fix_subnet
Feb 16, 2023
56ea222
update readme
Feb 16, 2023
bb27e01
impl -> implementations
Feb 17, 2023
7cf7dd1
update script.sh
Feb 17, 2023
1435939
rm gen_fake_cfg
Feb 17, 2023
100e2e1
add Implementation to readme
Feb 17, 2023
3efc0ba
update docstring
Feb 17, 2023
46a980b
add finetune_lr to config
Feb 17, 2023
ef3a79c
update readme
Feb 17, 2023
3d1a723
fix error in config
Feb 17, 2023
d7c039b
update links
Feb 17, 2023
c56dfeb
update configs
Feb 17, 2023
7ef7a60
refine
Feb 17, 2023
a5a6d13
fix spell error
Feb 17, 2023
7caf494
add test to readme
Feb 17, 2023
274c48f
update README
Feb 17, 2023
c46042a
update readme
Feb 17, 2023
cec5b3c
update readme
Feb 17, 2023
d5c90ee
update cite format
Feb 17, 2023
39aea6c
fix for ci
Feb 17, 2023
73a4170
update to pass ci
Feb 17, 2023
bd4ed3c
update readme
Feb 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,5 @@ venv.bak/
# Srun
*.out
batchscript-*
work_dir
mmdeploy
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,5 @@ repos:
^test
| ^docs
| ^configs
| ^.*/configs*
)
214 changes: 214 additions & 0 deletions configs/pruning/base/group_fisher/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
# Group_fisher pruning

> [Group Fisher Pruning for Practical Network Compression.](https://arxiv.org/pdf/2108.00708.pdf)

## Abstract

Network compression has been widely studied since it is able to reduce the memory and computation cost during inference. However, previous methods seldom deal with complicated structures like residual connections, group/depthwise convolution and feature pyramid network, where channels of multiple layers are coupled and need to be pruned simultaneously. In this paper, we present a general channel pruning approach that can be applied to various complicated structures. Particularly, we propose a layer grouping algorithm to find coupled channels automatically. Then we derive a unified metric based on Fisher information to evaluate the importance of a single channel and coupled channels. Moreover, we find that inference speedup on GPUs is more correlated with the reduction of memory rather than FLOPs, and thus we employ the memory reduction of each channel to normalize the importance. Our method can be used to prune any structures including those with coupled channels. We conduct extensive experiments on various backbones, including the classic ResNet and ResNeXt, mobilefriendly MobileNetV2, and the NAS-based RegNet, both on image classification and object detection which is under-explored. Experimental results validate that our method can effectively prune sophisticated networks, boosting inference speed without sacrificing accuracy.

![pipeline](/~https://github.com/jshilong/FisherPruning/blob/main/resources/structures.png?raw=true)

## Results and models

### Classification on ImageNet

| Model | Top-1 | Gap | Flop(G) | Remain(%) | Parameters(M) | Remain(%) | Config | Download |
| ------------------------ | ----- | ----- | ------- | --------- | ------------- | --------- | ------------------------------------- | ----------------------------------------------------- |
| ResNet50 | 76.55 | - | 4.11 | - | 25.6 | - | [mmcls][cls_r50_c] | [model][cls_r50_m] |
| ResNet50_pruned_act | 75.22 | -1.33 | 2.06 | 50.1% | 16.3 | 63.7% | [prune][r_a_pc] \| [finetune][r_a_fc] | [pruned][r_a_p] \| [finetuned][r_a_f] \| [log][r_a_l] |
| ResNet50_pruned_flops | 75.61 | -0.94 | 2.06 | 50.1% | 16.3 | 63.7% | [prune][r_f_pc] \| [finetune][r_f_fc] | [pruned][r_f_p] \| [finetuned][r_f_f] \| [log][r_f_l] |
| MobileNetV2 | 71.86 | - | 0.313 | - | 3.51 | - | [mmcls][cls_m_c] | [model][cls_m_m] |
| MobileNetV2_pruned_act | 70.82 | -1.04 | 0.207 | 66.1% | 3.18 | 90.6% | [prune][m_a_pc] \| [finetune][m_a_fc] | [pruned][m_a_p] \| [finetuned][m_a_f] \| [log][m_a_l] |
| MobileNetV2_pruned_flops | 70.87 | -0.99 | 0.207 | 66.1% | 2.82 | 88.7% | [prune][m_f_pc] \| [finetune][m_f_fc] | [pruned][m_f_p] \| [finetuned][m_f_f] \| [log][m_f_l] |

### Detection on COCO

| Model(Detector-Backbone) | AP | Gap | Flop(G) | Remain(%) | Parameters(M) | Remain(%) | Config | Download |
| ------------------------------ | ---- | ---- | ------- | --------- | ------------- | --------- | --------------------------------------- | -------------------------------------------------------- |
| RetinaNet-R50-FPN | 36.5 | - | 250 | - | 63.8 | - | [mmdet][det_rt_c] | [model][det_rt_m] |
| RetinaNet-R50-FPN_pruned_act | 36.5 | 0.0 | 126 | 50.4% | 34.6 | 54.2% | [prune][rt_a_pc] \| [finetune][rt_a_fc] | [pruned][rt_a_p] \| [finetuned][rt_a_f] \| [log][rt_a_l] |
| RetinaNet-R50-FPN_pruned_flops | 36.6 | +0.1 | 126 | 50.4% | 34.9 | 54.7% | [prune][rt_f_pc] \| [finetune][rt_f_fc] | [pruned][rt_f_p] \| [finetuned][rt_f_f] \| [log][rt_f_l] |

**Note**

- Because the pruning papers use different pretraining and finetuning settings, It is hard to compare them fairly. As a result, we prefer to apply algorithms on the openmmlab settings.
- This may make the experiment results are different from that in the original papers.

## Get Started

We have three steps to apply GroupFisher to your model, including Prune, Finetune, Deploy.

Note: please use torch>=1.12, as we need fxtracer to parse the models automatically.

### Prune

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh \
{config_folder}/group_fisher_{normalization_type}_prune_{model_name}.py 8 \
--work-dir $WORK_DIR
```

In the pruning config file. You have to fill some args as below.

```python
"""
_base_ (str): The path to your pretrained model checkpoint.
pretrained_path (str): The path to your pretrained model checkpoint.

interval (int): Interval between pruning two channels. You should ensure you
can reach your target pruning ratio when the training ends.
normalization_type (str): GroupFisher uses two methods to normlized the channel
importance, including ['flops','act']. The former uses flops, while the
latter uses the memory occupation of activation feature maps.
lr_ratio (float): Ratio to decrease lr rate. As pruning progress is unstable,
you need to decrease the original lr rate until the pruning training work
steadly without getting nan.

target_flop_ratio (float): The target flop ratio to prune your model.
input_shape (Tuple): input shape to measure the flops.
"""
```

After the pruning process, you will get a checkpoint of the pruned model named flops\_{target_flop_ratio}.pth in your workdir.

### Finetune

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh \
{config_folder}/group_fisher_{normalization_type}_finetune_{model_name}.py 8 \
--work-dir $WORK_DIR
```

There are also some args for you to fill in the config file as below.

```python
"""
_base_(str): The path to your pruning config file.
pruned_path (str): The path to the checkpoint of the pruned model.
finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr
rate of the pretrain.
"""
```

After finetuning, except a checkpoint of the best model, there is also a fix_subnet.json, which records the pruned model structure. It will be used when deploying.

### Test

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_test.sh \
{config_folder}/group_fisher_{normalization_type}_finetune_{model_name}.py {checkpoint_path} 8
```

### Deploy

First, we assume you are fimilar to mmdeploy. For a pruned model, you only need to use the pruning deploy config to instead the pretrain config to deploy the pruned version of your model.

```bash
python {mmdeploy}/tools/deploy.py \
{mmdeploy}/{mmdeploy_config}.py \
{config_folder}/group_fisher_{normalization_type}_deploy_{model_name}.py \
{path_to_finetuned_checkpoint}.pth \
{mmdeploy}/tests/data/tiger.jpeg
```

The deploy config has some args as below:

```python
"""
_base_ (str): The path to your pretrain config file.
fix_subnet (Union[dict,str]): The dict store the pruning structure or the
json file including it.
divisor (int): The divisor the make the channel number divisible.
"""
```

The divisor is important for the actual inference speed, and we suggest you to test it in \[1,2,4,8,16,32\] to find the fastest divisor.

## Implementation

All the modules of GroupFisher is placesded in mmrazor/implementations/pruning/group_fisher/.

| File | Module | Feature |
| -------------------- | -------------------------------------------------------------------- | --------------------------------------------------------------------------------------- |
| algorithm.py | GroupFisherAlgorithm | Dicide when to prune a channel according to the interval and the current iteration. |
| mutator.py | GroupFisherChannelMutator | Select the unit with the channel of the minimal importance and to prune it. |
| unit.py | GroupFisherChannelUnit | Compute fisher info |
| ops.py <br> counters | GroupFisherConv2d <br> GroupFisherLinear <br> corresbonding counters | Collect model info to compute fisher info, including activation, grad and tensor shape. |

There are also some modules to support GroupFisher. These modules may be refactored and moved to other folders as common modules for all pruning algorithms.

| File | Module | Feature |
| ------------------------- | ---------------------------------------- | ------------------------------------------------------------------- |
| hook.py | PruningStructureHook<br>ResourceInfoHook | Display pruning Structure iteratively. |
| prune_sub_model.py | GroupFisherSubModel | Convert a pruning algorithm(architecture) to a pruned static model. |
| prune_deploy_sub_model.py | GroupFisherDeploySubModel | Init a pruned static model for mmdeploy. |

## Citation

```latex
@InProceedings{Liu:2021,
TITLE = {Group Fisher Pruning for Practical Network Compression},
AUTHOR = {Liu, Liyang
AND Zhang, Shilong
AND Kuang, Zhanghui
AND Zhou, Aojun
AND Xue, Jing-hao
AND Wang, Xinjiang
AND Chen, Yimin
AND Yang, Wenming
AND Liao, Qingmin
AND Zhang, Wayne},
BOOKTITLE = {Proceedings of the 38th International Conference on Machine Learning},
YEAR = {2021},
SERIES = {Proceedings of Machine Learning Research},
MONTH = {18--24 Jul},
PUBLISHER = {PMLR},
}
```

<!-- model links
{model}_{prune_mode}_{file type}
model: r: resnet50, m: mobilenetv2, rt:retinanet
prune_mode: a: act, f: flops
file_type: p: pruned model, f:finetuned_model, l: log, pc: prune config, fc: finetune config.

repo link
{repo}_{model}_{file type}
-->

[cls_m_c]: /~https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py
[cls_m_m]: https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth
[cls_r50_c]: /~https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet50_8xb32_in1k.py
[cls_r50_m]: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth
[det_rt_c]: /~https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_r50_fpn_1x_coco.py
[det_rt_m]: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth
[m_a_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.pth
[m_a_fc]: ../../mmcls/group_fisher/mobilenet/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.py
[m_a_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/20230130_203443.json
[m_a_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/group_fisher_act_prune_mobilenet-v2_8xb32_in1k.pth
[m_a_pc]: ../../mmcls/group_fisher/mobilenet/group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py
[m_f_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.pth
[m_f_fc]: ../../mmcls/group_fisher/mobilenet/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.py
[m_f_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/20230201_211550.json
[m_f_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.pth
[m_f_pc]: ../../mmcls/group_fisher/mobilenet/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py
[rt_a_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/act/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.pth
[rt_a_fc]: ../../mmdet/group_fisher/retinanet/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.py
[rt_a_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/retinanet/act/20230113_231904.json
[rt_a_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/act/group_fisher_act_prune_retinanet_r50_fpn_1x_coco.pth
[rt_a_pc]: ../../mmdet/group_fisher/retinanet/group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py
[rt_f_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.pth
[rt_f_fc]: ../../mmdet/group_fisher/retinanet/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.py
[rt_f_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/20230129_101502.json
[rt_f_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.pth
[rt_f_pc]: ../../mmdet/group_fisher/retinanet/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py
[r_a_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_finetune_resnet50_8xb32_in1k.pth
[r_a_fc]: ../../mmcls/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k.py
[r_a_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/resnet50/act/20230130_175426.json
[r_a_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_prune_resnet50_8xb32_in1k.pth
[r_a_pc]: ../../mmcls/group_fisher/resnet50/group_fisher_act_prune_resnet50_8xb32_in1k.py
[r_f_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/group_fisher_flops_finetune_resnet50_8xb32_in1k.pth
[r_f_fc]: ../../mmcls/group_fisher/resnet50/group_fisher_flops_finetune_resnet50_8xb32_in1k.py
[r_f_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/20230129_190931.json
[r_f_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/group_fisher_flops_prune_resnet50_8xb32_in1k.pth
[r_f_pc]: ../../mmcls/group_fisher/resnet50/group_fisher_flops_prune_resnet50_8xb32_in1k.py
24 changes: 24 additions & 0 deletions configs/pruning/base/group_fisher/group_fisher_deploy_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#############################################################################
"""You have to fill these args.

_base_(str): The path to your pretrain config file.
fix_subnet (Union[dict,str]): The dict store the pruning structure or the
json file including it.
divisor (int): The divisor the make the channel number divisible.
"""

_base_ = ''
fix_subnet = {}
divisor = 8
##############################################################################

architecture = _base_.model

model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherDeploySubModel',
architecture=architecture,
fix_subnet=fix_subnet,
divisor=divisor,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#############################################################################
"""# You have to fill these args.

_base_(str): The path to your pruning config file.
pruned_path (str): The path to the checkpoint of the pruned model.
finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr
rate of the pretrain.
"""

_base_ = ''
pruned_path = ''
finetune_lr = 0.1
##############################################################################

algorithm = _base_.model
algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)

model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherSubModel',
algorithm=algorithm,
)

# restore lr
optim_wrapper = dict(optimizer=dict(lr=finetune_lr))

# remove pruning related hooks
custom_hooks = _base_.custom_hooks[:-2]

# delete ddp
model_wrapper_cfg = None
75 changes: 75 additions & 0 deletions configs/pruning/base/group_fisher/group_fisher_prune_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#############################################################################
"""You have to fill these args.

_base_ (str): The path to your pretrained model checkpoint.
pretrained_path (str): The path to your pretrained model checkpoint.

interval (int): Interval between pruning two channels. You should ensure you
can reach your target pruning ratio when the training ends.
normalization_type (str): GroupFisher uses two methods to normlized the channel
importance, including ['flops','act']. The former uses flops, while the
latter uses the memory occupation of activation feature maps.
lr_ratio (float): Ratio to decrease lr rate. As pruning progress is unstable,
you need to decrease the original lr rate until the pruning training work
steadly without getting nan.

target_flop_ratio (float): The target flop ratio to prune your model.
input_shape (Tuple): input shape to measure the flops.
"""

_base_ = ''
pretrained_path = ''

interval = 10
normalization_type = 'act'
lr_ratio = 0.1

target_flop_ratio = 0.5
input_shape = (1, 3, 224, 224)
##############################################################################

architecture = _base_.model

if hasattr(_base_, 'data_preprocessor'):
architecture.update({'data_preprocessor': _base_.data_preprocessor})
data_preprocessor = None

architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path)
architecture['_scope_'] = _base_.default_scope

model = dict(
_delete_=True,
_scope_='mmrazor',
type='GroupFisherAlgorithm',
architecture=architecture,
interval=interval,
mutator=dict(
type='GroupFisherChannelMutator',
parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'),
channel_unit_cfg=dict(
type='GroupFisherChannelUnit',
default_args=dict(normalization_type=normalization_type, ),
),
),
)

model_wrapper_cfg = dict(
type='mmrazor.GroupFisherDDP',
broadcast_buffers=False,
)

optim_wrapper = dict(
optimizer=dict(lr=_base_.optim_wrapper.optimizer.lr * lr_ratio))

custom_hooks = getattr(_base_, 'custom_hooks', []) + [
dict(type='mmrazor.PruningStructureHook'),
dict(
type='mmrazor.ResourceInfoHook',
interval=interval,
demo_input=dict(
type='mmrazor.DefaultDemoInput',
input_shape=input_shape,
),
save_ckpt_thr=[target_flop_ratio],
),
]
Loading