[MXNET-1004] Poisson Negative Log Likelihood loss (apache#12697)
* PoissonNLLLoss function to compute negative log likelihood loss

* Removing debugging print statements

* Pylint code formatting problems addressed

* Added Stirling approximation for factorial term in the denominator and test case for the same

* Separated the test cases for the flag values for logits and compute_full

* Added comments for numpy package inclusion and some pylint formatting

* Trigger CI

* Markdown file updated. Added entry for Poisson NLLLoss

* Fixing pending documentation issue

* Documentation docstring changed

* Addressed PR comment: removed extra newline

* Symbol PI corrected

* epsilon spelling correction

* More unit tests added - testing with mod.score() and mod.fit()

* changed the number of epochs

* PR comments addressed: added mod score tests and a newline

* Empty line added

* Adding hybridized test

* Trigger CI

* Variable names changed
gaurav-gireesh authored and lanking520 committed Oct 24, 2018
1 parent dabdf2a commit d6af8c9
Showing 3 changed files with 118 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/api/python/gluon/loss.md
@@ -25,6 +25,7 @@ This package includes several commonly used loss functions in neural networks.
LogisticLoss
TripletLoss
CTCLoss
PoissonNLLLoss
```


63 changes: 62 additions & 1 deletion python/mxnet/gluon/loss.py
@@ -23,8 +23,9 @@
'SigmoidBinaryCrossEntropyLoss', 'SigmoidBCELoss',
'SoftmaxCrossEntropyLoss', 'SoftmaxCELoss',
'KLDivLoss', 'CTCLoss', 'HuberLoss', 'HingeLoss',
'SquaredHingeLoss', 'LogisticLoss', 'TripletLoss']
'SquaredHingeLoss', 'LogisticLoss', 'TripletLoss', 'PoissonNLLLoss']

import numpy as np
from .. import ndarray
from ..base import numeric_types
from .block import HybridBlock
@@ -706,3 +707,63 @@ def hybrid_forward(self, F, pred, positive, negative):
axis=self._batch_axis, exclude=True)
loss = F.relu(loss + self._margin)
return _apply_weighting(F, loss, self._weight, None)


class PoissonNLLLoss(Loss):
r"""For a target (Random Variable) in a Poisson distribution, the function calculates the Negative
Log likelihood loss.
PoissonNLLLoss measures the loss accrued from a poisson regression prediction made by the model.
.. math::
L = \text{pred} - \text{target} * \log(\text{pred}) +\log(\text{target!})
`pred`, `target` can have arbitrary shape as long as they have the same number of elements.
Parameters
----------
from_logits : boolean, default True
Indicates whether the prediction is already in log space, i.e. whether log(predicted) has already been computed.
If True, the loss is computed as :math:`\exp(\text{pred}) - \text{target} * \text{pred}`; if False, the loss is
computed as :math:`\text{pred} - \text{target} * \log(\text{pred} + \text{epsilon})`. The default value is True.
weight : float or None
Global scalar weight for loss.
batch_axis : int, default 0
The axis that represents mini-batch.
compute_full: boolean, default False
Indicates whether to add an approximation (Stirling factor) for the factorial term in the formula for the loss.
The Stirling factor is:
:math:`\text{target} * \log(\text{target}) - \text{target} + 0.5 * \log(2 * \pi * \text{target})`
epsilon: float, default 1e-08
Small constant added inside the logarithm to avoid calculating log(0), which is not defined.
Inputs:
- **pred**: Predicted value
- **target**: Random variable (count or number) which belongs to a Poisson distribution.
- **sample_weight**: element-wise weighting tensor. Must be broadcastable
to the same shape as pred. For example, if pred has shape (64, 10)
and you want to weigh each sample in the batch separately,
sample_weight should have shape (64, 1).
Outputs:
- **loss**: Average loss (shape=(1,1)) of the loss tensor with shape (batch_size,).
"""
def __init__(self, weight=None, from_logits=True, batch_axis=0, compute_full=False, **kwargs):
super(PoissonNLLLoss, self).__init__(weight, batch_axis, **kwargs)
self._from_logits = from_logits
self._compute_full = compute_full

def hybrid_forward(self, F, pred, target, sample_weight=None, epsilon=1e-08):
target = _reshape_like(F, target, pred)
if self._from_logits:
loss = F.exp(pred) - target * pred
else:
loss = pred - target * F.log(pred + epsilon)
if self._compute_full:
# Using numpy's pi value
stirling_factor = target * F.log(target) - target + 0.5 * F.log(2 * target * np.pi)
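# Apply the Stirling factor only for targets greater than 1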
target_gt_1 = target > 1
stirling_factor *= target_gt_1
loss += stirling_factor
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss)
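
For context, a minimal usage sketch of the loss added in this commit (assuming an MXNet build that includes this change); it mirrors the two formulas documented above and cross-checks the `from_logits=True` case against NumPy:

```python
import numpy as np
import mxnet as mx
from mxnet import gluon

# Predictions and strictly positive Poisson-style targets (illustrative values)
pred = mx.nd.array([[0.5, 1.2], [0.1, 2.0]])
target = mx.nd.array([[1.0, 2.0], [3.0, 4.0]])

# from_logits=True: pred is treated as log(rate), loss = mean(exp(pred) - target * pred)
loss_logits = gluon.loss.PoissonNLLLoss(from_logits=True)
print(loss_logits(pred, target).asscalar())

# from_logits=False: pred is the rate itself, loss = mean(pred - target * log(pred + epsilon))
loss_rates = gluon.loss.PoissonNLLLoss(from_logits=False)
print(loss_rates(pred, target).asscalar())

# Cross-check the from_logits=True result with the closed-form expression in NumPy
np_pred, np_target = pred.asnumpy(), target.asnumpy()
print(np.mean(np.exp(np_pred) - np_target * np_pred))
```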
55 changes: 55 additions & 0 deletions tests/python/unittest/test_loss.py
@@ -348,6 +348,61 @@ def test_triplet_loss():
optimizer='adam')
assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05

@with_seed()
def test_poisson_nllloss():
pred = mx.nd.random.normal(shape=(3, 4))
min_pred = mx.nd.min(pred)
# This is necessary to ensure only positive random values are generated for the prediction,
# to avoid an invalid log calculation
pred[:] = pred + mx.nd.abs(min_pred)
target = mx.nd.random.normal(shape=(3, 4))
min_target = mx.nd.min(target)
# This is necessary to ensure only positive random values are generated for the target,
# to avoid an invalid log calculation
target[:] += mx.nd.abs(min_target)

Loss = gluon.loss.PoissonNLLLoss(from_logits=True)
Loss_no_logits = gluon.loss.PoissonNLLLoss(from_logits=False)
# 1) Testing with flag from_logits = True (the default), comparing against the brute-force formula
brute_loss = np.mean(np.exp(pred.asnumpy()) - target.asnumpy() * pred.asnumpy())
loss_withlogits = Loss(pred, target)
assert_almost_equal(brute_loss, loss_withlogits.asscalar())

# 2) Testing with flag from_logits = False
loss_no_logits = Loss_no_logits(pred, target)
np_loss_no_logits = np.mean(pred.asnumpy() - target.asnumpy() * np.log(pred.asnumpy() + 1e-08))
if np.isnan(loss_no_logits.asscalar()):
assert_almost_equal(np.isnan(np_loss_no_logits), np.isnan(loss_no_logits.asscalar()))
else:
assert_almost_equal(np_loss_no_logits, loss_no_logits.asscalar())

# 3) Testing the Stirling approximation
np_pred = np.random.uniform(1, 5, (2, 3))
np_target = np.random.uniform(1, 5, (2, 3))
np_compute_full = np.mean((np_pred - np_target * np.log(np_pred + 1e-08)) + ((np_target * np.log(np_target)-\
np_target + 0.5 * np.log(2 * np_target * np.pi))*(np_target > 1)))
Loss_compute_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True)
loss_compute_full = Loss_compute_full(mx.nd.array(np_pred), mx.nd.array(np_target))
assert_almost_equal(np_compute_full, loss_compute_full.asscalar())
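
As a concrete illustration of the formula exercised in case 3 above, here is a single-value check of the Stirling-corrected loss with from_logits=False and compute_full=True (the numbers are chosen purely for illustration):

```python
import numpy as np

pred, target, eps = 2.0, 3.0, 1e-08
base = pred - target * np.log(pred + eps)                                        # ~ -0.0794
stirling = target * np.log(target) - target + 0.5 * np.log(2 * np.pi * target)  # ~ 1.7641
# target > 1, so the Stirling term is applied; expected total ~ 1.6846
print(base + stirling)
```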

@with_seed()
def test_poisson_nllloss_mod():
N = 1000
data = mx.random.poisson(shape=(N, 2))
label = mx.random.poisson(lam=4, shape=(N, 1))
data_iter = mx.io.NDArrayIter(data, label, batch_size=20, label_name='label', shuffle=True)
output = mx.sym.exp(get_net(1))
l = mx.symbol.Variable('label')
Loss = gluon.loss.PoissonNLLLoss(from_logits=False)
loss = Loss(output, l)
loss = mx.sym.make_loss(loss)
mod = mx.mod.Module(loss, data_names=('data',), label_names=('label',))
mod.fit(data_iter, num_epoch=20, optimizer_params={'learning_rate': 0.01},
initializer=mx.init.Normal(sigma=0.1), eval_metric=mx.metric.Loss(),
optimizer='adam')
assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05
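
The Module-based check above has a natural Gluon counterpart; below is a hedged sketch of how the loss might be exercised in an imperative training loop (the single Dense layer and the hyperparameters are illustrative, not part of this PR):

```python
import mxnet as mx
from mxnet import autograd, gluon

# Toy data: Poisson-distributed features and labels, as in the module test above
N = 1000
data = mx.random.poisson(lam=2, shape=(N, 1))
label = mx.random.poisson(lam=4, shape=(N, 1))
loader = gluon.data.DataLoader(gluon.data.ArrayDataset(data, label),
                               batch_size=20, shuffle=True)

net = gluon.nn.Dense(1)                      # the output is treated as log(rate)
net.initialize(mx.init.Normal(sigma=0.1))
loss_fn = gluon.loss.PoissonNLLLoss(from_logits=True)
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': 0.01})

for epoch in range(5):
    for x, y in loader:
        with autograd.record():
            loss = loss_fn(net(x), y)        # already averaged over the batch
        loss.backward()
        trainer.step(1)                      # step size of 1 since the loss is a mean
```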

if __name__ == '__main__':
import nose
