Conversation
I think this is unintentionally a bigger change. The old code would display all metrics that are configured. This one will only show `loss` and `reg_loss`, or just `loss`.
allennlp/training/trainer.py
Outdated
@@ -409,6 +409,11 @@ def __init__(
         self._tensorboard.get_batch_num_total = lambda: self._batch_num_total
         self._tensorboard.enable_activation_logging(self.model)

+        # Only display reg_loss if the model's configuration has regularization.
+        self._show_metrics = ["loss", "reg_loss"]
+        if self.model.get_regularization_penalty() == 0.0:
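The hunk above is cut off inside the `if`; a minimal sketch of the presumable intent, with the `if` body inferred rather than copied from the PR:

```python
# Sketch only: the if-body is an assumption, not visible in the truncated diff.
self._show_metrics = ["loss", "reg_loss"]
if self.model.get_regularization_penalty() == 0.0:
    # Presumably: drop reg_loss from the display when no penalty is configured.
    self._show_metrics = ["loss"]
```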
It's likely that we have to detect whether a regularization loss is configured by looking at whether the number is `0`. But it would be better if we made that decision based on whether it's configured, not based on whether the result is zero. That way we can show it, for example, when it's misconfigured.
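To illustrate the distinction being drawn here, a configuration-based check might look like the sketch below; the `_regularizer` attribute name is an assumption about the model's internals, not something taken from this PR:

```python
# Hypothetical: decide from whether a regularizer object was configured,
# not from whether the computed penalty happens to be zero.
if getattr(self.model, "_regularizer", None) is None:
    self._show_metrics = ["loss"]
```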
I believe the `get_regularization_penalty()` function returns a 0 if the model is configured to not have regularization. This is what the docstring for that function says: "Returns 0 if the model was not configured to use regularization."
This is a perfectly fine PR and we could merge it like this.

But I have a suggestion. Feel free to roll your eyes at me. We could expand the type of `reg_loss` to be `Optional[float]` (or `Optional[FloatTensor]`). In other words, `Model.get_regularization_penalty()` would return `None` when there is no regularizer, and all the places in `trainer.py` where this is read have to handle that case properly. It's not a lot of places where this happens, so it shouldn't be too much work. Finally, there is a call like `Trainer.get_metrics_from_dict()` or something like that that takes a whole bunch of parameters, including `total_reg_loss`. That parameter becomes `Optional[float]` instead of the current `float`, and if `total_reg_loss is None`, then it wouldn't add the `"reg_loss"` key to the metrics at all.

This is a slightly larger scope for this change than originally intended, but I think it makes the whole thing cleaner. We no longer have to do the `hasattr()` stuff, and our `metrics` dict doesn't carry useless zeros. @matt-gardner, what do you think?
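A minimal sketch of this suggestion, assuming a simplified helper; `get_metrics` below stands in for the loosely-remembered `Trainer.get_metrics_from_dict()`, and its signature is illustrative rather than the actual allennlp API:

```python
from typing import Dict, Optional


def get_metrics(total_loss: float, total_reg_loss: Optional[float], num_batches: int) -> Dict[str, float]:
    """Build the metrics dict, adding "reg_loss" only when a regularizer is configured."""
    metrics = {"loss": total_loss / num_batches if num_batches > 0 else 0.0}
    if total_reg_loss is not None:
        metrics["reg_loss"] = total_reg_loss / num_batches if num_batches > 0 else 0.0
    return metrics
```

With `None` flowing through instead of a sentinel zero, the metrics dict never carries a meaningless `reg_loss` entry and no `hasattr()` checks are needed.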
I agree that what @dirkgr suggests sounds like a cleaner change.
Looks good, though I have an open question about the regularizer returning `float`. Does that ever happen? I think that's a bug in the regularizer if that happens.
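For context on why a plain `float` from a regularizer would be a bug (this reasoning is mine, not stated in the thread): a Python float is detached from PyTorch's autograd graph, so the penalty could never contribute gradients.

```python
import torch

param = torch.ones(3, requires_grad=True)

penalty_tensor = (param ** 2).sum()        # a tensor: stays in the autograd graph
penalty_float = (param ** 2).sum().item()  # a float: detached, gradients are lost

penalty_tensor.backward()
print(param.grad)  # tensor([2., 2., 2.]) -- only the tensor version back-propagates
```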
@@ -574,19 +578,20 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
             for batch in batch_group:
                 batch_outputs = self.batch_outputs(batch, for_training=True)
                 batch_group_outputs.append(batch_outputs)
-                loss = batch_outputs["loss"]
-                reg_loss = batch_outputs["reg_loss"]
+                loss = batch_outputs.get("loss")
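A hedged sketch of how the `.get()` pattern composes with the `Optional` design discussed above; the helper and its names are hypothetical, written only for illustration:

```python
from typing import Dict, Optional

import torch


def accumulate_reg_loss(batch_outputs: Dict[str, torch.Tensor], train_reg_loss: Optional[float]) -> Optional[float]:
    # "reg_loss" is absent when no regularizer is configured.
    reg_loss = batch_outputs.get("reg_loss")
    if reg_loss is None:
        return train_reg_loss
    return (train_reg_loss or 0.0) + reg_loss.item()
```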
I think "loss"
is always there.
Yes, it was more for uniformity's sake.
🎿
Fixes #4436
`reg_loss` is only displayed if the model being trained is configured to have a regularization penalty.