Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
allow implicit package imports (#3253)
Browse files Browse the repository at this point in the history
* allow implicit package imports

* tweak logic

* address PR feedback

* add pop_choice test

* add message to failure

* missing parenthesis
  • Loading branch information
joelgrus authored Sep 17, 2019
1 parent 48de866 commit 05be16a
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 8 deletions.
27 changes: 22 additions & 5 deletions allennlp/common/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,11 @@ def get(self, key: str, default: Any = DEFAULT):
value = self.params.get(key, default)
return self._check_is_dict(key, value)

def pop_choice(self, key: str, choices: List[Any], default_to_first_choice: bool = False) -> Any:
def pop_choice(self,
key: str,
choices: List[Any],
default_to_first_choice: bool = False,
allow_class_names: bool = True) -> Any:
"""
Gets the value of ``key`` in the ``params`` dictionary, ensuring that the value is one of
the given choices. Note that this `pops` the key from params, modifying the dictionary,
Expand All @@ -342,12 +346,21 @@ def pop_choice(self, key: str, choices: List[Any], default_to_first_choice: bool
``ConfigurationError``, because specifying the ``key`` is required (e.g., you `have` to
specify your model class when running an experiment, but you can feel free to use
default settings for encoders if you want).
allow_class_names : bool, optional (default = True)
If this is `True`, then we allow unknown choices that look like fully-qualified class names.
This is to allow e.g. specifying a model type as my_library.my_model.MyModel
and importing it on the fly. Our check for "looks like" is extremely lenient
and consists of checking that the value contains a '.'.
"""
default = choices[0] if default_to_first_choice else self.DEFAULT
value = self.pop(key, default)
if value not in choices:
ok_because_class_name = allow_class_names and '.' in value
if value not in choices and not ok_because_class_name:
key_str = self.history + key
message = '%s not in acceptable choices for %s: %s' % (value, key_str, str(choices))
message = (f"{value} not in acceptable choices for {key_str}: {choices}. "
"You should either use the --include-package flag to make sure the correct module "
"is loaded, or use a fully qualified class name in your config file like "
"""{"model": "my_module.models.MyModel"} to have it imported automatically.""")
raise ConfigurationError(message)
return value

Expand Down Expand Up @@ -541,7 +554,8 @@ def pop_choice(params: Dict[str, Any],
key: str,
choices: List[Any],
default_to_first_choice: bool = False,
history: str = "?.") -> Any:
history: str = "?.",
allow_class_names: bool = True) -> Any:
"""
Performs the same function as :func:`Params.pop_choice`, but is required in order to deal with
places that the Params object is not welcome, such as inside Keras layers. See the docstring
Expand All @@ -552,7 +566,10 @@ def pop_choice(params: Dict[str, Any],
history, so you'll have to fix that in the log if you want to actually recover the logged
parameters.
"""
value = Params(params, history).pop_choice(key, choices, default_to_first_choice)
value = Params(params, history).pop_choice(key,
choices,
default_to_first_choice,
allow_class_names=allow_class_names)
return value


Expand Down
33 changes: 30 additions & 3 deletions allennlp/common/registrable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
from collections import defaultdict
from typing import TypeVar, Type, Dict, List
import importlib
import logging

from allennlp.common.checks import ConfigurationError
Expand Down Expand Up @@ -70,9 +71,35 @@ def add_subclass_to_registry(subclass: Type[T]):
@classmethod
def by_name(cls: Type[T], name: str) -> Type[T]:
    """
    Return the class identified by ``name`` for the base class ``cls``.

    ``name`` is first looked up in the registry for ``cls``. If it is not registered
    but contains a ``.``, it is treated as a fully-qualified class name: everything
    before the last ``.`` is imported as a module and the final component is fetched
    from that module as an attribute.

    Parameters
    ----------
    name : ``str``
        Either a registered name or a fully-qualified path such as
        ``my_module.models.MyModel``.

    Raises
    ------
    ConfigurationError
        If ``name`` is not registered and either contains no ``.``, its module part
        cannot be imported, or the module has no attribute with the given class name.
    """
    logger.info(f"instantiating registered subclass {name} of {cls}")
    if name in Registrable._registry[cls]:
        return Registrable._registry[cls].get(name)

    if "." not in name:
        # is not a qualified class name
        raise ConfigurationError(f"{name} is not a registered name for {cls.__name__}. "
                                 "You probably need to use the --include-package flag "
                                 "to load your custom code. Alternatively, you can specify your choices "
                                 """using fully-qualified paths, e.g. {"model": "my_module.models.MyModel"} """
                                 "in which case they will be automatically imported correctly.")

    # This might be a fully qualified class name, so we'll try importing its "module"
    # and finding it there.
    submodule, _, class_name = name.rpartition(".")

    # An empty or relative module path (".Foo", ".pkg.Foo") would make
    # ``import_module`` raise ValueError / TypeError rather than
    # ModuleNotFoundError, so reject it up front with the same error.
    if not submodule or submodule.startswith("."):
        raise ConfigurationError(f"tried to interpret {name} as a path to a class "
                                 f"but unable to import module {submodule}")

    try:
        module = importlib.import_module(submodule)
    except ModuleNotFoundError as err:
        # Chain the cause: the missing module may be a dependency imported *by*
        # ``submodule`` rather than ``submodule`` itself.
        raise ConfigurationError(f"tried to interpret {name} as a path to a class "
                                 f"but unable to import module {submodule}") from err

    try:
        return getattr(module, class_name)
    except AttributeError as err:
        raise ConfigurationError(f"tried to interpret {name} as a path to a class "
                                 f"but unable to find class {class_name} in {submodule}") from err


