diff --git a/CHANGELOG.md b/CHANGELOG.md
index 40e6ef3fd7d..54b9d66b7cf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - We can now transparently read compressed input files during prediction.
 - LZMA compression is now supported.
+- Added a way to give JSON blobs as input to dataset readers in the `evaluate` command.
 - Added the argument `sub_module` in `PretrainedTransformerMismatchedEmbedder`
 
diff --git a/allennlp/commands/evaluate.py b/allennlp/commands/evaluate.py
index 2b9403a417f..ac7c6032b2b 100644
--- a/allennlp/commands/evaluate.py
+++ b/allennlp/commands/evaluate.py
@@ -7,6 +7,7 @@
 import argparse
 import json
 import logging
+from json import JSONDecodeError
 from pathlib import Path
 from os import PathLike
 from typing import Union, Dict, Any, Optional
@@ -35,14 +36,14 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
         subparser.add_argument(
             "input_file",
             type=str,
-            help="path to the file containing the evaluation data (for mutiple "
+            help="path to the file containing the evaluation data (for multiple "
             "files, put between filenames e.g., input1.txt,input2.txt)",
         )
 
         subparser.add_argument(
             "--output-file",
             type=str,
-            help="optional path to write the metrics to as JSON (for mutiple "
+            help="optional path to write the metrics to as JSON (for multiple "
             "files, put between filenames e.g., output1.txt,output2.txt)",
         )
 
@@ -258,17 +259,26 @@ def evaluate_from_archive(
     dataset_reader = archive.validation_dataset_reader
 
     # split files
-    evaluation_data_path_list = input_file.split(",")
+    try:
+        # Try reading it as a list of JSON objects first. Some readers require
+        # that kind of input.
+        evaluation_data_path_list = json.loads(f"[{input_file}]")
+    except JSONDecodeError:
+        evaluation_data_path_list = input_file.split(",")
 
     # TODO(gabeorlanski): Is it safe to always default to .outputs and .preds?
     # TODO(gabeorlanski): Add in way to save to specific output directory
     if metrics_output_file is not None:
         if auto_names == "METRICS" or auto_names == "ALL":
             logger.warning(
-                f"Passed output_files will be ignored, auto_names is" f" set to {auto_names}"
+                f"Passed output_files will be ignored, auto_names is set to {auto_names}"
             )
 
             # Keep the path of the parent otherwise it will write to the CWD
+            assert all(isinstance(p, str) for p in evaluation_data_path_list), (
+                "When specifying JSON blobs as input, the output files must be explicitly named with "
+                "--output-file."
+            )
             output_file_list = [
                 p.parent.joinpath(f"{p.stem}.outputs") for p in map(Path, evaluation_data_path_list)
             ]
@@ -285,6 +295,10 @@
             )
 
             # Keep the path of the parent otherwise it will write to the CWD
+            assert all(isinstance(p, str) for p in evaluation_data_path_list), (
+                "When specifying JSON blobs as input, the predictions output files must be explicitly named with "
+                "--predictions-output-file."
+            )
             predictions_output_file_list = [
                 p.parent.joinpath(f"{p.stem}.preds") for p in map(Path, evaluation_data_path_list)
             ]
@@ -307,13 +321,15 @@
         )
 
     all_metrics = {}
-    for index in range(len(evaluation_data_path_list)):
+    for index, evaluation_data_path in enumerate(evaluation_data_path_list):
         config = deepcopy(archive.config)
-        evaluation_data_path = evaluation_data_path_list[index]
 
         # Get the eval file name so we can save each metric by file name in the
         # output dictionary.
-        eval_file_name = Path(evaluation_data_path).stem
+        if isinstance(evaluation_data_path, str):
+            eval_file_name = Path(evaluation_data_path).stem
+        else:
+            eval_file_name = str(index)
 
         if metrics_output_file is not None:
             # noinspection PyUnboundLocalVariable
@@ -323,7 +339,7 @@ def evaluate_from_archive(
             # noinspection PyUnboundLocalVariable
             predictions_output_file_path = predictions_output_file_list[index]
 
-        logger.info("Reading evaluation data from %s", evaluation_data_path)
+        logger.info("Reading evaluation data from %s", eval_file_name)
         data_loader_params = config.get("validation_data_loader", None)
         if data_loader_params is None:
             data_loader_params = config.get("data_loader")
diff --git a/test_fixtures/basic_classifier/common.jsonnet b/test_fixtures/basic_classifier/common.jsonnet
index 86572153a9d..615fe6b335f 100644
--- a/test_fixtures/basic_classifier/common.jsonnet
+++ b/test_fixtures/basic_classifier/common.jsonnet
@@ -16,7 +16,6 @@
     "train_data_path": "test_fixtures/data/text_classification_json/imdb_corpus.jsonl",
     "validation_data_path": "test_fixtures/data/text_classification_json/imdb_corpus.jsonl",
     "data_loader": {
-
         "batch_sampler": {
             "type": "bucket",
             "batch_size": 5
diff --git a/tests/commands/evaluate_test.py b/tests/commands/evaluate_test.py
index eebf7753453..3ca2c967e68 100644
--- a/tests/commands/evaluate_test.py
+++ b/tests/commands/evaluate_test.py
@@ -163,7 +163,7 @@ def test_evaluate_works_with_vocab_expansion(self):
         kebab_args = ["evaluate", archive_path, evaluate_data_path, "--cuda-device", "-1"]
 
         # TODO(mattg): the unawarded_embeddings.gz file above doesn't exist, but this test still
-        # passes.  This suggests that vocab extension in evaluate isn't currently doing anything,
+        # passes. This suggests that vocab extension in evaluate isn't currently doing anything,
         # and so it is broken.
 
         # Evaluate 1 with no vocab expansion,
diff --git a/tests/models/multitask_test.py b/tests/models/multitask_test.py
index 43ffc50a33a..b5257e4ff7c 100644
--- a/tests/models/multitask_test.py
+++ b/tests/models/multitask_test.py
@@ -1,5 +1,8 @@
+import os
+
 import pytest
 
+from allennlp.common import Params
 from allennlp.common.testing import ModelTestCase
 from allennlp.data import Instance, Vocabulary, Batch
 from allennlp.data.fields import LabelField, TextField, MetadataField
@@ -101,3 +104,84 @@ def test_forward_works(self):
         )
         with pytest.raises(ValueError, match="duplicate argument text"):
             outputs = model.forward_on_instance(instance)
+
+    def test_train_and_evaluate(self):
+        from allennlp.commands.train import train_model
+        from allennlp.commands.evaluate import evaluate_from_args
+        import argparse
+        from allennlp.commands import Evaluate
+
+        model_name = "epwalsh/bert-xsmall-dummy"
+
+        def reader():
+            return {
+                "type": "text_classification_json",
+                "tokenizer": {"type": "pretrained_transformer", "model_name": model_name},
+                "token_indexers": {
+                    "tokens": {"type": "pretrained_transformer", "model_name": model_name}
+                },
+            }
+
+        def head():
+            return {
+                "type": "classifier",
+                "seq2vec_encoder": {"type": "cls_pooler", "embedding_dim": 20},
+                "input_dim": 20,
+                "num_labels": 2,
+            }
+
+        head_eins_input = "test_fixtures/data/text_classification_json/imdb_corpus.jsonl"
+        head_zwei_input = (
+            "test_fixtures/data/text_classification_json/ag_news_corpus_fake_sentiment_labels.jsonl"
+        )
+
+        params = Params(
+            {
+                "dataset_reader": {
+                    "type": "multitask",
+                    "readers": {
+                        "head_eins": reader(),
+                        "head_zwei": reader(),
+                    },
+                },
+                "model": {
+                    "type": "multitask",
+                    "backbone": {"type": "pretrained_transformer", "model_name": model_name},
+                    "heads": {
+                        "head_eins": head(),
+                        "head_zwei": head(),
+                    },
+                    "arg_name_mapping": {"backbone": {"tokens": "text"}},
+                },
+                "train_data_path": {"head_eins": head_eins_input, "head_zwei": head_zwei_input},
+                "data_loader": {"type": "multitask", "scheduler": {"batch_size": 2}},
+                "trainer": {
+                    "optimizer": {
+                        "type": "huggingface_adamw",
+                        "lr": 4e-5,
+                    },
+                    "num_epochs": 2,
+                },
+            }
+        )
+
+        serialization_dir = os.path.join(self.TEST_DIR, "serialization_dir")
+        train_model(params, serialization_dir=serialization_dir)
+
+        args = [
+            "evaluate",
+            str(self.TEST_DIR / "serialization_dir"),
+            f'{{"head_eins": "{head_eins_input}", "head_zwei": "{head_zwei_input}"}}',
+            "--output-file",
+            str(self.TEST_DIR / "output.txt"),
+            "--predictions-output-file",
+            str(self.TEST_DIR / "predictions.json"),
+        ]
+
+        parser = argparse.ArgumentParser(description="Testing")
+        subparsers = parser.add_subparsers(title="Commands", metavar="")
+        Evaluate().add_subparser(subparsers)
+        args = parser.parse_args(args)
+        metrics = evaluate_from_args(args)
+        assert "head_eins_accuracy" in metrics
+        assert "head_zwei_accuracy" in metrics
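For context, a minimal usage sketch of the JSON-blob input path added in `evaluate.py`, mirroring the multitask test above. The archive path, head names, and data paths are placeholders; the JSON string is picked up by the new `json.loads(f"[{input_file}]")` branch in `evaluate_from_archive`, so the output files must be named explicitly rather than auto-named.

    # Sketch only: exercises the new JSON-blob input to the `evaluate` command.
    import argparse

    from allennlp.commands import Evaluate
    from allennlp.commands.evaluate import evaluate_from_args

    # Build the `evaluate` argument parser the same way the test does.
    parser = argparse.ArgumentParser(description="Testing")
    subparsers = parser.add_subparsers(title="Commands", metavar="")
    Evaluate().add_subparser(subparsers)

    args = parser.parse_args(
        [
            "evaluate",
            "path/to/archive",  # placeholder: serialization dir or model.tar.gz of a multitask model
            '{"head_eins": "data/eins.jsonl", "head_zwei": "data/zwei.jsonl"}',  # JSON blob instead of a file path
            "--output-file",
            "metrics.json",
            "--predictions-output-file",
            "predictions.json",
        ]
    )
    metrics = evaluate_from_args(args)  # e.g. contains "head_eins_accuracy", "head_zwei_accuracy"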