diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 83763ac1c7..924653e894 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -84,11 +84,8 @@ jobs: strategy: matrix: python-version: [3.7] - torch: [1.3.0, 1.5.0+cu101, 1.6.0+cu101, 1.7.0+cu101, 1.8.0+cu101] + torch: [1.5.0+cu101, 1.6.0+cu101, 1.7.0+cu101, 1.8.0+cu101] include: - - torch: 1.3.0 - torchvision: 0.4.1 - mmcv: 1.3.0+cu101 - torch: 1.5.0+cu101 torchvision: 0.6.0+cu101 mmcv: 1.5.0+cu101 diff --git a/configs/recognition/tsn/README.md b/configs/recognition/tsn/README.md index 3466febc01..9f1f6388cc 100644 --- a/configs/recognition/tsn/README.md +++ b/configs/recognition/tsn/README.md @@ -67,11 +67,15 @@ It's possible and convenient to use a 3rd-party backbone for TSN under the frame - [x] Backbones from [MMClassification](/~https://github.com/open-mmlab/mmclassification/) - [x] Backbones from [TorchVision](/~https://github.com/pytorch/vision/) +- [x] Backbones from [TIMM (pytorch-image-models)](/~https://github.com/rwightman/pytorch-image-models) | config | resolution | gpus | backbone | pretrain | top1 acc | top5 acc | ckpt | log | json | | :----------------------------------------------------------- | :------------: | :--: | :----------------------------------------------------------: | :------: | :------: | :------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | | [tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | ResNeXt101-32x4d [[MMCls](/~https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext)] | ImageNet | 73.43 | 91.01 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb-16a8b561.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.json) | -| [tsn_dense161_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | ResNeXt101-32x4d [[TorchVision](/~https://github.com/pytorch/vision/)] | ImageNet | 72.78 | 90.75 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb-cbe85332.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.json) | +| [tsn_dense161_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | Densenet-161 [[TorchVision](/~https://github.com/pytorch/vision/)] | ImageNet | 72.78 | 90.75 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb-cbe85332.pth) | 
[log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.json) | +| [tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8 | Swin Transformer Base [[timm](/~https://github.com/rwightman/pytorch-image-models)] | ImageNet | 77.51 | 92.92 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb-805380f6.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.json) | + +1. Note that some backbones in TIMM are not supported for various reasons. Please refer to [PR #880](/~https://github.com/open-mmlab/mmaction2/pull/880) for details. ### Kinetics-400 Data Benchmark (8-gpus, ResNet50, ImageNet pretrain; 3 segments) diff --git a/configs/recognition/tsn/README_zh-CN.md b/configs/recognition/tsn/README_zh-CN.md index f6e9c08c61..40c14b28a3 100644 --- a/configs/recognition/tsn/README_zh-CN.md +++ b/configs/recognition/tsn/README_zh-CN.md @@ -66,10 +66,16 @@ 用户可在 MMAction2 的框架中使用第三方的主干网络训练 TSN,例如: - [x] MMClassification 中的主干网络 +- [x] TorchVision 中的主干网络 +- [x] pytorch-image-models(timm) 中的主干网络 | 配置文件 | 分辨率 | GPU 数量 | 主干网络 | 预训练 | top1 准确率 | top5 准确率 | ckpt | log | json | | :----------------------------------------------------------: | :------------: | :--: | :----------------------------------------------------------: | :------: | :------: | :------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | | [tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.py) | 短边 320 | 8x2 | ResNeXt101-32x4d [[MMCls](/~https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext)] | ImageNet | 73.43 | 91.01 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb-16a8b561.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.json) | +| [tsn_dense161_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.py) | 短边 320 | 8x2 | Densenet-161 [[TorchVision](/~https://github.com/pytorch/vision/)] | ImageNet | 72.78 | 90.75 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb-cbe85332.pth) | 
[log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.json) | +| [tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py) | 短边 320 | 8 | Swin Transformer Base [[timm](/~https://github.com/rwightman/pytorch-image-models)] | ImageNet | 77.51 | 92.92 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb-805380f6.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.json) | + +1. 由于多种原因,TIMM 中的一些模型未能得到支持,详情请参考 [PR #880](/~https://github.com/open-mmlab/mmaction2/pull/880)。 ### Kinetics-400 数据基准测试 (8 块 GPU, ResNet50, ImageNet 预训练; 3 个视频段) diff --git a/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py b/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py new file mode 100644 index 0000000000..815ef38d0f --- /dev/null +++ b/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py @@ -0,0 +1,105 @@ +_base_ = [ + '../../../_base_/schedules/sgd_100e.py', + '../../../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='Recognizer2D', + backbone=dict(type='timm.swin_base_patch4_window7_224', pretrained=True), + cls_head=dict( + type='TSNHead', + num_classes=400, + in_channels=1024, + spatial_type='avg', + consensus=dict(type='AvgConsensus', dim=1), + dropout_ratio=0.4, + init_std=0.01), + # model training and testing settings + train_cfg=None, + test_cfg=dict(average_clips=None)) + +# dataset settings +dataset_type = 'VideoDataset' +data_root = 'data/kinetics400/videos_train' +data_root_val = 'data/kinetics400/videos_val' +ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt' +ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt' +ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False) +train_pipeline = [ + dict(type='DecordInit'), + dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=3), + dict(type='DecordDecode'), + dict(type='RandomResizedCrop'), + dict(type='Resize', scale=(224, 224), keep_ratio=False), + dict(type='Flip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='FormatShape', input_format='NCHW'), + dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), + dict(type='ToTensor', keys=['imgs', 'label']) +] +val_pipeline = [ + dict(type='DecordInit'), + dict( + type='SampleFrames', + clip_len=1, + frame_interval=1, + num_clips=3, + test_mode=True), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='CenterCrop', crop_size=224), 
+ dict(type='Flip', flip_ratio=0), + dict(type='Normalize', **img_norm_cfg), + dict(type='FormatShape', input_format='NCHW'), + dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), + dict(type='ToTensor', keys=['imgs']) +] +test_pipeline = [ + dict(type='DecordInit'), + dict( + type='SampleFrames', + clip_len=1, + frame_interval=1, + num_clips=25, + test_mode=True), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='TenCrop', crop_size=224), + dict(type='Flip', flip_ratio=0), + dict(type='Normalize', **img_norm_cfg), + dict(type='FormatShape', input_format='NCHW'), + dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), + dict(type='ToTensor', keys=['imgs']) +] +data = dict( + videos_per_gpu=24, + workers_per_gpu=4, + test_dataloader=dict(videos_per_gpu=4), + train=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=data_root, + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix=data_root_val, + pipeline=val_pipeline), + test=dict( + type=dataset_type, + ann_file=ann_file_test, + data_prefix=data_root_val, + pipeline=test_pipeline)) +evaluation = dict( + interval=1, metrics=['top_k_accuracy', 'mean_class_accuracy']) + +# runtime settings +work_dir = './work_dirs/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/' # noqa +optimizer = dict( + type='SGD', + lr=0.0075, # this lr is used for 8 gpus + momentum=0.9, + weight_decay=0.0001) diff --git a/docs/changelog.md b/docs/changelog.md index 67fccc19a5..f3d58d3bfc 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,23 @@ ## Changelog +### Master + +**Highlights** + +- Support using backbones from pytorch-image-models(timm) + +**New Features** + +- Support using backbones from pytorch-image-models(timm) for TSN ([#880](/~https://github.com/open-mmlab/mmaction2/pull/880)) + +**Improvements** + +**Bug and Typo Fixes** + +**ModelZoo** + +- Add TSN with Swin Transformer backbone as an example for using pytorch-image-models(timm) backbones ([#880](/~https://github.com/open-mmlab/mmaction2/pull/880)) + ### 0.15.0 (31/05/2021) **Highlights** diff --git a/mmaction/models/recognizers/base.py b/mmaction/models/recognizers/base.py index bcbde8468a..41164f3bd2 100644 --- a/mmaction/models/recognizers/base.py +++ b/mmaction/models/recognizers/base.py @@ -60,6 +60,17 @@ def __init__(self, self.backbone.classifier = nn.Identity() self.backbone.fc = nn.Identity() self.backbone_from = 'torchvision' + elif backbone['type'].startswith('timm.'): + try: + import timm + except (ImportError, ModuleNotFoundError): + raise ImportError('Please install timm to use this ' + 'backbone.') + backbone_type = backbone.pop('type')[5:] + # disable the classifier + backbone['num_classes'] = 0 + self.backbone = timm.create_model(backbone_type, **backbone) + self.backbone_from = 'timm' else: self.backbone = builder.build_backbone(backbone) @@ -113,11 +124,11 @@ def init_weights(self): """Initialize the model network weights.""" if self.backbone_from in ['mmcls', 'mmaction2']: self.backbone.init_weights() - elif self.backbone_from == 'torchvision': + elif self.backbone_from in ['torchvision', 'timm']: warnings.warn('We do not initialize weights for backbones in ' - 'torchvision, since the weights for backbones in ' - 'torchvision are initialized in their __init__ ' - 'functions. 
') + f'{self.backbone_from}, since the weights for ' + f'backbones in {self.backbone_from} are initialized ' 'in their __init__ functions.') else: raise NotImplementedError('Unsupported backbone source ' f'{self.backbone_from}!') @@ -140,6 +151,8 @@ def extract_feat(self, imgs): if (hasattr(self.backbone, 'features') and self.backbone_from == 'torchvision'): x = self.backbone.features(imgs) + elif self.backbone_from == 'timm': + x = self.backbone.forward_features(imgs) else: x = self.backbone(imgs) return x diff --git a/mmaction/models/recognizers/recognizer2d.py b/mmaction/models/recognizers/recognizer2d.py index d3444845f6..6b4bedba04 100644 --- a/mmaction/models/recognizers/recognizer2d.py +++ b/mmaction/models/recognizers/recognizer2d.py @@ -21,7 +21,7 @@ def forward_train(self, imgs, labels, **kwargs): x = self.extract_feat(imgs) - if self.backbone_from == 'torchvision': + if self.backbone_from in ['torchvision', 'timm']: if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1): # apply adaptive avg pooling x = nn.AdaptiveAvgPool2d(1)(x) @@ -55,7 +55,7 @@ def _do_test(self, imgs): x = self.extract_feat(imgs) - if self.backbone_from == 'torchvision': + if self.backbone_from in ['torchvision', 'timm']: if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1): # apply adaptive avg pooling x = nn.AdaptiveAvgPool2d(1)(x) diff --git a/requirements/optional.txt b/requirements/optional.txt index 839d3acc87..3177ef6221 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -7,3 +7,4 @@ moviepy onnx onnxruntime PyTurboJPEG +timm diff --git a/tests/test_models/test_recognizers/test_recognizer2d.py b/tests/test_models/test_recognizers/test_recognizer2d.py index 927f046273..8d4cf23744 100644 --- a/tests/test_models/test_recognizers/test_recognizer2d.py +++ b/tests/test_models/test_recognizers/test_recognizer2d.py @@ -97,6 +97,28 @@ def test_tsn(): for one_img in img_list: recognizer(one_img, None, return_loss=False) + # test timm backbones + timm_backbone = dict(type='timm.efficientnet_b0', pretrained=False) + config.model['backbone'] = timm_backbone + config.model['cls_head']['in_channels'] = 1280 + + recognizer = build_recognizer(config.model) + + input_shape = (1, 3, 3, 32, 32) + demo_inputs = generate_recognizer_demo_inputs(input_shape) + + imgs = demo_inputs['imgs'] + gt_labels = demo_inputs['gt_labels'] + + losses = recognizer(imgs, gt_labels) + assert isinstance(losses, dict) + + # Test forward test + with torch.no_grad(): + img_list = [img[None, :] for img in imgs] + for one_img in img_list: + recognizer(one_img, None, return_loss=False) + def test_tsm(): config = get_recognizer_cfg('tsm/tsm_r50_1x1x8_50e_kinetics400_rgb.py')
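For reference, a minimal sketch of the user-facing pattern this diff enables: a backbone `type` prefixed with `timm.` is routed to `timm.create_model` with the classifier disabled (`num_classes=0`), and features are taken from `forward_features()`, so a config only needs to name the timm model and match the head's `in_channels` to its feature width. This sketch is illustrative and not one of the PR's files; the `efficientnet_b0`/1280 pairing follows the new unit test, and other model names or channel counts are assumptions to verify against the chosen timm model.

```python
# Illustrative TSN model settings with a timm backbone (not part of this PR's
# files). The 'timm.' prefix triggers the new branch in the base recognizer,
# which effectively calls
# timm.create_model('efficientnet_b0', num_classes=0, pretrained=True)
# and later uses backbone.forward_features() in extract_feat().
model = dict(
    type='Recognizer2D',
    backbone=dict(type='timm.efficientnet_b0', pretrained=True),
    cls_head=dict(
        type='TSNHead',
        num_classes=400,
        in_channels=1280,  # feature width of efficientnet_b0, as in the unit test
        spatial_type='avg',
        consensus=dict(type='AvgConsensus', dim=1),
        dropout_ratio=0.4,
        init_std=0.01),
    train_cfg=None,
    test_cfg=dict(average_clips=None))
```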