This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

ensure transformer params are frozen at initialization when train_parameters is false (#4377)

* ensure transformer params are frozen at initialization

* update CHANGELOG

* removed un-needed member var
epwalsh authored Jun 18, 2020
1 parent 3e8a9ef commit e52b751
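
For orientation, a minimal sketch of the approach this commit adopts: when training of the transformer is disabled, set requires_grad = False on its parameters at construction time, rather than only disabling gradient computation around the forward pass. The freeze_parameters helper and the nn.Linear stand-in below are illustrative, not part of the AllenNLP code.

    from torch import nn

    def freeze_parameters(module: nn.Module) -> nn.Module:
        """Freeze every parameter of `module` in place (illustrative helper)."""
        for param in module.parameters():
            # With requires_grad=False the parameters receive no gradient updates
            # and are reported as frozen by anything that inspects this flag.
            param.requires_grad = False
        return module

    # Stand-in for the wrapped pretrained transformer; any nn.Module behaves the same.
    encoder = freeze_parameters(nn.Linear(768, 768))
    assert not any(p.requires_grad for p in encoder.parameters())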
Showing 2 changed files with 43 additions and 44 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Fixed
 
 - Reduced the amount of log messages produced by `allennlp.common.file_utils`.
+- Fixed a bug where `PretrainedTransformerEmbedder` parameters appeared to be trainable
+  in the log output even when `train_parameters` was set to `False`.
 
 ## [v1.0.0](/~https://github.com/allenai/allennlp/releases/tag/v1.0.0) - 2020-06-16

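The CHANGELOG entry above describes the symptom: wrapping the forward pass in torch.set_grad_enabled(False) blocks gradient computation, but it leaves requires_grad untouched, so any logging that decides what is "trainable" by inspecting that flag still reports the full parameter count. A minimal reproduction sketch, using a hypothetical count_trainable helper and an nn.Linear stand-in rather than AllenNLP's actual logging code:

    import torch
    from torch import nn

    def count_trainable(module: nn.Module) -> int:
        # A common way training logs decide which parameters count as trainable.
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    model = nn.Linear(10, 10)  # stand-in for the wrapped transformer

    # Old behaviour: gradients are disabled only around the forward pass, so the
    # requires_grad flags (and therefore the reported count) are unchanged.
    with torch.set_grad_enabled(False):
        _ = model(torch.randn(2, 10))
    print(count_trainable(model))  # 110 -- still looks trainable

    # New behaviour: freeze at initialization, and the count drops to zero.
    for param in model.parameters():
        param.requires_grad = False
    print(count_trainable(model))  # 0
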
85 changes: 41 additions & 44 deletions allennlp/modules/token_embedders/pretrained_transformer_embedder.py
@@ -55,13 +55,16 @@ def __init__(
         # I'm not sure if this works for all models; open an issue on github if you find a case
         # where it doesn't work.
         self.output_dim = self.config.hidden_size
-        self._train_parameters = train_parameters
 
         tokenizer = PretrainedTransformerTokenizer(model_name)
         self._num_added_start_tokens = len(tokenizer.single_sequence_start_tokens)
         self._num_added_end_tokens = len(tokenizer.single_sequence_end_tokens)
         self._num_added_tokens = self._num_added_start_tokens + self._num_added_end_tokens
 
+        if not train_parameters:
+            for param in self.transformer_model.parameters():
+                param.requires_grad = False
+
     @overrides
     def get_output_dim(self):
         return self.output_dim
@@ -102,51 +105,45 @@ def forward(
         Shape: `[batch_size, num_wordpieces, embedding_size]`.
         """
-        with torch.set_grad_enabled(self._train_parameters):
-            # Some of the huggingface transformers don't support type ids at all and crash when you supply
-            # them. For others, you can supply a tensor of zeros, and if you don't, they act as if you did.
-            # There is no practical difference to the caller, so here we pretend that one case is the same
-            # as another case.
-            if type_ids is not None:
-                max_type_id = type_ids.max()
-                if max_type_id == 0:
-                    type_ids = None
-                else:
-                    if max_type_id >= self._number_of_token_type_embeddings():
-                        raise ValueError(
-                            "Found type ids too large for the chosen transformer model."
-                        )
-                    assert token_ids.shape == type_ids.shape
-
-            fold_long_sequences = (
-                self._max_length is not None and token_ids.size(1) > self._max_length
-            )
-            if fold_long_sequences:
-                batch_size, num_segment_concat_wordpieces = token_ids.size()
-                token_ids, segment_concat_mask, type_ids = self._fold_long_sequences(
-                    token_ids, segment_concat_mask, type_ids
-                )
-
-            transformer_mask = segment_concat_mask if self._max_length is not None else mask
-            # Shape: [batch_size, num_wordpieces, embedding_size],
-            # or if self._max_length is not None:
-            # [batch_size * num_segments, self._max_length, embedding_size]
-
-            # We call this with kwargs because some of the huggingface models don't have the
-            # token_type_ids parameter and fail even when it's given as None.
-            # Also, as of transformers v2.5.1, they are taking FloatTensor masks.
-            parameters = {"input_ids": token_ids, "attention_mask": transformer_mask.float()}
-            if type_ids is not None:
-                parameters["token_type_ids"] = type_ids
-            embeddings = self.transformer_model(**parameters)[0]
-
-            if fold_long_sequences:
-                embeddings = self._unfold_long_sequences(
-                    embeddings, segment_concat_mask, batch_size, num_segment_concat_wordpieces
-                )
-
-            return embeddings
+        # Some of the huggingface transformers don't support type ids at all and crash when you supply
+        # them. For others, you can supply a tensor of zeros, and if you don't, they act as if you did.
+        # There is no practical difference to the caller, so here we pretend that one case is the same
+        # as another case.
+        if type_ids is not None:
+            max_type_id = type_ids.max()
+            if max_type_id == 0:
+                type_ids = None
+            else:
+                if max_type_id >= self._number_of_token_type_embeddings():
+                    raise ValueError("Found type ids too large for the chosen transformer model.")
+                assert token_ids.shape == type_ids.shape
+
+        fold_long_sequences = self._max_length is not None and token_ids.size(1) > self._max_length
+        if fold_long_sequences:
+            batch_size, num_segment_concat_wordpieces = token_ids.size()
+            token_ids, segment_concat_mask, type_ids = self._fold_long_sequences(
+                token_ids, segment_concat_mask, type_ids
+            )
+
+        transformer_mask = segment_concat_mask if self._max_length is not None else mask
+        # Shape: [batch_size, num_wordpieces, embedding_size],
+        # or if self._max_length is not None:
+        # [batch_size * num_segments, self._max_length, embedding_size]
+
+        # We call this with kwargs because some of the huggingface models don't have the
+        # token_type_ids parameter and fail even when it's given as None.
+        # Also, as of transformers v2.5.1, they are taking FloatTensor masks.
+        parameters = {"input_ids": token_ids, "attention_mask": transformer_mask.float()}
+        if type_ids is not None:
+            parameters["token_type_ids"] = type_ids
+        embeddings = self.transformer_model(**parameters)[0]
+
+        if fold_long_sequences:
+            embeddings = self._unfold_long_sequences(
+                embeddings, segment_concat_mask, batch_size, num_segment_concat_wordpieces
+            )
+
+        return embeddings
 
     def _fold_long_sequences(
         self,
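A quick way to check the new behaviour from user code (a usage sketch under the assumption that PretrainedTransformerEmbedder is importable from allennlp.modules.token_embedders; this is not the test added in this PR, and "bert-base-uncased" is only an example model name):

    from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

    embedder = PretrainedTransformerEmbedder("bert-base-uncased", train_parameters=False)

    # With this commit, the transformer's weights are frozen as soon as __init__
    # returns, so none of the embedder's parameters should require gradients.
    assert not any(parameter.requires_grad for parameter in embedder.parameters())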
