Add TinyBERT data augmentation #1923
Conversation
Very much looking forward to the results of the first experiments! I thought I'd leave some feedback, although I understand this is still a draft.
Let's compare different GloVe embedding models, and maybe even fastText, in the experiments at some point later. I would also be interested to learn how often the replaced words are single-piece words (where BERT is used) or multiple-piece words (where GloVe is used).
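For illustration, a minimal sketch of how one might count single-piece vs. multiple-piece words with the Hugging Face tokenizer; the word list and the hard-coded model name are placeholders, not part of this PR:

from transformers import BertTokenizer

# Placeholder setup: any masked-LM tokenizer could be used here.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def is_single_piece(word: str) -> bool:
    # A word is "single-piece" if WordPiece keeps it as exactly one token.
    return len(tokenizer.tokenize(word)) == 1

sample_words = ["university", "photosynthesis", "haystack"]  # illustrative only
single = sum(is_single_piece(w) for w in sample_words)
print(f"{single}/{len(sample_words)} words are single-piece")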
haystack/utils/augment_squad.py
Outdated
@@ -0,0 +1,178 @@
"""
Script to perform data augmentation on a SQuAD like dataset to increase training data. It follows the approach oultined in the TinyBERT paper.
let's add a link to the paper here as well
I have added the link now.
haystack/utils/augment_squad.py
Outdated
""" | ||
Script to perform data augmentation on a SQuAD like dataset to increase training data. It follows the approach oultined in the TinyBERT paper. | ||
Usage: | ||
python augment_squad.py --squad_path <squad_path> --output_path <output_patn> \ |
typo: patn -> path
This is now fixed.
def load_glove(glove_path: Path = Path("glove.txt"), vocab_size: int = 100_000):
    if not glove_path.exists():
        zip_path = glove_path.parent / (glove_path.name + ".zip")
        request = requests.get("https://nlp.stanford.edu/data/glove.42B.300d.zip", allow_redirects=True)
Let's compare the performance of https://nlp.stanford.edu/data/glove.840B.300d.zip and https://nlp.stanford.edu/data/glove.6B.zip.
It would also be interesting to see whether fastText performs better than GloVe when used for data augmentation. We need a way to use data augmentation for non-English datasets as well, and fastText could help with that.
import fasttext
ft = fasttext.load_model('cc.en.300.bin')
#german model: https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.de.300.bin.gz
ft.get_nearest_neighbors('hello')
I agree that this would be interesting to see. I think with a model trained using whole word masking we could perhaps even try not using this additional step at all; however, it would probably be quite costly to test this.
Distilling SQuAD the way they did for the original results would be equivalent, time-wise, to about 100 epochs on an unaugmented dataset. 100 of those epochs would take about 100 * 45 min = 75 h. In addition to that, data augmentation takes about 45 h because you need to do a forward pass for each word.
Although I am trying to improve data augmentation speed and we could use a smaller model as the student, trying out a different tokenizer is probably not worth it, as about 94% of all words can be replaced using BERT (tested with this dataset).
haystack/utils/augment_squad.py
Outdated
possible_words.append([word] + tokenizer.convert_ids_to_tokens(ranking))

batch_index += 1
elif word in glove_word_id_mapping:
This is where we could try fastText. The elif would then become an else, because there are no out-of-vocabulary issues with fastText:
import fasttext
ft = fasttext.load_model('cc.en.300.bin')
#german model: https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.de.300.bin.gz
ft.get_nearest_neighbors('hello')
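As an example, a rough sketch of what candidate generation with fastText could look like; the function and variable names are assumptions, not the PR's code:

import fasttext

# English CommonCrawl vectors; the German model linked above works the same way.
ft = fasttext.load_model("cc.en.300.bin")

def fasttext_candidates(word: str, k: int = 20) -> list:
    # fastText builds vectors from character n-grams, so every word gets a
    # vector and no out-of-vocabulary fallback branch is needed.
    neighbors = ft.get_nearest_neighbors(word, k)  # list of (score, word) tuples
    return [word] + [neighbor for _, neighbor in neighbors]

print(fasttext_candidates("hello"))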
That could be useful, as in 43% of the cases where BERT can't be used, GloVe doesn't work either (same test on this dataset as above).
However, I'm not so sure how big the difference would really be: since BERT already covers about 94% of words, this would only affect roughly 43% of the remaining 6%, i.e. about 2.6% of all cases, and it would come at the cost of losing comparability to the original paper or of additional distillation runs (which, as explained above, take up a lot of time).
haystack/utils/augment_squad.py
Outdated
parser.add_argument("--replace_probability", type=float, default=0.4, help="Probability of replacing a word")
parser.add_argument("--glove_path", type=Path, default="glove.txt", help="Path to the glove file")

model = BertForMaskedLM.from_pretrained("bert-base-uncased")
Let's pass the model/tokenizer name as an argument to the script as well.
I have added these arguments.
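For reference, the added arguments could look roughly like this; the argument names and defaults are assumptions, not necessarily what was committed:

from argparse import ArgumentParser
from transformers import BertForMaskedLM, BertTokenizer

parser = ArgumentParser()
# Hypothetical arguments: names and defaults are illustrative only.
parser.add_argument("--model", type=str, default="bert-base-uncased",
                    help="Masked language model used to generate word replacements")
parser.add_argument("--tokenizer", type=str, default="bert-base-uncased",
                    help="Tokenizer belonging to the masked language model")
args = parser.parse_args()

model = BertForMaskedLM.from_pretrained(args.model)
tokenizer = BertTokenizer.from_pretrained(args.tokenizer)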
There are two points that I would like to talk about before merging this PR. First, let's talk about whether is_impossible should be set to True and why. Second, I came up with an idea for a test case to test the script end-to-end (at least the number of generated questions and the format of the generated squad file).
haystack/utils/augment_squad.py
Outdated
for topic in tqdm(squad["data"]):
    paragraphs = []
    for paragraph in topic["paragraphs"]:
        # make every question unanswerable as answer strings will probably match and aren't relevant for distillation
I am not sure I understand this comment correctly. I understand that answer strings won't be relevant for distillation because we will make predictions with the teacher model anyway. However, what do you mean by "answer strings will probably match"? Why do we want to set is_impossible to True? That would result in this question being handled as not answerable. Couldn't we leave question["answers"] = [] as is, but have question["is_impossible"] = False?
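To make the two options concrete, an augmented SQuAD question entry could look like either of these; the values are purely illustrative:

# Option discussed above: mark the augmented question as unanswerable.
question_as_impossible = {
    "question": "What is the augmented question?",
    "id": "augmented-0",
    "answers": [],
    "is_impossible": True,   # readers/trainers treat it as having no answer
}

# Alternative raised here: keep it answerable but without answer strings,
# since the labels come from the teacher's predictions during distillation.
question_without_answers = {
    "question": "What is the augmented question?",
    "id": "augmented-0",
    "answers": [],
    "is_impossible": False,
}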
args = parser.parse_args()

augment_squad(**vars(args))
Regarding testing: I would suggest that for a test we load a small/tiny squad file with SquadData and count the number of questions with
haystack/haystack/utils/squad_data.py, line 146 in 13510aa:
def count(self, unit="questions"):
The next step is to run the augment_squad() method and in the end load the result again with SquadData and count again to see whether the size of the dataset was multiplied as expected by multiplication_factor. What do you think?
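A rough sketch of such a test, assuming SquadData.from_file exists, that augment_squad() accepts these keyword arguments, and that the sample path is a placeholder:

from pathlib import Path

from haystack.utils.squad_data import SquadData
from haystack.utils.augment_squad import augment_squad

def test_augment_squad(tmp_path):
    # Placeholder path to a tiny SQuAD file checked into the test samples.
    squad_path = Path("test/samples/squad/tiny.json")
    output_path = tmp_path / "augmented.json"
    multiplication_factor = 4

    original_count = SquadData.from_file(squad_path).count(unit="questions")

    # Keyword arguments are assumptions; they must match the final script.
    augment_squad(
        squad_path=squad_path,
        output_path=output_path,
        multiplication_factor=multiplication_factor,
    )

    augmented_count = SquadData.from_file(output_path).count(unit="questions")
    # Expect the dataset to grow by multiplication_factor.
    assert augmented_count == multiplication_factor * original_count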
LGTM! Let's wait for the tests and merge if all of them are green.
Proposed changes:
This adds TinyBERT data augmentation as described in #1874.
In its current form, it is very separate from the rest of the codebase, as there wasn't an existing Haystack abstraction that seemed appropriate. DataSilo was considered, but this didn't seem to be in its spirit. Integrating it in distil_from, for example, would also probably need a lot of additional parameters that you wouldn't expect in this method.
Currently the following works:
Status (please check what you already did):
closes #1874