Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Consistently use underscores in Predictor names (#4340)
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkgr authored Jun 8, 2020
1 parent 2d03c41 commit 73289bc
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixes `PretrainedTransformerMismatchedIndexer` in the case where a token consists of zero word pieces.
- Fixes a bug when using a lazy dataset reader that results in a `UserWarning` from PyTorch being printed at
every iteration during training.
- Predictors names were inconsistently switching between dashes and underscores. Now they all use underscores.
- `Predictor.from_path` now automatically loads plugins (unless you specify `load_plugins=False`) so
that you don't have to manually import a bunch of modules when instantiating predictors from
an archive path.
Expand Down
2 changes: 1 addition & 1 deletion allennlp/models/simple_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,4 +220,4 @@ def get_metrics(self, reset: bool = False) -> Dict[str, float]:
metrics_to_return.update({x: y for x, y in f1_dict.items() if "overall" in x})
return metrics_to_return

default_predictor = "sentence-tagger"
default_predictor = "sentence_tagger"
4 changes: 2 additions & 2 deletions allennlp/predictors/sentence_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
from allennlp.predictors.predictor import Predictor


@Predictor.register("sentence-tagger")
@Predictor.register("sentence_tagger")
class SentenceTaggerPredictor(Predictor):
"""
Predictor for any model that takes in a sentence and returns
a single set of tags for it. In particular, it can be used with
the [`CrfTagger`](../models/crf_tagger.md) model
and also the [`SimpleTagger`](../models/simple_tagger.md) model.
Registered as a `Predictor` with name "sentence-tagger".
Registered as a `Predictor` with name "sentence_tagger".
"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion tests/interpret/input_reduction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_input_reduction(self):
archive = load_archive(
self.FIXTURES_ROOT / "simple_tagger" / "serialization" / "model.tar.gz"
)
predictor = Predictor.from_archive(archive, "sentence-tagger")
predictor = Predictor.from_archive(archive, "sentence_tagger")

reducer = InputReduction(predictor)
reduced = reducer.attack_from_json(inputs, "tokens", "grad_input_1")
Expand Down
10 changes: 5 additions & 5 deletions tests/predictors/predictor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ def test_from_archive_does_not_consume_params(self):
archive = load_archive(
self.FIXTURES_ROOT / "simple_tagger" / "serialization" / "model.tar.gz"
)
Predictor.from_archive(archive, "sentence-tagger")
Predictor.from_archive(archive, "sentence_tagger")

# If it consumes the params, this will raise an exception
Predictor.from_archive(archive, "sentence-tagger")
Predictor.from_archive(archive, "sentence_tagger")

def test_loads_correct_dataset_reader(self):
# This model has a different dataset reader configuration for train and validation. The
Expand All @@ -21,16 +21,16 @@ def test_loads_correct_dataset_reader(self):
self.FIXTURES_ROOT / "simple_tagger_with_span_f1" / "serialization" / "model.tar.gz"
)

predictor = Predictor.from_archive(archive, "sentence-tagger")
predictor = Predictor.from_archive(archive, "sentence_tagger")
assert predictor._dataset_reader._token_indexers["tokens"].namespace == "test_tokens"

predictor = Predictor.from_archive(
archive, "sentence-tagger", dataset_reader_to_load="train"
archive, "sentence_tagger", dataset_reader_to_load="train"
)
assert predictor._dataset_reader._token_indexers["tokens"].namespace == "tokens"

predictor = Predictor.from_archive(
archive, "sentence-tagger", dataset_reader_to_load="validation"
archive, "sentence_tagger", dataset_reader_to_load="validation"
)
assert predictor._dataset_reader._token_indexers["tokens"].namespace == "test_tokens"

Expand Down
2 changes: 1 addition & 1 deletion tests/predictors/sentence_tagger_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_predictions_to_labeled_instances(self):
archive = load_archive(
self.FIXTURES_ROOT / "simple_tagger" / "serialization" / "model.tar.gz"
)
predictor = Predictor.from_archive(archive, "sentence-tagger")
predictor = Predictor.from_archive(archive, "sentence_tagger")

instance = predictor._json_to_instance(inputs)
outputs = predictor._model.forward_on_instance(instance)
Expand Down

0 comments on commit 73289bc

Please sign in to comment.