@classmethod
def list_available(cls) -> List[str]:
Expand Down
17 changes: 17 additions & 0 deletions allennlp/tests/common/params_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pytest

from allennlp.common.checks import ConfigurationError
from allennlp.common.params import Params, unflatten, with_fallback, parse_overrides, infer_and_cast
from allennlp.common.testing import AllenNlpTestCase

Expand Down Expand Up @@ -435,3 +436,19 @@ def test_duplicate_copies_all_params_state(self):

assert new_params.loading_from_archive
assert new_params.files_to_archive == {"hey": "this is a path"}

def test_pop_choice(self):
    """Check ``Params.pop_choice`` for known, unknown and class-name-like values."""
    choices = ['my_model', 'other_model']

    # A value that is one of the choices is returned as-is.
    params = Params({'model': 'my_model'})
    assert params.pop_choice('model', choices) == 'my_model'

    # An unknown value without a '.' is rejected.
    params = Params({'model': 'non_existent_model'})
    with pytest.raises(ConfigurationError):
        params.pop_choice('model', choices)

    # An unknown value containing a '.' is accepted by default as a class name.
    # Pass the real ``choices`` list (not the string 'choices') so the check is
    # exercised against the same choices as the other cases.
    params = Params({'model': 'module.submodule.ModelName'})
    assert params.pop_choice('model', choices) == 'module.submodule.ModelName'

    # ... but is rejected when class names are explicitly disallowed.
    params = Params({'model': 'module.submodule.ModelName'})
    with pytest.raises(ConfigurationError):
        params.pop_choice('model', choices, allow_class_names=False)
42 changes: 42 additions & 0 deletions allennlp/tests/common/registrable_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# pylint: disable=no-self-use,invalid-name,too-many-public-methods
import inspect
import os
import sys

import pytest
import torch
import torch.nn.init
Expand Down Expand Up @@ -151,3 +155,41 @@ def test_registry_has_builtin_similarity_functions(self):
assert SimilarityFunction.by_name("bilinear").__name__ == 'BilinearSimilarity'
assert SimilarityFunction.by_name("linear").__name__ == 'LinearSimilarity'
assert SimilarityFunction.by_name("cosine").__name__ == 'CosineSimilarity'

def test_implicit_include_package(self):
    """Check that ``by_name`` imports a class from a fully-qualified path on the fly."""
    # Create a new package in a temporary dir
    packagedir = self.TEST_DIR / 'testpackage'
    packagedir.mkdir()  # pylint: disable=no-member
    (packagedir / '__init__.py').touch()  # pylint: disable=no-member

    # And add that directory to the path
    sys.path.insert(0, str(self.TEST_DIR))
    try:
        # Write out a duplicate dataset reader there, but registered under a different name.
        snli_reader = DatasetReader.by_name('snli')

        with open(inspect.getabsfile(snli_reader)) as f:
            code = f.read().replace("""@DatasetReader.register("snli")""",
                                    """@DatasetReader.register("snli-fake")""")

        with open(os.path.join(packagedir, 'reader.py'), 'w') as f:
            f.write(code)

        # The message assertions below sit outside the ``with`` bodies: a statement
        # placed after the raising call inside the block would never execute.

        # Fails to import by registered name
        with pytest.raises(ConfigurationError) as exc:
            DatasetReader.by_name('snli-fake')
        assert "is not a registered name" in str(exc.value)

        # Fails to import with wrong module name
        with pytest.raises(ConfigurationError) as exc:
            DatasetReader.by_name('testpackage.snli_reader.SnliFakeReader')
        assert "unable to import module" in str(exc.value)

        # Fails to import with wrong class name
        with pytest.raises(ConfigurationError) as exc:
            DatasetReader.by_name('testpackage.reader.SnliFakeReader')
        assert "unable to find class" in str(exc.value)

        # Imports successfully with right fully qualified name
        duplicate_reader = DatasetReader.by_name('testpackage.reader.SnliReader')
        assert duplicate_reader.__name__ == 'SnliReader'
    finally:
        # Restore sys.path so the temporary directory does not leak into other tests.
        sys.path.remove(str(self.TEST_DIR))

0 comments on commit 05be16a

Please sign in to comment.