-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy pathuniformerv2-base-p16-res224_clip-kinetics710-pre_8xb32-u8_kinetics600-rgb.py
174 lines (163 loc) · 5.41 KB
/
uniformerv2-base-p16-res224_clip-kinetics710-pre_8xb32-u8_kinetics600-rgb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
_base_ = ['../../_base_/default_runtime.py']

# ---------------------------------------------------------------------------
# Model: UniFormerV2 base, patch 16, 224px input, `num_frames`-frame clips,
# with backbone and head both initialised from the Kinetics-710 checkpoint
# below and a 600-way classification head.
# ---------------------------------------------------------------------------
num_frames = 8  # clip length; shared by the backbone and the samplers
model = dict(
    type='Recognizer3D',
    backbone=dict(
        type='UniFormerV2',
        # Image trunk geometry.
        input_resolution=224,
        patch_size=16,
        width=768,
        layers=12,
        heads=12,
        # Temporal settings.
        t_size=num_frames,
        dw_reduction=1.5,
        backbone_drop_path_rate=0.0,
        temporal_downsample=False,
        no_lmhra=True,
        # NOTE(review): presumably ignored while `no_lmhra=True`; confirm.
        double_lmhra=True,
        # Trunk layers 8..11 feed the extra `n_layers` modules; presumably
        # the last four of the 12 layers (0-indexed). Keep `n_layers` and
        # the length of `mlp_dropout` equal to len(return_list).
        return_list=list(range(8, 12)),
        n_layers=4,
        n_dim=768,
        n_head=12,
        mlp_factor=4.0,
        drop_path_rate=0.0,
        mlp_dropout=[0.5] * 4,
        clip_pretrained=False,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmaction/v1.0/recognition/uniformerv2/kinetics710/uniformerv2-base-p16-res224_clip-pre_u8_kinetics710-rgb_20221219-77d34f81.pth',  # noqa: E501
            prefix='backbone.')),
    cls_head=dict(
        type='UniFormerHead',
        dropout_ratio=0.5,
        num_classes=600,
        in_channels=768,
        average_clips='prob',
        # Presumably maps the K710 classifier outputs onto the 600 K600
        # classes (inferred from the file name; verify).
        channel_map='configs/recognition/uniformerv2/k710_channel_map/map_k600.json',  # noqa: E501
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmaction/v1.0/recognition/uniformerv2/kinetics710/uniformerv2-base-p16-res224_clip-pre_u8_kinetics710-rgb_20221219-77d34f81.pth',  # noqa: E501
            prefix='cls_head.')),
    # Normalisation constants are 0.45 * 255 and 0.225 * 255.
    data_preprocessor=dict(
        type='ActionDataPreprocessor',
        mean=[114.75] * 3,
        std=[57.375] * 3,
        format_shape='NCTHW'))
# ---------------------------------------------------------------------------
# Dataset settings: Kinetics-600 stored as video files with list annotations.
# ---------------------------------------------------------------------------
dataset_type = 'VideoDataset'
data_root = 'data/kinetics600/videos_train'
data_root_val = 'data/kinetics600/videos_val'
ann_file_train = 'data/kinetics600/kinetics600_train_list_videos.txt'
ann_file_val = 'data/kinetics600/kinetics600_val_list_videos.txt'
# Testing reuses the validation annotation list.
ann_file_test = ann_file_val
# Backend arguments forwarded to every DecordInit step below.
file_client_args = dict(io_backend='disk')

# Training: one `num_frames` clip per video, Resize (-1, 256), RandAugment,
# random resized crop, squash to 224x224, flip with probability 0.5.
train_pipeline = [
    dict(type='DecordInit', **file_client_args),
    dict(type='UniformSample', clip_len=num_frames, num_clips=1),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(
        type='PytorchVideoWrapper', op='RandAugment', magnitude=7,
        num_layers=4),
    dict(type='RandomResizedCrop'),
    dict(type='Resize', scale=(224, 224), keep_ratio=False),
    dict(type='Flip', flip_ratio=0.5),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='PackActionInputs'),
]

