-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathassign_answer.py
288 lines (249 loc) · 9.27 KB
/
assign_answer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
#%%
# Imports
import argparse
import json
import torch
from torch.utils.data import DataLoader
from dataset import Dictionary, VQAFeatureDataset
from aug_dataset import VQAAugFeatureDataset
import base_modelKD
from train import train
import utils
from vqa_debias_loss_functions import *
from tqdm import tqdm
from torch.autograd import Variable
import pickle
import os
import numpy as np
import skimage.io as io
import matplotlib.pyplot as plt
import random
def parse_args(argv=None):
    """Parse the command-line options of this label-assignment script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what happens when the
            script is run from the command line; passing a list lets callers
            parse options without touching ``sys.argv``.

    Returns:
        argparse.Namespace with ``dataset`` ('v2' or 'cpv2'), ``name``
        (augmented subset), ``split`` ('low' or 'high') and ``teacher_path``
        (path of the teacher checkpoint).
    """
    # The text is a description; passed positionally it would become `prog`
    # and replace the program name in the usage line.
    parser = argparse.ArgumentParser(
        description="Assign Label For Low Quality/High Quality DATA"
    )
    parser.add_argument(
        '--dataset', default='cpv2',
        choices=["v2", "cpv2"],
        help="dataset name"
    )
    parser.add_argument(
        '--name', default='other',
        choices=['number', 'color', 'other', 'paraphrasing', 'yesno', 'paired'],
        help="augmented dataset name"
    )
    parser.add_argument(
        '--split', default='low',
        choices=['low', 'high'],
        help="low quality or high quality dataset"
    )
    parser.add_argument(
        '--teacher_path', default='./logs/lmh_css/model.pth',
        type=str,
        help="Path of teacher model"
    )
    return parser.parse_args(argv)
args = parse_args()
dataset = args.dataset
name = args.name
split = args.split

# Original (non-augmented) training entries; v2 and cpv2 use different pickles.
if dataset == 'v2':
    original_dataset_path = './aug_data/v2_original_dataset.pkl'
else:
    original_dataset_path = './aug_data/original_dataset.pkl'
with open(original_dataset_path, 'rb') as f:
    original_dataset = pickle.load(f)
print('DATASET LEN', len(original_dataset))

# load label: table indexed by predicted answer-label id further down the script.
# Files are opened with `with` so the handles are closed after loading.
cache_dir = 'cache' if dataset == 'v2' else 'cp-cache'
cache_file = os.path.join('data', cache_dir, 'trainval_label2ans.pkl')
with open(cache_file, 'rb') as f:
    label2ans = pickle.load(f)

# load qid2type: question_id -> lower-cased question_type (qid2qtype)
# and question_id -> answer_type (qid2type).
if dataset == 'v2':
    # get question type; the v2 file wraps the records under 'annotations'
    with open('./data/v2_mscoco_train2014_annotations.json', 'r') as f:
        question_annotations = json.load(f)['annotations']
else:
    # dataset = 'cpv2'; the loaded JSON is iterated directly (no wrapper key)
    with open('./data/vqacp_v2_train_annotations.json', 'r') as f:
        question_annotations = json.load(f)

qid2qtype = {}
qid2type = {}
for anno in question_annotations:
    qid = anno['question_id']
    qid2qtype[qid] = anno['question_type'].lower()
    qid2type[qid] = anno['answer_type']
#%%
def handle(sentence: str):
    """Normalise a question string so it can serve as a lookup key.

    The text is lower-cased, then a fixed sequence of substring rewrites is
    applied. The rewrites are order-sensitive: each one sees the output of
    the previous ones (e.g. '.' is removed only after the "'s" split), so the
    table below must not be reordered.
    """
    rewrites = (
        (',', ''),
        ('?', ''),
        ("'s", " 's"),
        ('-', ' '),
        ('.', ''),
        ('"', ''),
        ("n't", ' not'),
        ('$', ' dollar '),
    )
    normalised = sentence.lower()
    for old, new in rewrites:
        normalised = normalised.replace(old, new)
    return normalised
#%%
from tqdm import tqdm  # NOTE: duplicates the top-of-file import; harmless
num_entries = len(original_dataset)

# img_id -> {'objects', 'attributes'} taken from the first entry seen for that image.
image_info = {}
for idx in tqdm(range(num_entries), ncols=100, total=num_entries):
    record = original_dataset[idx]
    if record['img_id'] not in image_info:
        image_info[record['img_id']] = {
            'objects': record['objects'],
            'attributes': record['attributes'],
        }

