This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Targeted hotflip attacks and beam search for input reduction (#3206)
* Targeted hotflip attack, beam search for input reduction

* Adding a test

* Fix tests

* pylint, mypy, tests

* last pylint (i hope...)
matt-gardner authored Aug 29, 2019
1 parent f2824fd commit 78ee3d8
Showing 11 changed files with 372 additions and 157 deletions.
5 changes: 2 additions & 3 deletions allennlp/data/fields/span_field.py
@@ -38,7 +38,7 @@ def __init__(self, span_start: int, span_end: int, sequence_field: SequenceField
                              f"but found ({span_start}, {span_end}).")

         if span_end > self.sequence_field.sequence_length() - 1:
-            raise ValueError(f"span_end must be < len(sequence_length) - 1, but found "
+            raise ValueError(f"span_end must be <= len(sequence_length) - 1, but found "
                              f"{span_end} and {self.sequence_field.sequence_length() - 1} respectively.")

     @overrides
@@ -62,5 +62,4 @@ def __str__(self) -> str:
     def __eq__(self, other) -> bool:
         if isinstance(other, tuple) and len(other) == 2:
             return other == (self.span_start, self.span_end)
-        else:
-            return id(self) == id(other)
+        return super().__eq__(other)
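The effect of this `__eq__` change, per the updated test further down: two `SpanField`s over the same text now compare equal by value rather than by object identity. A minimal sketch (token values here are illustrative):

```python
from allennlp.data.fields import SpanField, TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

# Two span fields over the same underlying sequence.
text = TextField([Token(t) for t in ["a", "b", "c", "d"]],
                 {"words": SingleIdTokenIndexer("words")})
span_a = SpanField(1, 2, text)
span_b = SpanField(1, 2, text)

assert span_a == (1, 2)  # tuple comparison, unchanged by this diff
assert span_a == span_b  # now True: falls back to Field.__eq__ (attribute
                         # comparison) instead of id()-based identity
```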
9 changes: 8 additions & 1 deletion allennlp/interpret/attackers/attacker.py
@@ -24,7 +24,8 @@ def attack_from_json(self,
                          inputs: JsonDict,
                          input_field_to_attack: str,
                          grad_input_field: str,
-                         ignore_tokens: List[str]) -> JsonDict:
+                         ignore_tokens: List[str],
+                         target: JsonDict) -> JsonDict:
         """
         This function finds a modification to the input text that would change the model's
         prediction in some desired manner (e.g., an adversarial attack).
@@ -40,6 +41,12 @@ def attack_from_json(self,
             The field in the gradients dictionary that contains the input gradients. For example,
             `grad_input_1` will be the field for single input tasks. See get_gradients() in
             `Predictor` for more information on field names.
+        target : ``JsonDict``
+            If given, this is a `targeted` attack, trying to change the prediction to a particular
+            value, instead of just changing it from its original prediction. Subclasses are not
+            required to accept this argument, as not all attacks make sense as targeted attacks.
+            Perhaps that means we should make the API more crisp, but adding another class is not
+            worth it.

         Returns
         -------
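To make the new `target` parameter concrete, here is a hedged sketch of how a targeted attack would be invoked through this API; the archive path, input key, and the shape of the `target` dictionary are illustrative assumptions, not fixed by this diff:

```python
from allennlp.predictors import Predictor
from allennlp.interpret.attackers import Hotflip

# Hypothetical model archive; any predictor exposing gradients should work.
predictor = Predictor.from_path("/path/to/model.tar.gz")
attacker = Hotflip(predictor)
attacker.initialize()  # precompute the vocabulary embedding matrix used for flips

# Untargeted: flip tokens until the prediction moves away from the original.
untargeted = attacker.attack_from_json({"sentence": "a very fun movie"})

# Targeted: flip tokens to push the prediction *toward* a chosen output.
targeted = attacker.attack_from_json({"sentence": "a very fun movie"},
                                     target={"label": "negative"})
```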
235 changes: 166 additions & 69 deletions allennlp/interpret/attackers/hotflip.py

Large diffs are not rendered by default.
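Since the hotflip.py diff is collapsed, here is a generic sketch of the first-order flip step that HotFlip-style attacks use. This is a reconstruction under stated assumptions, not the file's actual contents: the targeted variant simply flips the direction of the ranking, so the swap drives the loss of the *target* output down instead of pushing the original prediction's loss up.

```python
import numpy as np

def first_order_flip(grad: np.ndarray,
                     embedding_matrix: np.ndarray,
                     targeted: bool = False) -> int:
    """Pick the vocabulary index whose embedding best (anti-)aligns with the gradient.

    grad             -- dL/d(embedding) for one input token, shape (dim,)
    embedding_matrix -- all vocabulary embeddings, shape (vocab, dim)
    """
    # First-order Taylor approximation: swapping in embedding e changes the
    # loss by roughly (e - e_original) . grad; the e_original term is constant
    # across candidates, so ranking by e . grad suffices.
    scores = embedding_matrix @ grad
    # Untargeted: increase the loss of the original prediction (argmax).
    # Targeted: decrease the loss of the target label (argmin).
    return int(np.argmin(scores)) if targeted else int(np.argmax(scores))
```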

159 changes: 105 additions & 54 deletions allennlp/interpret/attackers/input_reduction.py
@@ -1,5 +1,6 @@
 from copy import deepcopy
 from typing import List, Tuple
+import heapq

 import numpy as np
 import torch
@@ -9,6 +10,7 @@
 from allennlp.data.fields import TextField, SequenceLabelField
 from allennlp.interpret.attackers import utils
 from allennlp.interpret.attackers.attacker import Attacker
+from allennlp.predictors import Predictor


 @Attacker.register('input-reduction')
@@ -22,71 +24,110 @@ class InputReduction(Attacker):
     This check is brittle, i.e., the code could break if the name of this field has changed, or if
     a non-NER model has a field called "tags".
     """
+    def __init__(self, predictor: Predictor, beam_size: int = 3) -> None:
+        super().__init__(predictor)
+        self.beam_size = beam_size
+
     def attack_from_json(self, inputs: JsonDict = None,
                          input_field_to_attack: str = 'tokens',
                          grad_input_field: str = 'grad_input_1',
-                         ignore_tokens: List[str] = None):
+                         ignore_tokens: List[str] = None,
+                         target: JsonDict = None):
+        if target is not None:
+            raise ValueError('Input reduction does not implement targeted attacks')
         ignore_tokens = ["@@NULL@@"] if ignore_tokens is None else ignore_tokens
         original_instances = self.predictor.json_to_labeled_instances(inputs)
         original_text_field: TextField = original_instances[0][input_field_to_attack]  # type: ignore
         original_tokens = deepcopy(original_text_field.tokens)
         final_tokens = []
-        for current_instance in original_instances:
-            # Save fields that must be checked for equality
-            fields_to_compare = utils.get_fields_to_compare(inputs, current_instance, input_field_to_attack)
-
-            # Set num_ignore_tokens, which tells input reduction when to stop
-            # We keep at least one token for input reduction on classification/entailment/etc.
-            if "tags" not in current_instance:
-                num_ignore_tokens = 1
-
-            # Set num_ignore_tokens for NER and build token mask
-            else:
-                num_ignore_tokens, tag_mask, original_tags = _get_ner_tags_and_mask(current_instance,
-                                                                                    input_field_to_attack,
-                                                                                    ignore_tokens)
-
-            current_text_field: TextField = current_instance[input_field_to_attack]  # type: ignore
-            current_tokens = deepcopy(current_text_field.tokens)
-            smallest_idx = -1
-            # keep removing tokens until prediction is about to change
-            while len(current_text_field.tokens) >= num_ignore_tokens:
+        for instance in original_instances:
+            final_tokens.append(self._attack_instance(inputs,
+                                                      instance,
+                                                      input_field_to_attack,
+                                                      grad_input_field,
+                                                      ignore_tokens))
+        return sanitize({"final": final_tokens, "original": original_tokens})

+    def _attack_instance(self,
+                         inputs: JsonDict,
+                         instance: Instance,
+                         input_field_to_attack: str,
+                         grad_input_field: str,
+                         ignore_tokens: List[str]):
+        # Save fields that must be checked for equality
+        fields_to_compare = utils.get_fields_to_compare(inputs, instance, input_field_to_attack)
+
+        # Set num_ignore_tokens, which tells input reduction when to stop
+        # We keep at least one token for input reduction on classification/entailment/etc.
+        if "tags" not in instance:
+            num_ignore_tokens = 1
+            tag_mask = None
+
+        # Set num_ignore_tokens for NER and build token mask
+        else:
+            num_ignore_tokens, tag_mask, original_tags = _get_ner_tags_and_mask(instance,
+                                                                                input_field_to_attack,
+                                                                                ignore_tokens)
+
+        text_field: TextField = instance[input_field_to_attack]  # type: ignore
+        current_tokens = deepcopy(text_field.tokens)
+        candidates = [(instance, -1, tag_mask)]
+        # keep removing tokens until prediction is about to change
+        while len(current_tokens) > num_ignore_tokens and candidates:
+            # sort current candidates by smallest length (we want to remove as many tokens as possible)
+            def get_length(input_instance: Instance):
+                input_text_field: TextField = input_instance[input_field_to_attack]  # type: ignore
+                return len(input_text_field.tokens)
+            candidates = heapq.nsmallest(self.beam_size, candidates, key=lambda x: get_length(x[0]))
+
+            beam_candidates = deepcopy(candidates)
+            candidates = []
+            for beam_instance, smallest_idx, tag_mask in beam_candidates:
                 # get gradients and predictions
-                grads, outputs = self.predictor.get_gradients([current_instance])
+                beam_tag_mask = deepcopy(tag_mask)
+                grads, outputs = self.predictor.get_gradients([beam_instance])

                 for output in outputs:
                     if isinstance(outputs[output], torch.Tensor):
                         outputs[output] = outputs[output].detach().cpu().numpy().squeeze().squeeze()
                     elif isinstance(outputs[output], list):
                         outputs[output] = outputs[output][0]

-                # Check if any fields have changed, if so, break loop
-                if "tags" not in current_instance:
-                    if any(current_instance[field] != fields_to_compare[field] for field in fields_to_compare):
-                        break
+                # Check if any fields have changed, if so, next beam
+                if "tags" not in instance:
+                    # relabel beam_instance since last iteration removed an input token
+                    beam_instance = self.predictor.predictions_to_labeled_instances(beam_instance, outputs)[0]
+                    if utils.instance_has_changed(beam_instance, fields_to_compare):
+                        continue

                 # special case for sentence tagging (we have tested NER)
                 else:
-                    if smallest_idx != -1:
-                        del tag_mask[smallest_idx]
-                    cur_tags = [outputs["tags"][x] for x in range(len(outputs["tags"])) if tag_mask[x]]
+                    # remove the mask where you remove the input token from.
+                    if smallest_idx != -1:  # Don't delete on the very first iteration
+                        del beam_tag_mask[smallest_idx]
+                    cur_tags = [outputs["tags"][x] for x in range(len(outputs["tags"])) if beam_tag_mask[x]]
                     if cur_tags != original_tags:
-                        break
+                        continue

                 # remove a token from the input
-                current_tokens = deepcopy(current_text_field.tokens)
-                current_instance, smallest_idx = _remove_one_token(current_instance,
+                text_field: TextField = beam_instance[input_field_to_attack]  # type: ignore
+                current_tokens = deepcopy(text_field.tokens)
+                reduced_instances_and_smallest = _remove_one_token(beam_instance,
                                                                    input_field_to_attack,
                                                                    grads[grad_input_field],
-                                                                   ignore_tokens)
-
-        final_tokens.append(current_tokens)
-        return sanitize({"final": final_tokens, "original": original_tokens})
+                                                                   ignore_tokens,
+                                                                   self.beam_size,
+                                                                   beam_tag_mask)
+                candidates.extend(reduced_instances_and_smallest)
+        return current_tokens


 def _remove_one_token(instance: Instance,
                       input_field_to_attack: str,
                       grads: np.ndarray,
-                      ignore_tokens: List[str]) -> Tuple[Instance, int]:
+                      ignore_tokens: List[str],
+                      beam_size: int,
+                      tag_mask: List[int]) -> List[Tuple[Instance, int, List[int]]]:
     """
     Finds the token with the smallest gradient and removes it.
     """
@@ -106,27 +147,37 @@ def _remove_one_token(instance: Instance,
     for idx, label in enumerate(labels):
         if label != "O":
             grads_mag[idx] = float("inf")
+    reduced_instances_and_smallest: List[Tuple[Instance, int, List[int]]] = []
+    for _ in range(beam_size):
+        # copy instance and edit later
+        copied_instance = deepcopy(instance)
+        copied_text_field: TextField = copied_instance[input_field_to_attack]  # type: ignore

-    smallest = np.argmin(grads_mag)
-    if smallest == float("inf"):  # if all are ignored tokens, return.
-        return instance, smallest
+        # find smallest
+        smallest = np.argmin(grads_mag)
+        if grads_mag[smallest] == float("inf"):  # if all are ignored tokens, return.
+            break
+        grads_mag[smallest] = float("inf")  # so the other beams don't use this token

-    # remove smallest
-    inputs_before_smallest = text_field.tokens[0:smallest]
-    inputs_after_smallest = text_field.tokens[smallest + 1:]
-    text_field.tokens = inputs_before_smallest + inputs_after_smallest
+        # remove smallest
+        inputs_before_smallest = copied_text_field.tokens[0:smallest]
+        inputs_after_smallest = copied_text_field.tokens[smallest + 1:]
+        copied_text_field.tokens = inputs_before_smallest + inputs_after_smallest

-    if "tags" in instance:
-        tag_field_before_smallest = tag_field.labels[0:smallest]
-        tag_field_after_smallest = tag_field.labels[smallest + 1:]
-        tag_field.labels = tag_field_before_smallest + tag_field_after_smallest  # type: ignore
-        tag_field.sequence_field = text_field
+        if "tags" in instance:
+            tag_field: SequenceLabelField = copied_instance["tags"]  # type: ignore
+            tag_field_before_smallest = tag_field.labels[0:smallest]
+            tag_field_after_smallest = tag_field.labels[smallest + 1:]
+            tag_field.labels = tag_field_before_smallest + tag_field_after_smallest  # type: ignore
+            tag_field.sequence_field = copied_text_field

-    instance.indexed = False
-    return instance, smallest
+        copied_instance.indexed = False
+        reduced_instances_and_smallest.append((copied_instance, smallest, tag_mask))
+
+    return reduced_instances_and_smallest


-def _get_ner_tags_and_mask(current_instance: Instance,
+def _get_ner_tags_and_mask(instance: Instance,
                            input_field_to_attack: str,
                            ignore_tokens: List[str]):
     """
@@ -135,15 +186,15 @@ def _get_ner_tags_and_mask(current_instance: Instance,
     """
     # Set num_ignore_tokens
     num_ignore_tokens = 0
-    input_field: TextField = current_instance[input_field_to_attack]  # type: ignore
+    input_field: TextField = instance[input_field_to_attack]  # type: ignore
     for token in input_field.tokens:
         if str(token) in ignore_tokens:
             num_ignore_tokens += 1

     # save the original tags and a 0/1 mask where the tags are
     tag_mask = []
     original_tags = []
-    tag_field: SequenceLabelField = current_instance["tags"]  # type: ignore
+    tag_field: SequenceLabelField = instance["tags"]  # type: ignore
     for label in tag_field.labels:
         if label != "O":
             tag_mask.append(1)
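With the beam search above, input reduction keeps the `beam_size` shortest candidates whose prediction is still unchanged (via `heapq.nsmallest` on token count) instead of greedily committing to a single removal per step. A hedged usage sketch; the archive path and input key are illustrative:

```python
from allennlp.predictors import Predictor
from allennlp.interpret.attackers import InputReduction

# Hypothetical archive; any classification-style predictor should work.
predictor = Predictor.from_path("/path/to/model.tar.gz")
reducer = InputReduction(predictor, beam_size=3)

result = reducer.attack_from_json({"sentence": "an utterly delightful film"},
                                  input_field_to_attack="tokens",
                                  grad_input_field="grad_input_1")
print(result["original"])  # the full token sequence
print(result["final"])     # reduced token lists that leave the prediction intact
```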
15 changes: 13 additions & 2 deletions allennlp/interpret/attackers/utils.py
@@ -12,7 +12,7 @@ def get_fields_to_compare(inputs: JsonDict, instance: Instance, input_field_to_a
     instance : ``Instance``
         A labeled instance that is output from json_to_labeled_instances().
     input_field_to_attack : ``str``
-        The key in the inputs JsonDict you want to attack, e.g., `tokens`.
+        The key in the inputs JsonDict you want to attack, e.g., tokens.

     Returns
     -------
@@ -22,6 +22,17 @@ def get_fields_to_compare(inputs: JsonDict, instance: Instance, input_field_to_a
     fields_to_compare = {
             key: instance[key]
             for key in instance.fields
-            if key not in inputs and key != input_field_to_attack
+            if key not in inputs and key != input_field_to_attack and key != 'metadata' and key != 'output'
     }
     return fields_to_compare
+
+
+def instance_has_changed(instance: Instance, fields_to_compare: JsonDict):
+    if 'clusters' in fields_to_compare:
+        # Coref needs a special case here, apparently. I (mattg) am not sure why the check below
+        # doesn't catch this case; TODO: look into this.
+        original_clusters = set(tuple(l) for l in fields_to_compare['clusters'])
+        new_clusters = set(tuple(l) for l in instance['clusters'])  # type: ignore
+        return original_clusters != new_clusters
+    if any(instance[field] != fields_to_compare[field] for field in fields_to_compare):
+        return True
+    return False
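A toy illustration of the coref normalization above; the cluster values are made up. Each cluster (a list of spans) becomes a hashable tuple, so two instances listing the same clusters in a different order still compare equal:

```python
original = [[(0, 1), (4, 5)], [(7, 8)]]   # clusters from the original instance
new      = [[(7, 8)], [(0, 1), (4, 5)]]   # same clusters, reordered

original_clusters = set(tuple(c) for c in original)
new_clusters = set(tuple(c) for c in new)
assert original_clusters == new_clusters  # instance_has_changed would return False
```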
10 changes: 3 additions & 7 deletions allennlp/interpret/saliency_interpreters/integrated_gradient.py
@@ -7,8 +7,7 @@
 from allennlp.common.util import JsonDict, sanitize
 from allennlp.data import Instance
 from allennlp.interpret.saliency_interpreters.saliency_interpreter import SaliencyInterpreter
-from allennlp.modules.text_field_embedders import TextFieldEmbedder
-
+from allennlp.nn import util

 @SaliencyInterpreter.register('integrated-gradient')
 class IntegratedGradient(SaliencyInterpreter):
@@ -52,11 +51,8 @@ def forward_hook(module, inputs, output):  # pylint: disable=unused-argument
             output.mul_(alpha)

         # Register the hook
-        handle = None
-        for module in self.predictor._model.modules():
-            if isinstance(module, TextFieldEmbedder):
-                handle = module.register_forward_hook(forward_hook)
-
+        embedding_layer = util.find_embedding_layer(self.predictor._model)
+        handle = embedding_layer.register_forward_hook(forward_hook)
         return handle

     def _integrate_gradients(self, instance: Instance) -> Dict[str, numpy.ndarray]:
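This file and the two below make the same swap: instead of scanning `modules()` for a `TextFieldEmbedder`, they ask `util.find_embedding_layer` for the embedding layer and hook it directly. A self-contained sketch of the forward-hook pattern in plain PyTorch (the model here is a stand-in, not an AllenNLP model):

```python
import torch

model = torch.nn.Sequential(torch.nn.Embedding(100, 8), torch.nn.Linear(8, 2))
embedding_layer = model[0]  # stand-in for util.find_embedding_layer(model)

captured = []

def forward_hook(module, inputs, output):
    # Runs on every forward pass; here we just record the embedding output,
    # as SimpleGradient does before computing saliency.
    captured.append(output.detach().clone())

handle = embedding_layer.register_forward_hook(forward_hook)
model(torch.tensor([[1, 2, 3]]))
handle.remove()  # interpreters drop the hook once gradients are collected

print(captured[0].shape)  # torch.Size([1, 3, 8])
```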
8 changes: 3 additions & 5 deletions allennlp/interpret/saliency_interpreters/simple_gradient.py
@@ -6,7 +6,7 @@

 from allennlp.common.util import JsonDict, sanitize
 from allennlp.interpret.saliency_interpreters.saliency_interpreter import SaliencyInterpreter
-from allennlp.modules.text_field_embedders import TextFieldEmbedder
+from allennlp.nn import util


 @SaliencyInterpreter.register('simple-gradient')
@@ -54,9 +54,7 @@ def _register_forward_hook(self, embeddings_list: List):
         def forward_hook(module, inputs, output):  # pylint: disable=unused-argument
             embeddings_list.append(output.squeeze(0).clone().detach().numpy())

-        handle = None
-        for module in self.predictor._model.modules():
-            if isinstance(module, TextFieldEmbedder):
-                handle = module.register_forward_hook(forward_hook)
+        embedding_layer = util.find_embedding_layer(self.predictor._model)
+        handle = embedding_layer.register_forward_hook(forward_hook)

         return handle
9 changes: 3 additions & 6 deletions allennlp/interpret/saliency_interpreters/smooth_gradient.py
@@ -8,8 +8,8 @@
 from allennlp.common.util import JsonDict, sanitize
 from allennlp.data import Instance
 from allennlp.interpret.saliency_interpreters.saliency_interpreter import SaliencyInterpreter
-from allennlp.modules.text_field_embedders import TextFieldEmbedder
 from allennlp.predictors import Predictor
+from allennlp.nn import util

 @SaliencyInterpreter.register('smooth-gradient')
 class SmoothGradient(SaliencyInterpreter):
@@ -58,11 +58,8 @@ def forward_hook(module, inputs, output):  # pylint: disable=unused-argument
             output.add_(noise)

         # Register the hook
-        handle = None
-        for module in self.predictor._model.modules():
-            if isinstance(module, TextFieldEmbedder):
-                handle = module.register_forward_hook(forward_hook)
-
+        embedding_layer = util.find_embedding_layer(self.predictor._model)
+        handle = embedding_layer.register_forward_hook(forward_hook)
         return handle

     def _smooth_grads(self, instance: Instance) -> Dict[str, numpy.ndarray]:
9 changes: 7 additions & 2 deletions allennlp/tests/data/fields/span_field_test.py
@@ -11,8 +11,9 @@
 class TestSpanField(AllenNlpTestCase):
     def setUp(self):
         super(TestSpanField, self).setUp()
+        self.indexers = {"words": SingleIdTokenIndexer("words")}
         self.text = TextField([Token(t) for t in ["here", "is", "a", "sentence", "for", "spans", "."]],
-                              {"words": SingleIdTokenIndexer("words")})
+                              self.indexers)

     def test_as_tensor_converts_span_field_correctly(self):
         span_field = SpanField(2, 3, self.text)
@@ -44,7 +45,11 @@ def test_printing_doesnt_crash(self):
     def test_equality(self):
         span_field1 = SpanField(2, 3, self.text)
         span_field2 = SpanField(2, 3, self.text)
+        span_field3 = SpanField(2, 3, TextField([Token(t) for t in ['not', 'the', 'same', 'tokens']],
+                                                self.indexers))

         assert span_field1 == (2, 3)
         assert span_field1 == span_field1
-        assert span_field1 != span_field2
+        assert span_field1 == span_field2
+        assert span_field1 != span_field3
+        assert span_field2 != span_field3