# Validation: a single clip in test mode and a single 224 center crop.
val_pipeline = [
    dict(type='DecordInit', **file_client_args),
    dict(
        type='UniformSample', clip_len=num_frames, num_clips=1,
        test_mode=True),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 224)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='PackActionInputs'),
]

# Testing: 4 clips x 3 crops per video (12 views), averaged by the head's
# `average_clips` setting.
test_pipeline = [
    dict(type='DecordInit', **file_client_args),
    dict(
        type='UniformSample', clip_len=num_frames, num_clips=4,
        test_mode=True),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 224)),
    dict(type='ThreeCrop', crop_size=224),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='PackActionInputs'),
]
# ---------------------------------------------------------------------------
# Dataloaders (batch_size=8, 8 workers each) and evaluators.
# ---------------------------------------------------------------------------
train_dataloader = dict(
    batch_size=8, num_workers=8, persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type, ann_file=ann_file_train,
        data_prefix=dict(video=data_root), pipeline=train_pipeline))

# Validation and test both read from the validation video directory and run
# the dataset in test mode; they differ only in annotation file and pipeline.
val_dataloader = dict(
    batch_size=8, num_workers=8, persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type, ann_file=ann_file_val,
        data_prefix=dict(video=data_root_val), pipeline=val_pipeline,
        test_mode=True))
test_dataloader = dict(
    batch_size=8, num_workers=8, persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type, ann_file=ann_file_test,
        data_prefix=dict(video=data_root_val), pipeline=test_pipeline,
        test_mode=True))

# Same metric for validation and test (separate but equal dicts).
val_evaluator = dict(type='AccMetric')
test_evaluator = dict(val_evaluator)
# ---------------------------------------------------------------------------
# Training schedule: 5 epochs, validating every epoch starting at epoch 1.
# ---------------------------------------------------------------------------
train_cfg = dict(
    type='EpochBasedTrainLoop', max_epochs=5, val_begin=1, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# AdamW (weight decay 0.05) with decay disabled for normalization and bias
# parameters; gradients are clipped to an L2 norm of 20.
base_lr = 2e-6
optim_wrapper = dict(
    optimizer=dict(
        type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05),
    paramwise_cfg=dict(norm_decay_mult=0.0, bias_decay_mult=0.0),
    clip_grad=dict(max_norm=20, norm_type=2))

# LR schedule, defined in epochs but applied per iteration
# (convert_to_iter_based=True):
#   - epoch [0, 1): LinearLR warm-up starting at 0.5x the optimizer LR;
#   - epochs [1, 5): cosine annealing over T_max=4 epochs, with a floor of
#     eta_min_ratio=0.5 times the optimizer LR.
# When changing the schedule length, keep the last scheduler's `end` equal
# to `train_cfg['max_epochs']` and its `T_max` equal to `end - begin`.
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=0.5,
        by_epoch=True,
        begin=0,
        end=1,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=4,
        eta_min_ratio=0.5,
        by_epoch=True,
        begin=1,
        end=5,
        convert_to_iter_based=True)
]

# Checkpoint every 3 epochs (keeping at most 5); log every 100 iterations.
default_hooks = dict(
    checkpoint=dict(interval=3, max_keep_ckpts=5), logger=dict(interval=100))

# Setting for scaling the LR automatically
#   - `enable` turns automatic LR scaling on or off; it is on here.
#   - `base_batch_size` = (8 GPUs) x (32 samples per GPU) = 256, the
#     "8xb32" setting in this config's file name.
# NOTE(review): mmengine multiplies the LR by
#   (world_size * train_dataloader batch_size) / base_batch_size.
# The train_dataloader in this file uses batch_size=8, so an 8-GPU run
# trains at 64 / 256 = 0.25x `base_lr`. Confirm whether 8 or 32 samples
# per GPU is intended before changing either value.
auto_scale_lr = dict(enable=True, base_batch_size=256)