
Reduce hotflip vocab size, batch input reduction beam search (#3270)
* saving state

* fixes

* Fixed tests

* remove initialize from constructor

* A couple of other minor fixes

* Fix test

* pylint, mypy

* better comments and documentation
matt-gardner authored Sep 21, 2019
1 parent 9a67546 commit 76d248f
Showing 10 changed files with 125 additions and 47 deletions.
5 changes: 5 additions & 0 deletions allennlp/data/token_indexers/elmo_indexer.py
@@ -85,6 +85,11 @@ def convert_word_to_char_ids(self, word: str) -> List[int]:
# +1 one for masking
return [c + 1 for c in char_ids]

def __eq__(self, other) -> bool:
if isinstance(self, other.__class__):
return self.__dict__ == other.__dict__
return NotImplemented


@TokenIndexer.register("elmo_characters")
class ELMoTokenCharactersIndexer(TokenIndexer[List[int]]):
107 changes: 71 additions & 36 deletions allennlp/interpret/attackers/hotflip.py
@@ -7,7 +7,9 @@

from allennlp.common.util import JsonDict, sanitize
from allennlp.data.fields import TextField
from allennlp.data.token_indexers import ELMoTokenCharactersIndexer, TokenCharactersIndexer
from allennlp.data.token_indexers import (ELMoTokenCharactersIndexer,
TokenCharactersIndexer,
SingleIdTokenIndexer)
from allennlp.data.tokenizers import Token
from allennlp.interpret.attackers import utils
from allennlp.interpret.attackers.attacker import Attacker
@@ -23,28 +25,56 @@ class Hotflip(Attacker):
"""
Runs the HotFlip style attack at the word-level https://arxiv.org/abs/1712.06751. We use the
first-order taylor approximation described in https://arxiv.org/abs/1903.06620, in the function
_first_order_taylor(). Constructing this object is expensive due to the construction of the
embedding matrix.
``_first_order_taylor()``.
We try to re-use the embedding matrix from the model when deciding what other words to flip a
token to. For a large class of models, this is straightforward. When there is a
character-level encoder, however (e.g., with ELMo, any char-CNN, etc.), or a combination of
encoders (e.g., ELMo + glove), we need to construct a fake embedding matrix that we can use in
``_first_order_taylor()``. We do this by getting a list of words from the model's vocabulary
and embedding them using the encoder. This can be expensive, both in terms of time and memory
usage, so we take a ``max_tokens`` parameter to limit the size of this fake embedding matrix.
This also requires a model to `have` a token vocabulary in the first place, which can be
problematic for models that only have character vocabularies.
Parameters
----------
predictor : ``Predictor``
The model (inside a Predictor) that we're attacking. We use this to get gradients and
predictions.
vocab_namespace : ``str``, optional (default='tokens')
We use this to know three things: (1) which tokens we should ignore when producing flips
(we don't consider non-alphanumeric tokens); (2) what the string value is of the token that
we produced, so we can show something human-readable to the user; and (3) if we need to
construct a fake embedding matrix, we use the tokens in the vocabulary as flip candidates.
max_tokens : ``int``, optional (default=5000)
This is only used when we need to construct a fake embedding matrix. That matrix can take
a lot of memory when the vocab size is large. This parameter puts a cap on the number of
tokens to use, so the fake embedding matrix doesn't take as much memory.
"""
def __init__(self, predictor: Predictor, vocab_namespace: str = 'tokens') -> None:
def __init__(self,
predictor: Predictor,
vocab_namespace: str = 'tokens',
max_tokens: int = 5000) -> None:
super().__init__(predictor)
self.vocab = self.predictor._model.vocab
self.namespace = vocab_namespace
# Force new tokens to be alphanumeric
self.max_tokens = max_tokens
self.invalid_replacement_indices: List[int] = []
for i in self.vocab._index_to_token[self.namespace]:
if not self.vocab._index_to_token[self.namespace][i].isalnum():
self.invalid_replacement_indices.append(i)
self.token_embedding: Embedding = None
self.embedding_matrix: torch.Tensor = None

def initialize(self):
"""
Call this function before running attack_from_json(). We put the call to
``_construct_embedding_matrix()`` in this function to prevent a large amount of compute
being done when __init__() is called.
"""
if self.token_embedding is None:
self.token_embedding = self._construct_embedding_matrix()
if self.embedding_matrix is None:
self.embedding_matrix = self._construct_embedding_matrix()
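A minimal usage sketch of the new interface (not part of the diff; the archive path, input JSON key, and field names below are placeholders to adjust for your own predictor):

```python
from allennlp.interpret.attackers import Hotflip
from allennlp.models.archival import load_archive
from allennlp.predictors import Predictor

# Hypothetical archive and inputs, for illustration only.
archive = load_archive("/path/to/model.tar.gz")
predictor = Predictor.from_archive(archive)

attacker = Hotflip(predictor, vocab_namespace="tokens", max_tokens=5000)
attacker.initialize()  # builds the (possibly fake) embedding matrix once, up front

attack = attacker.attack_from_json({"sentence": "a placeholder input sentence"},
                                   input_field_to_attack="tokens",
                                   grad_input_field="grad_input_1")
print(attack["final"])  # the flipped tokens
```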

def _construct_embedding_matrix(self) -> Embedding:
"""
@@ -56,44 +86,49 @@ def _construct_embedding_matrix(self) -> Embedding:
final output embedding. We then group all of those output embeddings into an "embedding
matrix".
"""
# Gets all tokens in the vocab and their corresponding IDs
all_tokens = self.vocab._token_to_index[self.namespace]
all_indices = list(self.vocab._index_to_token[self.namespace].keys())
all_inputs = {"tokens": torch.LongTensor(all_indices).unsqueeze(0)}

embedding_layer = util.find_embedding_layer(self.predictor._model)
if isinstance(embedding_layer, (Embedding, torch.nn.modules.sparse.Embedding)):
# If we're using something that already has only a single embedding matrix, we can just use
# that and bypass this method.
return embedding_layer.weight

# We take the top `self.max_tokens` as candidates for hotflip. Because we have to
# construct a new vector for each of these, we can't always afford to use the whole vocab,
# for both runtime and memory considerations.
all_tokens = list(self.vocab._token_to_index[self.namespace])[:self.max_tokens]
max_index = self.vocab.get_token_index(all_tokens[-1], self.namespace)
self.invalid_replacement_indices = [i for i in self.invalid_replacement_indices if i < max_index]

all_inputs = {}
# A bit of a hack; this will only work with some dataset readers, but it'll do for now.
indexers = self.predictor._dataset_reader._token_indexers # type: ignore
for token_indexer in indexers.values():
# handle when a model uses character-level inputs, e.g., a CharCNN
if isinstance(token_indexer, TokenCharactersIndexer):
for indexer_name, token_indexer in indexers.items():
if isinstance(token_indexer, SingleIdTokenIndexer):
all_indices = [self.vocab._token_to_index[self.namespace][token] for token in all_tokens]
all_inputs[indexer_name] = torch.LongTensor(all_indices).unsqueeze(0)
elif isinstance(token_indexer, TokenCharactersIndexer):
tokens = [Token(x) for x in all_tokens]
max_token_length = max(len(x) for x in all_tokens)
indexed_tokens = token_indexer.tokens_to_indices(tokens, self.vocab, "token_characters")
padded_tokens = token_indexer.as_padded_tensor(indexed_tokens,
{"token_characters": len(tokens)},
{"num_token_characters": max_token_length})
all_inputs['token_characters'] = torch.LongTensor(padded_tokens['token_characters']).unsqueeze(0)
# for ELMo models
if isinstance(token_indexer, ELMoTokenCharactersIndexer):
all_inputs[indexer_name] = torch.LongTensor(padded_tokens['token_characters']).unsqueeze(0)
elif isinstance(token_indexer, ELMoTokenCharactersIndexer):
elmo_tokens = []
for token in all_tokens:
elmo_indexed_token = token_indexer.tokens_to_indices([Token(text=token)],
self.vocab,
"sentence")["sentence"]
elmo_tokens.append(elmo_indexed_token[0])
all_inputs["elmo"] = torch.LongTensor(elmo_tokens).unsqueeze(0)
all_inputs[indexer_name] = torch.LongTensor(elmo_tokens).unsqueeze(0)
else:
raise RuntimeError('Unsupported token indexer:', token_indexer)

embedding_layer = util.find_embedding_layer(self.predictor._model)
if isinstance(embedding_layer, torch.nn.modules.sparse.Embedding):
embedding_matrix = embedding_layer.weight
else:
# pass all tokens through the fake matrix and create an embedding out of it.
embedding_matrix = embedding_layer(all_inputs).squeeze()
# pass all tokens through the fake matrix and create an embedding out of it.
embedding_matrix = embedding_layer(all_inputs).squeeze()

return Embedding(num_embeddings=self.vocab.get_vocab_size(self.namespace),
embedding_dim=embedding_matrix.shape[1],
weight=embedding_matrix,
trainable=False)
return embedding_matrix
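To make the "fake embedding matrix" idea concrete, here is a toy, self-contained illustration (not the commit's code): every candidate token is pushed through the model's character-level embedder once, and the per-token outputs are stacked into a (num_tokens, embedding_dim) matrix that ``_first_order_taylor()`` can treat like an ordinary embedding weight.

```python
import torch

class ToyCharEmbedder(torch.nn.Module):
    """Stand-in for a character-level encoder: embed each character id and
    mean-pool over characters to get one vector per token."""
    def __init__(self, char_vocab_size: int = 262, embedding_dim: int = 64):
        super().__init__()
        self.char_embedding = torch.nn.Embedding(char_vocab_size, embedding_dim)

    def forward(self, char_ids: torch.Tensor) -> torch.Tensor:
        # char_ids: (batch, num_tokens, num_chars) -> (batch, num_tokens, embedding_dim)
        return self.char_embedding(char_ids).mean(dim=2)

max_tokens, num_chars = 5000, 12
char_ids = torch.randint(0, 262, (1, max_tokens, num_chars))  # indexed candidate tokens
embedding_matrix = ToyCharEmbedder()(char_ids).squeeze(0)
print(embedding_matrix.shape)  # torch.Size([5000, 64]) -- one row per candidate token
```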

def attack_from_json(self,
inputs: JsonDict,
@@ -135,7 +170,7 @@ def attack_from_json(self,
token (hence the list of length one), and we want to change the prediction from
whatever it was to ``"she"``.
"""
if self.token_embedding is None:
if self.embedding_matrix is None:
self.initialize()
ignore_tokens = DEFAULT_IGNORE_TOKENS if ignore_tokens is None else ignore_tokens

@@ -197,7 +232,7 @@ def attack_from_json(self,

while True:
# Compute L2 norm of all grads.
grad = grads[grad_input_field]
grad = grads[grad_input_field][0]
grads_magnitude = [g.dot(g) for g in grad]

# only flip a token once
@@ -211,15 +246,15 @@ def attack_from_json(self,
break
flipped.append(index_of_token_to_flip)

# TODO(mattg): This is quite a bit of a hack, both for gpt2 and for getting the
# vocab id in general... I don't have better ideas at the moment, though.
indexer_name = 'tokens' if self.namespace == 'gpt2' else self.namespace
# TODO(mattg): This is quite a bit of a hack for getting the vocab id... I don't
# have better ideas at the moment, though.
indexer_name = self.namespace
input_tokens = text_field._indexed_tokens[indexer_name]
original_id_of_token_to_flip = input_tokens[index_of_token_to_flip]

# Get new token using taylor approximation.
new_id = self._first_order_taylor(grad[index_of_token_to_flip],
self.token_embedding.weight, # type: ignore
self.embedding_matrix,
original_id_of_token_to_flip,
sign)

@@ -258,7 +293,7 @@ def attack_from_json(self,
"outputs": outputs})

def _first_order_taylor(self, grad: numpy.ndarray,
embedding_matrix: torch.nn.parameter.Parameter,
embedding_matrix: torch.Tensor,
token_idx: int,
sign: int) -> int:
"""
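The body of ``_first_order_taylor()`` is not shown in this hunk. Below is a simplified, self-contained sketch of the computation it performs (the commit's real implementation works on batched tensors; this strips that away): pick the vocabulary entry that maximizes ``sign * (e_original - e_candidate) . grad``, masking out the invalid (non-alphanumeric) replacement indices.

```python
from typing import List

import numpy
import torch

def first_order_taylor_sketch(grad: numpy.ndarray,
                              embedding_matrix: torch.Tensor,
                              token_idx: int,
                              sign: int,
                              invalid_indices: List[int]) -> int:
    """Simplified version: score every candidate by the first-order Taylor
    estimate of the loss change from flipping token_idx to that candidate."""
    grad_t = torch.from_numpy(grad).float()              # (embedding_dim,)
    original_embedding = embedding_matrix[token_idx]     # (embedding_dim,)
    candidate_dot_grad = embedding_matrix @ grad_t       # (vocab_size,)
    original_dot_grad = original_embedding @ grad_t      # scalar
    scores = sign * (original_dot_grad - candidate_dot_grad)
    scores[invalid_indices] = -float("inf")              # never flip to punctuation etc.
    return int(scores.argmax().item())
```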
27 changes: 23 additions & 4 deletions allennlp/interpret/attackers/input_reduction.py
@@ -80,12 +80,31 @@ def get_length(input_instance: Instance):
return len(input_text_field.tokens)
candidates = heapq.nsmallest(self.beam_size, candidates, key=lambda x: get_length(x[0]))

beam_candidates = deepcopy(candidates)
# predictor.get_gradients is where the most expensive computation happens, so we're
# going to do it in a batch, up front, before iterating over the results.
copied_candidates = deepcopy(candidates)
all_grads, all_outputs = self.predictor.get_gradients([x[0] for x in copied_candidates])

# The output in `all_grads` and `all_outputs` is batched in a dictionary (e.g.,
# {'grad_output_1': batched_tensor}). We need to split this into a list of non-batched
# dictionaries that we can iterate over.
split_grads = []
for i in range(len(copied_candidates)):
split_grads.append({key: value[i] for key, value in all_grads.items()})
split_outputs = []
for i in range(len(copied_candidates)):
instance_outputs = {}
for key, value in all_outputs.items():
if key == 'loss':
continue
instance_outputs[key] = value[i]
split_outputs.append(instance_outputs)
beam_candidates = [(x[0], x[1], x[2], split_grads[i], split_outputs[i])
for i, x in enumerate(copied_candidates)]

candidates = []
for beam_instance, smallest_idx, tag_mask in beam_candidates:
# get gradients and predictions
for beam_instance, smallest_idx, tag_mask, grads, outputs in beam_candidates:
beam_tag_mask = deepcopy(tag_mask)
grads, outputs = self.predictor.get_gradients([beam_instance])

for output in outputs:
if isinstance(outputs[output], torch.Tensor):
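The splitting logic above turns one batched gradient call into per-instance dictionaries; a tiny, self-contained illustration with made-up shapes:

```python
import numpy

# e.g. 3 beam candidates, with two gradient outputs of 7 and 5 tokens, 16-dim embeddings
batched_grads = {"grad_input_1": numpy.zeros((3, 7, 16)),
                 "grad_input_2": numpy.zeros((3, 5, 16))}

per_instance_grads = [{key: value[i] for key, value in batched_grads.items()}
                      for i in range(3)]

assert len(per_instance_grads) == 3
assert per_instance_grads[0]["grad_input_1"].shape == (7, 16)
```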
3 changes: 2 additions & 1 deletion allennlp/interpret/saliency_interpreters/integrated_gradient.py
@@ -25,7 +25,8 @@ def saliency_interpret_from_json(self, inputs: JsonDict) -> JsonDict:

# Normalize results
for key, grad in grads.items():
embedding_grad = numpy.sum(grad, axis=1)
# The [0] here is undo-ing the batching that happens in get_gradients.
embedding_grad = numpy.sum(grad[0], axis=1)
norm = numpy.linalg.norm(embedding_grad, ord=1)
normalized_grad = [math.fabs(e) / norm for e in embedding_grad]
grads[key] = normalized_grad
3 changes: 2 additions & 1 deletion allennlp/interpret/saliency_interpreters/simple_gradient.py
@@ -36,7 +36,8 @@ def saliency_interpret_from_json(self, inputs: JsonDict) -> JsonDict:
# This is then used as an index into the reversed input array to match up the
# gradient and its respective embedding.
input_idx = int(key[-1]) - 1
emb_grad = numpy.sum(grad * embeddings_list[input_idx], axis=1)
# The [0] here is undo-ing the batching that happens in get_gradients.
emb_grad = numpy.sum(grad[0] * embeddings_list[input_idx], axis=1)
norm = numpy.linalg.norm(emb_grad, ord=1)
normalized_grad = [math.fabs(e) / norm for e in emb_grad]
grads[key] = normalized_grad
4 changes: 3 additions & 1 deletion allennlp/interpret/saliency_interpreters/smooth_gradient.py
@@ -35,7 +35,9 @@ def saliency_interpret_from_json(self, inputs: JsonDict) -> JsonDict:
for key, grad in grads.items():
# TODO (@Eric-Wallace), SmoothGrad is not using times input normalization.
# Fine for now, but should fix for consistency.
embedding_grad = numpy.sum(grad, axis=1)

# The [0] here is undo-ing the batching that happens in get_gradients.
embedding_grad = numpy.sum(grad[0], axis=1)
norm = numpy.linalg.norm(embedding_grad, ord=1)
normalized_grad = [math.fabs(e) / norm for e in embedding_grad]
grads[key] = normalized_grad
2 changes: 1 addition & 1 deletion allennlp/models/next_token_lm.py
@@ -68,10 +68,10 @@ def forward(self, # type: ignore
tokens: Dict[str, torch.LongTensor],
target_ids: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
# pylint: disable=arguments-differ
batch_size = tokens['tokens'].size()[0]

# Shape: (batch_size, num_tokens, embedding_dim)
embeddings = self._text_field_embedder(tokens)
batch_size = embeddings.size(0)

# Shape: (batch_size, num_tokens, encoding_dim)
if self._contextualizer:
15 changes: 15 additions & 0 deletions allennlp/nn/util.py
@@ -1488,6 +1488,8 @@ def find_embedding_layer(model: torch.nn.Module) -> torch.nn.Module:
from pytorch_transformers.modeling_gpt2 import GPT2Model
from pytorch_transformers.modeling_bert import BertEmbeddings as BertEmbeddingsNew
from allennlp.modules.text_field_embedders.text_field_embedder import TextFieldEmbedder
from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import Embedding
for module in model.modules():
if isinstance(module, BertEmbeddingsOld):
return module.word_embeddings
@@ -1497,5 +1499,18 @@ def find_embedding_layer(model: torch.nn.Module) -> torch.nn.Module:
return module.wte
for module in model.modules():
if isinstance(module, TextFieldEmbedder):
# pylint: disable=protected-access
if isinstance(module, BasicTextFieldEmbedder):
# We'll have a check for single Embedding cases, because we can be more efficient
# in cases like this. If this check fails, then for something like hotflip we need
# to actually run the text field embedder and construct a vector for each token.
if len(module._token_embedders) == 1:
embedder = list(module._token_embedders.values())[0]
if isinstance(embedder, Embedding):
if embedder._projection is None: # pylint: disable=protected-access
# If there's a projection inside the Embedding, then we need to return
# the whole TextFieldEmbedder, because there's more computation that
# needs to be done than just multiply by an embedding matrix.
return embedder
return module
raise RuntimeError("No embedding module found!")
2 changes: 1 addition & 1 deletion allennlp/predictors/predictor.py
@@ -120,7 +120,7 @@ def get_gradients(self,
grad_dict = dict()
for idx, grad in enumerate(embedding_gradients):
key = 'grad_input_' + str(idx + 1)
grad_dict[key] = grad.squeeze_(0).detach().cpu().numpy()
grad_dict[key] = grad.detach().cpu().numpy()

return grad_dict, outputs

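For context on why the ``[0]`` indexing appears in the saliency interpreters above, a self-contained illustration with made-up shapes: ``get_gradients`` used to squeeze away the batch dimension, and now keeps it so that batched callers (like input reduction's beam search) can split the result per instance.

```python
import numpy
import torch

embedding_gradient = torch.zeros(1, 9, 300)   # (batch=1, num_tokens, embedding_dim), made-up sizes

old_style = embedding_gradient.squeeze(0).detach().cpu().numpy()  # (9, 300): batch dim gone
new_style = embedding_gradient.detach().cpu().numpy()             # (1, 9, 300): batch dim kept

assert old_style.shape == (9, 300)
assert new_style[0].shape == (9, 300)                  # callers now index [0] for one instance
assert numpy.sum(new_style[0], axis=1).shape == (9,)   # per-token saliency, as in the diffs above
```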
4 changes: 2 additions & 2 deletions allennlp/tests/predictors/predictor_test.py
@@ -45,5 +45,5 @@ def test_get_gradients(self):
assert 'grad_input_2' in grads
assert grads['grad_input_1'] is not None
assert grads['grad_input_2'] is not None
assert len(grads['grad_input_1']) == 9 # 9 words in hypothesis
assert len(grads['grad_input_2']) == 5 # 5 words in premise
assert len(grads['grad_input_1'][0]) == 9 # 9 words in hypothesis
assert len(grads['grad_input_2'][0]) == 5 # 5 words in premise
