Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Fix wrong partition to types in DROP evaluation (#3263)
Browse files Browse the repository at this point in the history
* Fix wrong partition to types in drop evaluation

* add a simple test

* add another case so the issue is tested regardless of answers order
  • Loading branch information
eladsegal authored and matt-gardner committed Sep 19, 2019
1 parent 41a4776 commit daed835
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
22 changes: 22 additions & 0 deletions allennlp/tests/tools/drop_eval_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# pylint: disable=no-self-use,invalid-name

import io
from contextlib import redirect_stdout

from allennlp.tools.drop_eval import _normalize_answer, get_metrics, evaluate_json

class TestDropEvalNormalize:
Expand Down Expand Up @@ -139,3 +142,22 @@ def test_json_loader(self):
{"answer": {"spans": ["answer2"]}, "query_id":"qid2"}]}}
prediction = {"qid1": "answer", "qid2": "answer2"}
assert evaluate_json(annotation, prediction) == (0.5, 0.5)

def test_type_partition_output(self):
annotation = {"pid1": {"qa_pairs":[{"answer": {"number": "5"}, "validated_answers": \
[{"spans": ["7-meters"]}], "query_id":"qid1"}]}}
prediction = {"qid1": "5-yard"}
with io.StringIO() as buf, redirect_stdout(buf):
evaluate_json(annotation, prediction)
output = buf.getvalue()
lines = output.strip().split("\n")
assert lines[4] == 'number: 1 (100.00%)'

annotation = {"pid1": {"qa_pairs":[{"answer": {"spans": ["7-meters"]}, "validated_answers": \
[{"number": "5"}], "query_id":"qid1"}]}}
prediction = {"qid1": "5-yard"}
with io.StringIO() as buf, redirect_stdout(buf):
evaluate_json(annotation, prediction)
output = buf.getvalue()
lines = output.strip().split("\n")
assert lines[4] == 'number: 1 (100.00%)'
2 changes: 1 addition & 1 deletion allennlp/tools/drop_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def evaluate_json(annotations: Dict[str, Any], predicted_answers: Dict[str, Any]
if gold_answer[0].strip() != "":
max_em_score = max(max_em_score, em_score)
max_f1_score = max(max_f1_score, f1_score)
if max_em_score == em_score or max_f1_score == f1_score:
if max_em_score == em_score and max_f1_score == f1_score:
max_type = gold_type
else:
print("Missing prediction for question: {}".format(query_id))
Expand Down

0 comments on commit daed835

Please sign in to comment.