Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Revert batching for input reduction (#3276)
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-gardner authored Sep 24, 2019
1 parent 052e8d3 commit 2a95022
Showing 1 changed file with 5 additions and 24 deletions.
29 changes: 5 additions & 24 deletions allennlp/interpret/attackers/input_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,31 +80,12 @@ def get_length(input_instance: Instance):
return len(input_text_field.tokens)
candidates = heapq.nsmallest(self.beam_size, candidates, key=lambda x: get_length(x[0]))

# predictor.get_gradients is where the most expensive computation happens, so we're
# going to do it in a batch, up front, before iterating over the results.
copied_candidates = deepcopy(candidates)
all_grads, all_outputs = self.predictor.get_gradients([x[0] for x in copied_candidates])

# The output in `all_grads` and `all_outputs` is batched in a dictionary (e.g.,
# {'grad_output_1': batched_tensor}). We need to split this into a list of non-batched
# dictionaries that we can iterate over.
split_grads = []
for i in range(len(copied_candidates)):
split_grads.append({key: value[i] for key, value in all_grads.items()})
split_outputs = []
for i in range(len(copied_candidates)):
instance_outputs = {}
for key, value in all_outputs.items():
if key == 'loss':
continue
instance_outputs[key] = value[i]
split_outputs.append(instance_outputs)
beam_candidates = [(x[0], x[1], x[2], split_grads[i], split_outputs[i])
for i, x in enumerate(copied_candidates)]

beam_candidates = deepcopy(candidates)
candidates = []
for beam_instance, smallest_idx, tag_mask, grads, outputs in beam_candidates:
for beam_instance, smallest_idx, tag_mask in beam_candidates:
# get gradients and predictions
beam_tag_mask = deepcopy(tag_mask)
grads, outputs = self.predictor.get_gradients([beam_instance])

for output in outputs:
if isinstance(outputs[output], torch.Tensor):
Expand Down Expand Up @@ -133,7 +114,7 @@ def get_length(input_instance: Instance):
current_tokens = deepcopy(text_field.tokens)
reduced_instances_and_smallest = _remove_one_token(beam_instance,
input_field_to_attack,
grads[grad_input_field],
grads[grad_input_field][0],
ignore_tokens,
self.beam_size,
beam_tag_mask)
Expand Down

0 comments on commit 2a95022

Please sign in to comment.