Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

load plugins from Predictor.from_path #4333

Merged
merged 3 commits into from
Jun 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixes `PretrainedTransformerMismatchedIndexer` in the case where a token consists of zero word pieces.
- Fixes a bug when using a lazy dataset reader that results in a `UserWarning` from PyTorch being printed at
every iteration during training.
- `Predictor.from_path` now automatically loads plugins (unless you specify `load_plugins=False`) so
that you don't have to manually import a bunch of modules when instantiating predictors from
an archive path.

### Added

Expand Down
10 changes: 9 additions & 1 deletion allennlp/predictors/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch import Tensor
from torch import backends

from allennlp.common import Registrable
from allennlp.common import Registrable, plugins
from allennlp.common.util import JsonDict, sanitize
from allennlp.data import DatasetReader, Instance
from allennlp.data.batch import Batch
Expand Down Expand Up @@ -235,6 +235,7 @@ def from_path(
cuda_device: int = -1,
dataset_reader_to_load: str = "validation",
frozen: bool = True,
import_plugins: bool = True,
) -> "Predictor":
"""
Instantiate a `Predictor` from an archive path.
Expand All @@ -257,12 +258,19 @@ def from_path(
"validation".
frozen : `bool`, optional (default=`True`)
If we should call `model.eval()` when building the predictor.
import_plugins : `bool`, optional (default=`True`)
If `True`, we attempt to import plugins before loading the predictor.
This comes with additional overhead, but means you don't need to explicitly
import the modules that your predictor depends on as long as those modules
can be found by `allennlp.common.plugins.import_plugins()`.

# Returns

`Predictor`
A Predictor instance.
"""
if import_plugins:
plugins.import_plugins()
return Predictor.from_archive(
load_archive(archive_path, cuda_device=cuda_device),
predictor_name,
Expand Down