Commit
update vcr_wiki tasks
sheryc committed Jun 10, 2024
1 parent db8e718 commit 205721e
Showing 13 changed files with 70 additions and 71 deletions.
45 changes: 22 additions & 23 deletions lmms_eval/tasks/vcr_wiki/utils.py
@@ -15,22 +15,21 @@
 # Download the English and Chinese models
 try:
     nlp_en = spacy.load("en_core_web_sm")
-except:
+except Exception as e:
     download("en_core_web_sm")
     nlp_en = spacy.load("en_core_web_sm")
 
 try:
     nlp_zh = spacy.load("zh_core_web_sm")
-except:
+except Exception as e:
     download("zh_core_web_sm")
     nlp_zh = spacy.load("zh_core_web_sm")
 
-eval_logger = logging.getLogger("lmms-eval")
-
-dir_name = os.path.dirname(os.path.abspath(__file__))
-
-nlp = {"en": nlp_en, "zh": nlp_zh}
 rouge = evaluate.load("rouge")
 
+nlp = {"en": nlp_en, "zh": nlp_zh}
+eval_logger = logging.getLogger("lmms-eval")
+dir_name = os.path.dirname(os.path.abspath(__file__))
+
 aggregate_results_template = {
     "max_sim_val": 0,
@@ -89,7 +88,7 @@ def tokenize(text, language):
 def vcr_process_results_single(crossed_text, result, language):
     """
     Args:
-        doc: a instance of the eval dataset
+        doc: an instance of the eval dataset
         results: [pred]
     Returns:
         a dictionary with key: metric name (in this case vcr score), value: metric value
@@ -180,8 +179,8 @@ def vcr_process_results_single(crossed_text, result, language):
 def vcr_en_process_results(doc, results):
     """
     Args:
-        doc: a instance of the eval dataset
-        results: [pred]
+        doc: an instance of the eval dataset
+        results: [pred], with length = 1
     Returns:
         a dictionary with key: metric name (in this case vcr score), value: metric value
     """
@@ -196,13 +195,13 @@ def vcr_en_process_results(doc, results):
     }
     crossed_text = doc["crossed_text"]
     for i in range(len(crossed_text)):
-        tmp = vcr_process_results_single(crossed_text[i], results, "en")
+        tmp = vcr_process_results_single(crossed_text[i], results[0], "en")
         for k in output.keys():
             output[k].append(
                 {
                     "score": tmp[k],
-                    "max_sim_string": tmp["max_sim_string"],
-                    "crossed_text": crossed_text[i],
+                    "pred_ngram": tmp["max_sim_string"],
+                    "gt_ngram": crossed_text[i],
                     "caption": doc["caption"],
                 }
             )
@@ -212,10 +211,10 @@ def vcr_zh_process_results(doc, results):
 def vcr_zh_process_results(doc, results):
     """
     Args:
-        doc: a instance of the eval dataset
-        results: [pred]
+        doc: an instance of the eval dataset
+        results: [pred], with length = 1
     Returns:
-        a dictionary with key: metric name (in this case vcr score), value: metric value
+        a dictionary with key: metric name (in this case vcr score), value: metric value and other info
     """
     output = {
         "max_sim_val": [],
@@ -228,13 +227,13 @@ def vcr_zh_process_results(doc, results):
     }
     crossed_text = doc["crossed_text"]
     for i in range(len(crossed_text)):
-        tmp = vcr_process_results_single(crossed_text[i], results, "zh")
+        tmp = vcr_process_results_single(crossed_text[i], results[0], "zh")
         for k in output.keys():
             output[k].append(
                 {
                     "score": tmp[k],
-                    "max_sim_string": tmp["max_sim_string"],
-                    "crossed_text": crossed_text[i],
+                    "pred_ngram": tmp["max_sim_string"],
+                    "gt_ngram": crossed_text[i],
                     "caption": doc["caption"],
                 }
             )
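For reference, a minimal sketch of the per-n-gram entry that these two functions now append under each metric key. The doc and prediction values below are hypothetical and the score is illustrative (it is actually computed by vcr_process_results_single, whose body is not shown in this hunk); the point is the renamed "pred_ngram"/"gt_ngram" fields and that the single prediction is now taken as results[0]:

# Hypothetical inputs, for illustrating the data flow only.
doc = {
    "crossed_text": ["national park"],  # ground-truth covered n-grams
    "caption": "A view of the national park.",
}
results = ["national park"]  # [pred], length 1, so results[0] is the model output

# Shape of one entry appended to a metric list such as output["max_sim_val"]:
entry = {
    "score": 1.0,  # tmp[k] for this metric key (illustrative value)
    "pred_ngram": "national park",  # tmp["max_sim_string"]: best-matching n-gram from the prediction
    "gt_ngram": "national park",  # crossed_text[i]: the crossed-out ground-truth n-gram
    "caption": doc["caption"],
}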
@@ -244,9 +243,9 @@ def vcr_zh_process_results(doc, results):
 def vcr_aggregate_results(results, args):
     """
     Args:
-        results: a list of values returned by process_results
+        results: List[List[Dict]], list of results returned by process_results
     Returns:
-        A dictionary of dictionary of float, where the outer dictionary has keys "res_stacked_image" and "res_only_it_image"
+        A float value representing the final score of jaccard index or exact match
     """
     scores = 0
     count = 0
@@ -259,8 +258,8 @@

     now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
     path = generate_submission_file(f"vcr_submission_{now_date_time}.json", args)
-    with open(path, "w") as f:
-        json.dump(output_dict, f)
+    with open(path, "w", encoding='utf-8') as f:
+        json.dump(output_dict, f, indent=4, ensure_ascii=False)
     # print(f"Submission file saved to {path}")
     eval_logger.info(f"Submission file saved to {path}")
     return scores / count
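The switch to UTF-8 with ensure_ascii=False (plus indent=4 for readability) matters mainly for the Chinese tasks: with the defaults, json.dump writes the zh ground-truth and predicted n-grams as \uXXXX escapes. A small sketch with a hypothetical record:

import json

record = {"gt_ngram": "国家公园"}  # hypothetical zh submission entry
print(json.dumps(record))  # {"gt_ngram": "\u56fd\u5bb6\u516c\u56ed"}
print(json.dumps(record, ensure_ascii=False))  # {"gt_ngram": "国家公园"}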
8 changes: 4 additions & 4 deletions lmms_eval/tasks/vcr_wiki/…
@@ -1,14 +1,14 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-en-easy-test
task: "vcr_wiki_en_easy_5000"
test_split: train
task: "vcr_wiki_en_easy"
test_split: test
process_results: !function utils.vcr_en_process_results
metric_list:
- metric: jaccard
aggregation: !function utils.vcr_en_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
- metric: exact_match
aggregation: !function utils.vcr_en_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
model_specific_prompt_kwargs:
default:
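In all of these task configs, the aggregation entries now point at utils.vcr_aggregate_results instead of looping back to the process_results function. Conceptually, process_results runs once per document and returns a dict keyed by metric name, while aggregation reduces the collected per-document values for one metric into a single float. A simplified, hypothetical sketch of that wiring (an illustration of the intent, not lmms-eval's actual internals; all names and the signature are assumptions):

# Hypothetical harness loop, for illustration only.
def run_metric(docs, predictions, process_results, aggregation, args, key="jaccard"):
    # process_results is called per document with a single-element prediction list
    per_doc = [process_results(doc, [pred])[key] for doc, pred in zip(docs, predictions)]
    # aggregation reduces the list of per-document values into one score;
    # vcr_aggregate_results also receives args so it can write the submission file
    return aggregation(per_doc, args)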
8 changes: 4 additions & 4 deletions lmms_eval/tasks/vcr_wiki/vcr_wiki_en_easy_100.yaml
@@ -1,14 +1,14 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-en-easy-test
dataset_path: vcr-org/VCR-wiki-en-easy-test-100
task: "vcr_wiki_en_easy_100"
test_split: train[:100]
test_split: test
process_results: !function utils.vcr_en_process_results
metric_list:
- metric: jaccard
aggregation: !function utils.vcr_en_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
- metric: exact_match
aggregation: !function utils.vcr_en_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
model_specific_prompt_kwargs:
default:
8 changes: 4 additions & 4 deletions lmms_eval/tasks/vcr_wiki/vcr_wiki_en_easy_500.yaml
@@ -1,14 +1,14 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-en-easy-test
dataset_path: vcr-org/VCR-wiki-en-easy-test-500
task: "vcr_wiki_en_easy_500"
test_split: train[:500]
test_split: test
process_results: !function utils.vcr_en_process_results
metric_list:
- metric: jaccard
aggregation: !function utils.vcr_en_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
- metric: exact_match
aggregation: !function utils.vcr_en_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
model_specific_prompt_kwargs:
default:
8 changes: 4 additions & 4 deletions lmms_eval/tasks/vcr_wiki/…
@@ -1,14 +1,14 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-en-hard-test
task: "vcr_wiki_en_hard_5000"
test_split: train
task: "vcr_wiki_en_hard"
test_split: test
process_results: !function utils.vcr_en_process_results
metric_list:
- metric: jaccard
aggregation: !function utils.vcr_en_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
- metric: exact_match
aggregation: !function utils.vcr_en_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
model_specific_prompt_kwargs:
default:
8 changes: 4 additions & 4 deletions lmms_eval/tasks/vcr_wiki/vcr_wiki_en_hard_100.yaml
@@ -1,14 +1,14 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-en-hard-test
dataset_path: vcr-org/VCR-wiki-en-hard-test-100
task: "vcr_wiki_en_hard_100"
test_split: train[:100]
test_split: test
process_results: !function utils.vcr_en_process_results
metric_list:
- metric: jaccard
aggregation: !function utils.vcr_en_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
- metric: exact_match
aggregation: !function utils.vcr_en_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
model_specific_prompt_kwargs:
default:
8 changes: 4 additions & 4 deletions lmms_eval/tasks/vcr_wiki/vcr_wiki_en_hard_500.yaml
@@ -1,14 +1,14 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-en-hard-test
dataset_path: vcr-org/VCR-wiki-en-hard-test-500
task: "vcr_wiki_en_hard_500"
test_split: train[:500]
test_split: test
process_results: !function utils.vcr_en_process_results
metric_list:
- metric: jaccard
aggregation: !function utils.vcr_en_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
- metric: exact_match
aggregation: !function utils.vcr_en_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
model_specific_prompt_kwargs:
default:
8 changes: 4 additions & 4 deletions lmms_eval/tasks/vcr_wiki/…
@@ -1,14 +1,14 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-zh-easy-test
task: "vcr_wiki_zh_easy_5000"
test_split: train
task: "vcr_wiki_zh_easy"
test_split: test
process_results: !function utils.vcr_zh_process_results
metric_list:
- metric: jaccard
aggregation: !function utils.vcr_zh_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
- metric: exact_match
aggregation: !function utils.vcr_zh_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
model_specific_prompt_kwargs:
default:
8 changes: 4 additions & 4 deletions lmms_eval/tasks/vcr_wiki/vcr_wiki_zh_easy_100.yaml
@@ -1,14 +1,14 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-zh-easy-test
dataset_path: vcr-org/VCR-wiki-zh-easy-test-100
task: "vcr_wiki_zh_easy_100"
test_split: train[:100]
test_split: test
process_results: !function utils.vcr_zh_process_results
metric_list:
- metric: jaccard
aggregation: !function utils.vcr_zh_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
- metric: exact_match
aggregation: !function utils.vcr_zh_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
model_specific_prompt_kwargs:
default:
8 changes: 4 additions & 4 deletions lmms_eval/tasks/vcr_wiki/vcr_wiki_zh_easy_500.yaml
@@ -1,14 +1,14 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-zh-easy-test
dataset_path: vcr-org/VCR-wiki-zh-easy-test-500
task: "vcr_wiki_zh_easy_500"
test_split: train[:500]
test_split: test
process_results: !function utils.vcr_zh_process_results
metric_list:
- metric: jaccard
aggregation: !function utils.vcr_zh_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
- metric: exact_match
aggregation: !function utils.vcr_zh_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
model_specific_prompt_kwargs:
default:
8 changes: 4 additions & 4 deletions lmms_eval/tasks/vcr_wiki/…
@@ -1,14 +1,14 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-zh-hard-test
task: "vcr_wiki_zh_hard_5000"
test_split: train
task: "vcr_wiki_zh_hard"
test_split: test
process_results: !function utils.vcr_zh_process_results
metric_list:
- metric: jaccard
aggregation: !function utils.vcr_zh_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
- metric: exact_match
aggregation: !function utils.vcr_zh_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
model_specific_prompt_kwargs:
default:
8 changes: 4 additions & 4 deletions lmms_eval/tasks/vcr_wiki/vcr_wiki_zh_hard_100.yaml
@@ -1,14 +1,14 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-zh-hard-test
dataset_path: vcr-org/VCR-wiki-zh-hard-test-100
task: "vcr_wiki_zh_hard_100"
test_split: train[:100]
test_split: test
process_results: !function utils.vcr_zh_process_results
metric_list:
- metric: jaccard
aggregation: !function utils.vcr_zh_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
- metric: exact_match
aggregation: !function utils.vcr_zh_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
model_specific_prompt_kwargs:
default:
8 changes: 4 additions & 4 deletions lmms_eval/tasks/vcr_wiki/vcr_wiki_zh_hard_500.yaml
@@ -1,14 +1,14 @@
"include": "_default_template_vcr_yaml"
dataset_path: vcr-org/VCR-wiki-zh-hard-test
dataset_path: vcr-org/VCR-wiki-zh-hard-test-500
task: "vcr_wiki_zh_hard_500"
test_split: train[:500]
test_split: test
process_results: !function utils.vcr_zh_process_results
metric_list:
- metric: jaccard
aggregation: !function utils.vcr_zh_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
- metric: exact_match
aggregation: !function utils.vcr_zh_process_results
aggregation: !function utils.vcr_aggregate_results
higher_is_better: true
model_specific_prompt_kwargs:
default:
