EVQA.py
import torch
import torch.nn as nn


class EVQA(nn.Module):
    """Simple video QA baseline: encodes the video and each candidate
    question-answer pair independently, fuses the two embeddings by
    element-wise addition, and scores the fusion with a linear layer."""

    def __init__(self, vid_encoder, qns_encoder, device):
        """
        :param vid_encoder: module returning (outputs, hidden) for video features
        :param qns_encoder: module returning (outputs, hidden) for QA token
                            sequences; must expose `dim_hidden` and `use_bert`
        :param device: torch.device the model runs on
        """
        super(EVQA, self).__init__()
        self.vid_encoder = vid_encoder
        self.qns_encoder = qns_encoder
        self.device = device
        # Maps one fused video+QA embedding to a single score.
        self.FC = nn.Linear(qns_encoder.dim_hidden, 1)
    def forward(self, vid_feats, qas, qas_lengths):
        """
        :param vid_feats: video features, (batch, num_frames, feat_dim)
        :param qas: candidate QA token sequences, (batch, num_cand, seq_len),
                    with an extra trailing dim when BERT features are used
        :param qas_lengths: sequence lengths, (batch, num_cand)
        :return: (scores over candidates, index of the highest-scoring candidate)
        """
        # Move the candidate dimension to the front so each candidate QA pair
        # can be scored against the video one at a time.
        if self.qns_encoder.use_bert:
            cand_qas = qas.permute(1, 0, 2, 3)  # BERT features carry an extra dim
        else:
            cand_qas = qas.permute(1, 0, 2)
        cand_len = qas_lengths.permute(1, 0)

        out = []
        for idx, qa in enumerate(cand_qas):
            encoder_out = self.vq_encoder(vid_feats, qa, cand_len[idx])
            out.append(encoder_out)
        # (num_cand, batch) -> (batch, num_cand)
        out = torch.stack(out, 0).transpose(1, 0)
        _, predict_idx = torch.max(out, 1)
        return out, predict_idx
    def vq_encoder(self, vid_feats, qns, qns_lengths):
        """Fuse one QA candidate with the video; returns a scalar score per sample."""
        vid_outputs, vid_hidden = self.vid_encoder(vid_feats)
        qns_outputs, qns_hidden = self.qns_encoder(qns, qns_lengths)

        # Use the final hidden state of each encoder as its global embedding.
        qns_embed = qns_hidden[0].squeeze()
        vid_embed = vid_hidden[0].squeeze()

        # Element-wise addition as the fusion, followed by the linear scoring head.
        fuse = qns_embed + vid_embed
        outputs = self.FC(fuse).squeeze()
        return outputs