diff --git a/allennlp/common/registrable.py b/allennlp/common/registrable.py index d2ec1f4023a..be680c3ca18 100644 --- a/allennlp/common/registrable.py +++ b/allennlp/common/registrable.py @@ -39,21 +39,37 @@ class Registrable(FromParams): default_implementation: str = None @classmethod - def register(cls: Type[T], name: str): + def register(cls: Type[T], name: str, exist_ok=False): + """ + Register a class under a particular name. + + Parameters + ---------- + name: ``str`` + The name to register the class under. + exist_ok: ``bool`, optional (default=False) + If True, overwrites any existing models registered under ``name``. Else, + throws an error if a model is already registered under ``name``. + """ registry = Registrable._registry[cls] def add_subclass_to_registry(subclass: Type[T]): # Add to registry, raise an error if key has already been used. if name in registry: - message = "Cannot register %s as %s; name already in use for %s" % ( - name, cls.__name__, registry[name].__name__) - raise ConfigurationError(message) + if exist_ok: + message = (f"{name} has already been registered as {registry[name].__name__}, but " + f"exist_ok=True, so overwriting with {cls.__name__}") + logger.info(message) + else: + message = (f"Cannot register {name} as {cls.__name__}; " + f"name already in use for {registry[name].__name__}") + raise ConfigurationError(message) registry[name] = subclass return subclass return add_subclass_to_registry @classmethod def by_name(cls: Type[T], name: str) -> Type[T]: - logger.debug(f"instantiating registered subclass {name} of {cls}") + logger.info(f"instantiating registered subclass {name} of {cls}") if name not in Registrable._registry[cls]: raise ConfigurationError("%s is not a registered name for %s" % (name, cls.__name__)) return Registrable._registry[cls].get(name) diff --git a/allennlp/tests/common/registrable_test.py b/allennlp/tests/common/registrable_test.py index bb43995c94f..efe86182d64 100644 --- a/allennlp/tests/common/registrable_test.py +++ b/allennlp/tests/common/registrable_test.py @@ -27,7 +27,7 @@ def test_registrable_functionality_works(self): # This function tests the basic `Registrable` functionality: # # 1. The decorator should add things to the list. - # 2. The decorator should crash when adding a duplicate. + # 2. The decorator should crash when adding a duplicate (unless exist_ok=True). # 3. If a default is given, it should show up first in the list. # # What we don't test here is that built-in items are registered correctly. Those are @@ -56,6 +56,22 @@ class Fake(base_class): base_class.list_available() base_class.default_implementation = default + # Verify that registering under a name that already exists + # causes a ConfigurationError. + with pytest.raises(ConfigurationError): + @base_class.register('fake') + class FakeAlternate(base_class): + # pylint: disable=abstract-method + pass + + # Registering under a name that already exists should overwrite + # if exist_ok=True. + @base_class.register('fake', exist_ok=True) # pylint: disable=function-redefined + class FakeAlternate(base_class): + # pylint: disable=abstract-method + pass + assert base_class.by_name('fake') == FakeAlternate + del Registrable._registry[base_class]['fake'] # pylint: disable=protected-access # TODO(mattg): maybe move all of these into tests for the base class?