-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathtrainer.py
89 lines (77 loc) · 2.9 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import sys
# Make the sibling clip-grams checkout importable (provides `clipgrams` below).
sys.path.append('./clip-grams')
import json
import os
import torch
import numpy as np
import pytorch_lightning as pl
from CLIP import clip
from argparse import ArgumentParser
from pytorch_lightning.callbacks import ModelCheckpoint
import dataset
import model
import clipgrams
# Silence the HuggingFace tokenizers warning about forking after parallelism,
# which otherwise fires from DataLoader worker processes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def main():
    """Train the captioning model with PyTorch Lightning.

    Parses command-line arguments, loads the retrieval cache (text lines and
    their embeddings) either from a prebuilt clip-grams index directory or
    from a plain text file + ``.npy`` embedding file, then builds the data
    module and model and runs training with checkpointing on ``vloss``.
    """
    # Args
    parser = ArgumentParser()
    parser.add_argument('--datadir', type=str)
    parser.add_argument('--train_datadir', type=str)
    parser.add_argument('--dev_datadir', type=str)
    parser.add_argument('--textfile', type=str)
    parser.add_argument('--embfile', type=str)
    parser.add_argument('--index_dir', type=str, default=None)
    parser.add_argument('--clip_model', type=str, default='ViT-B/16')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--maxlen_enc', type=int, default=128)
    parser.add_argument('--maxlen_dec', type=int, default=32)
    parser.add_argument('--gpus', type=int, default=1)
    parser.add_argument('--tpu_cores', type=int, default=None)
    parser.add_argument('--nworkers', type=int, default=4)
    parser.add_argument('--topk', type=int, default=10)
    parser.add_argument('--val_after_n_epochs', type=int, default=1)
    # BUG FIX: the default was `1e5`, a float. argparse applies `type` only to
    # values parsed from the command line, NOT to defaults, so `max_steps`
    # would silently receive the float 100000.0. Use an int literal instead.
    parser.add_argument('--tmax', type=int, default=100000)
    parser.add_argument('--save_top_k', type=int, default=10)
    parser.add_argument('--lrate', type=float, default=3e-4)
    args = parser.parse_args()

    # Load cache: either a prebuilt clip-grams index, or raw text + .npy files.
    if args.index_dir:
        # Merge the index's saved arguments into ours without overriding any
        # value already set on the command line / parser defaults.
        fname = os.path.join(args.index_dir, 'args.txt')
        with open(fname, 'r', encoding='utf-8') as f:
            index_args = json.load(f)
        for key in list(index_args.keys()):
            if key not in args.__dict__.keys():
                args.__dict__[key] = index_args[key]
        # NOTE(review): `args.text_dir` is not declared by this parser; it is
        # assumed to be supplied by the index's args.txt — verify against the
        # clip-grams index builder.
        cache = clipgrams.TextDataset(folder=args.text_dir, args=args).data
        cache_emb = clipgrams.load_index(args)
    else:
        # One cache entry per (stripped) line of the text file.
        cache = []
        with open(args.textfile, encoding='utf-8') as f:
            for line in f:
                cache.append(line.strip())
        cache_emb = np.load(args.embfile)

    # Load image preprocessor (second element of clip.load's (model, preprocess)).
    preprocess = clip.load(args.clip_model, jit=False)[1]

    # Train model: keep the `save_top_k` checkpoints with the lowest val loss.
    ckpt_callback = ModelCheckpoint(
        monitor='vloss',
        mode='min',
        filename='-{epoch:02d}-{vloss:.3f}',
        save_top_k=args.save_top_k)
    datamodule = dataset.DataModule(
        train_datadir=args.train_datadir,
        dev_datadir=args.dev_datadir,
        batch_size=args.batch_size,
        nworkers=args.nworkers,
        preprocess=preprocess)
    net = model.Model(args, cache=cache, cache_emb=cache_emb)
    trainer = pl.Trainer(
        default_root_dir=args.datadir,
        gpus=args.gpus,
        tpu_cores=args.tpu_cores,
        max_steps=args.tmax,
        callbacks=[ckpt_callback],
        check_val_every_n_epoch=args.val_after_n_epochs)
    trainer.fit(net, datamodule)
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()