# handle(question) -> metadata of its first occurrence, plus the indices of
# every original entry whose normalised question text is the same.
question_info = {}
for idx in tqdm(range(num_entries), ncols=100, total=num_entries):
    record = original_dataset[idx]
    key = handle(record['question'])
    if key in question_info:
        question_info[key]['entry_idxs'].append(idx)
    else:
        q_id = record['q_id']
        question_info[key] = {
            'nouns': record['nouns'],
            'ori_nouns': record['ori_nouns'],
            'qtype': qid2qtype[q_id],
            'type': qid2type[q_id],
            'entry_idxs': [idx],
            'returned_imgs': [],
        }
# Get language bias, which is an input of CSS teacher model.
# Counter/defaultdict are not imported anywhere else in this file; import them
# explicitly instead of relying on the star import to re-export them.
from collections import Counter, defaultdict

print('Get language bias, which is an input of CSS teacher model.')
dictionary = Dictionary.load_from_file('data/dictionary.pkl')
train_dset = VQAFeatureDataset('train', dictionary, dataset=dataset, cache_image_features=False)
# get bias
answer_voc_size = train_dset.num_ans_candidates
# question_type -> answer label -> summed score over all training questions of that type
question_type_to_probs = defaultdict(Counter)
# question_type -> number of training questions of that type (labelled or not)
question_type_to_count = Counter()
for ex in train_dset.entries:
    ans = ex["answer"]
    q_type = ans["question_type"]
    question_type_to_count[q_type] += 1
    if ans["labels"] is not None:
        for label, score in zip(ans["labels"], ans["scores"]):
            question_type_to_probs[q_type][label] += score

# question_type -> dense float32 vector (length answer_voc_size) holding the
# mean score of each answer label over questions of that type.
question_type_to_prob_array = {}
for q_type, count in question_type_to_count.items():
    prob_array = np.zeros(answer_voc_size, np.float32)
    for label, total_score in question_type_to_probs[q_type].items():
        prob_array[label] += total_score
    prob_array /= count
    question_type_to_prob_array[q_type] = prob_array
print('Load model from:', args.teacher_path)
constructor = 'build_baseline0_newatt'
ood_model = getattr(base_modelKD, constructor)(train_dset, 1024).cuda()
# The debias head is attached after the first .cuda() call, so the model is
# moved to the GPU again below, in case LearnedMixinKD owns parameters/buffers.
ood_model.debias_loss_fn = LearnedMixinKD()
# map_location='cpu' lets a checkpoint written on any GPU (or on CPU) be read
# here; load_state_dict copies the tensors into the model's existing parameters.
model_state = torch.load(args.teacher_path, map_location='cpu')
ood_model.load_state_dict(model_state)
ood_model = ood_model.cuda().eval()

if split == 'high':
    load_name = 'high_' + name
    clean_name = 'high_clean_' + name
else:
    load_name = 'low_' + name
    clean_name = 'clean_' + name
# True for the low-quality split. Checked on `split` itself rather than by
# searching load_name for the substring 'low', which would also match any
# high-split subset whose name happened to contain 'low'.
low_assign_flag = (split == 'low')
print("LOAD PATH: ", load_name)
print("SAVE PATH: ", clean_name)
# dataset: shuffle=False so batches come out in entry order
aug_dset = VQAAugFeatureDataset(load_name, dictionary, cache_image_features=True, dataset=dataset)
eval_loader = DataLoader(aug_dset, 512, shuffle=False, num_workers=0)
#%%
# Store the question-type answer prior of each augmented entry under "bias".
# NOTE(review): presumably this is what VQAAugFeatureDataset yields as `b`
# in the evaluation loop — confirm in its __getitem__.
qtype_from_entry = (name == 'color')
for ex in aug_dset.entries:
    if qtype_from_entry:
        # color entries carry their own question type
        q_type = ex["qtype"]
    else:
        # NOTE(review): question_info is keyed by handle(question); this lookup
        # assumes ex['question'] is already in that normalised form — confirm.
        q_type = question_info[ex['question']]['qtype']
    ex["bias"] = question_type_to_prob_array[q_type]
#%%
# Load the raw entries of the chosen augmented subset (low or high split).
# NOTE(review): presumably the same file aug_dset was built from, since
# predictions are later matched to these entries by position — confirm.
low_quality_path = './aug_data/' + VQAAugFeatureDataset.path[dataset][load_name]
with open(low_quality_path, 'rb') as pkl_file:
    low_quality_dataset = pickle.load(pkl_file)
#%%
# Keep a copy of the answers as loaded under 'ori_answer_text'.
for entry in low_quality_dataset:
    entry['ori_answer_text'] = entry['answer_text']
