Adding distillation loss functions from TinyBERT #1879
Conversation
Looks very good already. Just some smaller changes requested. Most interesting for you is the missing `return` keyword in `haystack/nodes/reader/farm.py`, I guess. Happy to jump on a quick call in the afternoon if you want to discuss something.
def test_tinybert_distillation():
    student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_4L_312D")
    teacher = FARMReader(model_name_or_path="bert-base-uncased")
Could we use a smaller teacher model here to speed up the test?
This would be theoretically possible, but the teacher model would need to have exactly the same dimensions except for the number of layers, which would need to be a multiple of the number of student layers. This means it is quite hard to find a matching model. If it is a big performance issue, we could perhaps create our own "mock model" with the right parameters.
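As an illustration of the layer constraint mentioned above (a hypothetical helper, not code from this PR): with a 4-layer student and a 12-layer teacher, every third teacher layer is compared against a student layer.

def map_student_to_teacher_layers(student_layers: int, teacher_layers: int) -> list:
    """Hypothetical helper: map each student layer to a teacher layer for the TinyBERT loss."""
    if teacher_layers % student_layers != 0:
        raise ValueError("The number of teacher layers must be a multiple of the number of student layers.")
    k = teacher_layers // student_layers
    # Student layer i is distilled against teacher layer i * k.
    return [i * k for i in range(1, student_layers + 1)]

print(map_student_to_teacher_layers(4, 12))  # [3, 6, 9, 12]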
Alright, let's keep it as it is for now. 👍
:return: None
"""
if tinybert_loss:
    self._training_procedure(data_dir=data_dir, train_filename=train_filename,
Is a `return` missing here in front of `self._training_procedure(...)`?
No, task specific distillation for TinyBERT has two stages and the second stage is the same as what we have already implemented. So calling _training_procedure with tinybert=True only executes the first stage. I have added a short comment explaining that.
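A rough sketch of the two-stage flow described here (simplified; apart from `tinybert_loss`, `tinybert_epochs`, `data_dir`, `train_filename` and `tinybert=True`, the parameter names are assumptions and do not mirror the exact PR code):

def distil_from(self, teacher, data_dir, train_filename, tinybert_loss=False, tinybert_epochs=1, **kwargs):
    # Sketch only; further parameters and the teacher handling are omitted.
    if tinybert_loss:
        # Stage 1: TinyBERT intermediate-layer distillation (hidden states and attentions).
        # No return here on purpose, because stage 2 still has to run afterwards.
        self._training_procedure(data_dir=data_dir, train_filename=train_filename,
                                 n_epochs=tinybert_epochs, tinybert=True, **kwargs)
    # Stage 2: the prediction-layer distillation that was already implemented.
    return self._training_procedure(data_dir=data_dir, train_filename=train_filename, **kwargs)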
@@ -1,9 +1,6 @@
from typing import Optional, Union, Tuple, List, Callable

from typing import TYPE_CHECKING
Could you please explain why we got rid of these lines and how `_LRScheduler` is handled now, so that I understand it a bit better? A response to this comment would be fine. :)
The type hint for `DistillationTrainer` turned out to be wrong. Because of that, I don't need to import `TYPE_CHECKING` anymore, as it was only necessary for preventing the circular import of `FARMReader`. There was never really a reason to also use that for `_LRScheduler`, so `_LRScheduler` can just be imported normally.
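To make the import discussion concrete, roughly what the pattern looks like (illustrative, not the exact file contents):

# _LRScheduler lives in PyTorch and causes no import cycle, so a plain import is enough:
from torch.optim.lr_scheduler import _LRScheduler

# A TYPE_CHECKING guard is only needed for names that would be circular at runtime,
# e.g. FARMReader, which itself imports from this module:
#
#     from typing import TYPE_CHECKING
#     if TYPE_CHECKING:
#         from haystack.nodes.reader.farm import FARMReader
#
# With the type hint changed from "FARMReader" to "AdaptiveModel", that guard
# (and the TYPE_CHECKING import) is no longer needed at all.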
@@ -630,7 +627,7 @@ class DistillationTrainer(Trainer):
    """
    def __init__(
        self,
-        model: "FARMReader",
+        model: "AdaptiveModel",
Is the code ready to use models other than FARMReader in its current form?
It can basically train any `AdaptiveModel` with a QA `prediction_head`. I changed this line because I realised that `_training_procedure` only passes the `AdaptiveModel`. This behavior is exactly the same for the normal `Trainer` class.
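A small sketch of that relationship (the `inferencer.model` attribute path is an assumption and may differ from the actual code):

from haystack.nodes import FARMReader

# The FARMReader is only a wrapper; the trainer constructed inside _training_procedure
# receives the wrapped AdaptiveModel (with its QA prediction head), not the reader itself.
student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_4L_312D")
adaptive_model = student.inferencer.model  # assumed attribute path to the AdaptiveModel
print(type(adaptive_model).__name__)       # expected: the underlying AdaptiveModel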
@@ -484,6 +484,8 @@ def forward(
        input_ids: torch.Tensor,
        segment_ids: torch.Tensor,
        padding_mask: torch.Tensor,
+        output_hidden_states: bool = False,
+        output_attentions: bool = False,
Please add docstrings for these new parameters.
Ok, I have added the docstrings.
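Presumably the added docstrings follow the pattern suggested further down in this review; a sketch of the method signature from the diff above (the exact wording in the PR may differ):

# Method sketch only; the surrounding class and the existing docstring text are omitted.
def forward(
    self,
    input_ids: torch.Tensor,
    segment_ids: torch.Tensor,
    padding_mask: torch.Tensor,
    output_hidden_states: bool = False,
    output_attentions: bool = False,
):
    """
    ... existing parameter documentation ...

    :param output_hidden_states: Whether to output hidden states
    :param output_attentions: Whether to output attentions
    """
    ...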
sequence_output, pooled_output = output_tuple[0], output_tuple[1]
return sequence_output, pooled_output
return output_tuple
# if self.model.encoder.config.output_hidden_states == True:
Please check the commented code. :)
I have now deleted the commented code. It is unnecessary, as the output tuple is now handled by HuggingFace transformers.
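A short, self-contained illustration of that point, using the transformers library directly (not the PR code):

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("TinyBERT distillation example", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True, output_attentions=True)

# The extra tensors are already part of the model output, so no manual
# unpacking of the output tuple is needed on the Haystack/FARM side.
print(len(outputs.hidden_states))  # embedding output + one tensor per layer (13 for BERT base)
print(len(outputs.attentions))     # one attention tensor per layer (12 for BERT base)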
@@ -356,7 +356,7 @@ def prepare_labels(self, **kwargs):
        all_labels.append(labels)
        return all_labels

-    def forward(self, **kwargs):
+    def forward(self, output_hidden_states: bool = False, output_attentions: bool = False, **kwargs):
Please add the doc strings for the new parameters here as well, e.g.:
:param output_hidden_states: Whether to output hidden states
:param output_attentions: Whether to output attentions
I have added the doc strings.
Proposed changes:
This adds the distillation loss functions from TinyBERT as explained in #1873.
This adds two parameters to the `distil_from` method. Enabling the parameter `tinybert_loss` adds an additional distillation stage before the original one; `tinybert_epochs` specifies the number of epochs in this stage. The stage is realised using a new `TinyBERTDistillationTrainer` that computes the teacher hidden states and attentions on the fly. Caching of the teacher is not used, as this would take up too much memory (100s to 1000s of gigabytes). This means that the standard `DistillationTrainer` can be used.
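For reference, a hedged usage sketch of the new parameters (model names taken from the test above; the data paths and the positional teacher argument are placeholders/assumptions):

from haystack.nodes import FARMReader

student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_4L_312D")
teacher = FARMReader(model_name_or_path="bert-base-uncased")

student.distil_from(
    teacher,                        # assumed to be passed as the teacher model
    data_dir="data/squad",          # placeholder path
    train_filename="train.json",    # placeholder filename
    tinybert_loss=True,             # enable the additional TinyBERT distillation stage
    tinybert_epochs=1,              # number of epochs for that stage
)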