This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
Support for bart in allennlp-models (#4169)
* Support for bart in allennlp-models
  - Added an option to PretrainedTransformerEmbedder to allow usage with encoder-decoder models, and created a unit test.
  - Added the ROUGE-N metric. ROUGE-L will follow soon.
  - Added an indices_to_tokens abstract method to TokenIndexer and implemented it for PretrainedTransformerIndexer. This is useful for turning decoded sequences in seq2seq models into text.
  - Added a timestep parameter to the step function in beam search.
  - Other minor changes.
* Implemented ROUGE-L, updated ROUGE-N, new tests
  - Implemented the ROUGE-L metric (F1 score).
  - Implemented ROUGE-N recall, precision, and F1 as metrics that can be accessed separately.
  - Now computing the overall ROUGE-N/L as an average over the scores of each sequence pair, rather than summing counts across all pairs and then computing the metric (see the sketch after this list).
  - Added tests for the new padding behavior in get_text_field_mask.
  - Added tests for ROUGE-N/L.
  - Stylistic improvements.
* Polynomial lr scheduling, max tokens batch sampling, other small changes
  - Implemented polynomial learning rate scheduling, which is used in BART. The implementation is based on Fairseq's and TensorFlow's implementations (see the sketch after this list).
  - Implemented an option to specify the maximum number of tokens per batch, rather than specifying a fixed batch size. This is also used for fine-tuning BART. Added a unit test too.
  - For indices_to_tokens, removed code that strips the cls/sep tokens introduced by max length. Added a test to reflect this.
* Small stylistic changes
* Added documentation, separated max tokens sampler, fixed circular import, memory tracking per batch, polynomial lr decay bug fix
  - Added documentation for lazy_groups_of_max_size.
  - Some stylistic changes.
  - Made MaxTokensBatchSampler a subclass of BucketBatchSampler.
  - Annotated beam search with no_grad.
  - Fixed a bug in poly decay related to the lr of the first batch.
  - Fixed the circular import, finally.
  - Added GPU/CPU memory tracking for tensorboard per batch (previously this was only possible per epoch).
* Fixed linting errors, fixed rouge test
  - TODO: fix `TestPretrainedTransformerEmbedder.test_encoder_decoder_model` and `TestPretrainedTransformerIndexer.test_indices_to_tokens`. Both issues are related to the new tokenizers.
* Fixed issues with new tokenizers
  - Fixed an issue with roberta-based tokenizers in pretrained_transformer_indexer.
  - Temporary fix for incorrect type ids when using max length for tokens_to_indices in PretrainedTransformerIndexer.
  - Fixed the indexer test to not compare idx and idx_end of Tokens.
* Added max tokens batch sampler to __init__.py
* Fixed max tokens sampler to account for padding
* Fixed large batches due to short source sequences but long target sequences in max tokens batch sampler
* Formatting
* Filled in the changelog
* Tests have moved
* Fix docs
* Adds a test for the max tokens sampler
* Adds warning when a single instance is too big
* More docs changes
* Formatting
* Docs
* Fix old models
* Fixed linting and type checking errors
* Fix docs build
* Fix circular imports

Co-authored-by: Dirk Groeneveld <dirkg@allenai.org>
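To make a couple of these items concrete, here is a minimal sketch of a Fairseq/TensorFlow-style polynomial learning rate schedule with linear warmup. This illustrates the idea only; it is not the AllenNLP implementation from this commit, and the function name and default values are assumptions.

def polynomial_lr(
    step: int,
    total_steps: int,
    warmup_steps: int = 0,
    start_lr: float = 3e-5,
    end_lr: float = 0.0,
    power: float = 1.0,
) -> float:
    # Hypothetical helper, for illustration only; not the AllenNLP API.
    if warmup_steps > 0 and step < warmup_steps:
        # Linear warmup from 0 up to start_lr.
        return start_lr * step / warmup_steps
    if step >= total_steps:
        return end_lr
    # Polynomial decay from start_lr down to end_lr over the remaining steps.
    remaining = (total_steps - step) / max(1, total_steps - warmup_steps)
    return (start_lr - end_lr) * remaining ** power + end_lr

With power=1.0 this reduces to linear decay. Likewise, a sketch of per-pair ROUGE-N as described above: precision, recall, and F1 are computed from clipped n-gram overlap for a single prediction/reference pair, and the overall metric averages these per-pair scores. The function name is hypothetical.

from collections import Counter
from typing import List, Tuple

def rouge_n(prediction: List[str], reference: List[str], n: int = 2) -> Tuple[float, float, float]:
    # Hypothetical helper, for illustration only; not the AllenNLP API.
    pred_ngrams = Counter(tuple(prediction[i : i + n]) for i in range(len(prediction) - n + 1))
    ref_ngrams = Counter(tuple(reference[i : i + n]) for i in range(len(reference) - n + 1))
    overlap = sum((pred_ngrams & ref_ngrams).values())  # clipped n-gram matches
    precision = overlap / max(1, sum(pred_ngrams.values()))
    recall = overlap / max(1, sum(ref_ngrams.values()))
    f1 = 2 * precision * recall / max(1e-13, precision + recall)
    return precision, recall, f1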
Showing 29 changed files with 980 additions and 166 deletions.
import logging
import random
from typing import Iterable, Iterator, List, Optional, TypeVar

from torch.utils import data

from allennlp.data.samplers import BatchSampler, BucketBatchSampler

logger = logging.getLogger(__name__)

A = TypeVar("A")


@BatchSampler.register("max_tokens_sampler")
class MaxTokensBatchSampler(BucketBatchSampler):
    """
    A sampler which, by default, sorts instances by their padding length and then creates
    batches such that the number of tokens in a batch does not exceed the given maximum.
    You can provide a list of field names and padding keys (or pass none, in which case they
    will be inferred) which the dataset will be sorted by before doing this batching, causing
    inputs with similar length to be batched together and making computation more efficient
    (as less time is wasted on padded elements of the batch).

    # Parameters

    data_source : `data.Dataset`
        The pytorch `Dataset` of allennlp Instances to bucket.
    max_tokens : `int`
        The maximum number of tokens to include in a batch.
    sorting_keys : `List[str]`, optional
        To bucket inputs into batches, we want to group the instances by padding length, so
        that we minimize the amount of padding necessary per batch. In order to do this, we
        need to know which fields need what type of padding, and in what order.
        Specifying the right keys for this is a bit cryptic, so if this is not given we try to
        auto-detect the right keys by iterating through a few instances upfront, reading all of
        the padding keys and seeing which one has the longest length. We use that one for
        padding. This should give reasonable results in most cases. Some cases where it might
        not be the right thing to do are when you have a `ListField[TextField]`, or when you
        have a really long, constant-length `ArrayField`.
        When you need to specify this yourself, you can create an instance from your dataset
        and call `Instance.get_padding_lengths()` to see a list of all keys used in your data.
        You should give one or more of those as the sorting keys here.
    padding_noise : `float`, optional (default = `0.1`)
        When sorting by padding length, we add a bit of noise to the lengths, so that the
        sorting isn't deterministic. This parameter determines how much noise we add, as a
        percentage of the actual padding value for each instance.
    """

    def __init__(
        self,
        data_source: data.Dataset,
        max_tokens: Optional[int] = None,
        sorting_keys: Optional[List[str]] = None,
        padding_noise: float = 0.1,
    ):
        # The parent's batch_size is unused here (batches are delimited by token count
        # instead), so we pass a -1 placeholder and disable drop_last.
        super().__init__(data_source, -1, sorting_keys, padding_noise, False)

        if max_tokens is None:
            raise ValueError("max_tokens must be specified")
        self.max_tokens = max_tokens

    def _lazy_groups_of_max_size(
        self, iterable: Iterable[A], sizes: Iterable[int],
    ) -> Iterator[List[A]]:
        """
        Takes an `iterable` of data and an iterable `sizes` of the same length, giving the
        size of each corresponding item in `iterable`. The items from `iterable` are batched
        such that the total size of a batch, as computed from `sizes`, does not exceed
        `self.max_tokens`.
        """
        cur_max_size = 0
        group: List[A] = []

        iterator = iter(iterable)
        size_iter = iter(sizes)

        for item, size in zip(iterator, size_iter):
            if size > self.max_tokens:
                logger.warning(
                    "Found instance of size %d, which is bigger than the expected size for a batch (%d)",
                    size,
                    self.max_tokens,
                )
            # The cost of a batch is its padded size: the longest size seen so far times the
            # number of items the batch would contain after adding this one.
            group_size = max(size, cur_max_size) * (len(group) + 1)

            # Only yield a non-empty group, so that a single oversized instance produces a
            # batch of one rather than an empty batch.
            if group_size > self.max_tokens and len(group) > 0:
                yield group
                cur_max_size = 0
                group = []

            group.append(item)
            cur_max_size = max(cur_max_size, size)

        if len(group) != 0:
            yield group

    def __iter__(self) -> Iterator[List[int]]:
        indices, lengths = self._argsort_by_padding(self.data_source)

        max_lengths = [max(length) for length in lengths]
        group_iterator = self._lazy_groups_of_max_size(indices, max_lengths)

        batches = [list(group) for group in group_iterator]
        random.shuffle(batches)
        for batch in batches:
            yield batch

    def __len__(self):
        # There is no easy way to count the number of batches, so we need to iterate and count.
        return sum(1 for _ in self)
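For context, a minimal usage sketch of the sampler above. `my_dataset` and the parameter values are illustrative assumptions; each yielded batch is a list of dataset indices, which a data loader would then use to pull and collate instances.

# Hypothetical usage sketch; `my_dataset` and the values below are assumptions.
sampler = MaxTokensBatchSampler(
    data_source=my_dataset,          # a torch Dataset of AllenNLP Instances
    max_tokens=512,                  # upper bound on padded tokens per batch
    sorting_keys=["source_tokens"],  # or None to auto-detect the keys
)
for batch_indices in sampler:
    ...  # fetch my_dataset[i] for i in batch_indices, then pad and collate

Because the class is registered as "max_tokens_sampler", it can also be selected by that name in an AllenNLP configuration file.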