diff --git a/docs/_src/api/api/reader.md b/docs/_src/api/api/reader.md index 068584d352..a0d8b35fe1 100644 --- a/docs/_src/api/api/reader.md +++ b/docs/_src/api/api/reader.md @@ -157,23 +157,26 @@ If any checkpoints are stored, a subsequent run of train() will resume training None - -#### distil\_from + +#### distil\_prediction\_layer\_from ```python - | distil_from(teacher_model: "FARMReader", data_dir: str, train_filename: str, dev_filename: Optional[str] = None, test_filename: Optional[str] = None, use_gpu: Optional[bool] = None, student_batch_size: int = 10, teacher_batch_size: Optional[int] = None, n_epochs: int = 2, learning_rate: float = 1e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, use_amp: str = None, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, caching: bool = False, cache_path: Path = Path("cache/data_silo"), distillation_loss_weight: float = 0.5, distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div", temperature: float = 1.0, tinybert_loss: bool = False, tinybert_epochs: int = 1) + | distil_prediction_layer_from(teacher_model: "FARMReader", data_dir: str, train_filename: str, dev_filename: Optional[str] = None, test_filename: Optional[str] = None, use_gpu: Optional[bool] = None, student_batch_size: int = 10, teacher_batch_size: Optional[int] = None, n_epochs: int = 2, learning_rate: float = 3e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, use_amp: str = None, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, caching: bool = False, cache_path: Path = Path("cache/data_silo"), distillation_loss_weight: float = 0.5, distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div", temperature: float = 1.0) ``` -Fine-tune a model on a QA dataset using distillation. You need to provide a teacher model that is already finetuned on the dataset -and a student model that will be trained using the teacher's logits. The idea of this is to increase the accuracy of a lightweight student model +Fine-tune a model on a QA dataset using logit-based distillation. You need to provide a teacher model that is already finetuned on the dataset +and a student model that will be trained using the teacher's logits. The idea of this is to increase the accuracy of a lightweight student model. using a more complex teacher. +Originally proposed in: https://arxiv.org/pdf/1503.02531.pdf +This can also be considered as the second stage of distillation finetuning as described in the TinyBERT paper: +https://arxiv.org/pdf/1909.10351.pdf **Example** ```python student = FARMReader(model_name_or_path="prajjwal1/bert-medium") teacher = FARMReader(model_name_or_path="deepset/bert-large-uncased-whole-word-masking-squad2") -student.distil_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json", +student.distil_prediction_layer_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json", learning_rate=3e-5, distillation_loss_weight=1.0, temperature=5) ``` @@ -222,6 +225,75 @@ If any checkpoints are stored, a subsequent run of train() will resume training - `temperature`: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model. - `tinybert_loss`: Whether to use the TinyBERT loss function for distillation. This requires the student to be a TinyBERT model and the teacher to be a finetuned version of bert-base-uncased. - `tinybert_epochs`: Number of epochs to train the student model with the TinyBERT loss function. After this many epochs, the student model is trained with the regular distillation loss function. +- `tinybert_learning_rate`: Learning rate to use when training the student model with the TinyBERT loss function. +- `tinybert_train_filename`: Filename of training data to use when training the student model with the TinyBERT loss function. To best follow the original paper, this should be an augmented version of the training data created using the augment_squad.py script. If not specified, the training data from the original training is used. + +**Returns**: + +None + + +#### distil\_intermediate\_layers\_from + +```python + | distil_intermediate_layers_from(teacher_model: "FARMReader", data_dir: str, train_filename: str, dev_filename: Optional[str] = None, test_filename: Optional[str] = None, use_gpu: Optional[bool] = None, student_batch_size: int = 10, teacher_batch_size: Optional[int] = None, n_epochs: int = 5, learning_rate: float = 5e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, use_amp: str = None, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, caching: bool = False, cache_path: Path = Path("cache/data_silo"), distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "mse", temperature: float = 1.0) +``` + +The first stage of distillation finetuning as described in the TinyBERT paper: +https://arxiv.org/pdf/1909.10351.pdf + +**Example** +```python +student = FARMReader(model_name_or_path="prajjwal1/bert-medium") +teacher = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_6L_768D") + +student.distil_intermediate_layers_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json", + learning_rate=3e-5, distillation_loss_weight=1.0, temperature=5) +``` + +Checkpoints can be stored via setting `checkpoint_every` to a custom number of steps. +If any checkpoints are stored, a subsequent run of train() will resume training from the latest available checkpoint. + +**Arguments**: + +- `teacher_model`: Model whose logits will be used to improve accuracy +- `data_dir`: Path to directory containing your training data in SQuAD style +- `train_filename`: Filename of training data. To best follow the original paper, this should be an augmented version of the training data created using the augment_squad.py script +- `dev_filename`: Filename of dev / eval data +- `test_filename`: Filename of test data +- `dev_split`: Instead of specifying a dev_filename, you can also specify a ratio (e.g. 0.1) here + that gets split off from training data for eval. +- `use_gpu`: Whether to use GPU (if available) +- `student_batch_size`: Number of samples the student model receives in one batch for training +- `student_batch_size`: Number of samples the teacher model receives in one batch for distillation +- `n_epochs`: Number of iterations on the whole training data set +- `learning_rate`: Learning rate of the optimizer +- `max_seq_len`: Maximum text length (in tokens). Everything longer gets cut down. +- `warmup_proportion`: Proportion of training steps until maximum learning rate is reached. + Until that point LR is increasing linearly. After that it's decreasing again linearly. + Options for different schedules are available in FARM. +- `evaluate_every`: Evaluate the model every X steps on the hold-out eval dataset +- `save_dir`: Path to store the final model +- `num_processes`: The number of processes for `multiprocessing.Pool` during preprocessing. + Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set. + Set to None to use all CPU cores minus one. +- `use_amp`: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model. + Available options: + None (Don't use AMP) + "O0" (Normal FP32 training) + "O1" (Mixed Precision => Recommended) + "O2" (Almost FP16) + "O3" (Pure FP16). + See details on: https://nvidia.github.io/apex/amp.html +- `checkpoint_root_dir`: the Path of directory where all train checkpoints are saved. For each individual + checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created. +- `checkpoint_every`: save a train checkpoint after this many steps of training. +- `checkpoints_to_keep`: maximum number of train checkpoints to save. +:param caching whether or not to use caching for preprocessed dataset and teacher logits +- `cache_path`: Path to cache the preprocessed dataset and teacher logits +- `distillation_loss_weight`: The weight of the distillation loss. A higher weight means the teacher outputs are more important. +- `distillation_loss`: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits) +- `temperature`: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model. **Returns**: diff --git a/haystack/nodes/reader/farm.py b/haystack/nodes/reader/farm.py index 615320dc3c..d450271dcc 100644 --- a/haystack/nodes/reader/farm.py +++ b/haystack/nodes/reader/farm.py @@ -203,6 +203,8 @@ def _training_procedure( if not save_dir: save_dir = f"../../saved_models/{self.inferencer.model.language_model.name}" + if tinybert: + save_dir += "_tinybert_stage_1" # 1. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset label_list = ["start_token", "end_token"] @@ -378,7 +380,7 @@ def train( checkpoint_every=checkpoint_every, checkpoints_to_keep=checkpoints_to_keep, caching=caching, cache_path=cache_path) - def distil_from( + def distil_prediction_layer_from( self, teacher_model: "FARMReader", data_dir: str, @@ -389,7 +391,7 @@ def distil_from( student_batch_size: int = 10, teacher_batch_size: Optional[int] = None, n_epochs: int = 2, - learning_rate: float = 1e-5, + learning_rate: float = 3e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, @@ -405,20 +407,21 @@ def distil_from( distillation_loss_weight: float = 0.5, distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div", temperature: float = 1.0, - tinybert_loss: bool = False, - tinybert_epochs: int = 1, ): """ - Fine-tune a model on a QA dataset using distillation. You need to provide a teacher model that is already finetuned on the dataset - and a student model that will be trained using the teacher's logits. The idea of this is to increase the accuracy of a lightweight student model + Fine-tune a model on a QA dataset using logit-based distillation. You need to provide a teacher model that is already finetuned on the dataset + and a student model that will be trained using the teacher's logits. The idea of this is to increase the accuracy of a lightweight student model. using a more complex teacher. + Originally proposed in: https://arxiv.org/pdf/1503.02531.pdf + This can also be considered as the second stage of distillation finetuning as described in the TinyBERT paper: + https://arxiv.org/pdf/1909.10351.pdf **Example** ```python student = FARMReader(model_name_or_path="prajjwal1/bert-medium") teacher = FARMReader(model_name_or_path="deepset/bert-large-uncased-whole-word-masking-squad2") - student.distil_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json", + student.distil_prediction_layer_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json", learning_rate=3e-5, distillation_loss_weight=1.0, temperature=5) ``` @@ -465,20 +468,10 @@ def distil_from( :param temperature: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model. :param tinybert_loss: Whether to use the TinyBERT loss function for distillation. This requires the student to be a TinyBERT model and the teacher to be a finetuned version of bert-base-uncased. :param tinybert_epochs: Number of epochs to train the student model with the TinyBERT loss function. After this many epochs, the student model is trained with the regular distillation loss function. + :param tinybert_learning_rate: Learning rate to use when training the student model with the TinyBERT loss function. + :param tinybert_train_filename: Filename of training data to use when training the student model with the TinyBERT loss function. To best follow the original paper, this should be an augmented version of the training data created using the augment_squad.py script. If not specified, the training data from the original training is used. :return: None """ - if tinybert_loss: # do hidden state and attention distillation as additional stage - self._training_procedure(data_dir=data_dir, train_filename=train_filename, - dev_filename=dev_filename, test_filename=test_filename, - use_gpu=use_gpu, batch_size=student_batch_size, - n_epochs=tinybert_epochs, learning_rate=learning_rate, - max_seq_len=max_seq_len, warmup_proportion=warmup_proportion, - dev_split=dev_split, evaluate_every=evaluate_every, - save_dir=save_dir, num_processes=num_processes, - use_amp=use_amp, checkpoint_root_dir=checkpoint_root_dir, - checkpoint_every=checkpoint_every, checkpoints_to_keep=checkpoints_to_keep, - teacher_model=teacher_model, teacher_batch_size=teacher_batch_size, - caching=caching, cache_path=cache_path, tinybert=True) return self._training_procedure(data_dir=data_dir, train_filename=train_filename, dev_filename=dev_filename, test_filename=test_filename, use_gpu=use_gpu, batch_size=student_batch_size, @@ -492,6 +485,102 @@ def distil_from( caching=caching, cache_path=cache_path, distillation_loss_weight=distillation_loss_weight, distillation_loss=distillation_loss, temperature=temperature) + def distil_intermediate_layers_from( + self, + teacher_model: "FARMReader", + data_dir: str, + train_filename: str, + dev_filename: Optional[str] = None, + test_filename: Optional[str] = None, + use_gpu: Optional[bool] = None, + student_batch_size: int = 10, + teacher_batch_size: Optional[int] = None, + n_epochs: int = 5, + learning_rate: float = 5e-5, + max_seq_len: Optional[int] = None, + warmup_proportion: float = 0.2, + dev_split: float = 0, + evaluate_every: int = 300, + save_dir: Optional[str] = None, + num_processes: Optional[int] = None, + use_amp: str = None, + checkpoint_root_dir: Path = Path("model_checkpoints"), + checkpoint_every: Optional[int] = None, + checkpoints_to_keep: int = 3, + caching: bool = False, + cache_path: Path = Path("cache/data_silo"), + distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "mse", + temperature: float = 1.0, + ): + """ + The first stage of distillation finetuning as described in the TinyBERT paper: + https://arxiv.org/pdf/1909.10351.pdf + + **Example** + ```python + student = FARMReader(model_name_or_path="prajjwal1/bert-medium") + teacher = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_6L_768D") + + student.distil_intermediate_layers_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json", + learning_rate=3e-5, distillation_loss_weight=1.0, temperature=5) + ``` + + Checkpoints can be stored via setting `checkpoint_every` to a custom number of steps. + If any checkpoints are stored, a subsequent run of train() will resume training from the latest available checkpoint. + + :param teacher_model: Model whose logits will be used to improve accuracy + :param data_dir: Path to directory containing your training data in SQuAD style + :param train_filename: Filename of training data. To best follow the original paper, this should be an augmented version of the training data created using the augment_squad.py script + :param dev_filename: Filename of dev / eval data + :param test_filename: Filename of test data + :param dev_split: Instead of specifying a dev_filename, you can also specify a ratio (e.g. 0.1) here + that gets split off from training data for eval. + :param use_gpu: Whether to use GPU (if available) + :param student_batch_size: Number of samples the student model receives in one batch for training + :param student_batch_size: Number of samples the teacher model receives in one batch for distillation + :param n_epochs: Number of iterations on the whole training data set + :param learning_rate: Learning rate of the optimizer + :param max_seq_len: Maximum text length (in tokens). Everything longer gets cut down. + :param warmup_proportion: Proportion of training steps until maximum learning rate is reached. + Until that point LR is increasing linearly. After that it's decreasing again linearly. + Options for different schedules are available in FARM. + :param evaluate_every: Evaluate the model every X steps on the hold-out eval dataset + :param save_dir: Path to store the final model + :param num_processes: The number of processes for `multiprocessing.Pool` during preprocessing. + Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set. + Set to None to use all CPU cores minus one. + :param use_amp: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model. + Available options: + None (Don't use AMP) + "O0" (Normal FP32 training) + "O1" (Mixed Precision => Recommended) + "O2" (Almost FP16) + "O3" (Pure FP16). + See details on: https://nvidia.github.io/apex/amp.html + :param checkpoint_root_dir: the Path of directory where all train checkpoints are saved. For each individual + checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created. + :param checkpoint_every: save a train checkpoint after this many steps of training. + :param checkpoints_to_keep: maximum number of train checkpoints to save. + :param caching whether or not to use caching for preprocessed dataset and teacher logits + :param cache_path: Path to cache the preprocessed dataset and teacher logits + :param distillation_loss_weight: The weight of the distillation loss. A higher weight means the teacher outputs are more important. + :param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits) + :param temperature: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model. + :return: None + """ + return self._training_procedure(data_dir=data_dir, train_filename=train_filename, + dev_filename=dev_filename, test_filename=test_filename, + use_gpu=use_gpu, batch_size=student_batch_size, + n_epochs=n_epochs, learning_rate=learning_rate, + max_seq_len=max_seq_len, warmup_proportion=warmup_proportion, + dev_split=dev_split, evaluate_every=evaluate_every, + save_dir=save_dir, num_processes=num_processes, + use_amp=use_amp, checkpoint_root_dir=checkpoint_root_dir, + checkpoint_every=checkpoint_every, checkpoints_to_keep=checkpoints_to_keep, + teacher_model=teacher_model, teacher_batch_size=teacher_batch_size, + caching=caching, cache_path=cache_path, + distillation_loss=distillation_loss, temperature=temperature, tinybert=True) + def update_parameters( self, context_window_size: Optional[int] = None, diff --git a/test/test_distillation.py b/test/test_distillation.py index 0f8d4fb478..f63e6fd234 100644 --- a/test/test_distillation.py +++ b/test/test_distillation.py @@ -23,7 +23,7 @@ def test_distillation(): student_weights.pop(-2) # pooler is not updated due to different attention head - student.distil_from(teacher, data_dir="samples/squad", train_filename="tiny.json") + student.distil_prediction_layer_from(teacher, data_dir="samples/squad", train_filename="tiny.json") # create new checkpoint new_student_weights = create_checkpoint(student) @@ -47,7 +47,7 @@ def test_tinybert_distillation(): student_weights.pop(-1) # last layer is not affected by tinybert loss student_weights.pop(-1) # pooler is not updated due to different attention head - student._training_procedure(teacher_model=teacher, tinybert=True, data_dir="samples/squad", train_filename="tiny.json") + student.distil_intermediate_layers_from(teacher_model=teacher, data_dir="samples/squad", train_filename="tiny.json") # create new checkpoint new_student_weights = create_checkpoint(student)