From 7acc046678a17e48636f48208055f665f0b6f1af Mon Sep 17 00:00:00 2001 From: LKJacky <108643365+LKJacky@users.noreply.github.com> Date: Mon, 20 Feb 2023 14:29:42 +0800 Subject: [PATCH] Add GroupFisher pruning algorithm. (#459) * init * support expand dwconv * add tools * init * add import * add configs * add ut and fix bug * update * update finetune config * update impl imports * add deploy configs and result * add _train_step * detla_type -> normalization_type * change img link * add prune to config * add json dump when GroupFisherSubModel init * update prune config * update finetune config * update deploy config * update prune config * update readme * mutable_cfg -> fix_subnet * update readme * impl -> implementations * update script.sh * rm gen_fake_cfg * add Implementation to readme * update docstring * add finetune_lr to config * update readme * fix error in config * update links * update configs * refine * fix spell error * add test to readme * update README * update readme * update readme * update cite format * fix for ci * update to pass ci * update readme --------- Co-authored-by: liukai Co-authored-by: Your Name --- .gitignore | 2 + .pre-commit-config.yaml | 1 + configs/pruning/base/group_fisher/README.md | 214 ++++++++++++++++ .../group_fisher_deploy_template.py | 24 ++ .../group_fisher_finetune_template.py | 32 +++ .../group_fisher_prune_template.py | 75 ++++++ configs/pruning/mmcls/group_fisher/README.md | 11 + ...sher_act_deploy_mobilenet-v2_8xb32_in1k.py | 50 ++++ ...er_act_finetune_mobilenet-v2_8xb32_in1k.py | 31 +++ ...isher_act_prune_mobilenet-v2_8xb32_in1k.py | 75 ++++++ ...er_flops_deploy_mobilenet-v2_8xb32_in1k.py | 49 ++++ ..._flops_finetune_mobilenet-v2_8xb32_in1k.py | 32 +++ ...her_flops_prune_mobilenet-v2_8xb32_in1k.py | 5 + .../mmcls/group_fisher/mobilenet/script.sh | 7 + ...p_fisher_act_deploy_resnet50_8xb32_in1k.py | 61 +++++ ...fisher_act_finetune_resnet50_8xb32_in1k.py | 31 +++ ...up_fisher_act_prune_resnet50_8xb32_in1k.py | 75 ++++++ ...fisher_flops_deploy_resnet50_8xb32_in1k.py | 61 +++++ ...sher_flops_finetune_resnet50_8xb32_in1k.py | 31 +++ ..._fisher_flops_prune_resnet50_8xb32_in1k.py | 5 + .../mmcls/group_fisher/resnet50/script.sh | 7 + configs/pruning/mmdet/group_fisher/README.md | 11 + ...er_act_deploy_retinanet_r50_fpn_1x_coco.py | 73 ++++++ ..._act_finetune_retinanet_r50_fpn_1x_coco.py | 31 +++ ...her_act_prune_retinanet_r50_fpn_1x_coco.py | 75 ++++++ ..._flops_deploy_retinanet_r50_fpn_1x_coco.py | 73 ++++++ ...lops_finetune_retinanet_r50_fpn_1x_coco.py | 31 +++ ...r_flops_prune_retinanet_r50_fpn_1x_coco.py | 5 + .../mmdet/group_fisher/retinanet/script.sh | 7 + mmrazor/engine/hooks/group_fisher_hooks.py | 9 + mmrazor/implementations/__init__.py | 13 + mmrazor/implementations/pruning/__init__.py | 4 + .../pruning/group_fisher/__init__.py | 24 ++ .../pruning/group_fisher/algorithm.py | 86 +++++++ .../pruning/group_fisher/counters.py | 16 ++ .../pruning/group_fisher/hook.py | 183 ++++++++++++++ .../pruning/group_fisher/mutator.py | 87 +++++++ .../pruning/group_fisher/ops.py | 150 +++++++++++ .../group_fisher/prune_deploy_sub_model.py | 65 +++++ .../pruning/group_fisher/prune_sub_model.py | 105 ++++++++ .../pruning/group_fisher/unit.py | 230 +++++++++++++++++ .../pruning/group_fisher_algoritho.py | 7 + .../dynamic_ops/bricks/group_fisher_ops.py | 11 + .../units/group_fisher_unit.py | 7 + .../units/sequential_mutable_channel_unit.py | 2 +- .../channel_mutator/channel_mutator.py | 1 + .../channel_mutator/group_fisher_mutator.py | 7 + 
.../demo_inputs/default_demo_inputs.py | 9 +- .../task_modules/demo_inputs/demo_inputs.py | 8 +- .../counters/op_counters/__init__.py | 26 +- .../op_counters/dynamic_op_counters.py | 60 +++++ .../op_counters/group_fisher_counters.py | 7 + .../models/utils/expandable_utils/__init__.py | 15 ++ mmrazor/models/utils/expandable_utils/ops.py | 237 ++++++++++++++++++ .../models/utils/expandable_utils/tools.py | 84 +++++++ mmrazor/models/utils/expandable_utils/unit.py | 29 +++ mmrazor/utils/__init__.py | 3 +- mmrazor/utils/runtime_info.py | 58 +++++ mmrazor/utils/setup_env.py | 1 + tests/data/models.py | 39 ++- tests/test_impl/__init__.py | 1 + tests/test_impl/test_pruning/__init__.py | 1 + .../test_group_fisher/__init__.py | 1 + .../test_group_fisher/test_algorithm.py | 68 +++++ .../test_prune_deploy_sub_model.py | 49 ++++ .../test_group_fisher/test_prune_sub_model.py | 64 +++++ .../test_group_fisher/test_unit.py | 44 ++++ tests/test_models/test_utils/__init__.py | 1 + .../test_expandable_utils/__init__.py | 1 + .../test_expandable_utils/test_expand.py | 56 +++++ tools/pruning/get_flops.py | 55 ++++ 71 files changed, 3087 insertions(+), 22 deletions(-) create mode 100644 configs/pruning/base/group_fisher/README.md create mode 100644 configs/pruning/base/group_fisher/group_fisher_deploy_template.py create mode 100644 configs/pruning/base/group_fisher/group_fisher_finetune_template.py create mode 100644 configs/pruning/base/group_fisher/group_fisher_prune_template.py create mode 100644 configs/pruning/mmcls/group_fisher/README.md create mode 100644 configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_deploy_mobilenet-v2_8xb32_in1k.py create mode 100644 configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.py create mode 100644 configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py create mode 100644 configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_deploy_mobilenet-v2_8xb32_in1k.py create mode 100644 configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.py create mode 100644 configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py create mode 100644 configs/pruning/mmcls/group_fisher/mobilenet/script.sh create mode 100644 configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_deploy_resnet50_8xb32_in1k.py create mode 100644 configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k.py create mode 100644 configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_prune_resnet50_8xb32_in1k.py create mode 100644 configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_deploy_resnet50_8xb32_in1k.py create mode 100644 configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_finetune_resnet50_8xb32_in1k.py create mode 100644 configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_prune_resnet50_8xb32_in1k.py create mode 100644 configs/pruning/mmcls/group_fisher/resnet50/script.sh create mode 100644 configs/pruning/mmdet/group_fisher/README.md create mode 100644 configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_deploy_retinanet_r50_fpn_1x_coco.py create mode 100644 configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.py create mode 100644 configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py create mode 100644 
configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_deploy_retinanet_r50_fpn_1x_coco.py create mode 100644 configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.py create mode 100644 configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py create mode 100644 configs/pruning/mmdet/group_fisher/retinanet/script.sh create mode 100644 mmrazor/engine/hooks/group_fisher_hooks.py create mode 100644 mmrazor/implementations/__init__.py create mode 100644 mmrazor/implementations/pruning/__init__.py create mode 100644 mmrazor/implementations/pruning/group_fisher/__init__.py create mode 100644 mmrazor/implementations/pruning/group_fisher/algorithm.py create mode 100644 mmrazor/implementations/pruning/group_fisher/counters.py create mode 100644 mmrazor/implementations/pruning/group_fisher/hook.py create mode 100644 mmrazor/implementations/pruning/group_fisher/mutator.py create mode 100644 mmrazor/implementations/pruning/group_fisher/ops.py create mode 100644 mmrazor/implementations/pruning/group_fisher/prune_deploy_sub_model.py create mode 100644 mmrazor/implementations/pruning/group_fisher/prune_sub_model.py create mode 100644 mmrazor/implementations/pruning/group_fisher/unit.py create mode 100644 mmrazor/models/algorithms/pruning/group_fisher_algoritho.py create mode 100644 mmrazor/models/architectures/dynamic_ops/bricks/group_fisher_ops.py create mode 100644 mmrazor/models/mutables/mutable_channel/units/group_fisher_unit.py create mode 100644 mmrazor/models/mutators/channel_mutator/group_fisher_mutator.py create mode 100644 mmrazor/models/task_modules/estimators/counters/op_counters/dynamic_op_counters.py create mode 100644 mmrazor/models/task_modules/estimators/counters/op_counters/group_fisher_counters.py create mode 100644 mmrazor/models/utils/expandable_utils/__init__.py create mode 100644 mmrazor/models/utils/expandable_utils/ops.py create mode 100644 mmrazor/models/utils/expandable_utils/tools.py create mode 100644 mmrazor/models/utils/expandable_utils/unit.py create mode 100644 mmrazor/utils/runtime_info.py create mode 100644 tests/test_impl/__init__.py create mode 100644 tests/test_impl/test_pruning/__init__.py create mode 100644 tests/test_impl/test_pruning/test_group_fisher/__init__.py create mode 100644 tests/test_impl/test_pruning/test_group_fisher/test_algorithm.py create mode 100644 tests/test_impl/test_pruning/test_group_fisher/test_prune_deploy_sub_model.py create mode 100644 tests/test_impl/test_pruning/test_group_fisher/test_prune_sub_model.py create mode 100644 tests/test_impl/test_pruning/test_group_fisher/test_unit.py create mode 100644 tests/test_models/test_utils/__init__.py create mode 100644 tests/test_models/test_utils/test_expandable_utils/__init__.py create mode 100644 tests/test_models/test_utils/test_expandable_utils/test_expand.py create mode 100644 tools/pruning/get_flops.py diff --git a/.gitignore b/.gitignore index 321634843..92e7c7929 100644 --- a/.gitignore +++ b/.gitignore @@ -121,3 +121,5 @@ venv.bak/ # Srun *.out batchscript-* +work_dir +mmdeploy diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 491ddaa78..cd73ef928 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -68,4 +68,5 @@ repos: ^test | ^docs | ^configs + | ^.*/configs* ) diff --git a/configs/pruning/base/group_fisher/README.md b/configs/pruning/base/group_fisher/README.md new file mode 100644 index 000000000..cf85eb39e --- /dev/null +++ 
b/configs/pruning/base/group_fisher/README.md @@ -0,0 +1,214 @@ +# Group_fisher pruning + +> [Group Fisher Pruning for Practical Network Compression.](https://arxiv.org/pdf/2108.00708.pdf) + +## Abstract + +Network compression has been widely studied since it is able to reduce the memory and computation cost during inference. However, previous methods seldom deal with complicated structures like residual connections, group/depthwise convolution and feature pyramid network, where channels of multiple layers are coupled and need to be pruned simultaneously. In this paper, we present a general channel pruning approach that can be applied to various complicated structures. Particularly, we propose a layer grouping algorithm to find coupled channels automatically. Then we derive a unified metric based on Fisher information to evaluate the importance of a single channel and coupled channels. Moreover, we find that inference speedup on GPUs is more correlated with the reduction of memory rather than FLOPs, and thus we employ the memory reduction of each channel to normalize the importance. Our method can be used to prune any structures including those with coupled channels. We conduct extensive experiments on various backbones, including the classic ResNet and ResNeXt, mobilefriendly MobileNetV2, and the NAS-based RegNet, both on image classification and object detection which is under-explored. Experimental results validate that our method can effectively prune sophisticated networks, boosting inference speed without sacrificing accuracy. + +![pipeline](/~https://github.com/jshilong/FisherPruning/blob/main/resources/structures.png?raw=true) + +## Results and models + +### Classification on ImageNet + +| Model | Top-1 | Gap | Flop(G) | Remain(%) | Parameters(M) | Remain(%) | Config | Download | +| ------------------------ | ----- | ----- | ------- | --------- | ------------- | --------- | ------------------------------------- | ----------------------------------------------------- | +| ResNet50 | 76.55 | - | 4.11 | - | 25.6 | - | [mmcls][cls_r50_c] | [model][cls_r50_m] | +| ResNet50_pruned_act | 75.22 | -1.33 | 2.06 | 50.1% | 16.3 | 63.7% | [prune][r_a_pc] \| [finetune][r_a_fc] | [pruned][r_a_p] \| [finetuned][r_a_f] \| [log][r_a_l] | +| ResNet50_pruned_flops | 75.61 | -0.94 | 2.06 | 50.1% | 16.3 | 63.7% | [prune][r_f_pc] \| [finetune][r_f_fc] | [pruned][r_f_p] \| [finetuned][r_f_f] \| [log][r_f_l] | +| MobileNetV2 | 71.86 | - | 0.313 | - | 3.51 | - | [mmcls][cls_m_c] | [model][cls_m_m] | +| MobileNetV2_pruned_act | 70.82 | -1.04 | 0.207 | 66.1% | 3.18 | 90.6% | [prune][m_a_pc] \| [finetune][m_a_fc] | [pruned][m_a_p] \| [finetuned][m_a_f] \| [log][m_a_l] | +| MobileNetV2_pruned_flops | 70.87 | -0.99 | 0.207 | 66.1% | 2.82 | 88.7% | [prune][m_f_pc] \| [finetune][m_f_fc] | [pruned][m_f_p] \| [finetuned][m_f_f] \| [log][m_f_l] | + +### Detection on COCO + +| Model(Detector-Backbone) | AP | Gap | Flop(G) | Remain(%) | Parameters(M) | Remain(%) | Config | Download | +| ------------------------------ | ---- | ---- | ------- | --------- | ------------- | --------- | --------------------------------------- | -------------------------------------------------------- | +| RetinaNet-R50-FPN | 36.5 | - | 250 | - | 63.8 | - | [mmdet][det_rt_c] | [model][det_rt_m] | +| RetinaNet-R50-FPN_pruned_act | 36.5 | 0.0 | 126 | 50.4% | 34.6 | 54.2% | [prune][rt_a_pc] \| [finetune][rt_a_fc] | [pruned][rt_a_p] \| [finetuned][rt_a_f] \| [log][rt_a_l] | +| RetinaNet-R50-FPN_pruned_flops | 36.6 | +0.1 | 126 | 50.4% | 34.9 | 
54.7% | [prune][rt_f_pc] \| [finetune][rt_f_fc] | [pruned][rt_f_p] \| [finetuned][rt_f_f] \| [log][rt_f_l] |
+
+**Note**
+
+- Because the pruning papers use different pretraining and finetuning settings, it is hard to compare them fairly. Therefore, we prefer to apply the algorithms under the OpenMMLab settings.
+- This may make our experimental results differ from those reported in the original papers.
+
+## Get Started
+
+There are three steps to apply GroupFisher to your model: Prune, Finetune, and Deploy.
+
+Note: please use torch>=1.12, as we rely on FxTracer to parse models automatically.
+
+### Prune
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh \
+  {config_folder}/group_fisher_{normalization_type}_prune_{model_name}.py 8 \
+  --work-dir $WORK_DIR
+```
+
+In the pruning config file, you have to fill in the args below.
+
+```python
+"""
+_base_ (str): The path to your pretrain config file.
+pretrained_path (str): The path to your pretrained model checkpoint.
+
+interval (int): Interval between pruning two channels. You should ensure you
+    can reach your target pruning ratio when the training ends.
+normalization_type (str): GroupFisher uses two methods to normalize the channel
+    importance, including ['flops','act']. The former uses FLOPs, while the
+    latter uses the memory occupation of activation feature maps.
+lr_ratio (float): Ratio to decrease the learning rate. As the pruning process is
+    unstable, you need to decrease the original learning rate until the pruning
+    training works steadily without getting NaN.
+
+target_flop_ratio (float): The target FLOPs ratio to prune your model to.
+input_shape (Tuple): Input shape used to measure the FLOPs.
+"""
+```
+
+After the pruning process, you will get a checkpoint of the pruned model named flops\_{target_flop_ratio}.pth in your work dir.
+
+### Finetune
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh \
+  {config_folder}/group_fisher_{normalization_type}_finetune_{model_name}.py 8 \
+  --work-dir $WORK_DIR
+```
+
+There are also some args for you to fill in the config file, as below.
+
+```python
+"""
+_base_ (str): The path to your pruning config file.
+pruned_path (str): The path to the checkpoint of the pruned model.
+finetune_lr (float): The learning rate for finetuning. Usually, we directly use
+    the learning rate of the pretraining.
+"""
+```
+
+After finetuning, besides a checkpoint of the best model, a fix_subnet.json is also saved, which records the structure of the pruned model. It will be used when deploying.
+
+### Test
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_test.sh \
+  {config_folder}/group_fisher_{normalization_type}_finetune_{model_name}.py {checkpoint_path} 8
+```
+
+### Deploy
+
+First, we assume you are familiar with mmdeploy. For a pruned model, you only need to use the pruning deploy config instead of the pretrain config to deploy the pruned version of your model.
+
+```bash
+python {mmdeploy}/tools/deploy.py \
+    {mmdeploy}/{mmdeploy_config}.py \
+    {config_folder}/group_fisher_{normalization_type}_deploy_{model_name}.py \
+    {path_to_finetuned_checkpoint}.pth \
+    {mmdeploy}/tests/data/tiger.jpeg
+```
+
+The deploy config has some args as below:
+
+```python
+"""
+_base_ (str): The path to your pretrain config file.
+fix_subnet (Union[dict,str]): The dict that stores the pruning structure or the
+    json file containing it.
+divisor (int): The divisor to make the channel numbers divisible.
+"""
+```
+
+The divisor is important for the actual inference speed, and we suggest you test it in \[1,2,4,8,16,32\] to find the fastest divisor.
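+
+As a reference, a minimal deploy config might look like the sketch below. It assumes a hypothetical work_dirs path for the fix_subnet.json saved during finetuning; the structure follows group_fisher_deploy_template.py.
+
+```python
+_base_ = 'mmcls::resnet/resnet50_8xb32_in1k.py'
+# Either paste the pruned structure as a dict, or point to the
+# fix_subnet.json saved during finetuning (hypothetical path below).
+fix_subnet = 'work_dirs/group_fisher_act_finetune_resnet50_8xb32_in1k/fix_subnet.json'
+divisor = 8  # try 1, 2, 4, 8, 16 and 32, then keep the fastest one
+
+architecture = _base_.model
+
+model = dict(
+    _delete_=True,
+    _scope_='mmrazor',
+    type='GroupFisherDeploySubModel',
+    architecture=architecture,
+    fix_subnet=fix_subnet,
+    divisor=divisor,
+)
+```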
+""" +``` + +The divisor is important for the actual inference speed, and we suggest you to test it in \[1,2,4,8,16,32\] to find the fastest divisor. + +## Implementation + +All the modules of GroupFisher is placesded in mmrazor/implementations/pruning/group_fisher/. + +| File | Module | Feature | +| -------------------- | -------------------------------------------------------------------- | --------------------------------------------------------------------------------------- | +| algorithm.py | GroupFisherAlgorithm | Dicide when to prune a channel according to the interval and the current iteration. | +| mutator.py | GroupFisherChannelMutator | Select the unit with the channel of the minimal importance and to prune it. | +| unit.py | GroupFisherChannelUnit | Compute fisher info | +| ops.py
counters | GroupFisherConv2d
GroupFisherLinear
corresponding counters | Collect model info to compute the Fisher information, including activation, gradient and tensor shape. |
+
+There are also some modules to support GroupFisher. These modules may be refactored and moved to other folders as common modules for all pruning algorithms.
+
+| File | Module | Feature |
+| ------------------------- | ---------------------------------------- | ------------------------------------------------------------------- |
+| hook.py | PruningStructureHook
ResourceInfoHook | Display pruning Structure iteratively. | +| prune_sub_model.py | GroupFisherSubModel | Convert a pruning algorithm(architecture) to a pruned static model. | +| prune_deploy_sub_model.py | GroupFisherDeploySubModel | Init a pruned static model for mmdeploy. | + +## Citation + +```latex +@InProceedings{Liu:2021, + TITLE = {Group Fisher Pruning for Practical Network Compression}, + AUTHOR = {Liu, Liyang + AND Zhang, Shilong + AND Kuang, Zhanghui + AND Zhou, Aojun + AND Xue, Jing-hao + AND Wang, Xinjiang + AND Chen, Yimin + AND Yang, Wenming + AND Liao, Qingmin + AND Zhang, Wayne}, + BOOKTITLE = {Proceedings of the 38th International Conference on Machine Learning}, + YEAR = {2021}, + SERIES = {Proceedings of Machine Learning Research}, + MONTH = {18--24 Jul}, + PUBLISHER = {PMLR}, +} +``` + + + +[cls_m_c]: /~https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py +[cls_m_m]: https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth +[cls_r50_c]: /~https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet50_8xb32_in1k.py +[cls_r50_m]: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth +[det_rt_c]: /~https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_r50_fpn_1x_coco.py +[det_rt_m]: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth +[m_a_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.pth +[m_a_fc]: ../../mmcls/group_fisher/mobilenet/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.py +[m_a_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/20230130_203443.json +[m_a_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/group_fisher_act_prune_mobilenet-v2_8xb32_in1k.pth +[m_a_pc]: ../../mmcls/group_fisher/mobilenet/group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py +[m_f_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.pth +[m_f_fc]: ../../mmcls/group_fisher/mobilenet/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.py +[m_f_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/20230201_211550.json +[m_f_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.pth +[m_f_pc]: ../../mmcls/group_fisher/mobilenet/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py +[rt_a_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/act/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.pth +[rt_a_fc]: ../../mmdet/group_fisher/retinanet/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.py +[rt_a_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/retinanet/act/20230113_231904.json +[rt_a_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/act/group_fisher_act_prune_retinanet_r50_fpn_1x_coco.pth +[rt_a_pc]: ../../mmdet/group_fisher/retinanet/group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py +[rt_f_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.pth 
+[rt_f_fc]: ../../mmdet/group_fisher/retinanet/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.py +[rt_f_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/20230129_101502.json +[rt_f_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.pth +[rt_f_pc]: ../../mmdet/group_fisher/retinanet/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py +[r_a_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_finetune_resnet50_8xb32_in1k.pth +[r_a_fc]: ../../mmcls/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k.py +[r_a_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/resnet50/act/20230130_175426.json +[r_a_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_prune_resnet50_8xb32_in1k.pth +[r_a_pc]: ../../mmcls/group_fisher/resnet50/group_fisher_act_prune_resnet50_8xb32_in1k.py +[r_f_f]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/group_fisher_flops_finetune_resnet50_8xb32_in1k.pth +[r_f_fc]: ../../mmcls/group_fisher/resnet50/group_fisher_flops_finetune_resnet50_8xb32_in1k.py +[r_f_l]: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/20230129_190931.json +[r_f_p]: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/group_fisher_flops_prune_resnet50_8xb32_in1k.pth +[r_f_pc]: ../../mmcls/group_fisher/resnet50/group_fisher_flops_prune_resnet50_8xb32_in1k.py diff --git a/configs/pruning/base/group_fisher/group_fisher_deploy_template.py b/configs/pruning/base/group_fisher/group_fisher_deploy_template.py new file mode 100644 index 000000000..996444837 --- /dev/null +++ b/configs/pruning/base/group_fisher/group_fisher_deploy_template.py @@ -0,0 +1,24 @@ +############################################################################# +"""You have to fill these args. + +_base_(str): The path to your pretrain config file. +fix_subnet (Union[dict,str]): The dict store the pruning structure or the + json file including it. +divisor (int): The divisor the make the channel number divisible. +""" + +_base_ = '' +fix_subnet = {} +divisor = 8 +############################################################################## + +architecture = _base_.model + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherDeploySubModel', + architecture=architecture, + fix_subnet=fix_subnet, + divisor=divisor, +) diff --git a/configs/pruning/base/group_fisher/group_fisher_finetune_template.py b/configs/pruning/base/group_fisher/group_fisher_finetune_template.py new file mode 100644 index 000000000..bee977d93 --- /dev/null +++ b/configs/pruning/base/group_fisher/group_fisher_finetune_template.py @@ -0,0 +1,32 @@ +############################################################################# +"""# You have to fill these args. + +_base_(str): The path to your pruning config file. +pruned_path (str): The path to the checkpoint of the pruned model. +finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr + rate of the pretrain. 
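+
+A hypothetical example of filled-in values (the pruned_path below is a
+placeholder; see the configs under configs/pruning/mmcls/group_fisher for
+complete, ready-to-run examples):
+
+    _base_ = './group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py'
+    pruned_path = './work_dir/flops_0.65.pth'
+    finetune_lr = 0.045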
+""" + +_base_ = '' +pruned_path = '' +finetune_lr = 0.1 +############################################################################## + +algorithm = _base_.model +algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherSubModel', + algorithm=algorithm, +) + +# restore lr +optim_wrapper = dict(optimizer=dict(lr=finetune_lr)) + +# remove pruning related hooks +custom_hooks = _base_.custom_hooks[:-2] + +# delete ddp +model_wrapper_cfg = None diff --git a/configs/pruning/base/group_fisher/group_fisher_prune_template.py b/configs/pruning/base/group_fisher/group_fisher_prune_template.py new file mode 100644 index 000000000..dd2911d7e --- /dev/null +++ b/configs/pruning/base/group_fisher/group_fisher_prune_template.py @@ -0,0 +1,75 @@ +############################################################################# +"""You have to fill these args. + +_base_ (str): The path to your pretrained model checkpoint. +pretrained_path (str): The path to your pretrained model checkpoint. + +interval (int): Interval between pruning two channels. You should ensure you + can reach your target pruning ratio when the training ends. +normalization_type (str): GroupFisher uses two methods to normlized the channel + importance, including ['flops','act']. The former uses flops, while the + latter uses the memory occupation of activation feature maps. +lr_ratio (float): Ratio to decrease lr rate. As pruning progress is unstable, + you need to decrease the original lr rate until the pruning training work + steadly without getting nan. + +target_flop_ratio (float): The target flop ratio to prune your model. +input_shape (Tuple): input shape to measure the flops. +""" + +_base_ = '' +pretrained_path = '' + +interval = 10 +normalization_type = 'act' +lr_ratio = 0.1 + +target_flop_ratio = 0.5 +input_shape = (1, 3, 224, 224) +############################################################################## + +architecture = _base_.model + +if hasattr(_base_, 'data_preprocessor'): + architecture.update({'data_preprocessor': _base_.data_preprocessor}) + data_preprocessor = None + +architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path) +architecture['_scope_'] = _base_.default_scope + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherAlgorithm', + architecture=architecture, + interval=interval, + mutator=dict( + type='GroupFisherChannelMutator', + parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'), + channel_unit_cfg=dict( + type='GroupFisherChannelUnit', + default_args=dict(normalization_type=normalization_type, ), + ), + ), +) + +model_wrapper_cfg = dict( + type='mmrazor.GroupFisherDDP', + broadcast_buffers=False, +) + +optim_wrapper = dict( + optimizer=dict(lr=_base_.optim_wrapper.optimizer.lr * lr_ratio)) + +custom_hooks = getattr(_base_, 'custom_hooks', []) + [ + dict(type='mmrazor.PruningStructureHook'), + dict( + type='mmrazor.ResourceInfoHook', + interval=interval, + demo_input=dict( + type='mmrazor.DefaultDemoInput', + input_shape=input_shape, + ), + save_ckpt_thr=[target_flop_ratio], + ), +] diff --git a/configs/pruning/mmcls/group_fisher/README.md b/configs/pruning/mmcls/group_fisher/README.md new file mode 100644 index 000000000..9b3b09936 --- /dev/null +++ b/configs/pruning/mmcls/group_fisher/README.md @@ -0,0 +1,11 @@ +# Group_fisher pruning + +> [Group Fisher Pruning for Practical Network Compression.](https://arxiv.org/pdf/2108.00708.pdf) + +## Abstract + +Network 
compression has been widely studied since it is able to reduce the memory and computation cost during inference. However, previous methods seldom deal with complicated structures like residual connections, group/depthwise convolution and feature pyramid network, where channels of multiple layers are coupled and need to be pruned simultaneously. In this paper, we present a general channel pruning approach that can be applied to various complicated structures. Particularly, we propose a layer grouping algorithm to find coupled channels automatically. Then we derive a unified metric based on Fisher information to evaluate the importance of a single channel and coupled channels. Moreover, we find that inference speedup on GPUs is more correlated with the reduction of memory rather than FLOPs, and thus we employ the memory reduction of each channel to normalize the importance. Our method can be used to prune any structures including those with coupled channels. We conduct extensive experiments on various backbones, including the classic ResNet and ResNeXt, mobilefriendly MobileNetV2, and the NAS-based RegNet, both on image classification and object detection which is under-explored. Experimental results validate that our method can effectively prune sophisticated networks, boosting inference speed without sacrificing accuracy. + +![pipeline](/~https://github.com/jshilong/FisherPruning/blob/main/resources/structures.png?raw=true) + +**Please refer to the [full README](../../base/group_fisher/README.md) for more details.** diff --git a/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_deploy_mobilenet-v2_8xb32_in1k.py b/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_deploy_mobilenet-v2_8xb32_in1k.py new file mode 100644 index 000000000..21beb3370 --- /dev/null +++ b/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_deploy_mobilenet-v2_8xb32_in1k.py @@ -0,0 +1,50 @@ +############################################################################# +"""You have to fill these args. + +_base_(str): The path to your pretrain config file. +fix_subnet (Union[dict,str]): The dict store the pruning structure or the + json file including it. +divisor (int): The divisor the make the channel number divisible. 
+""" + +_base_ = 'mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py' +fix_subnet = { + 'backbone.conv1.conv_(0, 32)_32': 21, + 'backbone.layer1.0.conv.1.conv_(0, 16)_16': 10, + 'backbone.layer2.0.conv.0.conv_(0, 96)_96': 45, + 'backbone.layer2.0.conv.2.conv_(0, 24)_24': 24, + 'backbone.layer2.1.conv.0.conv_(0, 144)_144': 73, + 'backbone.layer3.0.conv.0.conv_(0, 144)_144': 85, + 'backbone.layer3.0.conv.2.conv_(0, 32)_32': 32, + 'backbone.layer3.1.conv.0.conv_(0, 192)_192': 95, + 'backbone.layer3.2.conv.0.conv_(0, 192)_192': 76, + 'backbone.layer4.0.conv.0.conv_(0, 192)_192': 160, + 'backbone.layer4.0.conv.2.conv_(0, 64)_64': 64, + 'backbone.layer4.1.conv.0.conv_(0, 384)_384': 204, + 'backbone.layer4.2.conv.0.conv_(0, 384)_384': 200, + 'backbone.layer4.3.conv.0.conv_(0, 384)_384': 217, + 'backbone.layer5.0.conv.0.conv_(0, 384)_384': 344, + 'backbone.layer5.0.conv.2.conv_(0, 96)_96': 96, + 'backbone.layer5.1.conv.0.conv_(0, 576)_576': 348, + 'backbone.layer5.2.conv.0.conv_(0, 576)_576': 338, + 'backbone.layer6.0.conv.0.conv_(0, 576)_576': 543, + 'backbone.layer6.0.conv.2.conv_(0, 160)_160': 160, + 'backbone.layer6.1.conv.0.conv_(0, 960)_960': 810, + 'backbone.layer6.2.conv.0.conv_(0, 960)_960': 803, + 'backbone.layer7.0.conv.0.conv_(0, 960)_960': 944, + 'backbone.layer7.0.conv.2.conv_(0, 320)_320': 320 +} +divisor = 8 + +############################################################################## + +architecture = _base_.model + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherDeploySubModel', + architecture=architecture, + fix_subnet=fix_subnet, + divisor=divisor, +) diff --git a/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.py b/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.py new file mode 100644 index 000000000..151e06103 --- /dev/null +++ b/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.py @@ -0,0 +1,31 @@ +############################################################################# +"""# You have to fill these args. + +_base_(str): The path to your pruning config file. +pruned_path (str): The path to the checkpoint of the pruned model. +finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr + rate of the pretrain. 
+""" + +_base_ = './group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py' +pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/group_fisher_act_prune_mobilenet-v2_8xb32_in1k.pth' # noqa +finetune_lr = 0.045 +############################################################################## +algorithm = _base_.model +algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherSubModel', + algorithm=algorithm, +) + +# restore lr +optim_wrapper = dict(optimizer=dict(lr=finetune_lr)) + +# remove pruning related hooks +custom_hooks = _base_.custom_hooks[:-2] + +# delete ddp +model_wrapper_cfg = None diff --git a/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py b/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py new file mode 100644 index 000000000..dd4d60ab1 --- /dev/null +++ b/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py @@ -0,0 +1,75 @@ +############################################################################# +"""You have to fill these args. + +_base_ (str): The path to your pretrained model checkpoint. +pretrained_path (str): The path to your pretrained model checkpoint. + +interval (int): Interval between pruning two channels. You should ensure you + can reach your target pruning ratio when the training ends. +normalization_type (str): GroupFisher uses two methods to normlized the channel + importance, including ['flops','act']. The former uses flops, while the + latter uses the memory occupation of activation feature maps. +lr_ratio (float): Ratio to decrease lr rate. As pruning progress is unstable, + you need to decrease the original lr rate until the pruning training work + steadly without getting nan. + +target_flop_ratio (float): The target flop ratio to prune your model. +input_shape (Tuple): input shape to measure the flops. 
+""" + +_base_ = 'mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py' +pretrained_path = 'https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth' # noqa + +interval = 25 +normalization_type = 'act' +lr_ratio = 0.1125 + +target_flop_ratio = 0.65 +input_shape = (1, 3, 224, 224) +############################################################################## + +architecture = _base_.model + +if hasattr(_base_, 'data_preprocessor'): + architecture.update({'data_preprocessor': _base_.data_preprocessor}) + data_preprocessor = None + +architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path) +architecture['_scope_'] = _base_.default_scope + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherAlgorithm', + architecture=architecture, + interval=interval, + mutator=dict( + type='GroupFisherChannelMutator', + parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'), + channel_unit_cfg=dict( + type='GroupFisherChannelUnit', + default_args=dict(normalization_type=normalization_type, ), + ), + ), +) + +model_wrapper_cfg = dict( + type='mmrazor.GroupFisherDDP', + broadcast_buffers=False, +) + +optim_wrapper = dict( + optimizer=dict(lr=_base_.optim_wrapper.optimizer.lr * lr_ratio)) + +custom_hooks = getattr(_base_, 'custom_hooks', []) + [ + dict(type='mmrazor.PruningStructureHook'), + dict( + type='mmrazor.ResourceInfoHook', + interval=interval, + demo_input=dict( + type='mmrazor.DefaultDemoInput', + input_shape=input_shape, + ), + save_ckpt_thr=[target_flop_ratio], + ), +] diff --git a/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_deploy_mobilenet-v2_8xb32_in1k.py b/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_deploy_mobilenet-v2_8xb32_in1k.py new file mode 100644 index 000000000..f89689ae8 --- /dev/null +++ b/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_deploy_mobilenet-v2_8xb32_in1k.py @@ -0,0 +1,49 @@ +############################################################################# +"""You have to fill these args. + +_base_(str): The path to your pretrain config file. +fix_subnet (Union[dict,str]): The dict store the pruning structure or the + json file including it. +divisor (int): The divisor the make the channel number divisible. 
+""" + +_base_ = 'mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py' +fix_subnet = { + 'backbone.conv1.conv_(0, 32)_32': 27, + 'backbone.layer1.0.conv.1.conv_(0, 16)_16': 16, + 'backbone.layer2.0.conv.0.conv_(0, 96)_96': 77, + 'backbone.layer2.0.conv.2.conv_(0, 24)_24': 24, + 'backbone.layer2.1.conv.0.conv_(0, 144)_144': 85, + 'backbone.layer3.0.conv.0.conv_(0, 144)_144': 115, + 'backbone.layer3.0.conv.2.conv_(0, 32)_32': 32, + 'backbone.layer3.1.conv.0.conv_(0, 192)_192': 102, + 'backbone.layer3.2.conv.0.conv_(0, 192)_192': 95, + 'backbone.layer4.0.conv.0.conv_(0, 192)_192': 181, + 'backbone.layer4.0.conv.2.conv_(0, 64)_64': 64, + 'backbone.layer4.1.conv.0.conv_(0, 384)_384': 169, + 'backbone.layer4.2.conv.0.conv_(0, 384)_384': 176, + 'backbone.layer4.3.conv.0.conv_(0, 384)_384': 180, + 'backbone.layer5.0.conv.0.conv_(0, 384)_384': 308, + 'backbone.layer5.0.conv.2.conv_(0, 96)_96': 96, + 'backbone.layer5.1.conv.0.conv_(0, 576)_576': 223, + 'backbone.layer5.2.conv.0.conv_(0, 576)_576': 241, + 'backbone.layer6.0.conv.0.conv_(0, 576)_576': 511, + 'backbone.layer6.0.conv.2.conv_(0, 160)_160': 160, + 'backbone.layer6.1.conv.0.conv_(0, 960)_960': 467, + 'backbone.layer6.2.conv.0.conv_(0, 960)_960': 510, + 'backbone.layer7.0.conv.0.conv_(0, 960)_960': 771, + 'backbone.layer7.0.conv.2.conv_(0, 320)_320': 320 +} +divisor = 8 + +############################################################################## +architecture = _base_.model + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherDeploySubModel', + architecture=architecture, + fix_subnet=fix_subnet, + divisor=divisor, +) diff --git a/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.py b/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.py new file mode 100644 index 000000000..18c9a99f1 --- /dev/null +++ b/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.py @@ -0,0 +1,32 @@ +############################################################################# +"""# You have to fill these args. + +_base_(str): The path to your pruning config file. +pruned_path (str): The path to the checkpoint of the pruned model. +finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr + rate of the pretrain. 
+""" + +_base_ = './group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py' +pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.pth' # noqa +finetune_lr = 0.045 +############################################################################## + +algorithm = _base_.model +algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherSubModel', + algorithm=algorithm, +) + +# restore lr +optim_wrapper = dict(optimizer=dict(lr=finetune_lr)) + +# remove pruning related hooks +custom_hooks = _base_.custom_hooks[:-2] + +# delete ddp +model_wrapper_cfg = None diff --git a/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py b/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py new file mode 100644 index 000000000..65a1fdd20 --- /dev/null +++ b/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py @@ -0,0 +1,5 @@ +_base_ = './group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py' +model = dict( + mutator=dict( + channel_unit_cfg=dict( + default_args=dict(normalization_type='flops', ), ), ), ) diff --git a/configs/pruning/mmcls/group_fisher/mobilenet/script.sh b/configs/pruning/mmcls/group_fisher/mobilenet/script.sh new file mode 100644 index 000000000..35bc63164 --- /dev/null +++ b/configs/pruning/mmcls/group_fisher/mobilenet/script.sh @@ -0,0 +1,7 @@ +# act mode +bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py 8 +bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.py 8 + +# flops mode +bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py 8 +bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.py 8 diff --git a/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_deploy_resnet50_8xb32_in1k.py b/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_deploy_resnet50_8xb32_in1k.py new file mode 100644 index 000000000..8fcb4082a --- /dev/null +++ b/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_deploy_resnet50_8xb32_in1k.py @@ -0,0 +1,61 @@ +############################################################################# +"""You have to fill these args. + +_base_(str): The path to your pretrain config file. +fix_subnet (Union[dict,str]): The dict store the pruning structure or the + json file including it. +divisor (int): The divisor the make the channel number divisible. 
+""" + +_base_ = 'mmcls::resnet/resnet50_8xb32_in1k.py' +fix_subnet = { + 'backbone.conv1_(0, 64)_64': 61, + 'backbone.layer1.0.conv1_(0, 64)_64': 27, + 'backbone.layer1.0.conv2_(0, 64)_64': 35, + 'backbone.layer1.0.conv3_(0, 256)_256': 241, + 'backbone.layer1.1.conv1_(0, 64)_64': 32, + 'backbone.layer1.1.conv2_(0, 64)_64': 29, + 'backbone.layer1.2.conv1_(0, 64)_64': 27, + 'backbone.layer1.2.conv2_(0, 64)_64': 42, + 'backbone.layer2.0.conv1_(0, 128)_128': 87, + 'backbone.layer2.0.conv2_(0, 128)_128': 107, + 'backbone.layer2.0.conv3_(0, 512)_512': 512, + 'backbone.layer2.1.conv1_(0, 128)_128': 44, + 'backbone.layer2.1.conv2_(0, 128)_128': 50, + 'backbone.layer2.2.conv1_(0, 128)_128': 52, + 'backbone.layer2.2.conv2_(0, 128)_128': 81, + 'backbone.layer2.3.conv1_(0, 128)_128': 47, + 'backbone.layer2.3.conv2_(0, 128)_128': 50, + 'backbone.layer3.0.conv1_(0, 256)_256': 210, + 'backbone.layer3.0.conv2_(0, 256)_256': 206, + 'backbone.layer3.0.conv3_(0, 1024)_1024': 1024, + 'backbone.layer3.1.conv1_(0, 256)_256': 107, + 'backbone.layer3.1.conv2_(0, 256)_256': 108, + 'backbone.layer3.2.conv1_(0, 256)_256': 86, + 'backbone.layer3.2.conv2_(0, 256)_256': 126, + 'backbone.layer3.3.conv1_(0, 256)_256': 91, + 'backbone.layer3.3.conv2_(0, 256)_256': 112, + 'backbone.layer3.4.conv1_(0, 256)_256': 98, + 'backbone.layer3.4.conv2_(0, 256)_256': 110, + 'backbone.layer3.5.conv1_(0, 256)_256': 112, + 'backbone.layer3.5.conv2_(0, 256)_256': 115, + 'backbone.layer4.0.conv1_(0, 512)_512': 397, + 'backbone.layer4.0.conv2_(0, 512)_512': 427, + 'backbone.layer4.1.conv1_(0, 512)_512': 373, + 'backbone.layer4.1.conv2_(0, 512)_512': 348, + 'backbone.layer4.2.conv1_(0, 512)_512': 433, + 'backbone.layer4.2.conv2_(0, 512)_512': 384 +} +divisor = 8 +############################################################################## + +architecture = _base_.model + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherDeploySubModel', + architecture=architecture, + fix_subnet=fix_subnet, + divisor=divisor, +) diff --git a/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k.py b/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k.py new file mode 100644 index 000000000..5d1f9380c --- /dev/null +++ b/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k.py @@ -0,0 +1,31 @@ +############################################################################# +"""# You have to fill these args. + +_base_(str): The path to your pruning config file. +pruned_path (str): The path to the checkpoint of the pruned model. +finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr + rate of the pretrain. 
+""" + +_base_ = './group_fisher_act_prune_resnet50_8xb32_in1k.py' +pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_prune_resnet50_8xb32_in1k.pth' # noqa +finetune_lr = 0.1 +############################################################################## +algorithm = _base_.model +algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherSubModel', + algorithm=algorithm, +) + +# restore lr +optim_wrapper = dict(optimizer=dict(lr=finetune_lr)) + +# remove pruning related hooks +custom_hooks = _base_.custom_hooks[:-2] + +# delete ddp +model_wrapper_cfg = None diff --git a/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_prune_resnet50_8xb32_in1k.py b/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_prune_resnet50_8xb32_in1k.py new file mode 100644 index 000000000..e37a2a79b --- /dev/null +++ b/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_prune_resnet50_8xb32_in1k.py @@ -0,0 +1,75 @@ +############################################################################# +"""You have to fill these args. + +_base_ (str): The path to your pretrained model checkpoint. +pretrained_path (str): The path to your pretrained model checkpoint. + +interval (int): Interval between pruning two channels. You should ensure you + can reach your target pruning ratio when the training ends. +normalization_type (str): GroupFisher uses two methods to normlized the channel + importance, including ['flops','act']. The former uses flops, while the + latter uses the memory occupation of activation feature maps. +lr_ratio (float): Ratio to decrease lr rate. As pruning progress is unstable, + you need to decrease the original lr rate until the pruning training work + steadly without getting nan. + +target_flop_ratio (float): The target flop ratio to prune your model. +input_shape (Tuple): input shape to measure the flops. 
+""" + +_base_ = 'mmcls::resnet/resnet50_8xb32_in1k.py' +pretrained_path = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa + +interval = 25 +normalization_type = 'act' +lr_ratio = 0.04 + +target_flop_ratio = 0.5 +input_shape = [1, 3, 224, 224] +############################################################################## + +architecture = _base_.model + +if hasattr(_base_, 'data_preprocessor'): + architecture.update({'data_preprocessor': _base_.data_preprocessor}) + data_preprocessor = None + +architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path) +architecture['_scope_'] = _base_.default_scope + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherAlgorithm', + architecture=architecture, + interval=interval, + mutator=dict( + type='GroupFisherChannelMutator', + parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'), + channel_unit_cfg=dict( + type='GroupFisherChannelUnit', + default_args=dict(normalization_type=normalization_type, ), + ), + ), +) + +model_wrapper_cfg = dict( + type='mmrazor.GroupFisherDDP', + broadcast_buffers=False, +) + +optim_wrapper = dict( + optimizer=dict(lr=_base_.optim_wrapper.optimizer.lr * lr_ratio)) + +custom_hooks = getattr(_base_, 'custom_hooks', []) + [ + dict(type='mmrazor.PruningStructureHook'), + dict( + type='mmrazor.ResourceInfoHook', + interval=interval, + demo_input=dict( + type='mmrazor.DefaultDemoInput', + input_shape=input_shape, + ), + save_ckpt_thr=[target_flop_ratio], + ), +] diff --git a/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_deploy_resnet50_8xb32_in1k.py b/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_deploy_resnet50_8xb32_in1k.py new file mode 100644 index 000000000..7dd84bd19 --- /dev/null +++ b/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_deploy_resnet50_8xb32_in1k.py @@ -0,0 +1,61 @@ +############################################################################# +"""You have to fill these args. + +_base_(str): The path to your pretrain config file. +fix_subnet (Union[dict,str]): The dict store the pruning structure or the + json file including it. +divisor (int): The divisor the make the channel number divisible. 
+""" + +_base_ = 'mmcls::resnet/resnet50_8xb32_in1k.py' +fix_subnet = { + 'backbone.conv1_(0, 64)_64': 61, + 'backbone.layer1.0.conv1_(0, 64)_64': 28, + 'backbone.layer1.0.conv2_(0, 64)_64': 35, + 'backbone.layer1.0.conv3_(0, 256)_256': 242, + 'backbone.layer1.1.conv1_(0, 64)_64': 31, + 'backbone.layer1.1.conv2_(0, 64)_64': 28, + 'backbone.layer1.2.conv1_(0, 64)_64': 26, + 'backbone.layer1.2.conv2_(0, 64)_64': 41, + 'backbone.layer2.0.conv1_(0, 128)_128': 90, + 'backbone.layer2.0.conv2_(0, 128)_128': 107, + 'backbone.layer2.0.conv3_(0, 512)_512': 509, + 'backbone.layer2.1.conv1_(0, 128)_128': 42, + 'backbone.layer2.1.conv2_(0, 128)_128': 50, + 'backbone.layer2.2.conv1_(0, 128)_128': 51, + 'backbone.layer2.2.conv2_(0, 128)_128': 84, + 'backbone.layer2.3.conv1_(0, 128)_128': 49, + 'backbone.layer2.3.conv2_(0, 128)_128': 51, + 'backbone.layer3.0.conv1_(0, 256)_256': 210, + 'backbone.layer3.0.conv2_(0, 256)_256': 207, + 'backbone.layer3.0.conv3_(0, 1024)_1024': 1024, + 'backbone.layer3.1.conv1_(0, 256)_256': 103, + 'backbone.layer3.1.conv2_(0, 256)_256': 108, + 'backbone.layer3.2.conv1_(0, 256)_256': 90, + 'backbone.layer3.2.conv2_(0, 256)_256': 124, + 'backbone.layer3.3.conv1_(0, 256)_256': 94, + 'backbone.layer3.3.conv2_(0, 256)_256': 114, + 'backbone.layer3.4.conv1_(0, 256)_256': 99, + 'backbone.layer3.4.conv2_(0, 256)_256': 111, + 'backbone.layer3.5.conv1_(0, 256)_256': 108, + 'backbone.layer3.5.conv2_(0, 256)_256': 111, + 'backbone.layer4.0.conv1_(0, 512)_512': 400, + 'backbone.layer4.0.conv2_(0, 512)_512': 421, + 'backbone.layer4.1.conv1_(0, 512)_512': 377, + 'backbone.layer4.1.conv2_(0, 512)_512': 347, + 'backbone.layer4.2.conv1_(0, 512)_512': 443, + 'backbone.layer4.2.conv2_(0, 512)_512': 376 +} +divisor = 8 +############################################################################## + +architecture = _base_.model + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherDeploySubModel', + architecture=architecture, + fix_subnet=fix_subnet, + divisor=divisor, +) diff --git a/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_finetune_resnet50_8xb32_in1k.py b/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_finetune_resnet50_8xb32_in1k.py new file mode 100644 index 000000000..b05be2676 --- /dev/null +++ b/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_finetune_resnet50_8xb32_in1k.py @@ -0,0 +1,31 @@ +############################################################################# +"""# You have to fill these args. + +_base_(str): The path to your pruning config file. +pruned_path (str): The path to the checkpoint of the pruned model. +finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr + rate of the pretrain. 
+""" + +_base_ = './group_fisher_flops_prune_resnet50_8xb32_in1k.py' +pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/group_fisher_flops_prune_resnet50_8xb32_in1k.pth' # noqa +finetune_lr = 0.1 +############################################################################## +algorithm = _base_.model +algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherSubModel', + algorithm=algorithm, +) + +# restore lr +optim_wrapper = dict(optimizer=dict(lr=finetune_lr)) + +# remove pruning related hooks +custom_hooks = _base_.custom_hooks[:-2] + +# delete ddp +model_wrapper_cfg = None diff --git a/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_prune_resnet50_8xb32_in1k.py b/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_prune_resnet50_8xb32_in1k.py new file mode 100644 index 000000000..06b90bda0 --- /dev/null +++ b/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_prune_resnet50_8xb32_in1k.py @@ -0,0 +1,5 @@ +_base_ = './group_fisher_act_prune_resnet50_8xb32_in1k.py' +model = dict( + mutator=dict( + channel_unit_cfg=dict( + default_args=dict(normalization_type='flops', ), ), ), ) diff --git a/configs/pruning/mmcls/group_fisher/resnet50/script.sh b/configs/pruning/mmcls/group_fisher/resnet50/script.sh new file mode 100644 index 000000000..534c3339c --- /dev/null +++ b/configs/pruning/mmcls/group_fisher/resnet50/script.sh @@ -0,0 +1,7 @@ +# act mode +bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_prune_resnet50_8xb32_in1k.py.py 8 +bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k.py.py 8 + +# flops mode +bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_prune_resnet50_8xb32_in1k.py.py 8 +bash ./tools/dist_train.sh configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_finetune_resnet50_8xb32_in1k.py 8 diff --git a/configs/pruning/mmdet/group_fisher/README.md b/configs/pruning/mmdet/group_fisher/README.md new file mode 100644 index 000000000..9b3b09936 --- /dev/null +++ b/configs/pruning/mmdet/group_fisher/README.md @@ -0,0 +1,11 @@ +# Group_fisher pruning + +> [Group Fisher Pruning for Practical Network Compression.](https://arxiv.org/pdf/2108.00708.pdf) + +## Abstract + +Network compression has been widely studied since it is able to reduce the memory and computation cost during inference. However, previous methods seldom deal with complicated structures like residual connections, group/depthwise convolution and feature pyramid network, where channels of multiple layers are coupled and need to be pruned simultaneously. In this paper, we present a general channel pruning approach that can be applied to various complicated structures. Particularly, we propose a layer grouping algorithm to find coupled channels automatically. Then we derive a unified metric based on Fisher information to evaluate the importance of a single channel and coupled channels. Moreover, we find that inference speedup on GPUs is more correlated with the reduction of memory rather than FLOPs, and thus we employ the memory reduction of each channel to normalize the importance. Our method can be used to prune any structures including those with coupled channels. 
We conduct extensive experiments on various backbones, including the classic ResNet and ResNeXt, mobilefriendly MobileNetV2, and the NAS-based RegNet, both on image classification and object detection which is under-explored. Experimental results validate that our method can effectively prune sophisticated networks, boosting inference speed without sacrificing accuracy. + +![pipeline](/~https://github.com/jshilong/FisherPruning/blob/main/resources/structures.png?raw=true) + +**Please refer to the [full README](../../base/group_fisher/README.md) for more details.** diff --git a/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_deploy_retinanet_r50_fpn_1x_coco.py b/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_deploy_retinanet_r50_fpn_1x_coco.py new file mode 100644 index 000000000..ecc6afbfa --- /dev/null +++ b/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_deploy_retinanet_r50_fpn_1x_coco.py @@ -0,0 +1,73 @@ +############################################################################# +"""You have to fill these args. + +_base_(str): The path to your pretrain config file. +fix_subnet (Union[dict,str]): The dict store the pruning structure or the + json file including it. +divisor (int): The divisor the make the channel number divisible. +""" + +_base_ = 'mmdet::retinanet/retinanet_r50_fpn_1x_coco.py' +fix_subnet = { + 'backbone.conv1_(0, 64)_64': 60, + 'backbone.layer1.0.conv1_(0, 64)_64': 48, + 'backbone.layer1.0.conv2_(0, 64)_64': 44, + 'backbone.layer1.0.conv3_(0, 256)_256': 250, + 'backbone.layer1.1.conv1_(0, 64)_64': 40, + 'backbone.layer1.1.conv2_(0, 64)_64': 41, + 'backbone.layer1.2.conv1_(0, 64)_64': 48, + 'backbone.layer1.2.conv2_(0, 64)_64': 62, + 'backbone.layer2.0.conv1_(0, 128)_128': 115, + 'backbone.layer2.0.conv2_(0, 128)_128': 127, + 'backbone.layer2.0.conv3_(0, 512)_512': 511, + 'backbone.layer2.1.conv1_(0, 128)_128': 69, + 'backbone.layer2.1.conv2_(0, 128)_128': 83, + 'backbone.layer2.2.conv1_(0, 128)_128': 111, + 'backbone.layer2.2.conv2_(0, 128)_128': 121, + 'backbone.layer2.3.conv1_(0, 128)_128': 122, + 'backbone.layer2.3.conv2_(0, 128)_128': 128, + 'backbone.layer3.0.conv1_(0, 256)_256': 255, + 'backbone.layer3.0.conv2_(0, 256)_256': 256, + 'backbone.layer3.0.conv3_(0, 1024)_1024': 1024, + 'backbone.layer3.1.conv1_(0, 256)_256': 216, + 'backbone.layer3.1.conv2_(0, 256)_256': 223, + 'backbone.layer3.2.conv1_(0, 256)_256': 229, + 'backbone.layer3.2.conv2_(0, 256)_256': 247, + 'backbone.layer3.3.conv1_(0, 256)_256': 239, + 'backbone.layer3.3.conv2_(0, 256)_256': 246, + 'backbone.layer3.4.conv1_(0, 256)_256': 237, + 'backbone.layer3.4.conv2_(0, 256)_256': 239, + 'backbone.layer3.5.conv1_(0, 256)_256': 233, + 'backbone.layer3.5.conv2_(0, 256)_256': 221, + 'backbone.layer4.0.conv1_(0, 512)_512': 499, + 'backbone.layer4.0.conv2_(0, 512)_512': 494, + 'backbone.layer4.0.conv3_(0, 2048)_2048': 2031, + 'backbone.layer4.1.conv1_(0, 512)_512': 451, + 'backbone.layer4.1.conv2_(0, 512)_512': 401, + 'backbone.layer4.2.conv1_(0, 512)_512': 396, + 'backbone.layer4.2.conv2_(0, 512)_512': 237, + 'neck.lateral_convs.0.conv_(0, 256)_256': 237, + 'neck.fpn_convs.0.conv_(0, 256)_256': 241, + 'bbox_head.cls_convs.0.conv_(0, 256)_256': 133, + 'bbox_head.cls_convs.1.conv_(0, 256)_256': 134, + 'bbox_head.cls_convs.2.conv_(0, 256)_256': 139, + 'bbox_head.cls_convs.3.conv_(0, 256)_256': 79, + 'bbox_head.reg_convs.0.conv_(0, 256)_256': 89, + 'bbox_head.reg_convs.1.conv_(0, 256)_256': 92, + 'bbox_head.reg_convs.2.conv_(0, 256)_256': 82, + 
'bbox_head.reg_convs.3.conv_(0, 256)_256': 117 +} +divisor = 8 + +############################################################################## + +architecture = _base_.model + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherDeploySubModel', + architecture=architecture, + fix_subnet=fix_subnet, + divisor=divisor, +) diff --git a/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.py b/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.py new file mode 100644 index 000000000..b0f7d08de --- /dev/null +++ b/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.py @@ -0,0 +1,31 @@ +############################################################################# +"""# You have to fill these args. + +_base_(str): The path to your pruning config file. +pruned_path (str): The path to the checkpoint of the pruned model. +finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr + rate of the pretrain. +""" + +_base_ = './group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py' +pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/act/group_fisher_act_prune_retinanet_r50_fpn_1x_coco.pth' # noqa +finetune_lr = 0.005 +############################################################################## +algorithm = _base_.model +algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherSubModel', + algorithm=algorithm, +) + +# restore lr +optim_wrapper = dict(optimizer=dict(lr=finetune_lr)) + +# remove pruning related hooks +custom_hooks = _base_.custom_hooks[:-2] + +# delete ddp +model_wrapper_cfg = None diff --git a/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py b/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py new file mode 100644 index 000000000..d324c933a --- /dev/null +++ b/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py @@ -0,0 +1,75 @@ +############################################################################# +"""You have to fill these args. + +_base_ (str): The path to your pretrained model checkpoint. +pretrained_path (str): The path to your pretrained model checkpoint. + +interval (int): Interval between pruning two channels. You should ensure you + can reach your target pruning ratio when the training ends. +normalization_type (str): GroupFisher uses two methods to normlized the channel + importance, including ['flops','act']. The former uses flops, while the + latter uses the memory occupation of activation feature maps. +lr_ratio (float): Ratio to decrease lr rate. As pruning progress is unstable, + you need to decrease the original lr rate until the pruning training work + steadly without getting nan. + +target_flop_ratio (float): The target flop ratio to prune your model. +input_shape (Tuple): input shape to measure the flops. 
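A quick way to sanity-check these settings before launching: `try_prune` removes at most one channel each time it fires, and it fires once every `interval` iterations, so the length of the training schedule bounds how much can be pruned. A rough check with assumed numbers (the iteration count below is illustrative, not taken from this config):

```python
# Rough pruning-budget check (illustrative numbers, not from this config).
interval = 10
total_train_iters = 90_000            # assumed length of the training schedule
pruning_steps = total_train_iters // interval
print(pruning_steps)                  # at most 9000 channels can be removed
# If reaching target_flop_ratio needs more pruning steps than this,
# lower `interval` or extend the schedule.
```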
+""" + +_base_ = 'mmdet::retinanet/retinanet_r50_fpn_1x_coco.py' +pretrained_path = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth' # noqa + +interval = 10 +normalization_type = 'act' +lr_ratio = 0.1 + +target_flop_ratio = 0.5 +input_shape = (1, 3, 1333, 800) +############################################################################## + +architecture = _base_.model + +if hasattr(_base_, 'data_preprocessor'): + architecture.update({'data_preprocessor': _base_.data_preprocessor}) + data_preprocessor = None + +architecture.init_cfg = dict(type='Pretrained', checkpoint=pretrained_path) +architecture['_scope_'] = _base_.default_scope + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherAlgorithm', + architecture=architecture, + interval=interval, + mutator=dict( + type='GroupFisherChannelMutator', + parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'), + channel_unit_cfg=dict( + type='GroupFisherChannelUnit', + default_args=dict(normalization_type=normalization_type, ), + ), + ), +) + +model_wrapper_cfg = dict( + type='mmrazor.GroupFisherDDP', + broadcast_buffers=False, +) + +optim_wrapper = dict( + optimizer=dict(lr=_base_.optim_wrapper.optimizer.lr * lr_ratio)) + +custom_hooks = getattr(_base_, 'custom_hooks', []) + [ + dict(type='mmrazor.PruningStructureHook'), + dict( + type='mmrazor.ResourceInfoHook', + interval=interval, + demo_input=dict( + type='mmrazor.DefaultDemoInput', + input_shape=input_shape, + ), + save_ckpt_thr=[target_flop_ratio], + ), +] diff --git a/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_deploy_retinanet_r50_fpn_1x_coco.py b/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_deploy_retinanet_r50_fpn_1x_coco.py new file mode 100644 index 000000000..f9cf21cc1 --- /dev/null +++ b/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_deploy_retinanet_r50_fpn_1x_coco.py @@ -0,0 +1,73 @@ +############################################################################# +"""You have to fill these args. + +_base_(str): The path to your pretrain config file. +fix_subnet (Union[dict,str]): The dict store the pruning structure or the + json file including it. +divisor (int): The divisor the make the channel number divisible. 
+""" + +_base_ = 'mmdet::retinanet/retinanet_r50_fpn_1x_coco.py' +fix_subnet = { + 'backbone.conv1_(0, 64)_64': 60, + 'backbone.layer1.0.conv1_(0, 64)_64': 47, + 'backbone.layer1.0.conv2_(0, 64)_64': 44, + 'backbone.layer1.0.conv3_(0, 256)_256': 249, + 'backbone.layer1.1.conv1_(0, 64)_64': 37, + 'backbone.layer1.1.conv2_(0, 64)_64': 37, + 'backbone.layer1.2.conv1_(0, 64)_64': 44, + 'backbone.layer1.2.conv2_(0, 64)_64': 62, + 'backbone.layer2.0.conv1_(0, 128)_128': 114, + 'backbone.layer2.0.conv2_(0, 128)_128': 127, + 'backbone.layer2.0.conv3_(0, 512)_512': 511, + 'backbone.layer2.1.conv1_(0, 128)_128': 65, + 'backbone.layer2.1.conv2_(0, 128)_128': 83, + 'backbone.layer2.2.conv1_(0, 128)_128': 106, + 'backbone.layer2.2.conv2_(0, 128)_128': 118, + 'backbone.layer2.3.conv1_(0, 128)_128': 118, + 'backbone.layer2.3.conv2_(0, 128)_128': 127, + 'backbone.layer3.0.conv1_(0, 256)_256': 255, + 'backbone.layer3.0.conv2_(0, 256)_256': 256, + 'backbone.layer3.0.conv3_(0, 1024)_1024': 1024, + 'backbone.layer3.1.conv1_(0, 256)_256': 214, + 'backbone.layer3.1.conv2_(0, 256)_256': 232, + 'backbone.layer3.2.conv1_(0, 256)_256': 224, + 'backbone.layer3.2.conv2_(0, 256)_256': 247, + 'backbone.layer3.3.conv1_(0, 256)_256': 240, + 'backbone.layer3.3.conv2_(0, 256)_256': 246, + 'backbone.layer3.4.conv1_(0, 256)_256': 240, + 'backbone.layer3.4.conv2_(0, 256)_256': 243, + 'backbone.layer3.5.conv1_(0, 256)_256': 238, + 'backbone.layer3.5.conv2_(0, 256)_256': 232, + 'backbone.layer4.0.conv1_(0, 512)_512': 503, + 'backbone.layer4.0.conv2_(0, 512)_512': 500, + 'backbone.layer4.0.conv3_(0, 2048)_2048': 2041, + 'backbone.layer4.1.conv1_(0, 512)_512': 466, + 'backbone.layer4.1.conv2_(0, 512)_512': 430, + 'backbone.layer4.2.conv1_(0, 512)_512': 406, + 'backbone.layer4.2.conv2_(0, 512)_512': 274, + 'neck.lateral_convs.0.conv_(0, 256)_256': 236, + 'neck.fpn_convs.0.conv_(0, 256)_256': 225, + 'bbox_head.cls_convs.0.conv_(0, 256)_256': 140, + 'bbox_head.cls_convs.1.conv_(0, 256)_256': 133, + 'bbox_head.cls_convs.2.conv_(0, 256)_256': 139, + 'bbox_head.cls_convs.3.conv_(0, 256)_256': 86, + 'bbox_head.reg_convs.0.conv_(0, 256)_256': 89, + 'bbox_head.reg_convs.1.conv_(0, 256)_256': 89, + 'bbox_head.reg_convs.2.conv_(0, 256)_256': 76, + 'bbox_head.reg_convs.3.conv_(0, 256)_256': 122, +} +divisor = 8 + +############################################################################## + +architecture = _base_.model + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherDeploySubModel', + architecture=architecture, + fix_subnet=fix_subnet, + divisor=divisor, +) diff --git a/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.py b/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.py new file mode 100644 index 000000000..9d2d3a001 --- /dev/null +++ b/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.py @@ -0,0 +1,31 @@ +############################################################################# +"""# You have to fill these args. + +_base_(str): The path to your pruning config file. +pruned_path (str): The path to the checkpoint of the pruned model. +finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr + rate of the pretrain. 
+""" + +_base_ = './group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py' +pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.pth' # noqa +finetune_lr = 0.005 +############################################################################## +algorithm = _base_.model +algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GroupFisherSubModel', + algorithm=algorithm, +) + +# restore lr +optim_wrapper = dict(optimizer=dict(lr=finetune_lr)) + +# remove pruning related hooks +custom_hooks = _base_.custom_hooks[:-2] + +# delete ddp +model_wrapper_cfg = None diff --git a/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py b/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py new file mode 100644 index 000000000..162db3bed --- /dev/null +++ b/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py @@ -0,0 +1,5 @@ +_base_ = './group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py' +model = dict( + mutator=dict( + channel_unit_cfg=dict( + default_args=dict(normalization_type='flops', ), ), ), ) diff --git a/configs/pruning/mmdet/group_fisher/retinanet/script.sh b/configs/pruning/mmdet/group_fisher/retinanet/script.sh new file mode 100644 index 000000000..246c5e34f --- /dev/null +++ b/configs/pruning/mmdet/group_fisher/retinanet/script.sh @@ -0,0 +1,7 @@ +# act mode +bash ./tools/dist_train.sh configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py 8 +bash ./tools/dist_train.sh configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.py 8 + +# flops mode +bash ./tools/dist_train.sh configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py 8 +bash ./tools/dist_train.sh configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.py 8 diff --git a/mmrazor/engine/hooks/group_fisher_hooks.py b/mmrazor/engine/hooks/group_fisher_hooks.py new file mode 100644 index 000000000..3ce8631b7 --- /dev/null +++ b/mmrazor/engine/hooks/group_fisher_hooks.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""This file includes the modules in the impl folder. + +As it only records impl modules, it is not initialized automatically. +""" +from mmrazor.implementations.pruning.group_fisher import \ + PruningStructureHook # noqa +from mmrazor.implementations.pruning.group_fisher import \ + ResourceInfoHook # noqa diff --git a/mmrazor/implementations/__init__.py b/mmrazor/implementations/__init__.py new file mode 100644 index 000000000..a03158f53 --- /dev/null +++ b/mmrazor/implementations/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""impl folder is an experimental file structure to store algorithm +implementations. + +Previous file structure splits the files of an algorithm into different folders +according to the types of these files. It may make it hard to understand an +algorithm. So we add the impl folder, where all files of an algorithm are +stored in one folder. As this structure is experimental, it may change rapidly. +""" + +from . 
import pruning # noqa + +__all__ = ['pruning'] diff --git a/mmrazor/implementations/pruning/__init__.py b/mmrazor/implementations/pruning/__init__.py new file mode 100644 index 000000000..e28ae7dc2 --- /dev/null +++ b/mmrazor/implementations/pruning/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from . import group_fisher + +__all__ = ['group_fisher'] diff --git a/mmrazor/implementations/pruning/group_fisher/__init__.py b/mmrazor/implementations/pruning/group_fisher/__init__.py new file mode 100644 index 000000000..5dd85ce3c --- /dev/null +++ b/mmrazor/implementations/pruning/group_fisher/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .algorithm import GroupFisherAlgorithm +from .counters import GroupFisherConv2dCounter, GroupFisherLinearCounter +from .hook import PruningStructureHook, ResourceInfoHook +from .mutator import GroupFisherChannelMutator +from .ops import GroupFisherConv2d, GroupFisherLinear, GroupFisherMixin +from .prune_deploy_sub_model import GroupFisherDeploySubModel +from .prune_sub_model import GroupFisherSubModel +from .unit import GroupFisherChannelUnit + +__all__ = [ + 'GroupFisherDeploySubModel', + 'GroupFisherSubModel', + 'GroupFisherAlgorithm', + 'GroupFisherConv2dCounter', + 'GroupFisherLinearCounter', + 'PruningStructureHook', + 'ResourceInfoHook', + 'GroupFisherChannelMutator', + 'GroupFisherChannelUnit', + 'GroupFisherConv2d', + 'GroupFisherLinear', + 'GroupFisherMixin', +] diff --git a/mmrazor/implementations/pruning/group_fisher/algorithm.py b/mmrazor/implementations/pruning/group_fisher/algorithm.py new file mode 100644 index 000000000..a90b406db --- /dev/null +++ b/mmrazor/implementations/pruning/group_fisher/algorithm.py @@ -0,0 +1,86 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from mmengine.logging import print_log +from mmengine.model import BaseModel, MMDistributedDataParallel + +from mmrazor.models.algorithms.base import BaseAlgorithm +from mmrazor.registry import MODEL_WRAPPERS, MODELS +from mmrazor.utils import RuntimeInfo +from .mutator import GroupFisherChannelMutator + + +@MODELS.register_module() +class GroupFisherAlgorithm(BaseAlgorithm): + """`Group Fisher Pruning for Practical Network Compression`. + https://arxiv.org/pdf/2108.00708.pdf. + + Args: + architecture (Union[BaseModel, Dict]): The model to be pruned. + mutator (Union[Dict, ChannelMutator], optional): The config + of a mutator. Defaults to dict( type='GroupFisherChannelMutator', + channel_unit_cfg=dict( type='GroupFisherChannelUnit')). + interval (int): The interval of pruning two channels. Defaults to 10. + data_preprocessor (Optional[Union[Dict, nn.Module]], optional): + Defaults to None. + init_cfg (Optional[Dict], optional): init config for the model. + Defaults to None. 
+ """ + + def __init__(self, + architecture: Union[BaseModel, Dict], + mutator: Union[Dict, GroupFisherChannelMutator] = dict( + type='GroupFisherChannelMutator', + channel_unit_cfg=dict(type='GroupFisherChannelUnit')), + interval: int = 10, + data_preprocessor: Optional[Union[Dict, nn.Module]] = None, + init_cfg: Optional[Dict] = None) -> None: + + super().__init__(architecture, data_preprocessor, init_cfg) + + self.interval = interval + + # using sync bn or normal bn + if dist.is_initialized(): + print_log('Convert Bn to SyncBn.') + self.architecture = nn.SyncBatchNorm.convert_sync_batchnorm( + self.architecture) + else: + from mmengine.model import revert_sync_batchnorm + self.architecture = revert_sync_batchnorm(self.architecture) + + # mutator + self.mutator: GroupFisherChannelMutator = MODELS.build(mutator) + self.mutator.prepare_from_supernet(self.architecture) + + def train_step(self, data: Union[dict, tuple, list], + optim_wrapper) -> Dict[str, torch.Tensor]: + return self._train_step(data, optim_wrapper) + + def _train_step(self, data: Union[dict, tuple, list], optim_wrapper): + """Train step function for GroupFisherAlgorithm and GroupFisherDDP.""" + self.mutator.start_record_info() + res = super().train_step(data, optim_wrapper) + self.mutator.end_record_info() + + self.mutator.update_imp() + self.mutator.reset_recorded_info() + + if RuntimeInfo.iter() % self.interval == 0: + self.mutator.try_prune() + self.mutator.reset_imp() + + return res + + +@MODEL_WRAPPERS.register_module() +class GroupFisherDDP(MMDistributedDataParallel): + """Train step for group fisher.""" + + def train_step(self, data: Union[dict, tuple, list], + optim_wrapper) -> Dict[str, torch.Tensor]: + algorithm = self.module + return GroupFisherAlgorithm._train_step(algorithm, data, optim_wrapper) diff --git a/mmrazor/implementations/pruning/group_fisher/counters.py b/mmrazor/implementations/pruning/group_fisher/counters.py new file mode 100644 index 000000000..6f41a0244 --- /dev/null +++ b/mmrazor/implementations/pruning/group_fisher/counters.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmrazor.models.task_modules.estimators.counters.op_counters.dynamic_op_counters import ( # noqa + DynamicConv2dCounter, DynamicLinearCounter) +from mmrazor.registry import TASK_UTILS + + +@TASK_UTILS.register_module() +class GroupFisherConv2dCounter(DynamicConv2dCounter): + """Counter of GroupFisherConv2d.""" + pass + + +@TASK_UTILS.register_module() +class GroupFisherLinearCounter(DynamicLinearCounter): + """Counter of GroupFisherLinear.""" + pass diff --git a/mmrazor/implementations/pruning/group_fisher/hook.py b/mmrazor/implementations/pruning/group_fisher/hook.py new file mode 100644 index 000000000..ef2e3aece --- /dev/null +++ b/mmrazor/implementations/pruning/group_fisher/hook.py @@ -0,0 +1,183 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
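Before the hooks defined next, it may help to see the full per-iteration cycle that `_train_step` above drives. The following is an illustrative sketch, not part of the patch, written against the `GroupFisherChannelMutator` interface added later in this PR:

```python
# Illustrative sketch of one Group Fisher training iteration (not part of the patch).
def pruning_train_step(mutator, run_forward_backward, interval, cur_iter):
    mutator.start_record_info()       # attach forward/backward hooks on dynamic ops
    loss = run_forward_backward()     # inputs and input-gradients get recorded
    mutator.end_record_info()         # detach hooks

    mutator.update_imp()              # accumulate normalized Fisher importance
    mutator.reset_recorded_info()     # drop recorded tensors to free memory

    if cur_iter % interval == 0:      # every `interval` iterations ...
        mutator.try_prune()           # ... drop the least important channel
        mutator.reset_imp()           # ... and restart importance accumulation
    return loss
```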
+import torch +import torch.nn as nn +from mmengine.dist import master_only +from mmengine.hooks import Hook +from mmengine.runner import Runner, save_checkpoint +from torch import distributed as torch_dist + +from mmrazor.models.algorithms import BaseAlgorithm +from mmrazor.models.mutators.channel_mutator.channel_mutator import \ + ChannelMutator +from mmrazor.models.task_modules.demo_inputs import DefaultDemoInput +from mmrazor.models.task_modules.estimators import ResourceEstimator +from mmrazor.registry import HOOKS, TASK_UTILS +from mmrazor.utils import RuntimeInfo, print_log + + +def get_model_from_runner(runner): + """Get the model from a runner.""" + if torch_dist.is_initialized(): + return runner.model.module + else: + return runner.model + + +def is_pruning_algorithm(algorithm): + """Check whether a model is a pruning algorithm.""" + return isinstance(algorithm, BaseAlgorithm) \ + and isinstance(getattr(algorithm, 'mutator', None), ChannelMutator) # noqa + + +@HOOKS.register_module() +class PruningStructureHook(Hook): + """This hook is used to display the structurn information during pruning. + + Args: + by_epoch (bool, optional): Whether to display structure information + iteratively by epoch. Defaults to True. + interval (int, optional): The interval between two structure + information display. + """ + + def __init__(self, by_epoch=True, interval=1) -> None: + + super().__init__() + self.by_epoch = by_epoch + self.interval = interval + + def show_unit_info(self, algorithm): + """Show unit information of an algorithm.""" + if is_pruning_algorithm(algorithm): + chices = algorithm.mutator.choice_template + import json + print_log(json.dumps(chices, indent=4)) + + for unit in algorithm.mutator.mutable_units: + if hasattr(unit, 'importance'): + imp = unit.importance() + print_log( + f'{unit.name}: \t{imp.min().item()}\t{imp.max().item()}' # noqa + ) + + @master_only + def show(self, runner): + """Show pruning algorithm information of a runner.""" + algorithm = get_model_from_runner(runner) + if is_pruning_algorithm(algorithm): + self.show_unit_info(algorithm) + + # hook points + + def after_train_epoch(self, runner) -> None: + if self.by_epoch and RuntimeInfo.epoch() % self.interval == 0: + self.show(runner) + + def after_train_iter(self, runner, batch_idx: int, data_batch, + outputs) -> None: + if not self.by_epoch and RuntimeInfo.iter() % self.interval == 0: + self.show(runner) + + +@HOOKS.register_module() +class ResourceInfoHook(Hook): + """This hook is used to display the resource related information and save + the checkpoint according to a threshold during pruning. + + Args: + demo_input (dict, optional): the demo input for ResourceEstimator. + Defaults to DefaultDemoInput([1, 3, 224, 224]). + interval (int, optional): the interval to check the resource. Defaults + to 10. + resource_type (str, optional): the type of resource to check. + Defaults to 'flops'. + save_ckpt_thr (list, optional): the threshold to save checkpoint. + Defaults to [0.5]. + early_stop (bool, optional): whether to stop when all checkpoints have + been saved according to save_ckpt_thr. Defaults to True. 
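Given the arguments above, a typical use of this hook is to request checkpoints at several FLOP ratios within a single pruning run. A hedged config sketch (the threshold values are examples, not the settings shipped in this PR):

```python
# Example ResourceInfoHook config (illustrative values).
dict(
    type='mmrazor.ResourceInfoHook',
    interval=10,
    resource_type='flops',
    demo_input=dict(type='mmrazor.DefaultDemoInput', input_shape=(1, 3, 224, 224)),
    save_ckpt_thr=[0.65, 0.5, 0.35],  # saved as flops_0.65.pth, flops_0.50.pth, flops_0.35.pth
    early_stop=True,                  # stop training once the smallest ratio is reached
)
```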
+ """ + + def __init__(self, + demo_input=DefaultDemoInput([1, 3, 224, 224]), + interval=10, + resource_type='flops', + save_ckpt_thr=[0.5], + early_stop=True) -> None: + + super().__init__() + if isinstance(demo_input, dict): + demo_input = TASK_UTILS.build(demo_input) + + self.demo_input = demo_input + self.save_ckpt_thr = sorted( + save_ckpt_thr, reverse=True) # big to small + self.resource_type = resource_type + self.early_stop = early_stop + self.estimator: ResourceEstimator = TASK_UTILS.build( + dict( + _scope_='mmrazor', + type='ResourceEstimator', + flops_params_cfg=dict( + input_shape=tuple(demo_input.input_shape), ))) + self.interval = interval + self.origin_delta = None + + def before_run(self, runner) -> None: + """Init original_resource.""" + model = get_model_from_runner(runner) + original_resource = self._evaluate(model) + print_log(f'get original resource: {original_resource}') + + self.origin_delta = original_resource[self.resource_type] + + # save checkpoint + + def after_train_iter(self, + runner: Runner, + batch_idx: int, + data_batch=None, + outputs=None) -> None: + """Check resource after train iteration.""" + if RuntimeInfo.iter() % self.interval == 0 and len( + self.save_ckpt_thr) > 0: + model = get_model_from_runner(runner) + current_delta = self._evaluate(model)[self.resource_type] + percent = current_delta / self.origin_delta + if percent < self.save_ckpt_thr[0]: + self._save_checkpoint(model, runner.work_dir, + self.save_ckpt_thr.pop(0)) + if self.early_stop and len(self.save_ckpt_thr) == 0: + exit() + + # show info + + @master_only + def after_train_epoch(self, runner) -> None: + """Check resource after train epoch.""" + model = get_model_from_runner(runner) + current_delta = self._evaluate(model)[self.resource_type] + print_log( + f'current {self.resource_type}: {current_delta} / {self.origin_delta}' # noqa + ) + + # + + def _evaluate(self, model: nn.Module): + """Evaluate the resource required by a model.""" + with torch.no_grad(): + training = model.training + model.eval() + res = self.estimator.estimate(model) + if training: + model.train() + return res + + @master_only + def _save_checkpoint(self, model, path, delta_percent): + """Save the checkpoint of a model.""" + ckpt = {'state_dict': model.state_dict()} + save_path = f'{path}/{self.resource_type}_{delta_percent:.2f}.pth' + save_checkpoint(ckpt, save_path) + print_log( + f'Save checkpoint to {save_path} with {self._evaluate(model)}' # noqa + ) diff --git a/mmrazor/implementations/pruning/group_fisher/mutator.py b/mmrazor/implementations/pruning/group_fisher/mutator.py new file mode 100644 index 000000000..d9e521a38 --- /dev/null +++ b/mmrazor/implementations/pruning/group_fisher/mutator.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Dict, List, Type, Union + +from mmengine.dist import dist + +from mmrazor.models.mutators.channel_mutator.channel_mutator import \ + ChannelMutator +from mmrazor.registry import MODELS +from mmrazor.utils import print_log +from .unit import GroupFisherChannelUnit + + +@MODELS.register_module() +class GroupFisherChannelMutator(ChannelMutator[GroupFisherChannelUnit]): + """Channel mutator for GroupFisher Pruning Algorithm. + + Args: + channel_unit_cfg (Union[dict, Type[ChannelUnitType]], optional): + Config of MutableChannelUnits. Defaults to + dict(type='GroupFisherChannelUnit', + default_args=dict(choice_mode='ratio')). + parse_cfg (Dict): The config of the tracer to parse the model. 
+ Defaults to dict(type='ChannelAnalyzer', + demo_input=(1, 3, 224, 224), + tracer_type='FxTracer'). + """ + + def __init__(self, + channel_unit_cfg: Union[dict, + Type[GroupFisherChannelUnit]] = dict( + type='GroupFisherChannelUnit'), + parse_cfg: Dict = dict( + type='ChannelAnalyzer', + demo_input=(1, 3, 224, 224), + tracer_type='FxTracer'), + **kwargs) -> None: + super().__init__(channel_unit_cfg, parse_cfg, **kwargs) + self.mutable_units: List[GroupFisherChannelUnit] + + def start_record_info(self) -> None: + """Start recording the related information.""" + for unit in self.mutable_units: + unit.start_record_fisher_info() + + def end_record_info(self) -> None: + """Stop recording the related information.""" + for unit in self.mutable_units: + unit.end_record_fisher_info() + + def reset_recorded_info(self) -> None: + """Reset the related information.""" + for unit in self.mutable_units: + unit.reset_recorded() + + def try_prune(self) -> None: + """Prune the channel with the minimum fisher unless it is the last + channel of the current layer.""" + min_imp = 1e5 + min_unit = self.mutable_units[0] + for unit in self.mutable_units: + if unit.mutable_channel.activated_channels > 1: + imp = unit.importance() + if imp.isnan().any(): + if dist.get_rank() == 0: + print_log( + f'{unit.name} detects nan in importance, this pruning skips.' # noqa + ) + return + if imp.min() < min_imp: + min_imp = imp.min().item() + min_unit = unit + if min_unit.try_to_prune_min_channel(): + if dist.get_rank() == 0: + print_log( + f'{min_unit.name} prunes a channel with min imp = {min_imp}' # noqa + ) + + def update_imp(self) -> None: + """Update the fisher information of each unit.""" + for unit in self.mutable_units: + unit.update_fisher_info() + + def reset_imp(self) -> None: + """Reset the fisher information of each unit.""" + for unit in self.mutable_units: + unit.reset_fisher_info() diff --git a/mmrazor/implementations/pruning/group_fisher/ops.py b/mmrazor/implementations/pruning/group_fisher/ops.py new file mode 100644 index 000000000..35dbbd749 --- /dev/null +++ b/mmrazor/implementations/pruning/group_fisher/ops.py @@ -0,0 +1,150 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
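For reference, the channel importance that `try_prune` above minimizes, and that the recording ops in this file together with `GroupFisherChannelUnit` compute, can be written as (paraphrasing the paper and the computation in `unit.py` below):

$$ s_c=\frac{1}{\Delta_c}\sum_{n}\Big(\sum_{l\in\text{unit}}\sum_{h,w}x^{(l)}_{n,c,h,w}\,\frac{\partial\mathcal{L}}{\partial x^{(l)}_{n,c,h,w}}\Big)^{2} $$

where the outer sum runs over samples in the batch, the inner sums over the coupled input layers of the unit and their spatial positions, and $\Delta_c$ is the per-channel FLOP reduction (`normalization_type='flops'`) or activation-memory reduction (`'act'`).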
+from typing import List + +import torch + +from mmrazor.models.architectures.dynamic_ops.bricks.dynamic_conv import \ + DynamicConv2d +from mmrazor.models.architectures.dynamic_ops.bricks.dynamic_linear import \ + DynamicLinear + + +class GroupFisherMixin: + """The mixin class for GroupFisher ops.""" + + def _init(self) -> None: + self.handlers: list = [] + self.recorded_input: List = [] + self.recorded_grad: List = [] + self.recorded_out_shape: List = [] + + def forward_hook_wrapper(self): + """Wrap the hook used in forward.""" + + def forward_hook(module: GroupFisherMixin, input, output): + module.recorded_out_shape.append(output.shape) + module.recorded_input.append(input[0]) + + return forward_hook + + def backward_hook_wrapper(self): + """Wrap the hook used in backward.""" + + def backward_hook(module: GroupFisherMixin, grad_in, grad_out): + module.recorded_grad.insert(0, grad_in[0]) + + return backward_hook + + def start_record(self: torch.nn.Module) -> None: + """Start recording information during forward and backward.""" + self.end_record() # ensure to run start_record only once + self.handlers.append( + self.register_forward_hook(self.forward_hook_wrapper())) + self.handlers.append( + self.register_backward_hook(self.backward_hook_wrapper())) + + def end_record(self): + """Stop recording information during forward and backward.""" + for handle in self.handlers: + handle.remove() + self.handlers = [] + + def reset_recorded(self): + """Reset the recorded information.""" + self.recorded_input = [] + self.recorded_grad = [] + self.recorded_out_shape = [] + + @property + def delta_flop_of_a_out_channel(self): + raise NotImplementedError() + + @property + def delta_flop_of_a_in_channel(self): + raise NotImplementedError() + + @property + def delta_memory_of_a_out_channel(self): + raise NotImplementedError() + + +class GroupFisherConv2d(DynamicConv2d, GroupFisherMixin): + """The Dynamic Conv2d operation used in GroupFisher Algorithm.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._init() + + @property + def delta_flop_of_a_out_channel(self) -> torch.Tensor: + """Calculate the summation of flops when prune an out_channel.""" + delta_flop_sum = 0 + for shape in self.recorded_out_shape: + _, _, h, w = shape + in_c = int(self.mutable_attrs['in_channels'].current_mask.float(). 
+ sum().item()) + # normal conv + if self.groups == 1: + delta_flop = h * w * self.kernel_size[0] * self.kernel_size[ + 1] * in_c + # dwconv + elif self.groups == self.in_channels == self.out_channels: + delta_flop = h * w * self.kernel_size[0] * self.kernel_size[1] + # groupwise conv + else: + raise NotImplementedError() + delta_flop_sum += delta_flop + return delta_flop_sum + + @property + def delta_flop_of_a_in_channel(self): + """Calculate the summation of flops when prune an in_channel.""" + delta_flop_sum = 0 + for shape in self.recorded_out_shape: + _, out_c, h, w = shape + # normal conv + if self.groups == 1: + delta_flop = h * w * self.kernel_size[0] * self.kernel_size[ + 1] * out_c + # dwconv + elif self.groups == self.in_channels == self.out_channels: + delta_flop = h * w * self.kernel_size[0] * self.kernel_size[1] + # groupwise conv + else: + raise NotImplementedError() + delta_flop_sum += delta_flop + return delta_flop_sum + + @property + def delta_memory_of_a_out_channel(self): + """Calculate the summation of memory when prune a channel.""" + delta_flop_sum = 0 + for shape in self.recorded_out_shape: + _, _, h, w = shape + delta_flop_sum += h * w + return delta_flop_sum + + +class GroupFisherLinear(DynamicLinear, GroupFisherMixin): + """The Dynamic Linear operation used in GroupFisher Algorithm.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._init() + + @property + def delta_flop_of_a_out_channel(self): + """Calculate the summation of flops when prune an out_channel.""" + in_c = self.mutable_attrs['in_channels'].current_mask.float().sum() + return in_c * len(self.recorded_out_shape) + + @property + def delta_flop_of_a_in_channel(self): + """Calculate the summation of flops when prune an in_channel.""" + out_c = self.mutable_attrs['out_channels'].current_mask.float().sum() + return out_c * len(self.recorded_out_shape) + + @property + def delta_memory_of_a_out_channel(self): + """Calculate the summation of memory when prune a channel.""" + return 1 * len(self.recorded_out_shape) diff --git a/mmrazor/implementations/pruning/group_fisher/prune_deploy_sub_model.py b/mmrazor/implementations/pruning/group_fisher/prune_deploy_sub_model.py new file mode 100644 index 000000000..c197c3ff7 --- /dev/null +++ b/mmrazor/implementations/pruning/group_fisher/prune_deploy_sub_model.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +from typing import Union + +import torch.nn as nn +from mmengine import fileio + +from mmrazor.registry import MODELS +from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet, + load_fix_subnet) +from mmrazor.utils import print_log + + +@MODELS.register_module() +def GroupFisherDeploySubModel(architecture, + fix_subnet: Union[dict, str] = {}, + divisor=1, + parse_cfg=dict( + _scope_='mmrazor', + type='ChannelAnalyzer', + demo_input=(1, 3, 224, 224), + tracer_type='FxTracer'), + **kwargs): + """Convert a architecture to a pruned static architecture for mmdeploy. + + Args: + architecture (Union[nn.Module, dict]): the model to be pruned. + fix_subnet (Union[dict, str]): the channel remaining ratio for each + unit, or the path of a file including this info. Defaults to {}. + divisor (int, optional): The divisor to make the channel number + divisible. Defaults to 1. + parse_cfg (dict, optional): The args for channel mutator. + Returns: + BaseModel: a BaseModel of mmengine. 
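One clarification on the `fix_subnet` argument described above: each key is a channel-unit name of the form `<module path>_(start, end)_<unit channel number>` as produced by the channel mutator, and its value is the number of channels kept after pruning. Instead of an inline dict, the path of the `fix_subnet.json` dumped during finetuning can also be passed. A hedged config sketch (the path is hypothetical):

```python
# Hypothetical deploy-config fragment: load the pruning structure from the
# fix_subnet.json dumped by the finetune step instead of an inline dict.
fix_subnet = 'work_dirs/group_fisher_act_prune/fix_subnet.json'  # assumed path

model = dict(
    _delete_=True,
    _scope_='mmrazor',
    type='GroupFisherDeploySubModel',
    architecture=architecture,   # the _base_ model, as in the deploy configs above
    fix_subnet=fix_subnet,
    divisor=8,
)
```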
+ """ + # import avoid circular import + from mmrazor.models.mutables import SequentialMutableChannelUnit + from mmrazor.models.mutators import ChannelMutator + from mmrazor.models.utils.expandable_utils.unit import ExpandableUnit + + # build architecture + if isinstance(architecture, dict): + architecture = MODELS.build(architecture) + assert isinstance(architecture, nn.Module) + + # to dynamic model + mutator = ChannelMutator[ExpandableUnit]( + channel_unit_cfg=SequentialMutableChannelUnit, parse_cfg=parse_cfg) + + mutator.prepare_from_supernet(architecture) + if isinstance(fix_subnet, str): + fix_subnet = fileio.load(fix_subnet) + assert isinstance(fix_subnet, dict) + mutator.set_choices(fix_subnet) + print_log(json.dumps(mutator.current_choices, indent=4)) + + fix_subnet = export_fix_subnet(architecture)[0] + load_fix_subnet(architecture, fix_subnet) + + # cooperate with mmdeploy to make the channel divisible after load + # the checkpoint. + if divisor != 1: + setattr(architecture, '_razor_divisor', divisor) + + return architecture diff --git a/mmrazor/implementations/pruning/group_fisher/prune_sub_model.py b/mmrazor/implementations/pruning/group_fisher/prune_sub_model.py new file mode 100644 index 000000000..87a77346d --- /dev/null +++ b/mmrazor/implementations/pruning/group_fisher/prune_sub_model.py @@ -0,0 +1,105 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import types + +import torch.nn as nn +from mmengine import dist, fileio +from mmengine.model import BaseModel, BaseModule + +from mmrazor.models.algorithms import BaseAlgorithm +from mmrazor.models.utils.expandable_utils import make_channel_divisible +from mmrazor.registry import MODELS +from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet, + load_fix_subnet) +from mmrazor.utils import RuntimeInfo, print_log + + +def clean_params_init_info(model: nn.Module): + """Clean param init info.""" + if hasattr(model, '_params_init_info'): + delattr(model, '_params_init_info') + for module in model.modules(): + if hasattr(module, '_params_init_info'): + delattr(module, '_params_init_info') + + +def clean_init_cfg(model: BaseModule): + """Clean init cfg.""" + for module in model.modules(): + if module is model: + continue + if isinstance(module, BaseModule): + module.init_cfg = {} + + +def hacky_init_weights_wrapper(fix_subnet): + """This init weight method is used to prevent the model init again after + build. + + Besides, It also save fix_subnet.json after RuntimeInfo is ready. + """ + + def hacky_init_weights(model): + if dist.get_rank() == 0: + try: + work_dir = RuntimeInfo.work_dir() + fileio.dump( + fix_subnet, work_dir + '/fix_subnet.json', indent=4) + print_log( + f'save pruning structure in {work_dir}/fix_subnet.json') + except Exception: + pass + + return hacky_init_weights + + +@MODELS.register_module() +def GroupFisherSubModel( + algorithm, + divisor=1, + **kargs, +): + """Convert a algorithm(with an architecture) to a static pruned + architecture. + + Args: + algorithm (Union[BaseAlgorithm, dict]): The pruning algorithm to + finetune. + divisor (int): The divisor to make the channel number + divisible. Defaults to 1. + + Returns: + nn.Module: a static model. 
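A note on the `divisor` argument shared by this wrapper and `GroupFisherDeploySubModel`: pruned channel counts are usually ragged numbers (237, 401, ...), and expanding them up to a multiple of 8 or 16 often maps better onto GPU kernels. Internally this relies on `make_channel_divisible` from the expandable_utils added later in this patch. A standalone usage sketch under assumptions (the torchvision model is only an example; the helper is meant for any traceable static model):

```python
# Hypothetical standalone use of make_channel_divisible (not from this PR's configs).
import torchvision
from mmrazor.models.utils.expandable_utils import make_channel_divisible

model = torchvision.models.resnet18()
structure = make_channel_divisible(model, divisor=8)  # expands channels of the model
print(structure)  # the new, divisible channel number of each unit
```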
+ """ + # init algorithm + if isinstance(algorithm, dict): + algorithm = MODELS.build(algorithm) # type: ignore + assert isinstance(algorithm, BaseAlgorithm) + algorithm.init_weights() + clean_params_init_info(algorithm) + + pruning_structure = algorithm.mutator.choice_template + print_log('PruneSubModel get pruning structure:') + print_log(json.dumps(pruning_structure, indent=4)) + + # to static model + fix_mutable = export_fix_subnet(algorithm.architecture)[0] + load_fix_subnet(algorithm.architecture, fix_mutable) + model = algorithm.architecture + + # make channel divisible + if divisor != 1: + divisible_structure = make_channel_divisible( + model, divisor=divisor, zero_weight=False) + + print_log('PruneSubModel get divisible pruning structure:') + print_log(json.dumps(divisible_structure, indent=4)) + pruning_structure = divisible_structure + + # refine model + model.data_preprocessor = algorithm.data_preprocessor + if isinstance(model, BaseModel): + model.init_cfg = None + model.init_weights = types.MethodType( + hacky_init_weights_wrapper(pruning_structure), model) + return model diff --git a/mmrazor/implementations/pruning/group_fisher/unit.py b/mmrazor/implementations/pruning/group_fisher/unit.py new file mode 100644 index 000000000..1c9128b78 --- /dev/null +++ b/mmrazor/implementations/pruning/group_fisher/unit.py @@ -0,0 +1,230 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torch.nn as nn +from mmengine.model.utils import _BatchNormXd +from mmengine.utils.dl_utils.parrots_wrapper import \ + SyncBatchNorm as EngineSyncBatchNorm +from torch import distributed as dist + +import mmrazor.models.architectures.dynamic_ops as dynamic_ops +from mmrazor.models.mutables.mutable_channel.mutable_channel_container import \ + MutableChannelContainer +from mmrazor.models.mutables.mutable_channel.units.l1_mutable_channel_unit import \ + L1MutableChannelUnit # noqa +from mmrazor.registry import MODELS +from .ops import GroupFisherConv2d, GroupFisherLinear, GroupFisherMixin + + +@MODELS.register_module() +class GroupFisherChannelUnit(L1MutableChannelUnit): + """ChannelUnit for GroupFisher Pruning Algorithm. + + Args: + num_channels (int): Number of channels. + normalization_type (str): Type of normalization. It can be one of + ['flops','act','none',]. Defaults to 'flop'. + mutate_linear (bool): Whether to prune linear layers. + """ + + def __init__(self, + num_channels: int, + normalization_type: str = 'flops', + mutate_linear=False, + *args) -> None: + super().__init__(num_channels, *args) + normalized_fisher_info = torch.zeros([self.num_channels]) + self.register_buffer('normalized_fisher_info', normalized_fisher_info) + self.normalized_fisher_info: torch.Tensor + + self.hook_handles: List = [] + assert normalization_type in ['flops', 'act', 'none'] + self.delta_type = normalization_type + + self.mutate_linear = mutate_linear + + def prepare_for_pruning(self, model: nn.Module) -> None: + """Prepare for pruning, including register mutable channels. + + Args: + model (nn.Module): The model need to be pruned. 
+ """ + # register MutableMask + self._replace_with_dynamic_ops( + model, { + nn.Conv2d: GroupFisherConv2d, + nn.BatchNorm2d: dynamic_ops.DynamicBatchNorm2d, + nn.Linear: GroupFisherLinear, + nn.SyncBatchNorm: dynamic_ops.DynamicSyncBatchNorm, + EngineSyncBatchNorm: dynamic_ops.DynamicSyncBatchNorm, + _BatchNormXd: dynamic_ops.DynamicBatchNormXd, + }) + self._register_channel_container(model, MutableChannelContainer) + self._register_mutable_channel(self.mutable_channel) + + # prune + def try_to_prune_min_channel(self) -> bool: + """Prune the channel with the minimum value of fisher information.""" + if self.mutable_channel.activated_channels > 1: + imp = self.importance() + index = imp.argmin() + self.mutable_channel.mask.scatter_(0, index, 0.0) + return True + else: + return False + + @property + def is_mutable(self) -> bool: + """Whether the unit is mutable.""" + mutable = super().is_mutable + if self.mutate_linear: + return mutable + else: + has_linear = False + for layer in self.input_related: + if isinstance(layer.module, nn.Linear): + has_linear = True + return mutable and (not has_linear) + + @property + def input_related_dynamic_ops(self): + for channel in self.input_related: + if isinstance(channel.module, GroupFisherMixin): + yield channel.module + + @property + def output_related_dynamic_ops(self): + for channel in self.output_related: + if isinstance(channel.module, GroupFisherMixin): + yield channel.module + + @property + def dynamic_ops(self): + for module in self.input_related_dynamic_ops: + yield module + for module in self.output_related_dynamic_ops: + yield module + + # fisher information recorded + + def start_record_fisher_info(self) -> None: + """Start recording the related fisher info of each channel.""" + for module in self.dynamic_ops: + module.start_record() + + def end_record_fisher_info(self) -> None: + """Stop recording the related fisher info of each channel.""" + for module in self.dynamic_ops: + module.end_record() + + def reset_recorded(self) -> None: + """Reset the recorded info of each channel.""" + for module in self.dynamic_ops: + module.reset_recorded() + + # fisher related computation + + def importance(self): + """The importance of each channel.""" + fisher = self.normalized_fisher_info.clone() + mask = self.mutable_channel.current_mask + n_mask = (1 - mask.float()).bool() + fisher.masked_fill_(n_mask, fisher.max() + 1) + return fisher + + def reset_fisher_info(self) -> None: + """Reset the related fisher info.""" + self.normalized_fisher_info.zero_() + + @torch.no_grad() + def update_fisher_info(self) -> None: + """Update the fisher info of each channel.""" + + batch_fisher_sum = self.current_batch_fisher + assert isinstance(batch_fisher_sum, torch.Tensor) + if dist.is_initialized(): + dist.all_reduce(batch_fisher_sum) + batch_fisher_sum = self._get_normalized_fisher_info( + batch_fisher_sum, self.delta_type) + self.normalized_fisher_info = self.normalized_fisher_info + batch_fisher_sum # noqa + + @property + def current_batch_fisher(self) -> torch.Tensor: + """Accumulate the unit's fisher info of this batch.""" + with torch.no_grad(): + fisher: torch.Tensor = 0 + for module in self.input_related_dynamic_ops: + fisher = fisher + self._fisher_of_a_module(module) + return (fisher**2).sum(0) # shape: [C] + + @torch.no_grad() + def _fisher_of_a_module(self, module: GroupFisherMixin) -> torch.Tensor: + """Calculate the fisher info of one module. + + Args: + module (GroupFisherConv2d): A `GroupFisherConv2d` module. 
+ + Return: + torch.Tensor: Whose shape is [B C] + """ + assert len(module.recorded_input) > 0 and \ + len(module.recorded_input) == len(module.recorded_grad) + fisher_sum: torch.Tensor = 0 + for input, grad_input in zip(module.recorded_input, + module.recorded_grad): + fisher: torch.Tensor = input * grad_input + if len(fisher.shape) == 4: + fisher = fisher.sum(dim=[2, 3]) + assert len(fisher.shape) == 2 # B C + fisher_sum = fisher_sum + fisher + assert isinstance(fisher_sum, torch.Tensor) + # expand to full num_channel + batch_size = fisher_sum.shape[0] + mask = self.mutable_channel.current_mask.unsqueeze(0).expand( + [batch_size, self.num_channels]) + zeros = fisher_sum.new_zeros([batch_size, self.num_channels]) + fisher_sum = zeros.masked_scatter_(mask, fisher_sum) + return fisher_sum + + @torch.no_grad() + def _get_normalized_fisher_info(self, + fisher_info, + delta_type='flop') -> torch.Tensor: + """Get the normalized fisher info. + + Args: + delta_type (str): Type of delta. Defaults to 'flop'. + """ + fisher = fisher_info.double() + if delta_type == 'flops': + delta_flop = self._delta_flop_of_a_channel + assert delta_flop > 0 + fisher = fisher / (float(delta_flop) / 1e9) + elif delta_type == 'act': + delta_memory = self._delta_memory_of_a_channel + assert delta_memory > 0 + fisher = fisher / (float(delta_memory) / 1e6) + elif delta_type == 'none': + pass + else: + raise NotImplementedError(delta_type) + return fisher + + @property + def _delta_flop_of_a_channel(self) -> torch.Tensor: + """Calculate the flops of a channel.""" + delta_flop = 0 + for module in self.output_related_dynamic_ops: + delta_flop += module.delta_flop_of_a_out_channel + for module in self.input_related_dynamic_ops: + delta_flop += module.delta_flop_of_a_in_channel + return delta_flop + + @property + def _delta_memory_of_a_channel(self) -> torch.Tensor: + """Calculate the memory of a channel.""" + delta_memory = 0 + for module in self.output_related_dynamic_ops: + delta_memory += module.delta_memory_of_a_out_channel + return delta_memory diff --git a/mmrazor/models/algorithms/pruning/group_fisher_algoritho.py b/mmrazor/models/algorithms/pruning/group_fisher_algoritho.py new file mode 100644 index 000000000..eccbe1228 --- /dev/null +++ b/mmrazor/models/algorithms/pruning/group_fisher_algoritho.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""This file includes the modules in the impl folder. + +As it only records impl modules, it is not initialized automatically. +""" +from mmrazor.implementations.pruning.group_fisher import \ + GroupFisherAlgorithm # noqa diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/group_fisher_ops.py b/mmrazor/models/architectures/dynamic_ops/bricks/group_fisher_ops.py new file mode 100644 index 000000000..c4a635607 --- /dev/null +++ b/mmrazor/models/architectures/dynamic_ops/bricks/group_fisher_ops.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""This file includes the modules in the impl folder. + +As it only records impl modules, it is not initialized automatically. 
+""" +from mmrazor.implementations.pruning.group_fisher import \ + GroupFisherConv2d # noqa +from mmrazor.implementations.pruning.group_fisher import \ + GroupFisherLinear # noqa +from mmrazor.implementations.pruning.group_fisher import \ + GroupFisherMixin # noqa diff --git a/mmrazor/models/mutables/mutable_channel/units/group_fisher_unit.py b/mmrazor/models/mutables/mutable_channel/units/group_fisher_unit.py new file mode 100644 index 000000000..7d33f4232 --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/units/group_fisher_unit.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""This file includes the modules in the impl folder. + +As it only records impl modules, it is not initialized automatically. +""" +from mmrazor.implementations.pruning.group_fisher import \ + GroupFisherChannelUnit # noqa diff --git a/mmrazor/models/mutables/mutable_channel/units/sequential_mutable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/sequential_mutable_channel_unit.py index 89dc785ed..d32c5fead 100644 --- a/mmrazor/models/mutables/mutable_channel/units/sequential_mutable_channel_unit.py +++ b/mmrazor/models/mutables/mutable_channel/units/sequential_mutable_channel_unit.py @@ -10,7 +10,6 @@ SyncBatchNorm as EngineSyncBatchNorm from mmrazor.models.architectures import dynamic_ops -from mmrazor.models.utils import make_divisible from mmrazor.registry import MODELS from ..mutable_channel_container import MutableChannelContainer from ..sequential_mutable_channel import SquentialMutableChannel @@ -134,6 +133,7 @@ def _get_valid_int_choice(self, choice: Union[float, int]) -> int: def _make_divisible(self, choice_int: int): """Make the choice divisible.""" + from mmrazor.models.utils import make_divisible return make_divisible(choice_int, self.divisor, self.min_value, self.min_ratio) diff --git a/mmrazor/models/mutators/channel_mutator/channel_mutator.py b/mmrazor/models/mutators/channel_mutator/channel_mutator.py index 38abd2fcc..3de024635 100644 --- a/mmrazor/models/mutators/channel_mutator/channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/channel_mutator.py @@ -66,6 +66,7 @@ def __init__(self, dict, Type[MutableChannelUnit]] = SequentialMutableChannelUnit, parse_cfg: Dict = dict( + _scope_='mmrazor', type='ChannelAnalyzer', demo_input=(1, 3, 224, 224), tracer_type='BackwardTracer'), diff --git a/mmrazor/models/mutators/channel_mutator/group_fisher_mutator.py b/mmrazor/models/mutators/channel_mutator/group_fisher_mutator.py new file mode 100644 index 000000000..cde31bacf --- /dev/null +++ b/mmrazor/models/mutators/channel_mutator/group_fisher_mutator.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""This file includes the modules in the impl folder. + +As it only records impl modules, it is not initialized automatically. 
+""" +from mmrazor.implementations.pruning.group_fisher import \ + GroupFisherChannelMutator # noqa diff --git a/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py b/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py index 75a1db293..63c60e8cb 100644 --- a/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py +++ b/mmrazor/models/task_modules/demo_inputs/default_demo_inputs.py @@ -6,6 +6,7 @@ from mmrazor.registry import TASK_UTILS from mmrazor.utils import get_placeholder +from ...algorithms.base import BaseAlgorithm from .demo_inputs import (BaseDemoInput, DefaultMMClsDemoInput, DefaultMMDemoInput, DefaultMMDetDemoInput, DefaultMMPoseDemoInput, DefaultMMRotateDemoInput, @@ -70,8 +71,12 @@ def get_default_demo_input_class(model, scope): def defaul_demo_inputs(model, input_shape, training=False, scope=None): """Get demo input according to a model and scope.""" - demo_input = get_default_demo_input_class(model, scope) - return demo_input().get_data(model, input_shape, training) + if isinstance(model, BaseAlgorithm): + return defaul_demo_inputs(model.architecture, input_shape, training, + scope) + else: + demo_input = get_default_demo_input_class(model, scope) + return demo_input().get_data(model, input_shape, training) @TASK_UTILS.register_module() diff --git a/mmrazor/models/task_modules/demo_inputs/demo_inputs.py b/mmrazor/models/task_modules/demo_inputs/demo_inputs.py index 8664f3a2d..e5c05fbcf 100644 --- a/mmrazor/models/task_modules/demo_inputs/demo_inputs.py +++ b/mmrazor/models/task_modules/demo_inputs/demo_inputs.py @@ -51,7 +51,9 @@ def _get_data(self, model, input_shape=None, training=None): return data def _get_mm_data(self, model, input_shape, training=False): - return {'inputs': torch.rand(input_shape), 'data_samples': None} + data = {'inputs': torch.rand(input_shape), 'data_samples': None} + data = model.data_preprocessor(data, training) + return data @TASK_UTILS.register_module() @@ -84,7 +86,7 @@ def _get_mm_data(self, model, input_shape, training=False): """Helper for get_data, including core logic to generate demo input.""" from mmdet.models import BaseDetector from mmdet.testing._utils import demo_mm_inputs - assert isinstance(model, BaseDetector) + assert isinstance(model, BaseDetector), f'{type(model)}' data = demo_mm_inputs(1, [input_shape[1:]], with_mask=True) data = model.data_preprocessor(data, training) @@ -132,7 +134,7 @@ def _get_mm_data(self, model, input_shape, training=False): from mmpose.models import TopdownPoseEstimator from .mmpose_demo_input import demo_mmpose_inputs - assert isinstance(model, TopdownPoseEstimator) + assert isinstance(model, TopdownPoseEstimator), f'{type(model)}' data = demo_mmpose_inputs(model, input_shape) return data diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/__init__.py b/mmrazor/models/task_modules/estimators/counters/op_counters/__init__.py index 6e33babe2..67c3d6207 100644 --- a/mmrazor/models/task_modules/estimators/counters/op_counters/__init__.py +++ b/mmrazor/models/task_modules/estimators/counters/op_counters/__init__.py @@ -13,10 +13,24 @@ from .upsample_layer_counter import UpsampleCounter __all__ = [ - 'ReLUCounter', 'PReLUCounter', 'ELUCounter', 'LeakyReLUCounter', - 'ReLU6Counter', 'BatchNorm1dCounter', 'BatchNorm2dCounter', - 'BatchNorm3dCounter', 'Conv1dCounter', 'Conv2dCounter', 'Conv3dCounter', - 'ConvTranspose2dCounter', 'UpsampleCounter', 'LinearCounter', - 'GroupNormCounter', 'InstanceNorm1dCounter', 'InstanceNorm2dCounter', - 'InstanceNorm3dCounter', 
'LayerNormCounter', 'BaseCounter' + 'ReLUCounter', + 'PReLUCounter', + 'ELUCounter', + 'LeakyReLUCounter', + 'ReLU6Counter', + 'BatchNorm1dCounter', + 'BatchNorm2dCounter', + 'BatchNorm3dCounter', + 'Conv1dCounter', + 'Conv2dCounter', + 'Conv3dCounter', + 'ConvTranspose2dCounter', + 'UpsampleCounter', + 'LinearCounter', + 'GroupNormCounter', + 'InstanceNorm1dCounter', + 'InstanceNorm2dCounter', + 'InstanceNorm3dCounter', + 'LayerNormCounter', + 'BaseCounter', ] diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/dynamic_op_counters.py b/mmrazor/models/task_modules/estimators/counters/op_counters/dynamic_op_counters.py new file mode 100644 index 000000000..2a58a09ec --- /dev/null +++ b/mmrazor/models/task_modules/estimators/counters/op_counters/dynamic_op_counters.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn + +from mmrazor.registry import TASK_UTILS +from .conv_layer_counter import Conv2dCounter +from .linear_layer_counter import LinearCounter + + +@TASK_UTILS.register_module() +class DynamicConv2dCounter(Conv2dCounter): + """Flop counter for DynamicCon2d.""" + + @staticmethod + def add_count_hook(module: nn.Conv2d, input: Tuple[torch.Tensor], + output: torch.Tensor) -> None: + """Count the flops and params of a DynamicConv2d. + + Args: + module (nn.Conv2d): A Conv2d module. + input (Tuple[torch.Tensor]): Input of this module. + output (torch.Tensor): Output of this module. + """ + batch_size = input[0].shape[0] + output_dims = list(output.shape[2:]) + + kernel_dims = list(module.kernel_size) + + out_channels = module.mutable_attrs['out_channels'].activated_channels + in_channels = module.mutable_attrs['in_channels'].activated_channels + + groups = module.groups + + filters_per_channel = out_channels / groups + conv_per_position_flops = int( + np.prod(kernel_dims)) * in_channels * filters_per_channel + + active_elements_count = batch_size * int(np.prod(output_dims)) + + overall_conv_flops = conv_per_position_flops * active_elements_count + overall_params = conv_per_position_flops + + bias_flops = 0 + overall_params = conv_per_position_flops + if module.bias is not None: + bias_flops = out_channels * active_elements_count + overall_params += out_channels + + overall_flops = overall_conv_flops + bias_flops + + module.__flops__ += overall_flops + module.__params__ += int(overall_params) + + +@TASK_UTILS.register_module() +class DynamicLinearCounter(LinearCounter): + pass diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/group_fisher_counters.py b/mmrazor/models/task_modules/estimators/counters/op_counters/group_fisher_counters.py new file mode 100644 index 000000000..7e85c33ae --- /dev/null +++ b/mmrazor/models/task_modules/estimators/counters/op_counters/group_fisher_counters.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""This file includes the modules in the impl folder. + +As it only records impl modules, it is not initialized automatically. +""" +from mmrazor.implementations.pruning.group_fisher import ( # noqa + GroupFisherConv2dCounter, GroupFisherLinearCounter) diff --git a/mmrazor/models/utils/expandable_utils/__init__.py b/mmrazor/models/utils/expandable_utils/__init__.py new file mode 100644 index 000000000..23eeb6073 --- /dev/null +++ b/mmrazor/models/utils/expandable_utils/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
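To make the `DynamicConv2dCounter` above concrete, here is a small hand check of its FLOP formula with assumed sizes (none of these numbers come from the PR):

```python
# Hand check of DynamicConv2dCounter's FLOP formula (assumed sizes).
import numpy as np

kernel_dims = [3, 3]                 # 3x3 convolution
in_channels, out_channels = 16, 32   # activated (i.e. un-pruned) channels
groups, batch_size = 1, 1
output_dims = [8, 8]                 # spatial size of the output feature map

filters_per_channel = out_channels / groups
conv_per_position_flops = int(np.prod(kernel_dims)) * in_channels * filters_per_channel
active_elements_count = batch_size * int(np.prod(output_dims))
print(conv_per_position_flops * active_elements_count)  # 294912.0 FLOPs (no bias term)
```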
+"""This module is used to expand the channels of a supernet. + +We only expose some tool functions, rather than all DynamicOps and +MutableChannelUnits, as They uses a few hacky operations. +""" +from .tools import (expand_expandable_dynamic_model, expand_static_model, + make_channel_divisible, to_expandable_model) + +__all__ = [ + 'make_channel_divisible', + 'to_expandable_model', + 'expand_expandable_dynamic_model', + 'expand_static_model', +] diff --git a/mmrazor/models/utils/expandable_utils/ops.py b/mmrazor/models/utils/expandable_utils/ops.py new file mode 100644 index 000000000..fa4c41db9 --- /dev/null +++ b/mmrazor/models/utils/expandable_utils/ops.py @@ -0,0 +1,237 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from mmrazor.models.architectures import dynamic_ops +from mmrazor.models.mutables import MutableChannelContainer + + +class ExpandableMixin: + """This minin coroperates with dynamic ops. + + It defines interfaces to expand the channels of ops. We can get a wider + network than original supernet with it. + """ + + def expand(self, zero=False): + """Expand the op. + + Args: + zero (bool, optional): whether to set new weights to zero. Defaults + to False. + """ + return self.get_expand_op( + self.expanded_in_channel, + self.expanded_out_channel, + zero=zero, + ) + + def get_expand_op(self, in_c, out_c, zero=False): + """Get an expanded op. + + Args: + in_c (int): New input channels + out_c (int): New output channels + zero (bool, optional): Whether to zero new weights. Defaults to + False. + """ + pass + + @property + def _original_in_channel(self): + """Return original in channel.""" + raise NotImplementedError() + + @property + def _original_out_channel(self): + """Return original out channel.""" + + @property + def expanded_in_channel(self): + """Return expanded in channel number.""" + if self.in_mutable is not None: + return self.in_mutable.current_mask.numel() + else: + return self._original_in_channel + + @property + def expanded_out_channel(self): + """Return expanded out channel number.""" + if self.out_mutable is not None: + return self.out_mutable.current_mask.numel() + else: + return self._original_out_channel + + @property + def mutable_in_mask(self): + """Return the mutable in mask.""" + if self.in_mutable is not None: + return self.in_mutable.current_mask + else: + if hasattr(self, 'weight'): + return self.weight.new_ones([self.expanded_in_channel]) + else: + return torch.ones([self.expanded_in_channel]) + + @property + def mutable_out_mask(self): + """Return the mutable out mask.""" + if self.out_mutable is not None: + return self.out_mutable.current_mask + else: + if hasattr(self, 'weight'): + return self.weight.new_ones([self.expanded_out_channel]) + else: + return torch.ones([self.expanded_out_channel]) + + @property + def in_mutable(self) -> MutableChannelContainer: + """In channel mask.""" + return self.get_mutable_attr('in_channels') # type: ignore + + @property + def out_mutable(self) -> MutableChannelContainer: + """Out channel mask.""" + return self.get_mutable_attr('out_channels') # type: ignore + + def zero_weight_(self: nn.Module): + """Zero all weights.""" + for p in self.parameters(): + p.data.zero_() + + @torch.no_grad() + def expand_matrix(self, weight: torch.Tensor, old_weight: torch.Tensor): + """Expand weight matrix.""" + assert len(weight.shape) == 3 # out in c + assert len(old_weight.shape) == 3 # out in c + mask = self.mutable_out_mask.float().unsqueeze( + -1) * 
self.mutable_in_mask.float().unsqueeze(0) + mask = mask.unsqueeze(-1).expand(*weight.shape) + weight.data.masked_scatter_(mask.bool(), old_weight) + return weight + + @torch.no_grad() + def expand_vector(self, weight: torch.Tensor, old_weight: torch.Tensor): + """Expand weight vector which has the shape of [out, c].""" + assert len(weight.shape) == 2 # out c + assert len(old_weight.shape) == 2 # out c + mask = self.mutable_out_mask + mask = mask.unsqueeze(-1).expand(*weight.shape) + weight.data.masked_scatter_(mask.bool(), old_weight) + return weight + + @torch.no_grad() + def expand_bias(self, bias: torch.Tensor, old_bias: torch.Tensor): + """Expand bias.""" + assert len(bias.shape) == 1 # out c + assert len(old_bias.shape) == 1 # out c + return self.expand_vector(bias.unsqueeze(-1), + old_bias.unsqueeze(-1)).squeeze(1) + + +class ExpandableConv2d(dynamic_ops.DynamicConv2d, ExpandableMixin): + + @property + def _original_in_channel(self): + return self.in_channels + + @property + def _original_out_channel(self): + return self.out_channels + + def get_expand_op(self, in_c, out_c, zero=False): + + if self.groups == 1: + return self._get_expand_op_normal_conv(in_c, out_c, zero=zero) + elif self.in_channels == self.out_channels == self.groups: + return self._get_expand_op_dw_conv(in_c, out_c, zero=zero) + else: + raise NotImplementedError('Groupwise conv is not supported yet.') + + def _get_expand_op_normal_conv(self, in_c, out_c, zero=False): + + module = nn.Conv2d(in_c, out_c, self.kernel_size, self.stride, + self.padding, self.dilation, self.groups, self.bias + is not None, self.padding_mode) + if zero: + ExpandableMixin.zero_weight_(module) + + weight = self.expand_matrix( + module.weight.flatten(2), self.weight.flatten(2)) + module.weight.data = weight.reshape(module.weight.shape) + if module.bias is not None and self.bias is not None: + bias = self.expand_vector( + module.bias.unsqueeze(-1), self.bias.unsqueeze(-1)) + module.bias.data = bias.reshape(module.bias.shape) + return module + + def _get_expand_op_dw_conv(self, in_c, out_c, zero=False): + assert in_c == out_c + module = nn.Conv2d(in_c, out_c, self.kernel_size, self.stride, + self.padding, self.dilation, in_c, self.bias + is not None, self.padding_mode) + if zero: + ExpandableMixin.zero_weight_(module) + + weight = self.expand_vector( + module.weight.flatten(1), self.weight.flatten(1)) + module.weight.data = weight.reshape(module.weight.shape) + if module.bias is not None and self.bias is not None: + bias = self.expand_vector( + module.bias.unsqueeze(-1), self.bias.unsqueeze(-1)) + module.bias.data = bias.reshape(module.bias.shape) + return module + + +class ExpandLinear(dynamic_ops.DynamicLinear, ExpandableMixin): + + @property + def _original_in_channel(self): + return self.in_features + + @property + def _original_out_channel(self): + return self.out_features + + def get_expand_op(self, in_c, out_c, zero=False): + module = nn.Linear(in_c, out_c, self.bias is not None) + if zero: + ExpandableMixin.zero_weight_(module) + + weight = self.expand_matrix( + module.weight.unsqueeze(-1), self.weight.unsqueeze(-1)) + module.weight.data = weight.reshape(module.weight.shape) + if module.bias is not None: + bias = self.expand_vector( + module.bias.unsqueeze(-1), self.bias.unsqueeze(-1)) + module.bias.data = bias.reshape(module.bias.shape) + return module + + +class ExpandableBatchNorm2d(dynamic_ops.DynamicBatchNorm2d, ExpandableMixin): + + @property + def _original_in_channel(self): + return self.num_features + + @property + def 
_original_out_channel(self): + return self.num_features + + def get_expand_op(self, in_c, out_c, zero=False): + assert in_c == out_c + module = nn.BatchNorm2d(in_c, self.eps, self.momentum, self.affine, + self.track_running_stats) + if zero: + ExpandableMixin.zero_weight_(module) + + if module.running_mean is not None: + module.running_mean.data = self.expand_bias( + module.running_mean, self.running_mean) + + if module.running_var is not None: + module.running_var.data = self.expand_bias(module.running_var, + self.running_var) + module.weight.data = self.expand_bias(module.weight, self.weight) + module.bias.data = self.expand_bias(module.bias, self.bias) + return module diff --git a/mmrazor/models/utils/expandable_utils/tools.py b/mmrazor/models/utils/expandable_utils/tools.py new file mode 100644 index 000000000..d6c559d91 --- /dev/null +++ b/mmrazor/models/utils/expandable_utils/tools.py @@ -0,0 +1,84 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Dict + +import torch.nn as nn + +from mmrazor.models.mutators import ChannelMutator +from .ops import ExpandableMixin +from .unit import ExpandableUnit + + +def to_expandable_model(model: nn.Module) -> ChannelMutator[ExpandableUnit]: + """Convert a static model to an expandable model.""" + state_dict = model.state_dict() + mutator = ChannelMutator[ExpandableUnit](channel_unit_cfg=ExpandableUnit) + mutator.prepare_from_supernet(model) + model.load_state_dict(state_dict) + return mutator + + +def expand_expandable_dynamic_model(model: nn.Module, zero=False) -> nn.Module: + """Expand a expandable model and return a expanded static model. + + Args: + model (nn.Module): The model to be expanded. + zero (bool, optional): Whether to zero expanded weight. Defaults to + False. + """ + + def traverse_children(module: nn.Module) -> None: + for name, mutable in module.items(): + if isinstance(mutable, ExpandableMixin): + module[name] = mutable.expand(zero=zero) + if hasattr(mutable, '_modules'): + traverse_children(mutable._modules) + + if isinstance(model, ExpandableMixin): + raise RuntimeError('Root model can not be dynamic op.') + + if hasattr(model, '_modules'): + traverse_children(model._modules) + return model + + +def expand_static_model(model: nn.Module, structure: Dict, zero_weight=True): + """Expand the channels of a model. + + Args: + model (nn.Module): the model to be expanded. + structure (Dict): the channel structure for the model. + divisor (_type_): the divisor to make the channels divisible. + """ + mutator = to_expandable_model(model) + for key, value in structure.items(): + mutator._name2unit[key].expand_to(value) + expand_expandable_dynamic_model(model, zero=zero_weight) + return model + + +def make_channel_divisible(model: nn.Module, divisor, zero_weight=True): + """Expand the channels of a model and return the new divisible channel + structure. + + Args: + model (nn.Module): the model to be expanded. + divisor (_type_): the divisor to make the channels divisible. 
+    """
+    # convert the static model to an expandable (dynamic) model
+    mutator = to_expandable_model(model)
+
+    structure = mutator.choice_template
+    for key, num in structure.items():
+        unit = mutator._name2unit[key]
+        if num % divisor == 0:
+            continue
+        else:
+            num = (num // divisor + 1) * divisor
+            num = max(num, unit.num_channels)
+            unit.expand_to(num)
+
+    model = expand_expandable_dynamic_model(model, zero=zero_weight)
+    mutator = to_expandable_model(copy.deepcopy(model))
+
+    return mutator.choice_template
diff --git a/mmrazor/models/utils/expandable_utils/unit.py b/mmrazor/models/utils/expandable_utils/unit.py
new file mode 100644
index 000000000..3a9b628c2
--- /dev/null
+++ b/mmrazor/models/utils/expandable_utils/unit.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from mmrazor.models.mutables import (L1MutableChannelUnit,
+                                     MutableChannelContainer)
+from .ops import ExpandableBatchNorm2d, ExpandableConv2d, ExpandLinear
+
+
+class ExpandableUnit(L1MutableChannelUnit):
+    """The unit used to replace modules in place with expandable dynamic ops."""
+
+    def prepare_for_pruning(self, model: nn.Module):
+        self._replace_with_dynamic_ops(
+            model, {
+                nn.Conv2d: ExpandableConv2d,
+                nn.BatchNorm2d: ExpandableBatchNorm2d,
+                nn.Linear: ExpandLinear,
+            })
+        self._register_channel_container(model, MutableChannelContainer)
+        self._register_mutable_channel(self.mutable_channel)
+
+    def expand(self, num):
+        expand_mask = self.mutable_channel.mask.new_zeros([num])
+        mask = torch.cat([self.mutable_channel.mask, expand_mask])
+        self.mutable_channel.mask = mask
+
+    def expand_to(self, num):
+        self.expand(num - self.num_channels)
diff --git a/mmrazor/utils/__init__.py b/mmrazor/utils/__init__.py
index 2d6e1ae43..a69480e94 100644
--- a/mmrazor/utils/__init__.py
+++ b/mmrazor/utils/__init__.py
@@ -3,6 +3,7 @@
 from .log_tools import get_level, print_log
 from .misc import find_latest_checkpoint
 from .placeholder import get_placeholder
+from .runtime_info import RuntimeInfo
 from .setup_env import register_all_modules, setup_multi_processes
 from .typing import (FixMutable, MultiMutatorsRandomSubnet,
                      SingleMutatorRandomSubnet, SupportRandomSubnet,
@@ -12,5 +13,5 @@
     'find_latest_checkpoint', 'setup_multi_processes', 'register_all_modules',
     'FixMutable', 'ValidFixMutable', 'SingleMutatorRandomSubnet',
     'MultiMutatorsRandomSubnet', 'SupportRandomSubnet', 'get_placeholder',
-    'IndexDict', 'get_level', 'print_log'
+    'IndexDict', 'get_level', 'print_log', 'RuntimeInfo'
 ]
diff --git a/mmrazor/utils/runtime_info.py b/mmrazor/utils/runtime_info.py
new file mode 100644
index 000000000..f117c2d06
--- /dev/null
+++ b/mmrazor/utils/runtime_info.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math + +from mmengine import Config, MessageHub + + +class RuntimeInfo(): + """A tools to get runtime info in MessageHub.""" + + @classmethod + def info(cls): + hub = MessageHub.get_current_instance() + return hub.runtime_info + + @classmethod + def get_info(cls, key): + info = cls.info() + if key in info: + return info[key] + else: + raise KeyError(key) + + @classmethod + def epoch(cls): + return cls.get_info('epoch') + + @classmethod + def max_epochs(cls): + return cls.get_info('max_epochs') + + @classmethod + def iter(cls): + return cls.get_info('iter') + + @classmethod + def max_iters(cls): + return cls.get_info('max_iters') + + @classmethod + def iter_by_epoch(cls): + iter_per_epoch = math.ceil(cls.max_iters() / cls.max_epochs()) + return cls.iter() % iter_per_epoch + + @classmethod + def iter_pre_epoch(cls): + iter_per_epoch = math.ceil(cls.max_iters() / cls.max_epochs()) + return iter_per_epoch + + @classmethod + def config(cls): + cfg: str = cls.get_info('cfg') + config = Config.fromstring(cfg, '.py') + return config + + @classmethod + def work_dir(cls): + config = cls.config() + return config['work_dir'] diff --git a/mmrazor/utils/setup_env.py b/mmrazor/utils/setup_env.py index 385be8624..a091933aa 100644 --- a/mmrazor/utils/setup_env.py +++ b/mmrazor/utils/setup_env.py @@ -63,6 +63,7 @@ def register_all_modules(init_default_scope: bool = True) -> None: import mmrazor.datasets # noqa: F401,F403 import mmrazor.engine # noqa: F401,F403 + import mmrazor.implementations # noqa: F401,F403 import mmrazor.models # noqa: F401,F403 import mmrazor.structures # noqa: F401,F403 if init_default_scope: diff --git a/tests/data/models.py b/tests/data/models.py index 220130b56..33fb0c624 100644 --- a/tests/data/models.py +++ b/tests/data/models.py @@ -78,6 +78,7 @@ def untracable_method(self, x): x = x * -2 return x + @MODELS.register_module() class UntracableBackBone(nn.Module): @@ -106,7 +107,6 @@ def forward(self, x): return self.head(self.backbone(x)) - class ConvAttnModel(Module): def __init__(self) -> None: @@ -123,6 +123,7 @@ def forward(self, x): x_last = self.conv2(x_attn) return self.head(x_last) + @MODELS.register_module() class LinearHeadForTest(Module): @@ -623,6 +624,27 @@ def _forward_attention(self, x: torch.Tensor): return self.proj(y) +def MMClsResNet18() -> BaseModel: + model_cfg = dict( + _scope_='mmcls', + type='ImageClassifier', + backbone=dict( + type='ResNet', + depth=18, + num_stages=4, + out_indices=(3, ), + style='pytorch'), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=512, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + topk=(1, 5), + )) + return MODELS.build(model_cfg) + + # models with dynamicop @@ -682,7 +704,7 @@ def current_choice(self): def current_choice(self, choice): super().current_choice(choice) - + class DynamicLinearModel(nn.Module): """ x @@ -843,7 +865,7 @@ class DynamicMMBlock(nn.Module): [4, 6, 1], [4, 6, 1], [6, 6, 1], - [6, 6, 1] + [6, 6, 1], ], num_out_channels=[ # [min_channel, max_channel, step] [16, 24, 8], @@ -852,11 +874,11 @@ class DynamicMMBlock(nn.Module): [64, 72, 8], [112, 128, 8], [192, 216, 8], - [216, 224, 8] + [216, 224, 8], ]) def __init__( - self, + self, conv_cfg: Dict = dict(type='mmrazor.BigNasConv2d'), norm_cfg: Dict = dict(type='mmrazor.DynamicBatchNorm2d'), fine_grained_mode: bool = False, @@ -936,12 +958,11 @@ def __init__( act_cfg=dict(type='Swish')))])) self.add_module('last_conv', last_layers) self.layers.append(last_layers) - + 
self.register_mutables() - def _make_single_layer(self, out_channels, num_blocks, - kernel_sizes, expand_ratios, - act_cfg, stride, use_se): + def _make_single_layer(self, out_channels, num_blocks, kernel_sizes, + expand_ratios, act_cfg, stride, use_se): _layers = [] for i in range(max(num_blocks)): if i >= 1: diff --git a/tests/test_impl/__init__.py b/tests/test_impl/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_impl/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_impl/test_pruning/__init__.py b/tests/test_impl/test_pruning/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_impl/test_pruning/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_impl/test_pruning/test_group_fisher/__init__.py b/tests/test_impl/test_pruning/test_group_fisher/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_impl/test_pruning/test_group_fisher/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_impl/test_pruning/test_group_fisher/test_algorithm.py b/tests/test_impl/test_pruning/test_group_fisher/test_algorithm.py new file mode 100644 index 000000000..ec2707282 --- /dev/null +++ b/tests/test_impl/test_pruning/test_group_fisher/test_algorithm.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch +from mmcls.structures import ClsDataSample +from mmengine import MessageHub + +from mmrazor.implementations.pruning.group_fisher.algorithm import \ + GroupFisherAlgorithm +from mmrazor.implementations.pruning.group_fisher.ops import GroupFisherConv2d +from ....data.models import MMClsResNet18 + +if torch.cuda.is_available(): + DEVICE = torch.device('cuda:0') +else: + DEVICE = torch.device('cpu') + + +class TestGroupFisherPruneAlgorithm(TestCase): + + def fake_cifar_data(self): + imgs = torch.randn(16, 3, 32, 32).to(DEVICE) + data_samples = [ + ClsDataSample().set_gt_label(torch.randint(0, 10, + (16, ))).to(DEVICE) + ] + + return {'inputs': imgs, 'data_samples': data_samples} + + def test_group_fisher_prune(self): + data = self.fake_cifar_data() + + MUTATOR_CONFIG = dict( + type='GroupFisherChannelMutator', + parse_cfg=dict( + type='ChannelAnalyzer', tracer_type='BackwardTracer'), + channel_unit_cfg=dict(type='GroupFisherChannelUnit')) + + epoch = 2 + interval = 1 + + algorithm = GroupFisherAlgorithm( + MMClsResNet18(), mutator=MUTATOR_CONFIG, + interval=interval).to(DEVICE) + mutator = algorithm.mutator + + for e in range(epoch): + for ite in range(10): + self._set_epoch_ite(e, ite, epoch) + algorithm.forward( + data['inputs'], data['data_samples'], mode='loss') + self.gen_fake_grad(mutator) + self.assertEqual(interval, algorithm.interval) + + def gen_fake_grad(self, mutator): + for unit in mutator.mutable_units: + for channel in unit.input_related: + module = channel.module + if isinstance(module, GroupFisherConv2d): + module.recorded_grad = module.recorded_input + + def _set_epoch_ite(self, epoch, ite, max_epoch): + iter_per_epoch = 10 + message_hub = MessageHub.get_current_instance() + message_hub.update_info('epoch', epoch) + message_hub.update_info('max_epochs', max_epoch) + message_hub.update_info('max_iters', max_epoch * 10) + message_hub.update_info('iter', ite + iter_per_epoch * epoch) diff --git a/tests/test_impl/test_pruning/test_group_fisher/test_prune_deploy_sub_model.py 
b/tests/test_impl/test_pruning/test_group_fisher/test_prune_deploy_sub_model.py
new file mode 100644
index 000000000..53452c178
--- /dev/null
+++ b/tests/test_impl/test_pruning/test_group_fisher/test_prune_deploy_sub_model.py
@@ -0,0 +1,49 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os
+from unittest import TestCase
+
+from mmengine import fileio
+
+from mmrazor.implementations.pruning.group_fisher.prune_deploy_sub_model import \
+    GroupFisherDeploySubModel  # noqa
+from ....data.models import MMClsResNet18
+from .test_prune_sub_model import PruneAlgorithm, get_model_structure
+
+
+class TestPruneDeploySubModel(TestCase):
+
+    def test_build_sub_model(self):
+        model = MMClsResNet18()
+
+        parse_cfg = dict(
+            _scope_='mmrazor',
+            type='ChannelAnalyzer',
+            demo_input=(1, 3, 224, 224),
+            tracer_type='BackwardTracer')
+        # get structure
+        algorithm = PruneAlgorithm(copy.deepcopy(model))
+        algorithm.random_prune()
+        structure = algorithm.mutator.current_choices
+
+        # test divisor
+        wrapper = GroupFisherDeploySubModel(
+            copy.deepcopy(model), structure, divisor=1, parse_cfg=parse_cfg)
+        self.assertSequenceEqual(
+            list(structure.values()),
+            list(get_model_structure(wrapper).values()))
+
+        wrapper = GroupFisherDeploySubModel(
+            copy.deepcopy(model), structure, divisor=8, parse_cfg=parse_cfg)
+        self.assertSequenceEqual(
+            list(structure.values()),
+            list(get_model_structure(wrapper).values()))
+
+        mutable_path = os.path.dirname(__file__) + '/mutable.json'
+        fileio.dump(algorithm.mutator.current_choices, mutable_path)
+        GroupFisherDeploySubModel(
+            copy.deepcopy(model),
+            divisor=1,
+            mutable_cfg=mutable_path,
+            parse_cfg=parse_cfg)
+        os.remove(mutable_path)
diff --git a/tests/test_impl/test_pruning/test_group_fisher/test_prune_sub_model.py b/tests/test_impl/test_pruning/test_group_fisher/test_prune_sub_model.py
new file mode 100644
index 000000000..83f29f281
--- /dev/null
+++ b/tests/test_impl/test_pruning/test_group_fisher/test_prune_sub_model.py
@@ -0,0 +1,64 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy +from typing import Dict, Union +from unittest import TestCase + +import torch + +from mmrazor.implementations.pruning.group_fisher.prune_sub_model import \ + GroupFisherSubModel +from mmrazor.models import BaseAlgorithm +from mmrazor.models.mutators import ChannelMutator +from mmrazor.registry import MODELS +from ....data.models import MMClsResNet18 + + +class PruneAlgorithm(BaseAlgorithm): + + def __init__(self, + architecture, + mutator: Union[Dict, ChannelMutator] = dict( + type='ChannelMutator', + channel_unit_cfg=dict( + type='SequentialMutableChannelUnit')), + data_preprocessor=None, + init_cfg=None) -> None: + super().__init__( + architecture, data_preprocessor, init_cfg, module_inplace=False) + if isinstance(mutator, dict): + mutator = MODELS.build(mutator) + assert isinstance(mutator, ChannelMutator) + self.mutator = mutator + mutator.prepare_from_supernet(self.architecture) + + def random_prune(self): + choices = self.mutator.sample_choices() + self.mutator.set_choices(choices) + + +def get_model_structure(model): + algorithm = PruneAlgorithm(copy.deepcopy(model)) + return algorithm.mutator.current_choices + + +class TestPruneSubModel(TestCase): + + def test_build_sub_model(self): + x = torch.rand([1, 3, 224, 224]) + model = MMClsResNet18() + algorithm = PruneAlgorithm(model) + algorithm.random_prune() + + # test divisor + static_model1 = GroupFisherSubModel(algorithm, divisor=1) + self.assertSequenceEqual( + list(algorithm.mutator.current_choices.values()), + list(get_model_structure(static_model1).values())) + + static_model2 = GroupFisherSubModel(algorithm, divisor=8) + for value in get_model_structure(static_model2).values(): + self.assertTrue(value % 8 == 0) + + y1 = static_model1(x) + y2 = static_model2(x) + self.assertTrue((y1 - y2).abs().max() < 1e-3) diff --git a/tests/test_impl/test_pruning/test_group_fisher/test_unit.py b/tests/test_impl/test_pruning/test_group_fisher/test_unit.py new file mode 100644 index 000000000..712d2fb50 --- /dev/null +++ b/tests/test_impl/test_pruning/test_group_fisher/test_unit.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import unittest + +import torch + +from mmrazor.implementations.pruning.group_fisher import \ + GroupFisherChannelMutator +from ....data.models import MMClsResNet18 + + +class TestGroupFisherChannelUnit(unittest.TestCase): + + def test_init(self): + model = MMClsResNet18() + mutator = GroupFisherChannelMutator( + parse_cfg=dict( + type='ChannelAnalyzer', + demo_input=(1, 3, 224, 224), + tracer_type='BackwardTracer')) + mutator.prepare_from_supernet(model) + + x = torch.rand([1, 3, 224, 224]) + mutator.start_record_info() + for i in range(2): + model.train() + loss = model(x).sum() + loss.backward() + mutator.end_record_info() + + for unit in mutator.mutable_units: + for module in unit.input_related_dynamic_ops: + self.assertEqual(len(module.recorded_input), 2) + self.assertEqual(len(module.recorded_grad), 2) + self.assertIsInstance(module.recorded_grad[0], torch.Tensor) + + unit = mutator.mutable_units[0] + fisher = unit._fisher_of_a_module(next(unit.input_related_dynamic_ops)) + self.assertEqual(list(fisher.shape), [1, unit.num_channels]) + + fisher = unit.current_batch_fisher + self.assertEqual(list(fisher.shape), [unit.num_channels]) + + fisher = unit._get_normalized_fisher_info(fisher, unit.delta_type) + unit.update_fisher_info() diff --git a/tests/test_models/test_utils/__init__.py b/tests/test_models/test_utils/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_models/test_utils/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_utils/test_expandable_utils/__init__.py b/tests/test_models/test_utils/test_expandable_utils/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_models/test_utils/test_expandable_utils/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_utils/test_expandable_utils/test_expand.py b/tests/test_models/test_utils/test_expandable_utils/test_expand.py new file mode 100644 index 000000000..486bf79eb --- /dev/null +++ b/tests/test_models/test_utils/test_expandable_utils/test_expand.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import unittest
+
+import torch
+
+from mmrazor.models.mutables import SimpleMutableChannel
+from mmrazor.models.utils.expandable_utils import (
+    expand_expandable_dynamic_model, make_channel_divisible,
+    to_expandable_model)
+from mmrazor.models.utils.expandable_utils.ops import ExpandLinear
+from ....data.models import DwConvModel, MultiConcatModel, SingleLineModel
+
+
+class TestExpand(unittest.TestCase):
+
+    def test_expand(self):
+        for Model in [MultiConcatModel, DwConvModel]:
+            x = torch.rand([1, 3, 224, 224])
+            model = Model()
+            print(model)
+            mutator = to_expandable_model(model)
+            print(mutator.choice_template)
+            print(model)
+            y1 = model(x)
+
+            for unit in mutator.mutable_units:
+                unit.expand(10)
+                print(unit.mutable_channel.mask.shape)
+            expand_expandable_dynamic_model(model, zero=True)
+            print(model)
+            y2 = model(x)
+            self.assertTrue((y1 - y2).abs().max() < 1e-3)
+
+    def test_expand_static_model(self):
+        x = torch.rand([1, 3, 224, 224])
+        model = SingleLineModel()
+        y1 = model(x)
+        make_channel_divisible(model, divisor=4)
+        y2 = model(x)
+        print(y1.reshape([-1])[:5])
+        print(y2.reshape([-1])[:5])
+        self.assertTrue((y1 - y2).abs().max() < 1e-3)
+
+    def test_ExpandLinear(self):
+        linear = ExpandLinear(3, 3)
+        mutable_in = SimpleMutableChannel(3)
+        mutable_out = SimpleMutableChannel(3)
+        linear.register_mutable_attr('in_channels', mutable_in)
+        linear.register_mutable_attr('out_channels', mutable_out)
+
+        print(linear.weight)
+
+        mutable_in.mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0])
+        mutable_out.mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0])
+        linear_ex = linear.expand(zero=True)
+        print(linear_ex.weight)
diff --git a/tools/pruning/get_flops.py b/tools/pruning/get_flops.py
new file mode 100644
index 000000000..409817e10
--- /dev/null
+++ b/tools/pruning/get_flops.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+
+from mmengine import Config
+
+from mmrazor.models.algorithms import ItePruneAlgorithm
+from mmrazor.models.task_modules import ResourceEstimator
+from mmrazor.models.task_modules.demo_inputs import DefaultDemoInput
+from mmrazor.registry import MODELS
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('config')
+    parser.add_argument('-H', default=224, type=int)
+    parser.add_argument('-W', default=224, type=int)
+    args = parser.parse_args()
+    return args
+
+
+def input_generator_wrapper(model, shape, training, scope=None):
+
+    def input_generator(input_shape):
+        inputs = DefaultDemoInput(scope=scope).get_data(
+            model, input_shape=input_shape, training=training)
+        if isinstance(inputs, dict) and 'mode' in inputs:
+            inputs['mode'] = 'tensor'
+        return inputs
+
+    return input_generator
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    config = Config.fromfile(args.config)
+    H = args.H
+    W = args.W
+
+    default_scope = config['default_scope']
+    model_config = config['model']
+    # model_config['_scope_'] = default_scope
+    model: ItePruneAlgorithm = MODELS.build(model_config)
+
+    estimator = ResourceEstimator(
+        flops_params_cfg=dict(
+            input_shape=(1, 3, H, W),
+            print_per_layer_stat=False,
+            input_constructor=input_generator_wrapper(
+                model,
+                (1, 3, H, W),
+                training=False,
+                scope=default_scope,
+            )))
+    result = estimator.estimate(model)
+    print(result)
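
A note on the new DynamicConv2dCounter added above: its hook derives FLOPs and parameters from the activated (currently kept) channels of a dynamic conv rather than from the full weight shape, which is what lets the estimator report numbers for a pruned subnet. The following standalone sketch reproduces the same arithmetic outside mmrazor; the helper name conv2d_flops_params is illustrative and not part of the patch.

# Standalone sketch of the formula used by DynamicConv2dCounter.add_count_hook.
import numpy as np


def conv2d_flops_params(batch_size, output_dims, kernel_dims, in_channels,
                        out_channels, groups, has_bias):
    """Return (flops, params) for one forward pass of a (dynamic) Conv2d."""
    filters_per_channel = out_channels / groups
    # multiply-adds per output position, counted over the active channels only
    conv_per_position_flops = int(
        np.prod(kernel_dims)) * in_channels * filters_per_channel
    active_elements = batch_size * int(np.prod(output_dims))

    flops = conv_per_position_flops * active_elements
    params = conv_per_position_flops
    if has_bias:
        flops += out_channels * active_elements
        params += out_channels
    return int(flops), int(params)


# E.g. a 3x3 conv pruned from 64 to 48 output channels on a 32x32 feature map:
print(conv2d_flops_params(1, (32, 32), (3, 3), 64, 48, 1, True))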
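
The core trick behind ExpandableMixin.expand_matrix is a masked scatter: the old weights are copied into the positions of the enlarged weight that the in/out masks mark as "old" channels, while newly added channels are left untouched (or zeroed). A minimal torch-only sketch of that operation, with toy masks chosen purely for illustration:

# Mirrors the masked_scatter_ pattern in ExpandableMixin.expand_matrix.
import torch

# A mask value of 1 marks positions that receive the old (pre-expansion) channels.
out_mask = torch.tensor([1., 1., 0., 1., 0.])  # 3 old -> 5 new out channels
in_mask = torch.tensor([1., 0., 1., 1.])       # 3 old -> 4 new in channels

old_weight = torch.arange(9.).reshape(3, 3, 1)  # [out, in, flattened kernel]
new_weight = torch.zeros(5, 4, 1)

mask = out_mask.unsqueeze(-1) * in_mask.unsqueeze(0)  # [out, in]
mask = mask.unsqueeze(-1).expand_as(new_weight)       # [out, in, kernel]
new_weight.masked_scatter_(mask.bool(), old_weight)
print(new_weight[:, :, 0])  # old values land on the masked rows/columns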
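
Putting the expandable_utils helpers together: to_expandable_model wraps a static network with expandable dynamic ops, unit.expand / expand_to enlarge the channel masks, and expand_expandable_dynamic_model (or the one-call make_channel_divisible) materializes a static, wider network whose new channels are zero-initialized, so the output should be preserved, as the new unit tests assert. A usage sketch, assuming mmrazor with this patch installed; ToyNet is only an illustration and must be traceable by the default ChannelAnalyzer.

# Usage sketch for the expandable_utils helpers (ToyNet is illustrative).
import torch
import torch.nn as nn

from mmrazor.models.utils.expandable_utils import (
    expand_expandable_dynamic_model, make_channel_divisible,
    to_expandable_model)


class ToyNet(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, padding=1)
        self.bn = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 10, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(10, 4)

    def forward(self, x):
        x = self.pool(self.conv2(self.bn(self.conv1(x)).relu()))
        return self.fc(x.flatten(1))


model = ToyNet().eval()
x = torch.rand(1, 3, 32, 32)
y_ref = model(x)

# Path 1: expand chosen units by hand, then materialize a static model again.
mutator = to_expandable_model(model)
for unit in mutator.mutable_units:
    unit.expand(2)  # append 2 channels to every prunable unit
model = expand_expandable_dynamic_model(model, zero=True)

# Path 2: one call that pads every unit up to a multiple of the divisor.
new_structure = make_channel_divisible(model, divisor=8)
print(new_structure)

# New channels are zero-initialized, so the output should be preserved.
assert (model(x) - y_ref).abs().max() < 1e-3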
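
RuntimeInfo, exported from mmrazor.utils above, is a thin reader over mmengine's MessageHub: whatever the Runner (or, as in the new unit tests, the code itself) pushes with update_info can be read back through its class methods, which is how the pruning hook and unit learn the current epoch and iteration. A small sketch of that round trip:

# Populate the MessageHub by hand (the Runner normally does this) and read it
# back through RuntimeInfo.
from mmengine import MessageHub

from mmrazor.utils import RuntimeInfo

hub = MessageHub.get_current_instance()
hub.update_info('epoch', 1)
hub.update_info('max_epochs', 5)
hub.update_info('iter', 30)
hub.update_info('max_iters', 50)

print(RuntimeInfo.epoch())          # 1
print(RuntimeInfo.max_epochs())     # 5
print(RuntimeInfo.iter_by_epoch())  # 30 % ceil(50 / 5) = 0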