Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Add exist_ok parameter to registrable.register decorator. (#3190)
Browse files Browse the repository at this point in the history
* Add exist_ok parameter to registrable.register decorator.

* Fix pylint.

* Fix pylint.

* Add docstring for registrable.register.

* Add a space to appease pylint.

* Switch if not exist_ok to else.

* Switch to fstrings.

* Change logger.debug to logger.info
  • Loading branch information
nelson-liu authored Aug 22, 2019
1 parent ce6dc72 commit 7738cb5
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
26 changes: 21 additions & 5 deletions allennlp/common/registrable.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,37 @@ class Registrable(FromParams):
default_implementation: str = None

@classmethod
def register(cls: Type[T], name: str):
def register(cls: Type[T], name: str, exist_ok=False):
"""
Register a class under a particular name.
Parameters
----------
name: ``str``
The name to register the class under.
exist_ok: ``bool`, optional (default=False)
If True, overwrites any existing models registered under ``name``. Else,
throws an error if a model is already registered under ``name``.
"""
registry = Registrable._registry[cls]
def add_subclass_to_registry(subclass: Type[T]):
# Add to registry, raise an error if key has already been used.
if name in registry:
message = "Cannot register %s as %s; name already in use for %s" % (
name, cls.__name__, registry[name].__name__)
raise ConfigurationError(message)
if exist_ok:
message = (f"{name} has already been registered as {registry[name].__name__}, but "
f"exist_ok=True, so overwriting with {cls.__name__}")
logger.info(message)
else:
message = (f"Cannot register {name} as {cls.__name__}; "
f"name already in use for {registry[name].__name__}")
raise ConfigurationError(message)
registry[name] = subclass
return subclass
return add_subclass_to_registry

@classmethod
def by_name(cls: Type[T], name: str) -> Type[T]:
logger.debug(f"instantiating registered subclass {name} of {cls}")
logger.info(f"instantiating registered subclass {name} of {cls}")
if name not in Registrable._registry[cls]:
raise ConfigurationError("%s is not a registered name for %s" % (name, cls.__name__))
return Registrable._registry[cls].get(name)
Expand Down
18 changes: 17 additions & 1 deletion allennlp/tests/common/registrable_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_registrable_functionality_works(self):
# This function tests the basic `Registrable` functionality:
#
# 1. The decorator should add things to the list.
# 2. The decorator should crash when adding a duplicate.
# 2. The decorator should crash when adding a duplicate (unless exist_ok=True).
# 3. If a default is given, it should show up first in the list.
#
# What we don't test here is that built-in items are registered correctly. Those are
Expand Down Expand Up @@ -56,6 +56,22 @@ class Fake(base_class):
base_class.list_available()
base_class.default_implementation = default

# Verify that registering under a name that already exists
# causes a ConfigurationError.
with pytest.raises(ConfigurationError):
@base_class.register('fake')
class FakeAlternate(base_class):
# pylint: disable=abstract-method
pass

# Registering under a name that already exists should overwrite
# if exist_ok=True.
@base_class.register('fake', exist_ok=True) # pylint: disable=function-redefined
class FakeAlternate(base_class):
# pylint: disable=abstract-method
pass
assert base_class.by_name('fake') == FakeAlternate

del Registrable._registry[base_class]['fake'] # pylint: disable=protected-access

# TODO(mattg): maybe move all of these into tests for the base class?
Expand Down

0 comments on commit 7738cb5

Please sign in to comment.