#%%
# begin assign label
print('Predict to get ood prediction and id prediction.')
# Predictions are written back by position (the loader has shuffle=False),
# so the dataset must yield exactly one sample per raw entry.
if len(aug_dset) != len(low_quality_dataset):
    raise ValueError(
        'aug_dset has %d samples but the raw pickle has %d entries; '
        'predictions cannot be aligned' % (len(aug_dset), len(low_quality_dataset))
    )
begin_idx = 0
logsigmoid = torch.nn.LogSigmoid()
with torch.no_grad():
    for v, q, a, b in tqdm(eval_loader, ncols=100, total=len(eval_loader), desc="eval"):
        # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4;
        # plain tensors are used directly.
        v = v.cuda()
        q = q.cuda()
        a = a.cuda()
        b = b.cuda()
        # Both teacher outputs come from one forward pass of ood_model.
        ood_pred, id_pred, _, _ = ood_model(v, q, a, b, None)
        # Target for the per-teacher losses: `b` on the low-quality split,
        # the loader's `a` tensor otherwise.
        bb = b if low_assign_flag else a
        # Per-sample sigmoid cross-entropy of each teacher against bb, summed over answers.
        loss_id = (-bb * logsigmoid(id_pred) - (1 - bb) * logsigmoid(-id_pred)).sum(dim=1)
        loss_ood = (-bb * logsigmoid(ood_pred) - (1 - bb) * logsigmoid(-ood_pred)).sum(dim=1)
        # Algebraically equal to the inverse-loss form with s = 1 / loss:
        #   w_id = s_ood / (s_id + s_ood),  w_ood = s_id / (s_id + s_ood)
        # but computed without reciprocals, so a single zero loss no longer
        # produces inf / inf = nan.
        total_loss = loss_id + loss_ood
        w_id = loss_id / total_loss
        w_ood = loss_ood / total_loss
        _, id_ans = torch.max(id_pred, dim=1)
        _, ood_ans = torch.max(ood_pred, dim=1)
        id_ans = id_ans.cpu().numpy()
        ood_ans = ood_ans.cpu().numpy()
        w_id = w_id.cpu().numpy()
        w_ood = w_ood.cpu().numpy()
        if low_assign_flag:
            id_pred = id_pred.sigmoid().cpu().numpy()
        else:
            # High-quality split: the stored ID "prediction" is the loader's `a` tensor itself.
            id_pred = a.cpu().numpy()
        ood_pred = ood_pred.sigmoid().cpu().numpy()
        for offset in range(len(v)):
            record = low_quality_dataset[begin_idx + offset]
            record['id_ans'] = label2ans[id_ans[offset]]
            record['ood_ans'] = label2ans[ood_ans[offset]]
            record['id_w'] = w_id[offset]
            record['ood_w'] = w_ood[offset]
            record['id_preds'] = id_pred[offset]
            record['ood_preds'] = ood_pred[offset]
        begin_idx += len(v)
#%%
print('Assign answer')
# Soft version by two teachers: per entry, mix the ID and OOD prediction
# vectors using the weights stored under 'id_w' / 'ood_w'.
# NOTE: despite the key name, 'logits' holds a mix of sigmoid outputs (or of
# the loader's `a` vector for the ID side on the high split), not raw logits.
# Alternatives that were tried and left disabled:
#   - OOD teacher only, hard label: answer_text = [ood_ans], scores = [1.0]
#   - two teachers, hard labels: [ood_ans] with score 1.0 if id_ans == ood_ans,
#     otherwise [ood_ans, id_ans] with scores [ood_w, id_w]
#   - simple average: logits = 0.5 * id_preds + 0.5 * ood_preds
#   - single teacher: logits = id_preds, or logits = ood_preds
for entry in low_quality_dataset:
    entry['logits'] = entry['id_w'] * entry['id_preds'] + entry['ood_w'] * entry['ood_preds']
print('Save')
# Whitelist of fields written to the cleaned pickle; every other key on an
# entry (the intermediate id_/ood_ predictions and weights, etc.) is dropped.
valid_keys = {'q_id', 'img_id', 'question', 'answer_text', 'scores', 'qtype', 'ori_answer_text', 'logits', 'nouns'}
for entry in low_quality_dataset:
    # Collect first, then delete: a dict must not change size while iterated.
    for key in [k for k in entry if k not in valid_keys]:
        del entry[key]
save_path = './aug_data/' + VQAAugFeatureDataset.path[dataset][clean_name]
with open(save_path, 'wb') as out_file:
    pickle.dump(low_quality_dataset, out_file)