Migrating to use native Pytorch AMP (#2827)
* Started making changes to use native Pytorch AMP

* Updated compute_loss functions to use torch.cuda.amp.autocast

* Updating docstrings

* Add use_amp to trainer_checkpoint

* Removed mentions of apex and started to add the necessary warnings

* Removing unused instances of use_amp variable

* Added fast training test for FARMReader. Needed to add max_query_length as a parameter in FARMReader.__init__ and FARMReader.train

* Make max_query_length optional in FARMReader.train

* Update lg

Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>
Co-authored-by: agnieszka-m <amarzec13@gmail.com>
3 people authored Jan 5, 2023
1 parent 35e9ff2 commit e84fae2
Showing 15 changed files with 253 additions and 275 deletions.
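For orientation, the native PyTorch AMP pattern this commit migrates to (replacing apex opt-level strings with a plain boolean flag) follows the standard autocast/GradScaler recipe. The sketch below is a minimal, self-contained illustration with a placeholder model, optimizer, and data, not Haystack code.

```python
# Minimal sketch of the native PyTorch AMP pattern adopted in this commit.
# The model, optimizer, and data are placeholders, not Haystack code.
import torch
from torch import nn

use_amp = torch.cuda.is_available()  # autocast/GradScaler target CUDA devices
device = "cuda" if use_amp else "cpu"

model = nn.Linear(16, 2).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for _ in range(3):  # stand-in for iterating over training batches
    inputs = torch.randn(8, 16, device=device)
    targets = torch.randint(0, 2, (8,), device=device)

    optimizer.zero_grad()
    # The forward pass and loss run under autocast, mirroring the updated
    # compute_loss functions mentioned in the commit message.
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = nn.functional.cross_entropy(model(inputs), targets)

    # GradScaler scales the loss to avoid FP16 gradient underflow; with
    # enabled=False these calls reduce to a plain backward/step.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```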
2 changes: 1 addition & 1 deletion docs/_src/api/api/evaluation.md
@@ -143,7 +143,7 @@ Computes Transformer-based similarity of predicted answer to gold labels to deri

Returns per QA pair a) the similarity of the most likely prediction (top 1) to all available gold labels
b) the highest similarity of all predictions to gold labels
c) a matrix consisting of the similarities of all the predicitions compared to all gold labels
c) a matrix consisting of the similarities of all the predictions compared to all gold labels

**Arguments**:

50 changes: 18 additions & 32 deletions docs/_src/api/api/reader.md
@@ -149,7 +149,7 @@ def train(data_dir: str,
evaluate_every: int = 300,
save_dir: Optional[str] = None,
num_processes: Optional[int] = None,
use_amp: str = None,
use_amp: bool = False,
checkpoint_root_dir: Path = Path("model_checkpoints"),
checkpoint_every: Optional[int] = None,
checkpoints_to_keep: int = 3,
@@ -193,14 +193,10 @@ Note that the evaluation report is logged at evaluation level INFO while Haystac
- `num_processes`: The number of processes for `multiprocessing.Pool` during preprocessing.
Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set.
Set to None to use all CPU cores minus one.
- `use_amp`: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model.
Available options:
None (Don't use AMP)
"O0" (Normal FP32 training)
"O1" (Mixed Precision => Recommended)
"O2" (Almost FP16)
"O3" (Pure FP16).
See details on: https://nvidia.github.io/apex/amp.html
- `use_amp`: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
training speed and reduce GPU memory usage (see the usage sketch after this parameter list).
For more information, see [Haystack Optimization](https://haystack.deepset.ai/guides/optimization)
and [Automatic Mixed Precision Package - Torch.amp](https://pytorch.org/docs/stable/amp.html).
- `checkpoint_root_dir`: The Path of a directory where all train checkpoints are saved. For each individual
checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
- `checkpoint_every`: Save a train checkpoint after this many steps of training.
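A hedged usage sketch of the new boolean flag on `FARMReader` (model name, paths, and filenames are placeholders; consult the current signatures for exact defaults):

```python
# Hedged sketch: fine-tuning a FARMReader with native AMP enabled.
# Model name and data paths are placeholders.
from haystack.nodes import FARMReader

reader = FARMReader(
    model_name_or_path="deepset/roberta-base-squad2",
    use_gpu=True,
    max_query_length=64,  # parameter added to FARMReader.__init__ in this commit
)
reader.train(
    data_dir="data/squad",        # directory with SQuAD-style training data
    train_filename="train.json",  # placeholder filename
    use_amp=True,                 # boolean flag replacing the old apex "O0"-"O3" strings
    n_epochs=1,
    batch_size=16,
)
```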
@@ -237,7 +233,7 @@ def distil_prediction_layer_from(
evaluate_every: int = 300,
save_dir: Optional[str] = None,
num_processes: Optional[int] = None,
use_amp: str = None,
use_amp: bool = False,
checkpoint_root_dir: Path = Path("model_checkpoints"),
checkpoint_every: Optional[int] = None,
checkpoints_to_keep: int = 3,
@@ -284,7 +280,7 @@ A list containing torch device objects and/or strings is supported (For example
[torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
parameter is not used and a single cpu device is used for inference.
- `student_batch_size`: Number of samples the student model receives in one batch for training
- `student_batch_size`: Number of samples the teacher model receives in one batch for distillation
- `teacher_batch_size`: Number of samples the teacher model receives in one batch for distillation
- `n_epochs`: Number of iterations on the whole training data set
- `learning_rate`: Learning rate of the optimizer
- `max_seq_len`: Maximum text length (in tokens). Everything longer gets cut down.
@@ -296,14 +292,10 @@ Options for different schedules are available in FARM.
- `num_processes`: The number of processes for `multiprocessing.Pool` during preprocessing.
Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set.
Set to None to use all CPU cores minus one.
- `use_amp`: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model.
Available options:
None (Don't use AMP)
"O0" (Normal FP32 training)
"O1" (Mixed Precision => Recommended)
"O2" (Almost FP16)
"O3" (Pure FP16).
See details on: https://nvidia.github.io/apex/amp.html
- `use_amp`: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
training speed and reduce GPU memory usage (see the distillation sketch after this parameter list).
For more information, see [Haystack Optimization](https://haystack.deepset.ai/guides/optimization)
and [Automatic Mixed Precision Package - Torch.amp](https://pytorch.org/docs/stable/amp.html).
- `checkpoint_root_dir`: the Path of directory where all train checkpoints are saved. For each individual
checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
- `checkpoint_every`: save a train checkpoint after this many steps of training.
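A hedged sketch of prediction-layer distillation with the boolean flag (model names and paths are placeholders):

```python
# Hedged sketch: prediction-layer distillation with native AMP enabled.
# Model names and paths are placeholders; logit distillation assumes the
# student and teacher share a compatible vocabulary.
from haystack.nodes import FARMReader

teacher = FARMReader(model_name_or_path="deepset/roberta-large-squad2", use_gpu=True)
student = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)

student.distil_prediction_layer_from(
    teacher_model=teacher,
    data_dir="data/squad",
    train_filename="train.json",
    use_amp=True,                # same boolean semantics as in train()
    distillation_loss="kl_div",  # or "mse", or a callable, per the docstring above
    temperature=2.0,
)
# distil_intermediate_layers_from (documented below) accepts the same use_amp
# flag for the intermediate-layer stage of TinyBERT-style distillation.
```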
@@ -347,7 +339,7 @@ def distil_intermediate_layers_from(
evaluate_every: int = 300,
save_dir: Optional[str] = None,
num_processes: Optional[int] = None,
use_amp: str = None,
use_amp: bool = False,
checkpoint_root_dir: Path = Path("model_checkpoints"),
checkpoint_every: Optional[int] = None,
checkpoints_to_keep: int = 3,
@@ -389,8 +381,7 @@ that gets split off from training data for eval.
A list containing torch device objects and/or strings is supported (For example
[torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
parameter is not used and a single cpu device is used for inference.
- `student_batch_size`: Number of samples the student model receives in one batch for training
- `student_batch_size`: Number of samples the teacher model receives in one batch for distillation
- `batch_size`: Number of samples the student model and teacher model receive in one batch for training
- `n_epochs`: Number of iterations on the whole training data set
- `learning_rate`: Learning rate of the optimizer
- `max_seq_len`: Maximum text length (in tokens). Everything longer gets cut down.
@@ -402,21 +393,16 @@ Options for different schedules are available in FARM.
- `num_processes`: The number of processes for `multiprocessing.Pool` during preprocessing.
Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set.
Set to None to use all CPU cores minus one.
- `use_amp`: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model.
Available options:
None (Don't use AMP)
"O0" (Normal FP32 training)
"O1" (Mixed Precision => Recommended)
"O2" (Almost FP16)
"O3" (Pure FP16).
See details on: https://nvidia.github.io/apex/amp.html
- `use_amp`: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
training speed and reduce GPU memory usage.
For more information, see [Haystack Optimization](https://haystack.deepset.ai/guides/optimization)
and [Automatic Mixed Precision Package - Torch.amp](https://pytorch.org/docs/stable/amp.html).
- `checkpoint_root_dir`: the Path of directory where all train checkpoints are saved. For each individual
checkpoint, a subdirectory with the name epoch_{epoch_num}_step_{step_num} is created.
- `checkpoint_every`: save a train checkpoint after this many steps of training.
- `checkpoints_to_keep`: maximum number of train checkpoints to save.
- `caching`: whether or not to use caching for preprocessed dataset and teacher logits
- `cache_path`: Path to cache the preprocessed dataset and teacher logits
- `distillation_loss_weight`: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
- `distillation_loss`: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
- `temperature`: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
- `processor`: The processor to use for preprocessing. If None, the default SquadProcessor is used.
@@ -663,7 +649,7 @@ Example:
**Arguments**:

- `question`: Question string
- `documents`: List of documents as string type
- `texts`: A list of Document texts as a string type
- `top_k`: The maximum number of answers to return

**Returns**:
24 changes: 10 additions & 14 deletions docs/_src/api/api/retriever.md
@@ -946,7 +946,7 @@ def train(data_dir: str,
weight_decay: float = 0.0,
num_warmup_steps: int = 100,
grad_acc_steps: int = 1,
use_amp: str = None,
use_amp: bool = False,
optimizer_name: str = "AdamW",
optimizer_correct_bias: bool = True,
save_dir: str = "../saved_models/dpr",
@@ -984,12 +984,10 @@ you should use the file_system strategy.
- `epsilon`: epsilon parameter of optimizer
- `weight_decay`: weight decay parameter of optimizer
- `grad_acc_steps`: number of steps to accumulate gradient over before back-propagation is done
- `use_amp`: Whether to use automatic mixed precision (AMP) or not. The options are:
"O0" (FP32)
"O1" (Mixed Precision)
"O2" (Almost FP16)
"O3" (Pure FP16).
For more information, refer to: https://nvidia.github.io/apex/amp.html
- `use_amp`: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
training speed and reduce GPU memory usage (see the retriever training sketch after this parameter list).
For more information, see [Haystack Optimization](https://haystack.deepset.ai/guides/optimization)
and [Automatic Mixed Precision Package - Torch.amp](https://pytorch.org/docs/stable/amp.html).
- `optimizer_name`: what optimizer to use (default: AdamW)
- `num_warmup_steps`: number of warmup steps
- `optimizer_correct_bias`: Whether to correct bias in optimizer
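A hedged sketch of retriever training with the new flag (document store, model names, and paths are placeholders):

```python
# Hedged sketch: training a DensePassageRetriever with native AMP enabled.
# The document store, data paths, and model names are placeholders.
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import DensePassageRetriever

retriever = DensePassageRetriever(
    document_store=InMemoryDocumentStore(),
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
    use_gpu=True,
)
retriever.train(
    data_dir="data/dpr",
    train_filename="train.json",
    use_amp=True,        # boolean flag replacing the old apex "O0"-"O3" strings
    n_epochs=1,
    batch_size=16,
    grad_acc_steps=4,
    save_dir="saved_models/dpr",
)
```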
@@ -1305,7 +1303,7 @@ def train(data_dir: str,
weight_decay: float = 0.0,
num_warmup_steps: int = 100,
grad_acc_steps: int = 1,
use_amp: str = None,
use_amp: bool = False,
optimizer_name: str = "AdamW",
optimizer_correct_bias: bool = True,
save_dir: str = "../saved_models/mm_retrieval",
@@ -1345,12 +1343,10 @@ very similar (high score by BM25) to query but do not contain the answer)-
- `epsilon`: Epsilon parameter of optimizer.
- `weight_decay`: Weight decay parameter of optimizer.
- `grad_acc_steps`: Number of steps to accumulate gradient over before back-propagation is done.
- `use_amp`: Whether to use automatic mixed precision (AMP) or not. The options are:
"O0" (FP32)
"O1" (Mixed Precision)
"O2" (Almost FP16)
"O3" (Pure FP16).
For more information, refer to: https://nvidia.github.io/apex/amp.html
- `use_amp`: Whether to use automatic mixed precision (AMP) natively implemented in PyTorch to improve
training speed and reduce GPU memory usage.
For more information, see [Haystack Optimization](https://haystack.deepset.ai/guides/optimization)
and [Automatic Mixed Precision Package - Torch.amp](https://pytorch.org/docs/stable/amp.html).
- `optimizer_name`: What optimizer to use (default: TransformersAdamW).
- `num_warmup_steps`: Number of warmup steps.
- `optimizer_correct_bias`: Whether to correct bias in optimizer.
8 changes: 5 additions & 3 deletions haystack/modeling/data_handler/processor.py
@@ -72,7 +72,7 @@ def __init__(
:param dev_filename: The name of the file containing the dev data. If None and 0.0 < dev_split < 1.0 the dev set
will be a slice of the train set.
:param test_filename: The name of the file containing test data.
:param dev_split: The proportion of the train set that will sliced. Only works if dev_filename is set to None
:param dev_split: The proportion of the train set that will be sliced. Only works if `dev_filename` is set to `None`.
:param data_dir: The directory in which the train, test and perhaps dev files can be found.
:param tasks: Tasks for which the processor shall extract labels from the input data.
Usually this includes a single, default task, e.g. text classification.
@@ -137,7 +137,7 @@ def load(
If None and 0.0 < dev_split < 1.0 the dev set
will be a slice of the train set.
:param test_filename: The name of the file containing test data.
:param dev_split: The proportion of the train set that will sliced.
:param dev_split: The proportion of the train set that will be sliced.
Only works if dev_filename is set to None
:param kwargs: placeholder for passing generic parameters
:return: An instance of the specified processor.
@@ -217,6 +217,7 @@ def convert_from_transformers(
tokenizer_class=None,
tokenizer_args=None,
use_fast=True,
max_query_length=64,
**kwargs,
):
tokenizer_args = tokenizer_args or {}
@@ -238,6 +239,7 @@
metric="squad",
data_dir="data",
doc_stride=doc_stride,
max_query_length=max_query_length,
)
elif task_type == "embeddings":
processor = InferenceProcessor(tokenizer=tokenizer, max_seq_len=max_seq_len)
@@ -396,7 +398,7 @@ def __init__(
:param dev_filename: The name of the file containing the dev data. If None and 0.0 < dev_split < 1.0 the dev set
will be a slice of the train set.
:param test_filename: None
:param dev_split: The proportion of the train set that will sliced. Only works if dev_filename is set to None
:param dev_split: The proportion of the train set that will be sliced. Only works if `dev_filename` is set to `None`.
:param doc_stride: When the document containing the answer is too long it gets split into part, strided by doc_stride
:param max_query_length: Maximum length of the question (in number of subword tokens)
:param proxies: proxy configuration to allow downloads of remote datasets.
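For context, a hedged sketch of constructing a SquadProcessor with an explicit max_query_length, mirroring the convert_from_transformers change above (tokenizer name and data directory are placeholders; exact constructor defaults may differ):

```python
# Hedged sketch: building a SquadProcessor with an explicit max_query_length.
# Tokenizer name and data directory are placeholders.
from transformers import AutoTokenizer
from haystack.modeling.data_handler.processor import SquadProcessor

tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
processor = SquadProcessor(
    tokenizer=tokenizer,
    max_seq_len=384,
    data_dir="data/squad",   # placeholder path
    metric="squad",
    doc_stride=128,
    max_query_length=64,     # the value threaded through Inferencer.load in this commit
)
```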
5 changes: 5 additions & 0 deletions haystack/modeling/infer.py
@@ -130,6 +130,7 @@ def load(
multithreading_rust: bool = True,
use_auth_token: Optional[Union[bool, str]] = None,
devices: Optional[List[Union[str, torch.device]]] = None,
max_query_length: int = 64,
**kwargs,
):
"""
@@ -178,6 +179,7 @@
`transformers-cli login` (stored in ~/.huggingface) will be used.
Additional information can be found here
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
:param max_query_length: Only QA: Maximum length of the question in number of tokens.
:return: An instance of the Inferencer.
"""
if tokenizer_args is None:
@@ -228,6 +230,7 @@ def load(
tokenizer_args=tokenizer_args,
use_fast=use_fast,
use_auth_token=use_auth_token,
max_query_length=max_query_length,
**kwargs,
)

@@ -241,6 +244,8 @@ def load(
"Please set a lower value for doc_stride (Suggestions: doc_stride=128, max_seq_len=384) "
)
processor.doc_stride = doc_stride
if hasattr(processor, "max_query_length"):
processor.max_query_length = max_query_length

return cls(
model,
2 changes: 1 addition & 1 deletion haystack/modeling/model/adaptive_model.py
@@ -280,7 +280,7 @@ def load( # type: ignore
* vocab.txt vocab file for language model, turning text to Wordpiece Tokens
:param load_dir: Location where the AdaptiveModel is stored.
:param device: To which device we want to sent the model, either torch.device("cpu") or torch.device("cuda").
:param device: Specifies the device to which you want to send the model, either torch.device("cpu") or torch.device("cuda").
:param strict: Whether to strictly enforce that the keys loaded from saved model match the ones in
the PredictionHead (see torch.nn.module.load_state_dict()).
:param processor: Processor to populate prediction head with information coming from tasks.