-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathrun_blip2_infoseek.py
122 lines (102 loc) · 4.57 KB
/
run_blip2_infoseek.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""BLIP2 Zeroshot Inference - InfosEEK"""
import os
import json
import torch
from PIL import Image
from lavis.models import load_model_and_preprocess
from multiprocessing import Pool
import argparse
from tqdm import tqdm
import time
def load_and_process_image(item):
    """Load one example's image and preprocess it for the model.

    Reads the module-level globals ``id2path`` (image_id -> file path),
    ``vis_processors`` and ``device``, which the script entry point sets up
    before this function is called.

    Args:
        item: Mapping with the keys ``"image_id"``, ``"question"`` and
            ``"data_id"``.

    Returns:
        A tuple ``(processed_image, question, data_id)``. ``processed_image``
        is the output of ``vis_processors["eval"]`` with an extra leading
        dimension added by ``unsqueeze(0)``, moved to ``device``.

    Raises:
        KeyError: If ``item`` lacks a required key or its image id has no
            entry in ``id2path``.
    """
    # Close the underlying file deterministically instead of leaving it to
    # the garbage collector; ``convert`` returns a converted copy, so the
    # result stays usable after the source image is closed.
    with Image.open(id2path[item["image_id"]]) as image_file:
        raw_image = image_file.convert("RGB")
    processed_image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
    return processed_image, item["question"], item["data_id"]
def process_images_in_batches(batch_data, batch_size, prompt):
    """Generate a prediction for every example, ``batch_size`` at a time.

    Reads the module-level global ``model`` and calls
    ``load_and_process_image`` for each example.

    Args:
        batch_data: Sequence of example dicts; each must be accepted by
            ``load_and_process_image``.
        batch_size: Number of examples passed to ``model.generate`` per call.
        prompt: Format string with one ``{}`` placeholder that receives the
            question text.

    Returns:
        A list of ``{"data_id": ..., "prediction": ...}`` dicts, appended in
        the order the batches were processed. Pairing ids with answers assumes
        ``model.generate`` returns one answer per input, in input order --
        TODO confirm against the model implementation.
    """
    output = []
    print("Generate predictions...")

    batch_starts = range(0, len(batch_data), batch_size)
    # Whole number of batches (including a final partial one) for the
    # progress message; the previous float division printed a fraction.
    num_batches = len(batch_starts)

    for batch_idx, start in enumerate(batch_starts):
        if (batch_idx + 1) % 100 == 0:
            print(f"Processing batch {batch_idx + 1}/{num_batches}")

        batch_subset = batch_data[start:start + batch_size]

        # Load and preprocess each image, keeping questions and ids aligned.
        batch_images, batch_questions, batch_ids = [], [], []
        for item in batch_subset:
            tmp_img, tmp_q, tmp_id = load_and_process_image(item)
            batch_images.append(tmp_img)
            batch_questions.append(tmp_q)
            batch_ids.append(tmp_id)

        # Each image carries a leading dimension of 1, so concatenating on
        # dim 0 stacks them into a single batch tensor.
        image_batch = torch.cat(batch_images, dim=0)

        # Wrap every question in the prompt template.
        batch_questions = [prompt.format(q) for q in batch_questions]

        start_time = time.time()
        answers = model.generate(
            {"image": image_batch, "prompt": batch_questions},
            length_penalty=-1,
        )
        print(f"Time for batch {batch_idx}: {time.time() - start_time}")

        # Use a distinct name here so the batch index is not shadowed.
        for data_id, ans in zip(batch_ids, answers):
            output.append({"data_id": data_id, "prediction": ans})
    return output
if __name__ == "__main__":
    # NOTE: id2path, device, model and vis_processors must stay module-level
    # names: the helper functions defined earlier in this file read them as
    # globals.

    # Input file for each supported split.
    split2data = {
        "val": "infoseek/infoseek_val.jsonl",
        "test": "infoseek/infoseek_test.jsonl",
        "human": "infoseek/infoseek_human.jsonl"
    }

    # argparse
    parser = argparse.ArgumentParser()
    # ``choices`` turns an unknown split into a usage error instead of a bare
    # KeyError further down.
    parser.add_argument("--split", type=str, default="val", choices=list(split2data), help="val, test, or human")
    parser.add_argument("--model_name", type=str, default="blip2_t5", help="blip2_t5 | blip2_vicuna_instruct | blip2_t5_instruct")
    parser.add_argument("--model_type", type=str, default="pretrain_flant5xxl", help="pretrain_flant5xxl | vicuna13b | flant5xxl")
    parser.add_argument("--output_dir", type=str, default="predictions", help="output directory")
    parser.add_argument("--batch_size", type=int, default=8, help="batch size")
    args = parser.parse_args()

    # Create the output directory up front so a missing directory cannot
    # discard the predictions after the whole inference run has finished.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    id2path = dict()
    # load image paths: Prepare a jsonl file to map image_id to image_path
    with open("id2image.jsonl", "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing newline)
            record = json.loads(line)
            id2path[record["image_id"]] = record["image_path"]

    # Read the input JSONL file
    with open(split2data[args.split], "r", encoding="utf-8") as f:
        batch_data = [json.loads(line) for line in f if line.strip()]

    # double check data exists:
    not_exist = []
    clean_batch_data = []
    for idx, item in enumerate(batch_data):
        if idx % 10000 == 0:
            print(f"Processing {idx}/{len(batch_data)}")
        # An image id without a mapping entry is treated like a missing file
        # rather than aborting the whole run with a KeyError.
        path = id2path.get(item["image_id"])
        if path is None or not os.path.exists(path):
            not_exist.append(item["image_id"])
        else:
            clean_batch_data.append(item)
    # Number of examples skipped because their image could not be found.
    print(len(not_exist))

    # setup device to use
    device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

    print("Load pretrained model...")
    # loads BLIP-2 pre-trained model
    model, vis_processors, _ = load_model_and_preprocess(name=args.model_name,
                                                         model_type=args.model_type,
                                                         is_eval=True, device=device)

    # Desired batch size
    batch_size = args.batch_size
    PROMPT = "Question: {} Short answer:"

    # Run the batch processing function
    output = process_images_in_batches(clean_batch_data, batch_size, prompt=PROMPT)

    # save output into jsonl
    output_path = os.path.join(args.output_dir, "zeroshot_{}_{}_{}.jsonl".format(
        args.model_name, args.model_type, args.split
    ))
    # Explicit UTF-8: predictions are written with ensure_ascii=False, which
    # can fail under a non-UTF-8 default locale encoding.
    with open(output_path, "w", encoding="utf-8") as f:
        for item in output:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")