-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
51 lines (44 loc) · 1.49 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import argparse
import os
import sys
import datetime
import yaml
import logging
from solver import SSDSolver
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour existing callers get.

    Returns:
        argparse.Namespace: holds ``cfg``, the value passed via ``--cfg``
        (``'./config.yaml'`` when the flag is omitted).
    """
    parser = argparse.ArgumentParser(description='Train networks.')
    parser.add_argument('--cfg', type=str, default='./config.yaml',
                        help='path to the configuration file')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: load the YAML config named by --cfg, pull out the
    # individual settings and hand them to SSDSolver, then start training.
    args = parse_args()

    # Without explicit configuration the root logger only emits WARNING and
    # above, so the config dump below would be silently dropped.
    # basicConfig() is a no-op if the root logger already has handlers.
    logging.basicConfig(level=logging.INFO)

    # NOTE(review): yaml.full_load accepts a wider tag set than
    # yaml.safe_load. Only point --cfg at trusted files; consider safe_load
    # if the config uses plain scalars/lists/mappings only -- TODO confirm.
    with open(args.cfg, 'r') as f:
        config = yaml.full_load(f)
    logging.info(config)

    # Every key below is required (a missing one raises KeyError), except
    # 'lr_decay', which falls back to 0.1.
    network = config['network']
    layers = config['layers']
    num_filters = config['num_filters']
    anchor_sizes = config['anchor_sizes']
    anchor_ratios = config['anchor_ratios']
    steps = config['steps']
    dataset = config['dataset']
    input_shape = config['input_shape']
    train_split = config['train_split']
    batch_size = config['batch_size']
    optimizer = config['optimizer']
    lr = config['lr']
    wd = config['wd']
    momentum = config['momentum']
    epoch = config['epoch']
    lr_decay = config.get('lr_decay', 0.1)
    val_split = config['val_split']
    use_amp = config['use_amp']
    gpus = config['gpus']
    save_prefix = config['save_prefix']

    # Arguments are positional; the order must match SSDSolver's signature.
    solver = SSDSolver(network, layers, num_filters, anchor_sizes,
                       anchor_ratios, steps, dataset, input_shape,
                       batch_size, optimizer, lr, wd, momentum, epoch,
                       lr_decay, train_split,
                       val_split, use_amp, gpus, save_prefix)
    solver.train()