-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathfew_shot_drive_huge.py
116 lines (103 loc) · 4.93 KB
/
few_shot_drive_huge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Base configs merged into this one by the mmengine config loader.
_base_ = [
    '../../_base_/seg_default_runtime.py',
    './git_huge.py',
]

# Checkpoint whose weights initialise the model before fine-tuning.
load_from = './universal_huge.pth'

# Values defined in the base configs, re-exported under local names.
global_bin = _base_.global_bin
base_img_size = _base_.base_img_size

backend_args = None
# Task configuration for 2-class semantic segmentation on DRIVE; attached to
# every sample as `git_cfg` by the AddMetaInfo transforms in the pipelines.
drive_semseg_cfgs = dict(
    mode='semantic_segmentation',
    # Sampling grid used per window.
    grid_resolution_perwin=[14, 14],
    samples_grids_eachwin=32,
    grid_interpolate=True,
    num_classes=2,
    # Vocabulary = the 2 classes + 1 extra entry
    # (NOTE(review): presumably a special token — confirm).
    num_vocal=2 + 1,
    total_num_vocal=2 + 1,
    max_decoder_length=16,
    global_only_image=True,
)
# Training pipeline: load image and mask, attach task metadata, apply one of
# two resize strategies, then flip / photometric augmentation, then pack.
drive_semseg_train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='SegLoadAnnotations'),
    dict(
        type='AddMetaInfo',
        meta_dict=dict(
            task_name='semantic_segmentation',
            # Class count, vocabulary size and decoder length are taken from
            # drive_semseg_cfgs instead of repeated literals, so the head
            # settings cannot drift from the task config.
            head_cfg=dict(
                num_classes=drive_semseg_cfgs['num_classes'],
                num_vocal=drive_semseg_cfgs['num_vocal'],
                dec_length=drive_semseg_cfgs['max_decoder_length'],
                dec_pixel_resolution=[4, 4],
                arg_max_inference=True,
                ignore_index=255),
            git_cfg=drive_semseg_cfgs)),
    # RandomChoice selects one of the candidate sub-pipelines below.
    dict(
        type='RandomChoice',
        transforms=[
            # Candidate 1: resize directly to 672x672.
            [dict(type='RandomChoiceResize', scales=[(672, 672)],
                  keep_ratio=False)],
            # Candidate 2: resize to one of 1.0x, 1.1x, ..., 2.0x of 672
            # (exact integer arithmetic; same values as int(672 * x * 0.1)),
            # then crop back to 672x672.
            [dict(type='RandomChoiceResize',
                  scales=[(672 * x // 10, 672 * x // 10)
                          for x in range(10, 21)],
                  keep_ratio=False),
             dict(type='SegRandomCrop', crop_size=(672, 672),
                  cat_max_ratio=0.75)],
        ]),
    dict(type='MMCVRandomFlip', prob=0.5),
    dict(type='SegPhotoMetricDistortion'),
    dict(type='PackSegInputs',
         meta_keys=('img_path', 'seg_map_path', 'ori_shape', 'img_shape',
                    'pad_shape', 'scale_factor', 'flip', 'flip_direction',
                    'reduce_zero_label', 'task_name', 'head_cfg', 'git_cfg')),
]
# Evaluation pipeline: fixed 672x672 resize, ground-truth masks, task
# metadata, then packing. No random augmentation.
drive_semseg_test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(672, 672), keep_ratio=False),
    # Loads ground-truth masks. Remove this transform if the split has no
    # annotations (NOTE(review): the IoUMetric evaluator presumably needs
    # them, so metrics would then be unavailable).
    dict(type='SegLoadAnnotations'),
    dict(
        type='AddMetaInfo',
        meta_dict=dict(
            task_name='semantic_segmentation',
            # Class count, vocabulary size and decoder length are taken from
            # drive_semseg_cfgs instead of repeated literals, so the head
            # settings stay in sync with the task config.
            head_cfg=dict(
                num_classes=drive_semseg_cfgs['num_classes'],
                num_vocal=drive_semseg_cfgs['num_vocal'],
                dec_length=drive_semseg_cfgs['max_decoder_length'],
                dec_pixel_resolution=[4, 4],
                arg_max_inference=True,
                ignore_index=255),
            git_cfg=drive_semseg_cfgs)),
    dict(type='PackSegInputs',
         meta_keys=('img_path', 'seg_map_path', 'ori_shape', 'img_shape',
                    'pad_shape', 'scale_factor', 'flip', 'flip_direction',
                    'reduce_zero_label', 'task_name', 'head_cfg', 'git_cfg')),
]
# Training loader over the DRIVE training split.
train_dataloader = dict(
    batch_size=1,
    num_workers=1,
    persistent_workers=True,
    # NOTE(review): shuffle=False visits the training samples in a fixed
    # order every pass — confirm this is intended for few-shot training.
    sampler=dict(type='InfiniteSampler', shuffle=False),
    dataset=dict(
        type='DRIVEDataset',
        data_root='data/DRIVE',
        data_prefix=dict(img_path='images/training',
                         seg_map_path='annotations/training'),
        # Presumably caps training at this many support samples (few-shot);
        # verify against DRIVEDataset.
        support_num=5 * 2,
        return_classes=True,
        pipeline=drive_semseg_train_pipeline,
    ),
)
# Iteration-based schedule: 100 training iterations, validating every 50.
max_iters = 100
train_cfg = dict(type='IterBasedTrainLoop',
                 max_iters=max_iters,
                 val_interval=50)
test_cfg = dict(type='TestLoop')
val_cfg = dict(type='ValLoop')
# NOTE(review): the only milestone equals max_iters, so the 0.1 decay can only
# trigger at the very end of training — the LR is effectively constant.
# Confirm this is intended.
param_scheduler = [
    dict(type='MultiStepLR',
         by_epoch=False,
         milestones=[max_iters],
         gamma=0.1),
]
# Validation loader over the DRIVE validation split, using the evaluation
# pipeline.
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='DRIVEDataset',
        data_root='data/DRIVE',
        data_prefix=dict(img_path='images/validation',
                         seg_map_path='annotations/validation'),
        return_classes=True,
        pipeline=drive_semseg_test_pipeline,
    ),
)
test_pipeline = drive_semseg_test_pipeline
# Testing reuses the very same loader and metric objects as validation.
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
test_evaluator = val_evaluator
# Visualisation backend(s); the same list object is handed to the visualizer.
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(type='SegLocalVisualizer',
                  vis_backends=vis_backends,
                  name='visualizer')

# Runtime hooks: logging every 5 iterations, a checkpoint every 50.
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook',
                interval=5,
                log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook',
                    by_epoch=False,
                    interval=50,
                    max_keep_ckpts=50),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook',
                       draw=False,
                       interval=10,
                       show=False),
)
log_processor = dict(type='LogProcessor', window_size=4000, by_epoch=False)