This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Add options for focal loss (#3036)
* Add options for focal loss

Add focal loss to deal with class imbalance

* Fix some typing problems

* Update util.py

* Update util.py

* Update util.py

* Fix too long lines

* Fix keyword argument alert

* Update util.py

* Update util.py

* Fix a problem where focal loss was not activated

* Add test for focal loss gamma

* Fix some decimal precision problems

* Add focal loss alpha test

* Update util_test.py

* Update util.py

* Update util_test.py

* Address some pylint and mypy problems

* Update util.py

* Update util.py

* Update util.py

* Update util.py

* Update util_test.py

* Update util_test.py

It was in the wrong place. Sorry.

* Restore not-callable after torch.tensor()

* Update util_test.py

For a clearer cross_entropy formulation

* Update util.py

Combine everything into `weights` and avoid reference to local variables later.

* Update util_test.py

Add `@flaky` also to token-average tests.

* Update util.py

Avoid involving `gamma` or `alpha` in the average.

* Update util_test.py

* Update util_test.py

Add more tolerance to the token-average test so it only complains when the difference exceeds 1/1000.
guoquan authored and matt-gardner committed Jul 10, 2019
1 parent c22ed57 commit ebe9113
Showing 2 changed files with 163 additions and 7 deletions.
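For context on what the diff below implements: focal loss (Lin et al., 2017, cited in the new docstring) rescales the per-token cross entropy by a factor (1 - p_t)^gamma, optionally combined with a class-balancing weight alpha_t. The following is a minimal standalone sketch of that weighting for the binary-alpha case, written for illustration only and not part of this commit:

    import torch
    import torch.nn.functional as F

    def focal_nll(logits: torch.Tensor, targets: torch.Tensor,
                  gamma: float = 2.0, alpha: float = 0.25) -> torch.Tensor:
        # logits: (num_tokens, num_classes), targets: (num_tokens,) with gold class indices
        log_probs = F.log_softmax(logits, dim=-1)
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # log p_t of the gold class
        pt = log_pt.exp()
        focal_factor = (1.0 - pt) ** gamma  # down-weights well-classified tokens
        alpha_t = torch.where(targets == 1,  # binary case: alpha for class 1, 1 - alpha for class 0
                              torch.full_like(pt, alpha),
                              torch.full_like(pt, 1.0 - alpha))
        return -(alpha_t * focal_factor * log_pt).mean()

The utility changed below applies the same two factors by folding them into the existing ``weights`` tensor, so padding masks, label smoothing, and the batch/token averaging modes keep working unchanged.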
73 changes: 66 additions & 7 deletions allennlp/nn/util.py
@@ -3,11 +3,12 @@
"""
# pylint: disable=too-many-lines
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar
from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
import logging
import copy
import math
import json
import numpy

import torch

@@ -628,7 +629,10 @@ def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
targets: torch.LongTensor,
weights: torch.FloatTensor,
average: str = "batch",
label_smoothing: float = None) -> torch.FloatTensor:
label_smoothing: float = None,
gamma: float = None,
alpha: Union[float, List[float], torch.FloatTensor] = None
) -> torch.FloatTensor:
"""
Computes the cross entropy loss of a sequence, weighted with respect to
some user provided weights. Note that the weighting here is not the same as
@@ -655,6 +659,19 @@ def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
For example, with a label smoothing value of 0.2, a 4 class classification
target would look like ``[0.05, 0.05, 0.85, 0.05]`` if the 3rd class was
the correct label.
    gamma : ``float``, optional (default = None)
        Focal loss[*] focusing parameter ``gamma`` reduces the relative loss for
        well-classified examples and puts more focus on hard, misclassified ones.
        The greater ``gamma`` is, the stronger the focus on hard examples.
    alpha : ``float``, ``List[float]``, or ``torch.FloatTensor``, optional (default = None)
        Focal loss[*] weighting factor ``alpha`` to balance between classes. It can be
        used independently of ``gamma``. If a single ``float`` is provided, the binary
        case is assumed, with ``alpha`` and ``1 - alpha`` used as the weights for the
        positive and negative classes respectively. If a list of ``float`` with the same
        length as the number of classes is provided, the weights match the classes by index.
    [*] T. Lin, P. Goyal, R. Girshick, K. He and P. Dollár, "Focal Loss for
        Dense Object Detection," 2017 IEEE International Conference on Computer
        Vision (ICCV), Venice, 2017, pp. 2999-3007.
Returns
-------
@@ -667,12 +684,54 @@ def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
raise ValueError("Got average f{average}, expected one of "
"None, 'token', or 'batch'")

# make sure weights are float
weights = weights.float()
# sum all dim except batch
non_batch_dims = tuple(range(1, len(weights.shape)))
# shape : (batch_size,)
weights_batch_sum = weights.sum(dim=non_batch_dims)
# shape : (batch * sequence_length, num_classes)
logits_flat = logits.view(-1, logits.size(-1))
# shape : (batch * sequence_length, num_classes)
log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
# shape : (batch * max_len, 1)
targets_flat = targets.view(-1, 1).long()
# focal loss coefficient
if gamma:
# shape : (batch * sequence_length, num_classes)
probs_flat = log_probs_flat.exp()
        # shape : (batch * sequence_length, 1)
probs_flat = torch.gather(probs_flat, dim=1, index=targets_flat)
        # shape : (batch * sequence_length, 1)
focal_factor = (1. - probs_flat) ** gamma
# shape : (batch, sequence_length)
focal_factor = focal_factor.view(*targets.size())
weights = weights * focal_factor

if alpha is not None:
# shape : () / (num_classes,)
if isinstance(alpha, (float, int)):
# pylint: disable=not-callable
# shape : (2,)
alpha_factor = torch.tensor([1. - float(alpha), float(alpha)],
dtype=weights.dtype, device=weights.device)
# pylint: enable=not-callable
elif isinstance(alpha, (list, numpy.ndarray, torch.Tensor)):
# pylint: disable=not-callable
# shape : (c,)
alpha_factor = torch.tensor(alpha, dtype=weights.dtype, device=weights.device)
# pylint: enable=not-callable
if not alpha_factor.size():
# shape : (1,)
alpha_factor = alpha_factor.view(1)
# shape : (2,)
alpha_factor = torch.cat([1 - alpha_factor, alpha_factor])
else:
raise TypeError(('alpha must be float, list of float, or torch.FloatTensor, '
'{} provided.').format(type(alpha)))
# shape : (batch, max_len)
alpha_factor = torch.gather(alpha_factor, dim=0, index=targets_flat.view(-1)).view(*targets.size())
weights = weights * alpha_factor

if label_smoothing is not None and label_smoothing > 0.0:
num_classes = logits.size(-1)
@@ -691,18 +750,18 @@ def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
# shape : (batch, sequence_length)
negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
# shape : (batch, sequence_length)
negative_log_likelihood = negative_log_likelihood * weights.float()
negative_log_likelihood = negative_log_likelihood * weights

if average == "batch":
# shape : (batch_size,)
per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13)
num_non_empty_sequences = ((weights.sum(1) > 0).float().sum() + 1e-13)
per_batch_loss = negative_log_likelihood.sum(non_batch_dims) / (weights_batch_sum + 1e-13)
num_non_empty_sequences = ((weights_batch_sum > 0).float().sum() + 1e-13)
return per_batch_loss.sum() / num_non_empty_sequences
elif average == "token":
return negative_log_likelihood.sum() / (weights.sum().float() + 1e-13)
return negative_log_likelihood.sum() / (weights_batch_sum.sum() + 1e-13)
else:
# shape : (batch_size,)
per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13)
per_batch_loss = negative_log_likelihood.sum(non_batch_dims) / (weights_batch_sum + 1e-13)
return per_batch_loss


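A minimal usage sketch of the options added above (assuming a build of allennlp that includes this change); shapes follow the function's contract: ``logits`` is ``(batch, sequence_length, num_classes)``, while ``targets`` and ``weights`` are ``(batch, sequence_length)``:

    import torch
    from allennlp.nn import util

    logits = torch.randn(2, 5, 3)          # (batch, sequence_length, num_classes)
    targets = torch.randint(0, 3, (2, 5))  # gold class indices
    weights = torch.ones(2, 5)             # 1.0 for real tokens, 0.0 for padding

    # Focusing parameter only: harder tokens contribute relatively more to the loss.
    loss_gamma = util.sequence_cross_entropy_with_logits(logits, targets, weights, gamma=2.0)

    # Per-class balancing weights; the list length must equal num_classes.
    loss_alpha = util.sequence_cross_entropy_with_logits(
            logits, targets, weights, average="token", alpha=[0.2, 0.3, 0.5])

The example values for ``gamma`` and ``alpha`` are illustrative; the paper cited in the docstring reports gamma = 2 and alpha = 0.25 working well for its detection experiments.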
97 changes: 97 additions & 0 deletions allennlp/tests/nn/util_test.py
@@ -6,6 +6,7 @@
from numpy.testing import assert_array_almost_equal, assert_almost_equal
import torch
import pytest
from flaky import flaky

from allennlp.common.checks import ConfigurationError
from allennlp.common.testing import AllenNlpTestCase
@@ -661,6 +662,7 @@ def test_sequence_cross_entropy_with_logits_averages_batch_correctly(self):
# Batch has one completely padded row, so divide by 4.
assert loss.data.numpy() == vector_loss.sum().item() / 4

@flaky(max_runs=3, min_passes=1)
def test_sequence_cross_entropy_with_logits_averages_token_correctly(self):
# test token average is the same as multiplying the per-batch loss
# with the per-batch weights and dividing by the total weight
@@ -681,6 +683,101 @@ def test_sequence_cross_entropy_with_logits_averages_token_correctly(self):
average_token_loss = (total_token_loss / weights.float().sum()).detach()
assert_almost_equal(loss.detach().item(), average_token_loss.item(), decimal=5)

def test_sequence_cross_entropy_with_logits_gamma_correctly(self):
batch = 1
length = 3
classes = 4
gamma = abs(numpy.random.randn()) # [0, +inf)

tensor = torch.rand([batch, length, classes])
targets = torch.LongTensor(numpy.random.randint(0, classes, [batch, length]))
weights = torch.ones([batch, length])

loss = util.sequence_cross_entropy_with_logits(tensor, targets, weights, gamma=gamma)

correct_loss = 0.
for logit, label in zip(tensor.squeeze(0), targets.squeeze(0)):
p = torch.nn.functional.softmax(logit, dim=-1)
pt = p[label]
ft = (1 - pt) ** gamma
correct_loss += - pt.log() * ft
# Average over sequence.
correct_loss = correct_loss / length
numpy.testing.assert_array_almost_equal(loss.data.numpy(), correct_loss.data.numpy())

def test_sequence_cross_entropy_with_logits_alpha_float_correctly(self):
batch = 1
length = 3
classes = 2 # alpha float for binary class only
alpha = numpy.random.rand() if numpy.random.rand() > 0.5 else (1. - numpy.random.rand()) # [0, 1]

tensor = torch.rand([batch, length, classes])
targets = torch.LongTensor(numpy.random.randint(0, classes, [batch, length]))
weights = torch.ones([batch, length])

loss = util.sequence_cross_entropy_with_logits(tensor, targets, weights, alpha=alpha)

correct_loss = 0.
for logit, label in zip(tensor.squeeze(0), targets.squeeze(0)):
logp = torch.nn.functional.log_softmax(logit, dim=-1)
logpt = logp[label]
if label:
at = alpha
else:
at = 1 - alpha
correct_loss += - logpt * at
# Average over sequence.
correct_loss = correct_loss / length
numpy.testing.assert_array_almost_equal(loss.data.numpy(), correct_loss.data.numpy())

def test_sequence_cross_entropy_with_logits_alpha_single_float_correctly(self):
batch = 1
length = 3
classes = 2 # alpha float for binary class only
alpha = numpy.random.rand() if numpy.random.rand() > 0.5 else (1. - numpy.random.rand()) # [0, 1]
alpha = torch.tensor(alpha)

tensor = torch.rand([batch, length, classes])
targets = torch.LongTensor(numpy.random.randint(0, classes, [batch, length]))
weights = torch.ones([batch, length])

loss = util.sequence_cross_entropy_with_logits(tensor, targets, weights, alpha=alpha)

correct_loss = 0.
for logit, label in zip(tensor.squeeze(0), targets.squeeze(0)):
logp = torch.nn.functional.log_softmax(logit, dim=-1)
logpt = logp[label]
if label:
at = alpha
else:
at = 1 - alpha
correct_loss += - logpt * at
# Average over sequence.
correct_loss = correct_loss / length
numpy.testing.assert_array_almost_equal(loss.data.numpy(), correct_loss.data.numpy())

def test_sequence_cross_entropy_with_logits_alpha_list_correctly(self):
batch = 1
length = 3
        classes = 4 # a list-valued alpha works for any number of classes
alpha = abs(numpy.random.randn(classes)) # [0, +inf)

tensor = torch.rand([batch, length, classes])
targets = torch.LongTensor(numpy.random.randint(0, classes, [batch, length]))
weights = torch.ones([batch, length])

loss = util.sequence_cross_entropy_with_logits(tensor, targets, weights, alpha=alpha)

correct_loss = 0.
for logit, label in zip(tensor.squeeze(0), targets.squeeze(0)):
logp = torch.nn.functional.log_softmax(logit, dim=-1)
logpt = logp[label]
at = alpha[label]
correct_loss += - logpt * at
# Average over sequence.
correct_loss = correct_loss / length
numpy.testing.assert_array_almost_equal(loss.data.numpy(), correct_loss.data.numpy())

def test_replace_masked_values_replaces_masked_values_with_finite_value(self):
tensor = torch.FloatTensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
mask = torch.FloatTensor([[1, 1, 0]])
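The new tests above can be selected with pytest's keyword filter, e.g. ``pytest allennlp/tests/nn/util_test.py -k "gamma or alpha"``. For a quick interactive check outside the suite, the same comparison the gamma test makes can be reproduced directly (a sketch, assuming the patched ``allennlp.nn.util`` is importable):

    import numpy
    import torch
    from allennlp.nn import util

    logits = torch.rand(1, 3, 4)            # (batch=1, length=3, classes=4)
    targets = torch.randint(0, 4, (1, 3))
    weights = torch.ones(1, 3)
    gamma = 2.0

    loss = util.sequence_cross_entropy_with_logits(logits, targets, weights, gamma=gamma)

    # Brute-force reference, mirroring test_sequence_cross_entropy_with_logits_gamma_correctly.
    reference = 0.0
    for logit, label in zip(logits.squeeze(0), targets.squeeze(0)):
        pt = torch.nn.functional.softmax(logit, dim=-1)[label]
        reference += -((1 - pt) ** gamma) * pt.log()
    reference = reference / 3               # average over the sequence

    numpy.testing.assert_array_almost_equal(loss.detach().numpy(), reference.detach().numpy())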
