This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

upgrade to pytorch 1.2 #3182

Merged
merged 16 commits on Aug 24, 2019
27 changes: 24 additions & 3 deletions allennlp/common/testing/model_test_case.py
@@ -46,7 +46,8 @@ def ensure_model_can_train_save_and_load(self,
tolerance: float = 1e-4,
cuda_device: int = -1,
gradients_to_ignore: Set[str] = None,
overrides: str = ""):
overrides: str = "",
disable_dropout: bool = True):
"""
Parameters
----------
@@ -68,6 +69,9 @@ def ensure_model_can_train_save_and_load(self,
infrequently-used parameters that are hard to force the model to use in a small test).
overrides : ``str``, optional (default = "")
A JSON string that we will use to override values in the input parameter file.
disable_dropout : ``bool``, optional (default = True)
If True we will set all dropout to 0 before checking gradients. (Otherwise, with small
datasets, you may get zero gradients because of unlucky dropout.)
"""
save_dir = self.TEST_DIR / "save_and_load_test"
archive_file = save_dir / "model.tar.gz"
@@ -103,7 +107,7 @@ def ensure_model_can_train_save_and_load(self,

# Check gradients are None for non-trainable parameters and check that
# trainable parameters receive some gradient if they are trainable.
self.check_model_computes_gradients_correctly(model, model_batch, gradients_to_ignore)
self.check_model_computes_gradients_correctly(model, model_batch, gradients_to_ignore, disable_dropout)

# The datasets themselves should be identical.
assert model_batch.keys() == loaded_batch.keys()
@@ -167,9 +171,20 @@ def assert_fields_equal(self, field1, field2, name: str, tolerance: float = 1e-6
@staticmethod
def check_model_computes_gradients_correctly(model: Model,
model_batch: Dict[str, Union[Any, Dict[str, Any]]],
params_to_ignore: Set[str] = None):
params_to_ignore: Set[str] = None,
disable_dropout: bool = True):
print("Checking gradients")
model.zero_grad()

original_dropouts: Dict[str, float] = {}

if disable_dropout:
# Remember original dropouts so we can restore them.
for name, module in model.named_modules():
if isinstance(module, torch.nn.Dropout):
original_dropouts[name] = getattr(module, 'p')
setattr(module, 'p', 0)

result = model(**model_batch)
result["loss"].backward()
has_zero_or_none_grads = {}
@@ -197,6 +212,12 @@ def check_model_computes_gradients_correctly(model: Model,
print(f"Parameter: {name} had incorrect gradient: {grad}")
raise Exception("Incorrect gradients found. See stdout for more info.")

# Now restore dropouts if we disabled them.
if disable_dropout:
for name, module in model.named_modules():
if name in original_dropouts:
setattr(module, 'p', original_dropouts[name])

def ensure_batch_predictions_are_consistent(
self,
keys_to_ignore: Iterable[str] = ()):
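The new `disable_dropout` flag zeroes out every `torch.nn.Dropout` module before the gradient check and restores the original probabilities afterwards. A minimal standalone sketch of the same idea, assuming a hypothetical `model` that returns a dict with a "loss" key (the AllenNLP convention):

```python
import torch

def backward_with_dropout_disabled(model: torch.nn.Module, batch: dict) -> None:
    # Remember each Dropout module's original probability so it can be restored.
    original_p = {name: module.p for name, module in model.named_modules()
                  if isinstance(module, torch.nn.Dropout)}
    for name, module in model.named_modules():
        if name in original_p:
            module.p = 0.0
    try:
        # Run a forward/backward pass with dropout effectively turned off.
        model(**batch)["loss"].backward()
    finally:
        # Restore the original dropout probabilities.
        for name, module in model.named_modules():
            if name in original_p:
                module.p = original_p[name]
```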
2 changes: 1 addition & 1 deletion allennlp/models/biaffine_dependency_parser.py
@@ -436,7 +436,7 @@ def _greedy_decode(self,
attended_arcs = attended_arcs + torch.diag(attended_arcs.new(mask.size(1)).fill_(-numpy.inf))
# Mask padded tokens, because we only want to consider actual words as heads.
if mask is not None:
minus_mask = (1 - mask).byte().unsqueeze(2)
minus_mask = (1 - mask).to(dtype=torch.bool).unsqueeze(2)
attended_arcs.masked_fill_(minus_mask, -numpy.inf)

# Compute the heads greedily.
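The mask change here is the recurring theme of this PR: PyTorch 1.2 deprecates `uint8` (`.byte()`) masks for operations like `masked_fill_` and `masked_select` in favour of `bool` masks. A toy illustration of the new idiom (shapes and values are made up):

```python
import torch

attended_arcs = torch.randn(2, 3, 3)          # (batch, length, length) arc scores
mask = torch.tensor([[1, 1, 0],
                     [1, 0, 0]])               # 1 = real token, 0 = padding

# Old idiom: (1 - mask).byte() -- deprecated for masked_fill_ in PyTorch 1.2.
# New idiom: convert the inverted mask to bool before masked_fill_.
minus_mask = (1 - mask).to(dtype=torch.bool).unsqueeze(2)
attended_arcs.masked_fill_(minus_mask, float("-inf"))
```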
5 changes: 4 additions & 1 deletion allennlp/models/crf_tagger.py
@@ -204,7 +204,10 @@ def forward(self, # type: ignore
if tags is not None:
# Add negative log-likelihood as loss
log_likelihood = self.crf(logits, tags, mask)
output["loss"] = -log_likelihood

# It's not clear why, but pylint seems to think `log_likelihood` is a tuple
# (in fact, it's a torch.Tensor), so we need a disable.
output["loss"] = -log_likelihood # pylint: disable=invalid-unary-operand-type

# Represent viterbi tags as "class probabilities" that we can
# feed into the metrics
2 changes: 1 addition & 1 deletion allennlp/models/graph_parser.py
@@ -331,7 +331,7 @@ def _greedy_decode(arc_scores: torch.Tensor,
# shape (batch_size, sequence_length, sequence_length, num_tags)
arc_tag_logits = arc_tag_logits + inf_diagonal_mask.unsqueeze(0).unsqueeze(-1)
# Mask padded tokens, because we only want to consider actual word -> word edges.
minus_mask = (1 - mask).byte().unsqueeze(2)
minus_mask = (1 - mask).to(dtype=torch.bool).unsqueeze(2)
arc_scores.masked_fill_(minus_mask, -numpy.inf)
arc_tag_logits.masked_fill_(minus_mask.unsqueeze(-1), -numpy.inf)
# shape (batch_size, sequence_length, sequence_length)
2 changes: 1 addition & 1 deletion allennlp/models/language_model.py
@@ -149,7 +149,7 @@ def _get_target_token_embeddings(self,
mask: torch.Tensor,
direction: int) -> torch.Tensor:
# Need to shift the mask in the correct direction
zero_col = token_embeddings.new_zeros(mask.size(0), 1).byte()
zero_col = token_embeddings.new_zeros(mask.size(0), 1).to(dtype=torch.bool)
if direction == 0:
# forward direction, get token to right
shifted_mask = torch.cat([zero_col, mask[:, 0:-1]], dim=1)
2 changes: 1 addition & 1 deletion allennlp/modules/seq2seq_encoders/gated_cnn_encoder.py
@@ -169,7 +169,7 @@ def forward(self,

# We need to broadcast the mask to feature dimension,
# and to use masked_fill_ we need the inverse of the mask.
mask_for_fill = (1 - mask).unsqueeze(1).byte()
mask_for_fill = (1 - mask).unsqueeze(1).to(dtype=torch.bool)

if self._return_all_layers:
# outputs will be [[all forward layers], [all backward layers]]
@@ -71,7 +71,7 @@ def forward(self, # pylint: disable=arguments-differ
# also mask out positions corresponding to oov
mask *= (inputs != self._oov_idx).long()
for document, doc_mask in zip(inputs, mask):
document = torch.masked_select(document, doc_mask.byte())
document = torch.masked_select(document, doc_mask.to(dtype=torch.bool))
vec = torch.bincount(document, minlength=self.vocab_size).float()
vec = vec.view(1, -1)
bag_of_words_vectors.append(vec)
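For context, the loop in the hunk above turns each document into a bag-of-words count vector by selecting the unmasked token ids and counting them with `bincount`. A tiny self-contained sketch (vocabulary size and ids chosen arbitrarily):

```python
import torch

vocab_size = 6
document = torch.tensor([2, 5, 2, 0, 0])   # token ids, trailing 0s are padding
doc_mask = torch.tensor([1, 1, 1, 0, 0])   # 1 = keep, 0 = drop

kept = torch.masked_select(document, doc_mask.to(dtype=torch.bool))   # tensor([2, 5, 2])
vec = torch.bincount(kept, minlength=vocab_size).float().view(1, -1)
print(vec)   # tensor([[0., 0., 2., 0., 0., 1.]])
```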
10 changes: 5 additions & 5 deletions allennlp/nn/util.py
@@ -270,7 +270,7 @@ def masked_softmax(vector: torch.Tensor,
result = result * mask
result = result / (result.sum(dim=dim, keepdim=True) + 1e-13)
else:
masked_vector = vector.masked_fill((1 - mask).byte(), mask_fill_value)
masked_vector = vector.masked_fill((1 - mask).to(dtype=torch.bool), mask_fill_value)
result = torch.nn.functional.softmax(masked_vector, dim=dim)
return result

@@ -334,7 +334,7 @@ def masked_max(vector: torch.Tensor,
-------
A ``torch.Tensor`` including the maximum values.
"""
one_minus_mask = (1.0 - mask).byte()
one_minus_mask = (1.0 - mask).to(dtype=torch.bool)
replaced_vector = vector.masked_fill(one_minus_mask, min_val)
max_value, _ = replaced_vector.max(dim=dim, keepdim=keepdim)
return max_value
@@ -365,7 +365,7 @@ def masked_mean(vector: torch.Tensor,
-------
A ``torch.Tensor`` including the mean values.
"""
one_minus_mask = (1.0 - mask).byte()
one_minus_mask = (1.0 - mask).to(dtype=torch.bool)
replaced_vector = vector.masked_fill(one_minus_mask, 0.0)

value_sum = torch.sum(replaced_vector, dim=dim, keepdim=keepdim)
@@ -773,11 +773,11 @@ def replace_masked_values(tensor: torch.Tensor, mask: torch.Tensor, replace_with

This just does ``tensor.masked_fill()``, except the pytorch method fills in things with a mask
value of 1, where we want the opposite. You can do this in your own code with
``tensor.masked_fill((1 - mask).byte(), replace_with)``.
``tensor.masked_fill((1 - mask).to(dtype=torch.bool), replace_with)``.
"""
if tensor.dim() != mask.dim():
raise ConfigurationError("tensor.dim() (%d) != mask.dim() (%d)" % (tensor.dim(), mask.dim()))
return tensor.masked_fill((1 - mask).byte(), replace_with)
return tensor.masked_fill((1 - mask).to(dtype=torch.bool), replace_with)


def tensors_equal(tensor1: torch.Tensor, tensor2: torch.Tensor, tolerance: float = 1e-12) -> bool:
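As the docstring explains, `replace_masked_values` inverts the mask before calling `masked_fill`, so positions where the mask is 0 receive the replacement value. A quick usage sketch with toy values:

```python
import torch
from allennlp.nn.util import replace_masked_values

vector = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1, 1, 0]])   # 1 = keep, 0 = replace

result = replace_masked_values(vector, mask, -1e7)
# result: tensor([[ 1.0e+00,  2.0e+00, -1.0e+07]])

# Equivalent inline form from the docstring:
same = vector.masked_fill((1 - mask).to(dtype=torch.bool), -1e7)
```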
3 changes: 3 additions & 0 deletions allennlp/tests/commands/docstring_help_test.py
@@ -3,10 +3,13 @@
import re
import subprocess

import pytest

import allennlp.commands
from allennlp.common.testing import AllenNlpTestCase


@pytest.mark.skip(reason="This test is slow and somewhat fragile and doesn't need to run every commit.")
class TestDocstringHelp(AllenNlpTestCase):
RE_DOCSTRING_CALL_SUBCOMMAND_HELP = re.compile(r'^\s*\$ (allennlp \S+ --help)$', re.MULTILINE)
RE_STARTS_WITH_INDENTATION = re.compile(r'^ {4}', re.MULTILINE)
@@ -14,7 +14,7 @@
},
"iterator": {
"type": "same_language",
"batch_size": 8,
"batch_size": 10,
"sorting_keys": [["words", "num_tokens"]],
"instances_per_epoch": 8
},
2 changes: 1 addition & 1 deletion allennlp/tests/fixtures/naqanet/experiment.json
@@ -114,7 +114,7 @@
},
"iterator": {
"type": "basic",
"batch_size": 16
"batch_size": 8
},
"trainer": {
"num_epochs": 1,
@@ -8,8 +8,8 @@ def setUp(self):
self.FIXTURES_ROOT / "data" / "dependencies_multilang" / "*")

def test_dependency_parser_can_save_and_load(self):
self.ensure_model_can_train_save_and_load(self.param_file)
self.ensure_model_can_train_save_and_load(self.param_file, gradients_to_ignore={"arc_attention._bias"})

def test_mst_decoding_can_run_forward(self):
self.model.use_mst_decoding_for_validation = True
self.ensure_model_can_train_save_and_load(self.param_file)
self.ensure_model_can_train_save_and_load(self.param_file, gradients_to_ignore={"arc_attention._bias"})
4 changes: 2 additions & 2 deletions allennlp/training/metrics/attachment_scores.py
@@ -60,11 +60,11 @@ def __call__(self, # type: ignore
gold_indices = gold_indices.long()
gold_labels = gold_labels.long()

# Multiply by a mask donoting locations of
# Multiply by a mask denoting locations of
# gold labels which we should ignore.
for label in self._ignore_classes:
label_mask = gold_labels.eq(label)
mask = mask * (1 - label_mask).long()
mask = mask * (~label_mask).long()

correct_indices = predicted_indices.eq(gold_indices).long() * mask
unlabeled_exact_match = (correct_indices + (1 - mask)).prod(dim=-1)
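The switch from `(1 - label_mask)` to `(~label_mask)` reflects that comparison ops such as `eq` now return `bool` tensors, which are negated with `~` rather than with arithmetic. A toy illustration:

```python
import torch

gold_labels = torch.tensor([[3, 1, 4, 1]])
mask = torch.tensor([[1, 1, 1, 0]])

# eq() yields a bool tensor in PyTorch >= 1.2; invert it with ~, not 1 - x.
label_mask = gold_labels.eq(1)         # tensor([[False,  True, False,  True]])
mask = mask * (~label_mask).long()     # tensor([[1, 0, 1, 0]])
```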
2 changes: 1 addition & 1 deletion allennlp/training/metrics/auc.py
@@ -60,7 +60,7 @@ def __call__(self,
if mask is None:
batch_size = gold_labels.shape[0]
mask = torch.ones(batch_size)
mask = mask.byte()
mask = mask.to(dtype=torch.bool)

self._all_predictions = torch.cat([self._all_predictions,
torch.masked_select(predictions, mask).float()], dim=0)
2 changes: 1 addition & 1 deletion allennlp/training/metrics/bleu.py
@@ -106,7 +106,7 @@ def _get_brevity_penalty(self) -> float:
return math.exp(1.0 - self._reference_lengths / self._prediction_lengths)

def _get_valid_tokens_mask(self, tensor: torch.LongTensor) -> torch.ByteTensor:
valid_tokens_mask = torch.ones(tensor.size(), dtype=torch.uint8)
valid_tokens_mask = torch.ones(tensor.size(), dtype=torch.bool)
for index in self._exclude_indices:
valid_tokens_mask = valid_tokens_mask & (tensor != index)
return valid_tokens_mask
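The BLEU helper now builds its valid-token mask directly as a `bool` tensor and combines per-index exclusions with `&`. A small sketch with made-up exclude indices:

```python
import torch

tensor = torch.tensor([[5, 0, 7, 1]])
exclude_indices = {0, 1}   # e.g. padding and start-symbol ids

valid_tokens_mask = torch.ones(tensor.size(), dtype=torch.bool)
for index in exclude_indices:
    # Keep a position only if it is not one of the excluded ids.
    valid_tokens_mask = valid_tokens_mask & (tensor != index)
# valid_tokens_mask: tensor([[ True, False,  True, False]])
```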
2 changes: 1 addition & 1 deletion allennlp/training/metrics/fbeta_measure.py
@@ -119,7 +119,7 @@ def __call__(self,

if mask is None:
mask = torch.ones_like(gold_labels)
mask = mask.to(torch.uint8)
mask = mask.to(dtype=torch.bool)
gold_labels = gold_labels.float()

argmax_predictions = predictions.max(dim=-1)[1].float()
6 changes: 3 additions & 3 deletions requirements.txt
@@ -7,9 +7,9 @@

# This installs Pytorch for CUDA 8 only. If you are using a newer version,
# please visit https://pytorch.org/ and install the relevant version.
# For now AllenNLP works with both PyTorch 1.0 and 0.4.1. Expect that in
# the future only >=1.0 will be supported.
torch>=0.4.1,<1.2
# allennlp>0.8.5 requires PyTorch 1.2 or greater. If you need to use
# an older version of PyTorch you'll also need to use an older version of allennlp.
torch>=1.2.0

# Parameter parsing (but not on Windows).
jsonnet>=0.10.0 ; sys.platform != 'win32'
2 changes: 1 addition & 1 deletion setup.py
@@ -101,7 +101,7 @@
packages=find_packages(exclude=["*.tests", "*.tests.*",
"tests.*", "tests"]),
install_requires=[
'torch>=0.4.1,<1.2',
'torch>=1.2.0',
Contributor:

Do we strictly need features from 1.2.0, or is 1.0 sufficient?

Contributor Author:

This PR is in response to breaking changes that happened between 1.1 and 1.2, but going forward we'd also like to use PyTorch's DataLoader infrastructure, and for that to support allennlp datasets we need torch>=1.2.

Contributor:

Got it. Thanks for explaining.

"jsonnet>=0.10.0 ; sys.platform != 'win32'",
'overrides',
'nltk',