Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
improve error message for Vocab.get_token_index (#4185)
Browse files Browse the repository at this point in the history
* improve error message Vocab.get_token_index

* fix comments

* fix linting

* a little better wording
  • Loading branch information
epwalsh authored May 1, 2020
1 parent 31616de commit 2602c8f
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 4 deletions.
9 changes: 6 additions & 3 deletions allennlp/data/vocabulary.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,15 +651,18 @@ def get_token_to_index_vocabulary(self, namespace: str = "tokens") -> Dict[str,
return self._token_to_index[namespace]

# NOTE(review): this method is shown as a rendered diff whose +/- markers and
# indentation were lost, so pre-commit and post-commit lines are interleaved
# (the file header above reports 6 additions and 3 deletions). The per-line
# notes below mark which side each changed line appears to belong to; confirm
# against the repository before relying on them.
def get_token_index(self, token: str, namespace: str = "tokens") -> int:
# Pre-commit line, presumably removed: membership test before indexing.
if token in self._token_to_index[namespace]:
# Post-commit line, presumably added: attempt the lookup directly instead.
try:
return self._token_to_index[namespace][token]
# Pre-commit line, presumably removed.
else:
# Post-commit line, presumably added: the token is not in this namespace.
except KeyError:
try:
# Fall back to the index of the OOV token in the same namespace.
return self._token_to_index[namespace][self._oov_token]
except KeyError:
# Neither the token nor the OOV token is present in this namespace.
logger.error("Namespace: %s", namespace)
logger.error("Token: %s", token)
# Pre-commit line, presumably removed: bare re-raise of the inner KeyError.
raise
# Post-commit lines, presumably added: raise a KeyError whose message names
# the token, the namespace, and the missing OOV token.
raise KeyError(
f"'{token}' not found in vocab namespace '{namespace}', and namespace "
f"does not contain the default OOV token ('{self._oov_token}')"
)

def get_token_from_index(self, index: int, namespace: str = "tokens") -> str:
    """Return the token stored at ``index`` in ``namespace``.

    This is a direct lookup in ``self._index_to_token[namespace][index]``;
    any error raised by that lookup propagates to the caller unchanged.
    """
    return self._index_to_token[namespace][index]
Expand Down
36 changes: 35 additions & 1 deletion allennlp/tests/data/vocabulary_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ def test_namespace_dependent_default_dict(self):
assert default_dict["foobaz"] == 3

def test_unknown_token(self):

# We're putting this behavior in a test so that the behavior is documented. There is
# solver code that depends in a small way on how we treat the unknown token, so any
# breaking change to this behavior should break a test, so you know you've done something
Expand All @@ -185,6 +184,41 @@ def test_unknown_token(self):
assert oov_index == 1
assert vocab.get_token_index("unseen word") == oov_index

def test_get_token_index(self):
    # The behavior of get_token_index depends on whether or not the namespace has an OOV token.
    vocab = Vocabulary(
        counter={"labels": {"foo": 3, "bar": 2}, "tokens": {"foo": 3, "bar": 2}},
        non_padded_namespaces=["labels"],
    )

    # Quick sanity check, this is what the token to index mappings should look like.
    expected_token_to_index_dicts = {
        "tokens": {vocab._padding_token: 0, vocab._oov_token: 1, "foo": 2, "bar": 3},
        "labels": {"foo": 0, "bar": 1},
    }
    assert vocab._token_to_index["tokens"] == expected_token_to_index_dicts["tokens"]
    assert vocab._token_to_index["labels"] == expected_token_to_index_dicts["labels"]

    # get_token_index should return the OOV token index for OOV tokens when it can.
    assert vocab.get_token_index("baz", "tokens") == 1

    # get_token_index should raise a helpful error message when the token is OOV and
    # there is no default OOV token in the namespace.
    with pytest.raises(
        KeyError,
        match=r"'baz' not found .* and namespace does not contain the default OOV token .*",
    ):
        vocab.get_token_index("baz", "labels")

    # The same should happen for the default OOV token itself, if not in the namespace.
    with pytest.raises(KeyError, match=rf"'{vocab._oov_token}' not found .*"):
        vocab.get_token_index(vocab._oov_token, "labels")

    # Now just make sure the token_to_index mappings haven't been modified
    # (since these are defaultdicts, a failed lookup could otherwise insert a key).
    assert vocab._token_to_index["tokens"] == expected_token_to_index_dicts["tokens"]
    assert vocab._token_to_index["labels"] == expected_token_to_index_dicts["labels"]

def test_set_from_file_reads_padded_files(self):

vocab_filename = self.TEST_DIR / "vocab_file"
Expand Down

0 comments on commit 2602c8f

Please sign in to comment.