-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy patheval.py
67 lines (53 loc) · 2.4 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '6'
import sys
sys.path.append('..')
from utils.dataset import getDataset
from models.tracker_eval import *
from torch.utils.data import DataLoader
def _load_model(ckpt_path):
    """Build the evaluation tracker and load its weights from a checkpoint.

    Args:
        ckpt_path: Path to a checkpoint whose loaded object has a
            ``"state_dict"`` entry accepted by ``TrackerNetEval``.

    Returns:
        The ``TrackerNetEval`` model with the loaded weights, moved to the GPU.
    """
    import torch

    model = TrackerNetEval(feature_dim=384, hgnn=True)
    # Deserialize onto the CPU: a checkpoint saved on another GPU index would
    # otherwise fail to load when CUDA_VISIBLE_DEVICES exposes a single device.
    # load_state_dict copies the values into the model's own parameters and
    # the whole model is moved to the GPU right after, so the result is the
    # same as loading straight onto the GPU.
    model_info = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(model_info["state_dict"])
    return model.cuda()


def _track_chunk(model, td, data):
    """Run the tracker over every unroll step of one batch.

    Args:
        model: The loaded tracker returned by ``_load_model``.
        td: The per-sequence dataset. Its ``reprojectImageTo3D_ph`` is called
            on the first batch element's predicted disparity and updated
            left position to produce each step's 3D position.
        data: Dict of GPU tensors for one batch. Reads ``'u_centers_l'``,
            ``'ref_img'``, ``'ev_frame_left'`` and ``'ev_frame_right'``. The
            event frames are sliced along dim 1, one slice per unroll step.

    Returns:
        The per-step 3D positions stacked along a new leading dimension.
    """
    import torch

    current_pos_l = data['u_centers_l']
    ref_patch = data['ref_img']
    pred = None  # recurrent state handed back into the model on each step
    pos_3d = []
    for unroll in range(data['ev_frame_left'].shape[1]):
        ev_frame_l = data['ev_frame_left'][:, unroll]
        ev_frame_r = data['ev_frame_right'][:, unroll]
        flow_l_pred, disp_pred, pred = model(
            ev_frame_l, ev_frame_r, ref_patch, current_pos_l, None, pred=pred)
        # In-place update, which also mutates data['u_centers_l'].
        # NOTE(review): kept in place in case the model retains a reference
        # to this tensor (e.g. via `pred`); confirm before making it
        # out-of-place.
        current_pos_l += flow_l_pred.detach()
        pos_3d.append(
            td.reprojectImageTo3D_ph(disp_pred[0].cpu(), current_pos_l[0].cpu()))
    return torch.stack(pos_3d)


def main(ckpt_path='./ckpt.pth', data_folder='./data', output_root='./output'):
    """Evaluate a tracker checkpoint on the test split and save 3D tracks.

    For every test sequence, two files are written into
    ``<output_root>/<sequence_name>/``:

    - ``pos_3d_pred.npy``: the predicted 3D positions stacked over unroll steps.
    - ``pos_3d_gt.npy``: ``gt['track_3d']`` of the first batch element, with
      its first two axes swapped.

    Args:
        ckpt_path: Checkpoint to evaluate.
        data_folder: Root folder passed to ``getDataset``.
        output_root: Directory under which one folder per sequence is created.
    """
    import numpy as np
    import torch

    model = _load_model(ckpt_path)
    print(f"Loaded from: {ckpt_path}")
    test_dataset = getDataset(data_folder=data_folder, train=False)

    model.eval()
    with torch.no_grad():
        for n, td in enumerate(test_dataset):
            loader = DataLoader(td, batch_size=1)
            # Reset the model before each sequence (presumably clears tracking
            # state carried over from the previous sequence; confirm).
            model.reset()
            out_dir = os.path.join(output_root, td.sequence_name)
            os.makedirs(out_dir, exist_ok=True)
            for _ts, data, gt in loader:
                data = {k: v.cuda() for k, v in data.items()}
                pos_3d = _track_chunk(model, td, data)
                # NOTE(review): the file names do not depend on the batch, so
                # if `loader` yields more than one batch each one overwrites
                # the previous files. Confirm that each sequence holds one item.
                np.save(os.path.join(out_dir, 'pos_3d_pred.npy'),
                        pos_3d.cpu().numpy())
                # Ground truth is only read here for saving, so it is never
                # copied to the GPU.
                np.save(os.path.join(out_dir, 'pos_3d_gt.npy'),
                        gt['track_3d'][0].transpose(1, 0).cpu().numpy())
            message = f'Test, Sequence: [{n}]/[{len(test_dataset)}]'
            print(message)


if __name__ == '__main__':
    main()