This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Makes the evaluate command work for the multitask case (Second Edition) (#5579)

* Adds the ability to evaluate on JSON blobs

* Formatting
dirkgr authored Feb 28, 2022
1 parent 9f03803 commit 3fa5193
Showing 5 changed files with 110 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- We can now transparently read compressed input files during prediction.
- LZMA compression is now supported.
- Added a way to give JSON blobs as input to dataset readers in the `evaluate` command.
- Added the argument `sub_module` in `PretrainedTransformerMismatchedEmbedder`


32 changes: 24 additions & 8 deletions allennlp/commands/evaluate.py
@@ -7,6 +7,7 @@
import argparse
import json
import logging
from json import JSONDecodeError
from pathlib import Path
from os import PathLike
from typing import Union, Dict, Any, Optional
@@ -35,14 +36,14 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
subparser.add_argument(
"input_file",
type=str,
help="path to the file containing the evaluation data (for mutiple "
help="path to the file containing the evaluation data (for multiple "
"files, put between filenames e.g., input1.txt,input2.txt)",
)

subparser.add_argument(
"--output-file",
type=str,
help="optional path to write the metrics to as JSON (for mutiple "
help="optional path to write the metrics to as JSON (for multiple "
"files, put between filenames e.g., output1.txt,output2.txt)",
)

@@ -258,17 +259,26 @@ def evaluate_from_archive(
dataset_reader = archive.validation_dataset_reader

# split files
evaluation_data_path_list = input_file.split(",")
try:
# Try reading it as a list of JSON objects first. Some readers require
# that kind of input.
evaluation_data_path_list = json.loads(f"[{input_file}]")
except JSONDecodeError:
evaluation_data_path_list = input_file.split(",")

# TODO(gabeorlanski): Is it safe to always default to .outputs and .preds?
# TODO(gabeorlanski): Add in way to save to specific output directory
if metrics_output_file is not None:
if auto_names == "METRICS" or auto_names == "ALL":
logger.warning(
f"Passed output_files will be ignored, auto_names is" f" set to {auto_names}"
f"Passed output_files will be ignored, auto_names is set to {auto_names}"
)

# Keep the path of the parent otherwise it will write to the CWD
assert all(isinstance(p, str) for p in evaluation_data_path_list), (
"When specifying JSON blobs as input, the output files must be explicitly named with "
"--output-file."
)
output_file_list = [
p.parent.joinpath(f"{p.stem}.outputs") for p in map(Path, evaluation_data_path_list)
]
@@ -285,6 +295,10 @@
)

# Keep the path of the parent otherwise it will write to the CWD
assert all(isinstance(p, str) for p in evaluation_data_path_list), (
"When specifying JSON blobs as input, the predictions output files must be explicitly named with "
"--predictions-output-file."
)
predictions_output_file_list = [
p.parent.joinpath(f"{p.stem}.preds") for p in map(Path, evaluation_data_path_list)
]
@@ -307,13 +321,15 @@
)

all_metrics = {}
for index in range(len(evaluation_data_path_list)):
for index, evaluation_data_path in enumerate(evaluation_data_path_list):
config = deepcopy(archive.config)
evaluation_data_path = evaluation_data_path_list[index]

# Get the eval file name so we can save each metric by file name in the
# output dictionary.
eval_file_name = Path(evaluation_data_path).stem
if isinstance(evaluation_data_path, str):
eval_file_name = Path(evaluation_data_path).stem
else:
eval_file_name = str(index)

if metrics_output_file is not None:
# noinspection PyUnboundLocalVariable
@@ -323,7 +339,7 @@
# noinspection PyUnboundLocalVariable
predictions_output_file_path = predictions_output_file_list[index]

logger.info("Reading evaluation data from %s", evaluation_data_path)
logger.info("Reading evaluation data from %s", eval_file_name)
data_loader_params = config.get("validation_data_loader", None)
if data_loader_params is None:
data_loader_params = config.get("data_loader")
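
For context, a minimal sketch of what the new input handling in `evaluate_from_archive` does, written as a hypothetical standalone helper `parse_input_file` that mirrors the inlined `try`/`except` above; the filenames are illustrative, not real fixtures:

```python
import json
from json import JSONDecodeError
from typing import Any, List


def parse_input_file(input_file: str) -> List[Any]:
    # Hypothetical helper mirroring the inlined logic in evaluate_from_archive.
    try:
        # Wrapping the argument in brackets turns one or more comma-separated
        # JSON blobs into a JSON list of objects.
        return json.loads(f"[{input_file}]")
    except JSONDecodeError:
        # Not valid JSON, so keep the old behavior: comma-separated file paths.
        return input_file.split(",")


# A JSON blob mapping multitask heads to their evaluation files parses to a
# one-element list containing a dict (illustrative names):
parse_input_file('{"head_eins": "eins.jsonl", "head_zwei": "zwei.jsonl"}')
# -> [{'head_eins': 'eins.jsonl', 'head_zwei': 'zwei.jsonl'}]

# Plain comma-separated paths are not valid JSON, so the fallback applies:
parse_input_file("input1.txt,input2.txt")
# -> ['input1.txt', 'input2.txt']
```

Because a JSON blob has no filename stem, per-entry metric keys fall back to the list index (`str(index)` above), and the added assertions require explicitly named `--output-file` / `--predictions-output-file` whenever output names would otherwise be auto-derived from input paths.
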
1 change: 0 additions & 1 deletion test_fixtures/basic_classifier/common.jsonnet
@@ -16,7 +16,6 @@
"train_data_path": "test_fixtures/data/text_classification_json/imdb_corpus.jsonl",
"validation_data_path": "test_fixtures/data/text_classification_json/imdb_corpus.jsonl",
"data_loader": {

"batch_sampler": {
"type": "bucket",
"batch_size": 5
2 changes: 1 addition & 1 deletion tests/commands/evaluate_test.py
@@ -163,7 +163,7 @@ def test_evaluate_works_with_vocab_expansion(self):
kebab_args = ["evaluate", archive_path, evaluate_data_path, "--cuda-device", "-1"]

# TODO(mattg): the unawarded_embeddings.gz file above doesn't exist, but this test still
# passes. This suggests that vocab extension in evaluate isn't currently doing anything,
# passes. This suggests that vocab extension in evaluate isn't currently doing anything,
# and so it is broken.

# Evaluate 1 with no vocab expansion,
84 changes: 84 additions & 0 deletions tests/models/multitask_test.py
@@ -1,5 +1,8 @@
import os

import pytest

from allennlp.common import Params
from allennlp.common.testing import ModelTestCase
from allennlp.data import Instance, Vocabulary, Batch
from allennlp.data.fields import LabelField, TextField, MetadataField
@@ -101,3 +104,84 @@ def test_forward_works(self):
)
with pytest.raises(ValueError, match="duplicate argument text"):
outputs = model.forward_on_instance(instance)

def test_train_and_evaluate(self):
from allennlp.commands.train import train_model
from allennlp.commands.evaluate import evaluate_from_args
import argparse
from allennlp.commands import Evaluate

model_name = "epwalsh/bert-xsmall-dummy"

def reader():
return {
"type": "text_classification_json",
"tokenizer": {"type": "pretrained_transformer", "model_name": model_name},
"token_indexers": {
"tokens": {"type": "pretrained_transformer", "model_name": model_name}
},
}

def head():
return {
"type": "classifier",
"seq2vec_encoder": {"type": "cls_pooler", "embedding_dim": 20},
"input_dim": 20,
"num_labels": 2,
}

head_eins_input = "test_fixtures/data/text_classification_json/imdb_corpus.jsonl"
head_zwei_input = (
"test_fixtures/data/text_classification_json/ag_news_corpus_fake_sentiment_labels.jsonl"
)

params = Params(
{
"dataset_reader": {
"type": "multitask",
"readers": {
"head_eins": reader(),
"head_zwei": reader(),
},
},
"model": {
"type": "multitask",
"backbone": {"type": "pretrained_transformer", "model_name": model_name},
"heads": {
"head_eins": head(),
"head_zwei": head(),
},
"arg_name_mapping": {"backbone": {"tokens": "text"}},
},
"train_data_path": {"head_eins": head_eins_input, "head_zwei": head_zwei_input},
"data_loader": {"type": "multitask", "scheduler": {"batch_size": 2}},
"trainer": {
"optimizer": {
"type": "huggingface_adamw",
"lr": 4e-5,
},
"num_epochs": 2,
},
}
)

serialization_dir = os.path.join(self.TEST_DIR, "serialization_dir")
train_model(params, serialization_dir=serialization_dir)

args = [
"evaluate",
str(self.TEST_DIR / "serialization_dir"),
f'{{"head_eins": "{head_eins_input}", "head_zwei": "{head_zwei_input}"}}',
"--output-file",
str(self.TEST_DIR / "output.txt"),
"--predictions-output-file",
str(self.TEST_DIR / "predictions.json"),
]

parser = argparse.ArgumentParser(description="Testing")
subparsers = parser.add_subparsers(title="Commands", metavar="")
Evaluate().add_subparser(subparsers)
args = parser.parse_args(args)
metrics = evaluate_from_args(args)
assert "head_eins_accuracy" in metrics
assert "head_zwei_accuracy" in metrics
