Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

fix a bug when using fp16 training & gradient clipping #5426

Merged
merged 3 commits into from
Oct 7, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion allennlp/training/gradient_descent_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch.cuda import amp
from torch.nn.utils import clip_grad_norm_
import torch.distributed as dist
from torch.cuda.amp.grad_scaler import OptState

from allennlp.common.checks import ConfigurationError, check_for_gpu
from allennlp.common import util as common_util, Tqdm, Lazy
Expand Down Expand Up @@ -349,6 +350,23 @@ def _pytorch_model(self):
return self.model
return self._ddp_wrapped_model.model

def clip_gradient(self):
"""
Performs gradient clipping.
If the model is in mixed precision training, we would first unscale the gradient.
"""
if self._grad_clipping is not None:
# 1. We have to unscale the gradient before clipping
if self._scaler is not None:
optimizer_state = self._scaler._per_optimizer_states[id(self.optimizer)]
# 2. The `unscale_` shouldn't be performed more than once per optimizer per step call,
# so we only perform `unscale_` if it has not already been called.
if optimizer_state["stage"] is not OptState.UNSCALED:
self._scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_value_(
[p for p in self.model.parameters() if p.grad is not None], self._grad_clipping
)

def rescale_gradients(self) -> Optional[float]:
"""
Performs gradient rescaling. Is a no-op if gradient rescaling is not enabled.
Expand Down Expand Up @@ -518,6 +536,7 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
train_loss += batch_loss

batch_grad_norm = self.rescale_gradients()
self.clip_gradient()

if self._learning_rate_scheduler:
self._learning_rate_scheduler.step_batch(self._total_batches_completed + 1)
Expand Down Expand Up @@ -756,7 +775,6 @@ def train(self) -> Dict[str, Any]:
callback.on_end(self, metrics=metrics, epoch=epoch, is_primary=self._primary)

def _try_train(self) -> Tuple[Dict[str, Any], int]:
training_util.enable_gradient_clipping(self.model, self._grad_clipping)

logger.info("Beginning training.")

Expand Down