This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1004] Poisson Negative Log Likelihood loss #12697

Merged
20 commits merged on Oct 13, 2018
Changes from 8 commits
Commits (20)
60dc8bc
PoissonNLLLoss function to compute negative log likelihood loss
gaurav-gireesh Sep 28, 2018
4821d5e
Removing debugging print statements
gaurav-gireesh Sep 28, 2018
178ce39
Pylint code formatting problems addressed
gaurav-gireesh Sep 28, 2018
b394e15
Added Stirling approximation for factorial term in the denominator an…
gaurav-gireesh Oct 1, 2018
7b00bf1
Separated the test cases for Flag value for logits and compute_full
gaurav-gireesh Oct 1, 2018
a19f494
Added comments for package- numpy inclusion and some pylint formatting
gaurav-gireesh Oct 1, 2018
3d18177
Trigger CI
gaurav-gireesh Oct 8, 2018
56a8faf
Markdown file updated. Added entry for Poisson NLLLoss
gaurav-gireesh Oct 8, 2018
814a04b
Fixing pending documentation issue
gaurav-gireesh Oct 9, 2018
4b37836
Documentation docstring changed
gaurav-gireesh Oct 9, 2018
506ae21
Addressed PR comment: removed extra newline.
gaurav-gireesh Oct 9, 2018
ed2ee68
Symbol PI corrected
gaurav-gireesh Oct 9, 2018
00b0e71
epsilon spelling correction
gaurav-gireesh Oct 9, 2018
790ca09
More unit tests added - testing with mod.score() and mod.fit()
gaurav-gireesh Oct 10, 2018
26c657f
changed the number of epochs
gaurav-gireesh Oct 10, 2018
db229a0
PR comments addressed: added mod score tests and a newline
gaurav-gireesh Oct 10, 2018
d7b1c4a
Empty line added
gaurav-gireesh Oct 10, 2018
f88c558
Adding hybridized test
gaurav-gireesh Oct 11, 2018
66e6775
Trigger CI
gaurav-gireesh Oct 11, 2018
3585a1e
Variable names changed
gaurav-gireesh Oct 12, 2018
1 change: 1 addition & 0 deletions docs/api/python/gluon/loss.md
@@ -25,6 +25,7 @@ This package includes several commonly used loss functions in neural networks.
LogisticLoss
TripletLoss
CTCLoss
PoissonNLLLoss
```


67 changes: 65 additions & 2 deletions python/mxnet/gluon/loss.py
@@ -23,12 +23,13 @@
'SigmoidBinaryCrossEntropyLoss', 'SigmoidBCELoss',
'SoftmaxCrossEntropyLoss', 'SoftmaxCELoss',
'KLDivLoss', 'CTCLoss', 'HuberLoss', 'HingeLoss',
'SquaredHingeLoss', 'LogisticLoss', 'TripletLoss']

'SquaredHingeLoss', 'LogisticLoss', 'TripletLoss', 'PoissonNLLLoss']
Contributor: nit: please add line space

Contributor Author: Added it. Thanks for pointing out.

import numpy as np
from .. import ndarray
from ..base import numeric_types
from .block import HybridBlock


Member: nit: remove newline

def _apply_weighting(F, loss, weight=None, sample_weight=None):
"""Apply weighting to loss.

@@ -706,3 +707,65 @@ def hybrid_forward(self, F, pred, positive, negative):
axis=self._batch_axis, exclude=True)
loss = F.relu(loss + self._margin)
return _apply_weighting(F, loss, self._weight, None)


class PoissonNLLLoss(Loss):
r"""For a target (Random Variable) in a Poisson distribution, the function calculates the Negative
Log likelihood loss.
PoissonNLLLoss measures the loss accrued from a poisson regression prediction made by the model.

.. math::
L = \text{pred} - \text{target} * \log(\text{pred}) + \log(\text{target}!)
Member: nit: why ! after \text{target!}
Also, do you have any reference for the definition of this loss function?

Contributor Author: @anirudhacharya The "!" denotes the factorial of the target value. The formula for the probability of the target value given the mean of the Poisson distribution contains a factorial term in the denominator, which is approximated using Stirling's approximation. Taking the log of the formula reduces it to the form given in the documentation.
Some Wikipedia pages that I have found helpful are Poisson regression and Poisson distribution.
Also, I have attached a link to PyTorch's implementation of the loss function in the description of the PR.
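For clarity, here is a short restatement of that derivation (an editorial addition, not text from the PR), starting from the Poisson probability mass function with rate lambda = pred and count k = target:

```latex
% Poisson probability mass function:
%   P(k; \lambda) = \frac{\lambda^{k} e^{-\lambda}}{k!}
% Negative log likelihood, with pred = \lambda and target = k:
-\log P(\text{target}; \text{pred})
    = \text{pred} - \text{target}\,\log(\text{pred}) + \log(\text{target}!)
```

The log(target!) term is the one that is dropped by default and optionally approximated with Stirling's formula via compute_full.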


`pred`, `target` can have arbitrary shape as long as they have the same number of elements.

Parameters
----------
from_logits : boolean, default True
Indicates whether the log of the predicted value has already been computed. If True, the loss is computed as
:math:`\exp(\text{pred}) - \text{target} * \text{pred}`, and if False, the loss is computed as
:math:`\text{pred} - \text{target} * \log(\text{pred} + \text{epsilon})`.
Member: these definitions seem to be at odds with the definition given above in line 718

Contributor Author: Yes, the formulae are different. A few points:

  1. The factorial term gets dropped in most calculations. However, it can be approximated using the Stirling factor/approximation; setting the compute_full flag to True incorporates this approximation into the loss computation.
  2. Setting the from_logits flag to True or False changes the formula depending on whether the prediction has already been logged (log pred) or not.

weight : float or None
Global scalar weight for loss.
batch_axis : int, default 0
The axis that represents mini-batch.
compute_full : boolean, default False
Indicates whether to add an approximation (the Stirling factor) for the factorial term in the formula for the loss.
The Stirling factor is:
\text{target} * \log(\text{target}) - \text{target} + 0.5 * \log(2 * \pi * \text{target})
epsilon : float, default 1e-08
This is to avoid calculating log(0), which is not defined.


Inputs:
------
- **pred**: prediction tensor with arbitrary shape
- **target**: Random variable (count or number) which belongs to a Poisson distribution.
- **sample_weight**: element-wise weighting tensor. Must be broadcastable
to the same shape as pred. For example, if pred has shape (64, 10)
and you want to weigh each sample in the batch separately,
sample_weight should have shape (64, 1).

Outputs:
--------
- **loss**: Average loss (shape=(1,1)) of the loss tensor with shape (batch_size,).
"""
def __init__(self, weight=None, from_logits=True, batch_axis=0, compute_full=False, **kwargs):
super(PoissonNLLLoss, self).__init__(weight, batch_axis, **kwargs)
self._from_logits = from_logits
self._compute_full = compute_full

def hybrid_forward(self, F, pred, target, sample_weight=None, epsilon=1e-08):
target = _reshape_like(F, target, pred)
if self._from_logits:
loss = F.exp(pred) - target * pred
else:
loss = pred - target * F.log(pred + epsilon)
if self._compute_full:
# Stirling's approximation of log(target!), using numpy's pi value
stirling_factor = target * F.log(target) - target + 0.5 * F.log(2 * target * np.pi)
# Apply the term only where target > 1: log(target!) is zero for target <= 1,
# and the approximation is not valid there.
target_gt_1 = target > 1
stirling_factor *= target_gt_1
Member: Add a comment explaining why we need this

Contributor Author: Thanks! Yes, I have mentioned the use of the parameter 'compute_full' in the function docstring, which explains why the calculation is needed.

loss += stirling_factor
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss)
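As a usage note for readers of this review (an editorial sketch, not part of the PR; the array values are only illustrative), the two from_logits modes should agree when the logit input is the log of the rate input:

```python
import mxnet as mx
import numpy as np
from mxnet import gluon

rate = mx.nd.array([[1.0, 2.0, 3.0]])     # predicted Poisson rates (lambda)
target = mx.nd.array([[0.0, 1.0, 4.0]])   # observed counts

loss_logits = gluon.loss.PoissonNLLLoss(from_logits=True)   # expects log(lambda) as pred
loss_rates = gluon.loss.PoissonNLLLoss(from_logits=False)   # expects lambda directly

l_logits = loss_logits(mx.nd.log(rate), target).asscalar()
l_rates = loss_rates(rate, target).asscalar()
# Equal up to the epsilon added inside the log for numerical stability.
np.testing.assert_almost_equal(l_logits, l_rates, decimal=5)

# compute_full=True additionally adds the Stirling approximation of log(target!)
# for elements where target > 1.
loss_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True)
print(loss_full(rate, target).asscalar())
```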
38 changes: 38 additions & 0 deletions tests/python/unittest/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,44 @@ def test_triplet_loss():
optimizer='adam')
assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05

@with_seed()
def test_poisson_nllloss():
Member: Please separate into 3 unit tests, clearly stating what conditions you are testing.

Contributor Author (gaurav-gireesh, Oct 2, 2018): @roywei Yes, I have separated the use cases, with comments on what exactly is being tested. I have not split them into different functions, to stay consistent with the other loss function tests.
Thank you for your comment.

pred = mx.nd.random.normal(shape=(3, 4))
min_pred = mx.nd.min(pred)
# This is necessary to ensure only positive random values are generated for the prediction,
# to avoid an invalid log calculation
pred[:] = pred + mx.nd.abs(min_pred)
target = mx.nd.random.normal(shape=(3, 4))
min_target = mx.nd.min(target)
# This is necessary to ensure only positive random values are generated for the target,
# to avoid an invalid log calculation
target[:] += mx.nd.abs(min_target)

Loss = gluon.loss.PoissonNLLLoss(from_logits=True)
Loss_no_logits = gluon.loss.PoissonNLLLoss(from_logits=False)
# Calculating by the brute-force formula for the default value of from_logits = True

# 1) Testing for flag logits = True
brute_loss = np.mean(np.exp(pred.asnumpy()) - target.asnumpy() * pred.asnumpy())
loss_withlogits = Loss(pred, target)
assert_almost_equal(brute_loss, loss_withlogits.asscalar())
Member: why not test this loss function using the Module API with module.fit and module.score?

Contributor Author: Thank you for the suggestion. It is a valid point; however, it is hard to obtain a synthetic dataset in which random inputs X are correlated with a target that follows a Poisson distribution. I have tried training the model on random datasets, but no loss gradient (decreasing loss over epochs) was observed. This could be contributed incrementally if I come across a dataset that shows the loss decreasing with epochs, as it does for models such as logistic regression, linear regression, and so on.
The formula that computes the loss value is implemented in the function and can be tested against raw calculations, which is also what unit tests of other loss functions, such as test_bce_loss (binary cross entropy), do.

Contributor Author: Unit test added.


# 2) Testing for flag logits = False
loss_no_logits = Loss_no_logits(pred, target)
np_loss_no_logits = np.mean(pred.asnumpy() - target.asnumpy() * np.log(pred.asnumpy() + 1e-08))
if np.isnan(loss_no_logits.asscalar()):
assert_almost_equal(np.isnan(np_loss_no_logits), np.isnan(loss_no_logits.asscalar()))
else:
assert_almost_equal(np_loss_no_logits, loss_no_logits.asscalar())

# 3) Testing for Stirling approximation
np_pred = np.random.uniform(1, 5, (2, 3))
Contributor: Please add tests for hybridized version as well.

Contributor Author: Added the unit test for it.

np_target = np.random.uniform(1, 5, (2, 3))
np_compute_full = np.mean((np_pred - np_target * np.log(np_pred + 1e-08)) + ((np_target * np.log(np_target) -\
np_target + 0.5 * np.log(2 * np_target * np.pi)) * (np_target > 1)))
Loss_compute_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True)
loss_compute_full = Loss_compute_full(mx.nd.array(np_pred), mx.nd.array(np_target))
assert_almost_equal(np_compute_full, loss_compute_full.asscalar())
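Following up on the hybridization request above, a sketch of such a check (an assumed example modeled on the existing tests and imports in this file, not necessarily the test added in the later commit) could look like this:

```python
@with_seed()
def test_poisson_nllloss_hybridize():
    # Hybridize the block and verify that it still matches a NumPy computation
    # of mean(exp(pred) - target * pred), i.e. the from_logits=True formula.
    loss = gluon.loss.PoissonNLLLoss(from_logits=True)
    loss.hybridize()
    pred = mx.nd.random.uniform(0.1, 1.0, shape=(3, 4))
    target = mx.nd.random.uniform(0.1, 1.0, shape=(3, 4))
    brute_loss = np.mean(np.exp(pred.asnumpy()) - target.asnumpy() * pred.asnumpy())
    assert_almost_equal(brute_loss, loss(pred, target).asscalar())
```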

if __name__ == '__main__':
import nose