From 9e425a743bfe6afaed709f487300eb3fc6e2f967 Mon Sep 17 00:00:00 2001 From: irving Date: Sat, 22 May 2021 19:26:28 +0800 Subject: [PATCH 1/8] first commit --- configs/recognition/tsn/README.md | 2 + ...sformer_320p_1x1x3_100e_kinetics400_rgb.py | 100 ++++++++++++++++++ mmaction/models/recognizers/base.py | 21 +++- mmaction/models/recognizers/recognizer2d.py | 4 +- requirements/optional.txt | 1 + .../test_recognizers/test_recognizer2d.py | 22 ++++ 6 files changed, 144 insertions(+), 6 deletions(-) create mode 100644 configs/recognition/tsn/custom_backbones/tsn_swin_transformer_320p_1x1x3_100e_kinetics400_rgb.py diff --git a/configs/recognition/tsn/README.md b/configs/recognition/tsn/README.md index 3466febc01..745b20cb12 100644 --- a/configs/recognition/tsn/README.md +++ b/configs/recognition/tsn/README.md @@ -67,11 +67,13 @@ It's possible and convenient to use a 3rd-party backbone for TSN under the frame - [x] Backbones from [MMClassification](/~https://github.com/open-mmlab/mmclassification/) - [x] Backbones from [TorchVision](/~https://github.com/pytorch/vision/) +- [x] Backbones from [pytorch-image-models(timm)](/~https://github.com/rwightman/pytorch-image-models) | config | resolution | gpus | backbone | pretrain | top1 acc | top5 acc | ckpt | log | json | | :----------------------------------------------------------- | :------------: | :--: | :----------------------------------------------------------: | :------: | :------: | :------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | | [tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | ResNeXt101-32x4d [[MMCls](/~https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext)] | ImageNet | 73.43 | 91.01 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb-16a8b561.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.json) | | [tsn_dense161_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | ResNeXt101-32x4d [[TorchVision](/~https://github.com/pytorch/vision/)] | ImageNet | 72.78 | 90.75 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb-cbe85332.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.json) | +| [tsn_swin_transformer_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | Swin Transformer Base [[timm](/~https://github.com/rwightman/pytorch-image-models)] | ImageNet | | | [ckpt]() | [log]() | [json]() | ### Kinetics-400 Data Benchmark (8-gpus, ResNet50, ImageNet 
pretrain; 3 segments)
diff --git a/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_320p_1x1x3_100e_kinetics400_rgb.py b/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_320p_1x1x3_100e_kinetics400_rgb.py
new file mode 100644
index 0000000000..0d389c330b
--- /dev/null
+++ b/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_320p_1x1x3_100e_kinetics400_rgb.py
@@ -0,0 +1,100 @@
+_base_ = [
+    '../../../_base_/schedules/sgd_100e.py',
+    '../../../_base_/default_runtime.py'
+]
+
+# model settings
+model = dict(
+    type='Recognizer2D',
+    backbone=dict(type='timm.swin_base_patch4_window7_224', pretrained=True),
+    cls_head=dict(
+        type='TSNHead',
+        num_classes=400,
+        in_channels=1024,
+        spatial_type='avg',
+        consensus=dict(type='AvgConsensus', dim=1),
+        dropout_ratio=0.4,
+        init_std=0.01),
+    # model training and testing settings
+    train_cfg=None,
+    test_cfg=dict(average_clips=None))
+
+# dataset settings
+dataset_type = 'RawframeDataset'
+data_root = 'data/kinetics400/rawframes_train_320p'
+data_root_val = 'data/kinetics400/rawframes_val_320p'
+ann_file_train = 'data/kinetics400/kinetics400_train_list_rawframes_320p.txt'
+ann_file_val = 'data/kinetics400/kinetics400_val_list_rawframes_320p.txt'
+ann_file_test = 'data/kinetics400/kinetics400_val_list_rawframes_320p.txt'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
+train_pipeline = [
+    dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=3),
+    dict(type='RawFrameDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='RandomResizedCrop'),
+    dict(type='Resize', scale=(224, 224), keep_ratio=False),
+    dict(type='Flip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs', 'label'])
+]
+val_pipeline = [
+    dict(
+        type='SampleFrames',
+        clip_len=1,
+        frame_interval=1,
+        num_clips=3,
+        test_mode=True),
+    dict(type='RawFrameDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='CenterCrop', crop_size=256),
+    dict(type='Flip', flip_ratio=0),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs'])
+]
+test_pipeline = [
+    dict(
+        type='SampleFrames',
+        clip_len=1,
+        frame_interval=1,
+        num_clips=25,
+        test_mode=True),
+    dict(type='RawFrameDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='ThreeCrop', crop_size=256),
+    dict(type='Flip', flip_ratio=0),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs'])
+]
+data = dict(
+    videos_per_gpu=12,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        ann_file=ann_file_train,
+        data_prefix=data_root,
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=ann_file_val,
+        data_prefix=data_root_val,
+        pipeline=val_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=ann_file_test,
+        data_prefix=data_root_val,
+        pipeline=test_pipeline))
+
+# runtime settings
+work_dir = './work_dirs/tsn_swin_transformer_320p_1x1x3_100e_kinetics400_rgb/'
+optimizer = dict(
+    type='SGD',
+    lr=0.00375,  # this lr is used for 8 gpus
+    momentum=0.9,
+    weight_decay=0.0001)
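The following sketch (illustration only, not part of this patch; it assumes timm is installed) shows where the `in_channels=1024` used by `cls_head` in the config above comes from: timm models report their final feature dimension as `num_features`, and the TSNHead input width has to match it.

import timm

# num_classes=0 removes timm's classification head, mirroring what the
# 'timm.' branch added to base.py (next diff) does before wrapping the
# backbone.
backbone = timm.create_model(
    'swin_base_patch4_window7_224', pretrained=False, num_classes=0)

# -> 1024 for the Swin-Base model used above; this is the value that must
# be passed to cls_head.in_channels in the config.
print(backbone.num_features)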
diff --git a/mmaction/models/recognizers/base.py b/mmaction/models/recognizers/base.py
index 281aa547e1..eea7f8ed8c 100644
--- a/mmaction/models/recognizers/base.py
+++ b/mmaction/models/recognizers/base.py
@@ -58,6 +58,17 @@ def __init__(self,
             self.backbone.classifier = nn.Identity()
             self.backbone.fc = nn.Identity()
             self.backbone_from = 'torchvision'
+        elif backbone['type'].startswith('timm.'):
+            try:
+                import timm
+            except (ImportError, ModuleNotFoundError):
+                raise ImportError('Please install timm to use this '
+                                  'backbone.')
+            backbone_type = backbone.pop('type')[5:]
+            # disable the classifier
+            backbone['num_classes'] = 0
+            self.backbone = timm.create_model(backbone_type, **backbone)
+            self.backbone_from = 'timm'
         else:
             self.backbone = builder.build_backbone(backbone)
 
@@ -100,11 +111,11 @@ def init_weights(self):
         """Initialize the model network weights."""
         if self.backbone_from in ['mmcls', 'mmaction2']:
             self.backbone.init_weights()
-        elif self.backbone_from == 'torchvision':
+        elif self.backbone_from in ['torchvision', 'timm']:
             warnings.warn('We do not initialize weights for backbones in '
-                          'torchvision, since the weights for backbones in '
-                          'torchvision are initialized in their __init__ '
-                          'functions. ')
+                          f'{self.backbone_from}, since the weights for '
+                          f'backbones in {self.backbone_from} are initialized '
+                          'in their __init__ functions.')
         else:
             raise NotImplementedError('Unsupported backbone source '
                                       f'{self.backbone_from}!')
@@ -126,6 +137,8 @@ def extract_feat(self, imgs):
         if (hasattr(self.backbone, 'features')
                 and self.backbone_from == 'torchvision'):
             x = self.backbone.features(imgs)
+        elif self.backbone_from == 'timm':
+            x = self.backbone.forward_features(imgs)
         else:
             x = self.backbone(imgs)
         return x
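As a standalone sketch of what the new `timm.` branch above does (illustration only, not part of this patch; it assumes timm and torch are installed and mirrors the logic rather than importing mmaction):

import timm
import torch

backbone_cfg = dict(type='timm.efficientnet_b0', pretrained=False)

# strip the 'timm.' prefix, exactly like backbone.pop('type')[5:] above
backbone_type = backbone_cfg.pop('type')[5:]
backbone_cfg['num_classes'] = 0  # disable timm's classifier
backbone = timm.create_model(backbone_type, **backbone_cfg)

# extract_feat() routes timm backbones through forward_features(), which
# returns the backbone's feature output instead of classification logits.
feats = backbone.forward_features(torch.randn(1, 3, 224, 224))
print(feats.shape)  # (1, 1280, 7, 7) for efficientnet_b0 at 224x224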
diff --git a/mmaction/models/recognizers/recognizer2d.py b/mmaction/models/recognizers/recognizer2d.py
index 16f6349be8..f7fb93ac7a 100644
--- a/mmaction/models/recognizers/recognizer2d.py
+++ b/mmaction/models/recognizers/recognizer2d.py
@@ -19,7 +19,7 @@ def forward_train(self, imgs, labels, **kwargs):
 
         x = self.extract_feat(imgs)
 
-        if self.backbone_from == 'torchvision':
+        if self.backbone_from in ['torchvision', 'timm']:
             if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1):
                 # apply adaptive avg pooling
                 x = nn.AdaptiveAvgPool2d(1)(x)
@@ -53,7 +53,7 @@ def _do_test(self, imgs):
 
         x = self.extract_feat(imgs)
 
-        if self.backbone_from == 'torchvision':
+        if self.backbone_from in ['torchvision', 'timm']:
             if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1):
                 # apply adaptive avg pooling
                 x = nn.AdaptiveAvgPool2d(1)(x)
diff --git a/requirements/optional.txt b/requirements/optional.txt
index 839d3acc87..3177ef6221 100644
--- a/requirements/optional.txt
+++ b/requirements/optional.txt
@@ -7,3 +7,4 @@ moviepy
 onnx
 onnxruntime
 PyTurboJPEG
+timm
diff --git a/tests/test_models/test_recognizers/test_recognizer2d.py b/tests/test_models/test_recognizers/test_recognizer2d.py
index 927f046273..8d4cf23744 100644
--- a/tests/test_models/test_recognizers/test_recognizer2d.py
+++ b/tests/test_models/test_recognizers/test_recognizer2d.py
@@ -97,6 +97,28 @@ def test_tsn():
         for one_img in img_list:
             recognizer(one_img, None, return_loss=False)
 
+    # test timm backbones
+    timm_backbone = dict(type='timm.efficientnet_b0', pretrained=False)
+    config.model['backbone'] = timm_backbone
+    config.model['cls_head']['in_channels'] = 1280
+
+    recognizer = build_recognizer(config.model)
+
+    input_shape = (1, 3, 3, 32, 32)
+    demo_inputs = generate_recognizer_demo_inputs(input_shape)
+
+    imgs = demo_inputs['imgs']
+    gt_labels = demo_inputs['gt_labels']
+
+    losses = recognizer(imgs, gt_labels)
+    assert isinstance(losses, dict)
+
+    # Test forward test
+    with torch.no_grad():
+        img_list = [img[None, :] for img in imgs]
+        for one_img in img_list:
+            recognizer(one_img, None, return_loss=False)
+
 
 def test_tsm():
     config = get_recognizer_cfg('tsm/tsm_r50_1x1x8_50e_kinetics400_rgb.py')
From f692e60ddabaf2df0bd4be680b99085c0be059c0 Mon Sep 17 00:00:00 2001
From: irvingzhang0512
Date: Mon, 7 Jun 2021 10:59:12 +0800
Subject: [PATCH 2/8] update swin transformer res

---
 configs/recognition/tsn/README.md             |  4 +--
 ..._video_320p_1x1x3_100e_kinetics400_rgb.py} | 36 +++++++++++--------
 2 files changed, 23 insertions(+), 17 deletions(-)
 rename configs/recognition/tsn/custom_backbones/{tsn_swin_transformer_320p_1x1x3_100e_kinetics400_rgb.py => tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py} (74%)

diff --git a/configs/recognition/tsn/README.md b/configs/recognition/tsn/README.md
index 745b20cb12..a1c3b8e81c 100644
--- a/configs/recognition/tsn/README.md
+++ b/configs/recognition/tsn/README.md
@@ -72,8 +72,8 @@
 | config | resolution | gpus | backbone | pretrain | top1 acc | top5 acc | ckpt | log | json |
 | :----------------------------------------------------------- | :------------: | :--: | :----------------------------------------------------------: | :------: | :------: | :------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
 | [tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | ResNeXt101-32x4d [[MMCls](/~https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext)] | ImageNet | 73.43 | 91.01 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb-16a8b561.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.json) |
-|
[tsn_dense161_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | Densenet-161 [[TorchVision](/~https://github.com/pytorch/vision/)] | ImageNet | 72.78 | 90.75 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb-cbe85332.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.json) | +| [tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8 | Swin Transformer Base [[timm](/~https://github.com/rwightman/pytorch-image-models)] | ImageNet | 77.31 | 92.88 | [ckpt]() | [log]() | [json]() | ### Kinetics-400 Data Benchmark (8-gpus, ResNet50, ImageNet pretrain; 3 segments) diff --git a/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_320p_1x1x3_100e_kinetics400_rgb.py b/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py similarity index 74% rename from configs/recognition/tsn/custom_backbones/tsn_swin_transformer_320p_1x1x3_100e_kinetics400_rgb.py rename to configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py index 0d389c330b..fd310324bc 100644 --- a/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_320p_1x1x3_100e_kinetics400_rgb.py +++ b/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py @@ -20,18 +20,18 @@ test_cfg=dict(average_clips=None)) # dataset settings -dataset_type = 'RawframeDataset' -data_root = 'data/kinetics400/rawframes_train_320p' -data_root_val = 'data/kinetics400/rawframes_val_320p' -ann_file_train = 'data/kinetics400/kinetics400_train_list_rawframes_320p.txt' -ann_file_val = 'data/kinetics400/kinetics400_val_list_rawframes_320p.txt' -ann_file_test = 'data/kinetics400/kinetics400_val_list_rawframes_320p.txt' +dataset_type = 'VideoDataset' +data_root = 'data/kinetics400/videos_train' +data_root_val = 'data/kinetics400/videos_val' +ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt' +ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt' +ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False) train_pipeline = [ + dict(type='DecordInit'), dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=3), - dict(type='RawFrameDecode'), - dict(type='Resize', scale=(-1, 256)), + dict(type='DecordDecode'), dict(type='RandomResizedCrop'), dict(type='Resize', scale=(224, 224), keep_ratio=False), dict(type='Flip', flip_ratio=0.5), @@ -41,15 +41,16 @@ dict(type='ToTensor', keys=['imgs', 'label']) ] val_pipeline = [ + dict(type='DecordInit'), dict( type='SampleFrames', clip_len=1, frame_interval=1, num_clips=3, test_mode=True), - dict(type='RawFrameDecode'), + dict(type='DecordDecode'), dict(type='Resize', scale=(-1, 256)), - dict(type='CenterCrop', crop_size=256), + dict(type='CenterCrop', crop_size=224), dict(type='Flip', flip_ratio=0), 
dict(type='Normalize', **img_norm_cfg), dict(type='FormatShape', input_format='NCHW'), @@ -57,15 +58,16 @@ dict(type='ToTensor', keys=['imgs']) ] test_pipeline = [ + dict(type='DecordInit'), dict( type='SampleFrames', clip_len=1, frame_interval=1, num_clips=25, test_mode=True), - dict(type='RawFrameDecode'), + dict(type='DecordDecode'), dict(type='Resize', scale=(-1, 256)), - dict(type='ThreeCrop', crop_size=256), + dict(type='TenCrop', crop_size=224), dict(type='Flip', flip_ratio=0), dict(type='Normalize', **img_norm_cfg), dict(type='FormatShape', input_format='NCHW'), @@ -73,8 +75,9 @@ dict(type='ToTensor', keys=['imgs']) ] data = dict( - videos_per_gpu=12, + videos_per_gpu=24, workers_per_gpu=4, + test_dataloader=dict(videos_per_gpu=4), train=dict( type=dataset_type, ann_file=ann_file_train, @@ -90,11 +93,14 @@ ann_file=ann_file_test, data_prefix=data_root_val, pipeline=test_pipeline)) +evaluation = dict( + interval=1, metrics=['top_k_accuracy', 'mean_class_accuracy']) + # runtime settings -work_dir = './work_dirs/tsn_swin_transformer_320p_1x1x3_100e_kinetics400_rgb/' +work_dir = './work_dirs/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/' optimizer = dict( type='SGD', - lr=0.00375, # this lr is used for 8 gpus + lr=0.0075, # this lr is used for 8 gpus momentum=0.9, weight_decay=0.0001) From e443f033f4409655a58c3f9d678e11cf2af07c46 Mon Sep 17 00:00:00 2001 From: irvingzhang0512 Date: Mon, 7 Jun 2021 11:25:20 +0800 Subject: [PATCH 3/8] update cn docs --- configs/recognition/tsn/README_zh-CN.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/configs/recognition/tsn/README_zh-CN.md b/configs/recognition/tsn/README_zh-CN.md index f6e9c08c61..404d58f004 100644 --- a/configs/recognition/tsn/README_zh-CN.md +++ b/configs/recognition/tsn/README_zh-CN.md @@ -66,10 +66,14 @@ 用户可在 MMAction2 的框架中使用第三方的主干网络训练 TSN,例如: - [x] MMClassification 中的主干网络 +- [x] TorchVision 中的主干网络 +- [x] pytorch-image-models(timm) 中的主干网络 | 配置文件 | 分辨率 | GPU 数量 | 主干网络 | 预训练 | top1 准确率 | top5 准确率 | ckpt | log | json | | :----------------------------------------------------------: | :------------: | :--: | :----------------------------------------------------------: | :------: | :------: | :------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | | [tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.py) | 短边 320 | 8x2 | ResNeXt101-32x4d [[MMCls](/~https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext)] | ImageNet | 73.43 | 91.01 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb-16a8b561.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.json) | +| [tsn_dense161_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.py) | 短边 320 | 8x2 | Densenet-161 [[TorchVision](/~https://github.com/pytorch/vision/)] | ImageNet | 72.78 | 90.75 | 
[ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb-cbe85332.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.json) | +| [tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py) | 短边 320 | 8 | Swin Transformer Base [[timm](/~https://github.com/rwightman/pytorch-image-models)] | ImageNet | 77.31 | 92.88 | [ckpt]() | [log]() | [json]() | ### Kinetics-400 数据基准测试 (8 块 GPU, ResNet50, ImageNet 预训练; 3 个视频段) From 25b2fcc95a32d5e45e950446f2320d38e12409d8 Mon Sep 17 00:00:00 2001 From: irvingzhang0512 Date: Mon, 7 Jun 2021 11:30:37 +0800 Subject: [PATCH 4/8] update changelog --- docs/changelog.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index 67fccc19a5..f3d58d3bfc 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,23 @@ ## Changelog +### Master + +**Highlights** + +- Support using backbone from pytorch-image-models(timm) + +**New Features** + +- Support using backbones from pytorch-image-models(timm) for TSN ([#880](/~https://github.com/open-mmlab/mmaction2/pull/880)) + +**Improvements** + +**Bug and Typo Fixes** + +**ModelZoo** + +- Add TSN with Swin Transformer backbone as an example for using pytorch-image-models(timm) backbones ([#880](/~https://github.com/open-mmlab/mmaction2/pull/880)) + ### 0.15.0 (31/05/2021) **Highlights** From 9fa4be9bd275e81de77697ee3d921fa00f0db035 Mon Sep 17 00:00:00 2001 From: irvingzhang0512 Date: Mon, 7 Jun 2021 15:22:56 +0800 Subject: [PATCH 5/8] fix lint --- ...n_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py b/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py index fd310324bc..815ef38d0f 100644 --- a/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py +++ b/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py @@ -96,9 +96,8 @@ evaluation = dict( interval=1, metrics=['top_k_accuracy', 'mean_class_accuracy']) - # runtime settings -work_dir = './work_dirs/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/' +work_dir = './work_dirs/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/' # noqa optimizer = dict( type='SGD', lr=0.0075, # this lr is used for 8 gpus From 4621a7f5a16ecbdfc5ac1fe98cd21236c28d8893 Mon Sep 17 00:00:00 2001 From: Kenny Date: Thu, 10 Jun 2021 14:12:40 +0800 Subject: [PATCH 6/8] update model links --- configs/recognition/tsn/README.md | 4 ++-- configs/recognition/tsn/README_zh-CN.md | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/configs/recognition/tsn/README.md b/configs/recognition/tsn/README.md index a1c3b8e81c..0edd90630f 100644 --- a/configs/recognition/tsn/README.md +++ b/configs/recognition/tsn/README.md @@ -67,13 +67,13 @@ It's possible and convenient to use a 3rd-party 
backbone for TSN under the frame - [x] Backbones from [MMClassification](/~https://github.com/open-mmlab/mmclassification/) - [x] Backbones from [TorchVision](/~https://github.com/pytorch/vision/) -- [x] Backbones from [pytorch-image-models(timm)](/~https://github.com/rwightman/pytorch-image-models) +- [x] Backbones from [TIMM (pytorch-image-models)](/~https://github.com/rwightman/pytorch-image-models) | config | resolution | gpus | backbone | pretrain | top1 acc | top5 acc | ckpt | log | json | | :----------------------------------------------------------- | :------------: | :--: | :----------------------------------------------------------: | :------: | :------: | :------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | | [tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | ResNeXt101-32x4d [[MMCls](/~https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext)] | ImageNet | 73.43 | 91.01 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb-16a8b561.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.json) | | [tsn_dense161_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | Densenet-161 [[TorchVision](/~https://github.com/pytorch/vision/)] | ImageNet | 72.78 | 90.75 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb-cbe85332.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.json) | -| [tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8 | Swin Transformer Base [[timm](/~https://github.com/rwightman/pytorch-image-models)] | ImageNet | 77.31 | 92.88 | [ckpt]() | [log]() | [json]() | +| [tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8 | Swin Transformer Base [[timm](/~https://github.com/rwightman/pytorch-image-models)] | ImageNet | 77.51 | 92.92 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb-805380f6.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.log) | 
[json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.json) | ### Kinetics-400 Data Benchmark (8-gpus, ResNet50, ImageNet pretrain; 3 segments) diff --git a/configs/recognition/tsn/README_zh-CN.md b/configs/recognition/tsn/README_zh-CN.md index 404d58f004..63fe85e4c7 100644 --- a/configs/recognition/tsn/README_zh-CN.md +++ b/configs/recognition/tsn/README_zh-CN.md @@ -73,7 +73,7 @@ | :----------------------------------------------------------: | :------------: | :--: | :----------------------------------------------------------: | :------: | :------: | :------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | | [tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.py) | 短边 320 | 8x2 | ResNeXt101-32x4d [[MMCls](/~https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext)] | ImageNet | 73.43 | 91.01 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb-16a8b561.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.json) | | [tsn_dense161_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.py) | 短边 320 | 8x2 | Densenet-161 [[TorchVision](/~https://github.com/pytorch/vision/)] | ImageNet | 72.78 | 90.75 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb-cbe85332.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.json) | -| [tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py) | 短边 320 | 8 | Swin Transformer Base [[timm](/~https://github.com/rwightman/pytorch-image-models)] | ImageNet | 77.31 | 92.88 | [ckpt]() | [log]() | [json]() | +| [tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8 | Swin Transformer Base [[timm](/~https://github.com/rwightman/pytorch-image-models)] | ImageNet | 77.51 | 92.92 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb-805380f6.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.log) | 
[json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.json) |
 
 ### Kinetics-400 数据基准测试 (8 块 GPU, ResNet50, ImageNet 预训练; 3 个视频段)
From c6f57e915e0cd59e30b8ecf98b687402088d3d60 Mon Sep 17 00:00:00 2001
From: Kenny
Date: Thu, 10 Jun 2021 14:45:39 +0800
Subject: [PATCH 7/8] no longer run torch1.3.0 in CI

---
 .github/workflows/build.yml | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 83763ac1c7..924653e894 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -84,11 +84,8 @@ jobs:
     strategy:
       matrix:
         python-version: [3.7]
-        torch: [1.3.0, 1.5.0+cu101, 1.6.0+cu101, 1.7.0+cu101, 1.8.0+cu101]
+        torch: [1.5.0+cu101, 1.6.0+cu101, 1.7.0+cu101, 1.8.0+cu101]
         include:
-          - torch: 1.3.0
-            torchvision: 0.4.1
-            mmcv: 1.3.0+cu101
          - torch: 1.5.0+cu101
            torchvision: 0.6.0+cu101
            mmcv: 1.5.0+cu101
From fd5531d4ece5dfc03c7da5bed9519fd46ca0585d Mon Sep 17 00:00:00 2001
From: Kenny
Date: Thu, 10 Jun 2021 19:39:12 +0800
Subject: [PATCH 8/8] upload README

---
 configs/recognition/tsn/README.md       | 2 ++
 configs/recognition/tsn/README_zh-CN.md | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/configs/recognition/tsn/README.md b/configs/recognition/tsn/README.md
index 0edd90630f..9f1f6388cc 100644
--- a/configs/recognition/tsn/README.md
+++ b/configs/recognition/tsn/README.md
@@ -75,6 +75,8 @@ It's possible and convenient to use a 3rd-party backbone for TSN under the frame
 | [tsn_dense161_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | Densenet-161 [[TorchVision](/~https://github.com/pytorch/vision/)] | ImageNet | 72.78 | 90.75 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb-cbe85332.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.json) |
 | [tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8 | Swin Transformer Base [[timm](/~https://github.com/rwightman/pytorch-image-models)] | ImageNet | 77.51 | 92.92 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb-805380f6.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.json) |
+1. Note that some backbones in TIMM are not supported due to multiple reasons. Please refer to [PR #880](/~https://github.com/open-mmlab/mmaction2/pull/880) for details.
+
 ### Kinetics-400 Data Benchmark (8-gpus, ResNet50, ImageNet pretrain; 3 segments)
 
 In data benchmark, we compare:
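For reference, a minimal inference sketch for the checkpoint linked above (illustration only, not part of this patch; init_recognizer and inference_recognizer are the mmaction.apis helpers of this era, and the exact signature and return format may differ in other mmaction2 versions — the video and label-map paths are the repo's demo assets):

from mmaction.apis import inference_recognizer, init_recognizer

config = ('configs/recognition/tsn/custom_backbones/'
          'tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py')
checkpoint = ('tsn_swin_transformer_video_320p_1x1x3_100e_'
              'kinetics400_rgb-805380f6.pth')  # downloaded from the link above

model = init_recognizer(config, checkpoint, device='cuda:0')
# returns the top scoring (label, score) pairs for the given video
results = inference_recognizer(model, 'demo/demo.mp4',
                               'tools/data/kinetics/label_map_k400.txt')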
diff --git a/configs/recognition/tsn/README_zh-CN.md b/configs/recognition/tsn/README_zh-CN.md
index 63fe85e4c7..40c14b28a3 100644
--- a/configs/recognition/tsn/README_zh-CN.md
+++ b/configs/recognition/tsn/README_zh-CN.md
@@ -75,6 +75,8 @@
 | [tsn_dense161_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.py) | 短边 320 | 8x2 | Densenet-161 [[TorchVision](/~https://github.com/pytorch/vision/)] | ImageNet | 72.78 | 90.75 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb-cbe85332.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.json) |
 | [tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8 | Swin Transformer Base [[timm](/~https://github.com/rwightman/pytorch-image-models)] | ImageNet | 77.51 | 92.92 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb-805380f6.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb/tsn_swin_transformer_video_320p_1x1x3_100e_kinetics400_rgb.json) |
+1. 由于多种原因,TIMM 中的一些模型未能得到支持,详情请参考 [PR #880](/~https://github.com/open-mmlab/mmaction2/pull/880)。
+
 ### Kinetics-400 数据基准测试 (8 块 GPU, ResNet50, ImageNet 预训练; 3 个视频段)
 
 在数据基准测试中,比较:
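Finally, a quick probe related to the note above about unsupported TIMM backbones (illustration only, not part of this patch series; model names are examples and feature shapes depend on the installed timm version). Recognizer2D consumes forward_features() output directly — 4-D (N, C, H, W) maps get adaptive average pooling, while other layouts skip that branch — so a backbone whose features fit neither pattern, or whose width disagrees with cls_head.in_channels, needs extra handling:

import timm
import torch

for name in ('efficientnet_b0', 'swin_base_patch4_window7_224'):
    backbone = timm.create_model(name, pretrained=False, num_classes=0)
    with torch.no_grad():
        feats = backbone.forward_features(torch.randn(1, 3, 224, 224))
    # the shape tells whether the recognizer's pooling path applies as-is
    print(name, tuple(feats.shape), 'num_features =', backbone.num_features)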