[Feature] Support using backbones from pytorch-image-models (timm) for TSN #880

Merged: 10 commits merged on Jun 10, 2021. Showing changes from 9 commits.
5 changes: 1 addition & 4 deletions .github/workflows/build.yml
@@ -84,11 +84,8 @@ jobs:
     strategy:
       matrix:
         python-version: [3.7]
-        torch: [1.3.0, 1.5.0+cu101, 1.6.0+cu101, 1.7.0+cu101, 1.8.0+cu101]
+        torch: [1.5.0+cu101, 1.6.0+cu101, 1.7.0+cu101, 1.8.0+cu101]
         include:
-          - torch: 1.3.0
-            torchvision: 0.4.1
-            mmcv: 1.3.0+cu101
           - torch: 1.5.0+cu101
             torchvision: 0.6.0+cu101
             mmcv: 1.5.0+cu101
4 changes: 3 additions & 1 deletion configs/recognition/tsn/README.md
@@ -67,11 +67,13 @@ It's possible and convenient to use a 3rd-party backbone for TSN under the framework

- [x] Backbones from [MMClassification](/~https://github.com/open-mmlab/mmclassification/)
- [x] Backbones from [TorchVision](/~https://github.com/pytorch/vision/)
+- [x] Backbones from [TIMM (pytorch-image-models)](/~https://github.com/rwightman/pytorch-image-models)

| config | resolution | gpus | backbone | pretrain | top1 acc | top5 acc | ckpt | log | json |
| :----------------------------------------------------------- | :------------: | :--: | :----------------------------------------------------------: | :------: | :------: | :------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
| [tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | ResNeXt101-32x4d [[MMCls](/~https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext)] | ImageNet | 73.43 | 91.01 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb-16a8b561.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.json) |
-| [tsn_dense161_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | ResNeXt101-32x4d [[TorchVision](/~https://github.com/pytorch/vision/)] | ImageNet | 72.78 | 90.75 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb-cbe85332.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.json) |
+| [tsn_dense161_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | Densenet-161 [[TorchVision](/~https://github.com/pytorch/vision/)] | ImageNet | 72.78 | 90.75 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb-cbe85332.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.json) |
+| [tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8 | Swin Transformer Base [[timm](/~https://github.com/rwightman/pytorch-image-models)] | ImageNet | 77.51 | 92.92 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb-805380f6.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.json) |
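
Using a timm backbone requires no new code in a config: prefixing the model name with `timm.` is enough. Below is a minimal sketch distilled from the Swin config added in this PR; any other timm model name, and the `in_channels` that matches it, are assumptions the reader should verify.

```python
# Minimal TSN model config with a timm backbone (sketch).
# The 'timm.' prefix routes backbone construction through
# timm.create_model() instead of MMAction2's own registry.
model = dict(
    type='Recognizer2D',
    backbone=dict(
        type='timm.swin_base_patch4_window7_224',  # timm model name after the prefix
        pretrained=True),  # load timm's ImageNet checkpoint
    cls_head=dict(
        type='TSNHead',
        num_classes=400,
        in_channels=1024,  # must match the backbone's feature dimension
        consensus=dict(type='AvgConsensus', dim=1)))
```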

### Kinetics-400 Data Benchmark (8-gpus, ResNet50, ImageNet pretrain; 3 segments)

4 changes: 4 additions & 0 deletions configs/recognition/tsn/README_zh-CN.md
@@ -66,10 +66,14 @@
Users can train TSN with third-party backbones under the MMAction2 framework, for example:

- [x] Backbones from MMClassification
+- [x] Backbones from TorchVision
+- [x] Backbones from pytorch-image-models (timm)

| config | resolution | gpus | backbone | pretrain | top1 acc | top5 acc | ckpt | log | json |
| :----------------------------------------------------------: | :------------: | :--: | :----------------------------------------------------------: | :------: | :------: | :------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
| [tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | ResNeXt101-32x4d [[MMCls](/~https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext)] | ImageNet | 73.43 | 91.01 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb-16a8b561.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.json) |
+| [tsn_dense161_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | Densenet-161 [[TorchVision](/~https://github.com/pytorch/vision/)] | ImageNet | 72.78 | 90.75 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb-cbe85332.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.json) |
+| [tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8 | Swin Transformer Base [[timm](/~https://github.com/rwightman/pytorch-image-models)] | ImageNet | 77.51 | 92.92 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb-805380f6.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.json) |

### Kinetics-400 Data Benchmark (8 GPUs, ResNet50, ImageNet pretrain; 3 segments)

105 changes: 105 additions & 0 deletions configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py
@@ -0,0 +1,105 @@
_base_ = [
    '../../../_base_/schedules/sgd_100e.py',
    '../../../_base_/default_runtime.py'
]

# model settings
model = dict(
    type='Recognizer2D',
    backbone=dict(type='timm.swin_base_patch4_window7_224', pretrained=True),
    cls_head=dict(
        type='TSNHead',
        num_classes=400,
        in_channels=1024,
        spatial_type='avg',
        consensus=dict(type='AvgConsensus', dim=1),
        dropout_ratio=0.4,
        init_std=0.01),
    # model training and testing settings
    train_cfg=None,
    test_cfg=dict(average_clips=None))

# dataset settings
dataset_type = 'VideoDataset'
data_root = 'data/kinetics400/videos_train'
data_root_val = 'data/kinetics400/videos_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
    dict(type='DecordInit'),
    dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=3),
    dict(type='DecordDecode'),
    dict(type='RandomResizedCrop'),
    dict(type='Resize', scale=(224, 224), keep_ratio=False),
    dict(type='Flip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
    dict(type='DecordInit'),
    dict(
        type='SampleFrames',
        clip_len=1,
        frame_interval=1,
        num_clips=3,
        test_mode=True),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Flip', flip_ratio=0),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
    dict(type='DecordInit'),
    dict(
        type='SampleFrames',
        clip_len=1,
        frame_interval=1,
        num_clips=25,
        test_mode=True),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='TenCrop', crop_size=224),
    dict(type='Flip', flip_ratio=0),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs'])
]
data = dict(
    videos_per_gpu=24,
    workers_per_gpu=4,
    test_dataloader=dict(videos_per_gpu=4),
    train=dict(
        type=dataset_type,
        ann_file=ann_file_train,
        data_prefix=data_root,
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        data_prefix=data_root_val,
        pipeline=val_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=ann_file_test,
        data_prefix=data_root_val,
        pipeline=test_pipeline))
evaluation = dict(
    interval=1, metrics=['top_k_accuracy', 'mean_class_accuracy'])

# runtime settings
work_dir = './work_dirs/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/'  # noqa
optimizer = dict(
    type='SGD',
    lr=0.0075,  # this lr is used for 8 gpus
    momentum=0.9,
    weight_decay=0.0001)
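
A detail worth checking when adapting this config to other timm models: `in_channels` of the head must equal the backbone's output feature dimension. A quick sanity check, assuming timm's standard `num_features` attribute:

```python
import timm

# Build the backbone exactly as base.py does: num_classes=0 removes
# timm's classification head so only features are produced.
backbone = timm.create_model(
    'swin_base_patch4_window7_224', pretrained=False, num_classes=0)
print(backbone.num_features)  # 1024 -> the TSNHead's in_channels above
```
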
18 changes: 18 additions & 0 deletions docs/changelog.md
@@ -1,5 +1,23 @@
## Changelog

### Master

**Highlights**

- Support using backbones from pytorch-image-models (timm)

**New Features**

- Support using backbones from pytorch-image-models (timm) for TSN ([#880](/~https://github.com/open-mmlab/mmaction2/pull/880))

**Improvements**

**Bug and Typo Fixes**

**ModelZoo**

- Add TSN with Swin Transformer backbone as an example for using pytorch-image-models (timm) backbones ([#880](/~https://github.com/open-mmlab/mmaction2/pull/880))

### 0.15.0 (31/05/2021)

**Highlights**
21 changes: 17 additions & 4 deletions mmaction/models/recognizers/base.py
@@ -60,6 +60,17 @@ def __init__(self,
             self.backbone.classifier = nn.Identity()
             self.backbone.fc = nn.Identity()
             self.backbone_from = 'torchvision'
+        elif backbone['type'].startswith('timm.'):
+            try:
+                import timm
+            except (ImportError, ModuleNotFoundError):
+                raise ImportError('Please install timm to use this '
+                                  'backbone.')
+            backbone_type = backbone.pop('type')[5:]
+            # disable the classifier
+            backbone['num_classes'] = 0
+            self.backbone = timm.create_model(backbone_type, **backbone)
+            self.backbone_from = 'timm'
         else:
             self.backbone = builder.build_backbone(backbone)
@@ -113,11 +124,11 @@ def init_weights(self):
"""Initialize the model network weights."""
if self.backbone_from in ['mmcls', 'mmaction2']:
self.backbone.init_weights()
elif self.backbone_from == 'torchvision':
elif self.backbone_from in ['torchvision', 'timm']:
warnings.warn('We do not initialize weights for backbones in '
'torchvision, since the weights for backbones in '
'torchvision are initialized in their __init__ '
'functions. ')
f'{self.backbone_from}, since the weights for '
f'backbones in {self.backbone_from} are initialized'
'in their __init__ functions.')
else:
raise NotImplementedError('Unsupported backbone source '
f'{self.backbone_from}!')
@@ -140,6 +151,8 @@ def extract_feat(self, imgs):
         if (hasattr(self.backbone, 'features')
                 and self.backbone_from == 'torchvision'):
             x = self.backbone.features(imgs)
+        elif self.backbone_from == 'timm':
+            x = self.backbone.forward_features(imgs)
         else:
             x = self.backbone(imgs)
         return x
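
Taken together, the hunks above amount to the following standalone flow. This is a sketch for illustration, using the `efficientnet_b0` name that the new unit test exercises:

```python
import timm
import torch

# What __init__ does with dict(type='timm.efficientnet_b0', pretrained=False):
backbone_cfg = dict(type='timm.efficientnet_b0', pretrained=False)
backbone_type = backbone_cfg.pop('type')[5:]  # strip the 'timm.' prefix
backbone_cfg['num_classes'] = 0               # disable timm's classifier head
backbone = timm.create_model(backbone_type, **backbone_cfg)

# What extract_feat() does for timm backbones: forward_features()
# returns the pre-pooling feature map, e.g. (1, 1280, 7, 7) here.
feats = backbone.forward_features(torch.randn(1, 3, 224, 224))
```
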
4 changes: 2 additions & 2 deletions mmaction/models/recognizers/recognizer2d.py
@@ -21,7 +21,7 @@ def forward_train(self, imgs, labels, **kwargs):

         x = self.extract_feat(imgs)

-        if self.backbone_from == 'torchvision':
+        if self.backbone_from in ['torchvision', 'timm']:
             if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1):
                 # apply adaptive avg pooling
                 x = nn.AdaptiveAvgPool2d(1)(x)
@@ -55,7 +55,7 @@ def _do_test(self, imgs):

         x = self.extract_feat(imgs)

-        if self.backbone_from == 'torchvision':
+        if self.backbone_from in ['torchvision', 'timm']:
             if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1):
                 # apply adaptive avg pooling
                 x = nn.AdaptiveAvgPool2d(1)(x)
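
The pooling branch exists because forward_features() of CNN-style timm backbones keeps the spatial grid, while the TSN head consumes one feature vector per frame. A sketch of the shape transformation, with dimensions assumed from efficientnet_b0:

```python
import torch
from torch import nn

x = torch.randn(3, 1280, 7, 7)  # 3 sampled frames: (N * num_segs, C, H, W)
if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1):
    # collapse the spatial grid to one vector per frame
    x = nn.AdaptiveAvgPool2d(1)(x)
print(x.shape)  # torch.Size([3, 1280, 1, 1])
```
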
1 change: 1 addition & 0 deletions requirements/optional.txt
@@ -7,3 +7,4 @@ moviepy
onnx
onnxruntime
PyTurboJPEG
+timm
22 changes: 22 additions & 0 deletions tests/test_models/test_recognizers/test_recognizer2d.py
@@ -97,6 +97,28 @@ def test_tsn():
        for one_img in img_list:
            recognizer(one_img, None, return_loss=False)

    # test timm backbones
    timm_backbone = dict(type='timm.efficientnet_b0', pretrained=False)
    config.model['backbone'] = timm_backbone
    config.model['cls_head']['in_channels'] = 1280

    recognizer = build_recognizer(config.model)

    input_shape = (1, 3, 3, 32, 32)
    demo_inputs = generate_recognizer_demo_inputs(input_shape)

    imgs = demo_inputs['imgs']
    gt_labels = demo_inputs['gt_labels']

    losses = recognizer(imgs, gt_labels)
    assert isinstance(losses, dict)

    # Test forward test
    with torch.no_grad():
        img_list = [img[None, :] for img in imgs]
        for one_img in img_list:
            recognizer(one_img, None, return_loss=False)


def test_tsm():
    config = get_recognizer_cfg('tsm/tsm_r50_1x1x8_50e_kinetics400_rgb.py')