-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_facecnet.py
48 lines (40 loc) · 1.14 KB
/
train_facecnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import json
import sys
import pprint
from argparse import Namespace
sys.path.append(".")
from options.train_options import TrainOptions
from model.facenet import Net
def create_initial_experiment_dir(opts) -> None:
    """Create a fresh experiment directory and save the run options in it.

    Creates ``opts.exp_dir``, pretty-prints the options to stdout, and writes
    them to ``<exp_dir>/opt.json`` as indented JSON with sorted keys.

    Args:
        opts: Options object exposing an ``exp_dir`` attribute. Everything in
            ``vars(opts)`` is dumped, so the values must be JSON-serializable.

    Raises:
        Exception: If ``opts.exp_dir`` already exists, so the contents of an
            existing directory are never overwritten.
    """
    try:
        # Let makedirs perform the existence check itself: a separate
        # os.path.exists() test leaves a window in which another process can
        # create the directory between the check and the creation.
        os.makedirs(opts.exp_dir)
    except FileExistsError as err:
        raise Exception(
            'Oops... {} already exists'.format(opts.exp_dir)
        ) from err

    opts_dict = vars(opts)
    pprint.pprint(opts_dict)
    with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f:
        json.dump(opts_dict, f, indent=4, sort_keys=True)
if __name__ == '__main__':
    # Build the run options (presumably parsed from the command line by
    # TrainOptions -- confirm in options/train_options.py).
    opts = TrainOptions().parse()
    # None is passed as the previous training checkpoint; how Net treats
    # that value is defined in model/facenet.py.
    previous_train_ckpt = None
    # Raises if opts.exp_dir already exists, so nothing below runs against
    # a pre-existing experiment directory.
    create_initial_experiment_dir(opts)
    coach = Net(opts, previous_train_ckpt)
    coach.train()
"""
Example usage:
python train_facecnet.py \
--dataset_type facecnet_data \
--exp_dir experiment/ \
--start_from_latent_avg \
--val_interval 40000 \
--save_interval 20000 \
--max_steps 200000 \
--image_interval 1000 \
--stylegan_size 512 \
--stylegan_weights pretrained_models/stylegan2.pt \
--workers 4 \
--batch_size 4 \
--test_batch_size 8 \
--test_workers 4 \
--checkpoint_path pretrained_models/0412_best_model.pt \
--load_model pretrained_model/facenet_lq_best_model.pt
"""