[MXNET-1004] Poisson Negative Log-Likelihood loss #12697
@@ -23,8 +23,9 @@
            'SigmoidBinaryCrossEntropyLoss', 'SigmoidBCELoss',
            'SoftmaxCrossEntropyLoss', 'SoftmaxCELoss',
            'KLDivLoss', 'CTCLoss', 'HuberLoss', 'HingeLoss',
-           'SquaredHingeLoss', 'LogisticLoss', 'TripletLoss']
+           'SquaredHingeLoss', 'LogisticLoss', 'TripletLoss', 'PoissonNLLLoss']

+import numpy as np
 from .. import ndarray
 from ..base import numeric_types
 from .block import HybridBlock
@@ -706,3 +707,63 @@ def hybrid_forward(self, F, pred, positive, negative):
                           axis=self._batch_axis, exclude=True)
        loss = F.relu(loss + self._margin)
        return _apply_weighting(F, loss, self._weight, None)

class PoissonNLLLoss(Loss):
r"""For a target (Random Variable) in a Poisson distribution, the function calculates the Negative | ||
Log likelihood loss. | ||
PoissonNLLLoss measures the loss accrued from a poisson regression prediction made by the model. | ||
|
||
.. math:: | ||
L = \text{pred} - \text{target} * \log(\text{pred}) +\log(\text{target!}) | ||
Review comment: nit: why the "!" after \text{target!}? Also, do you have any reference for the definition of this loss function?

Reply: @anirudhacharya The "!" is for the factorial of the target value. The formula for the probability of the target value given the mean of the Poisson distribution contains a factorial term in the denominator, which is approximated here with Stirling's approximation. Taking the log of the formula reduces it to the form given in the documentation.
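
As a quick illustration of the reply above (not part of the PR), one can verify numerically that the Stirling factor approximates log(target!); math.lgamma(k + 1) gives the exact log-factorial:

import math

for k in [2, 5, 10, 50]:
    exact = math.lgamma(k + 1)  # exact value of log(k!)
    stirling = k * math.log(k) - k + 0.5 * math.log(2 * math.pi * k)
    print(k, round(exact, 4), round(stirling, 4))  # the two values converge as k grows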

    `pred` and `target` can have arbitrary shapes as long as they have the same
    number of elements.

    Parameters
    ----------
    from_logits : boolean, default True
        Indicates whether the log of the predicted value has already been
        computed. If True, the loss is computed as
        :math:`\exp(\text{pred}) - \text{target} * \text{pred}`, and if False,
        the loss is computed as
        :math:`\text{pred} - \text{target} * \log(\text{pred} + \text{epsilon})`.
        The default value is True.
    weight : float or None
        Global scalar weight for loss.
    batch_axis : int, default 0
        The axis that represents the mini-batch.
    compute_full : boolean, default False
        Indicates whether to add an approximation (Stirling factor) for the
        factorial term in the formula for the loss. The Stirling factor is:
        :math:`\text{target} * \log(\text{target}) - \text{target} +
        0.5 * \log(2 * \pi * \text{target})`
    epsilon : float, default 1e-08
        Small constant to avoid computing :math:`\log(0)`, which is not defined.

    Inputs:
        - **pred**: Predicted value.
        - **target**: Random variable (count or number) belonging to a Poisson distribution.
        - **sample_weight**: element-wise weighting tensor. Must be broadcastable
          to the same shape as pred. For example, if pred has shape (64, 10)
          and you want to weigh each sample in the batch separately,
          sample_weight should have shape (64, 1).

    Outputs:
        - **loss**: Average loss (a scalar) of the loss tensor with shape (batch_size,).
    """
    def __init__(self, weight=None, from_logits=True, batch_axis=0, compute_full=False, **kwargs):
        super(PoissonNLLLoss, self).__init__(weight, batch_axis, **kwargs)
        self._from_logits = from_logits
        self._compute_full = compute_full

    def hybrid_forward(self, F, pred, target, sample_weight=None, epsilon=1e-08):
        target = _reshape_like(F, target, pred)
        if self._from_logits:
            loss = F.exp(pred) - target * pred
        else:
            loss = pred - target * F.log(pred + epsilon)
        if self._compute_full:
            # Using numpy's pi value
            stirling_factor = target * F.log(target) - target + 0.5 * F.log(2 * target * np.pi)
            # Apply Stirling's approximation of log(target!) only where target > 1:
            # for target values of 0 or 1 the factorial term is exactly 0, and the
            # approximation is invalid at target == 0 (log(0)), so mask those out.
            target_gt_1 = target > 1
            stirling_factor *= target_gt_1
            loss += stirling_factor
        loss = _apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss)

Review comment (on the masking above): Add a comment explaining why we need this.

Reply: Thanks!
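
For reference, a minimal usage sketch of the new loss (illustrative only, not part of the diff):

import mxnet as mx
from mxnet import gluon

pred = mx.nd.array([[1.0, 2.0], [0.5, 1.5]])    # predictions
target = mx.nd.array([[2.0, 1.0], [1.0, 3.0]])  # Poisson-distributed counts

# from_logits=True (default): pred is treated as the log of the Poisson rate
loss_fn = gluon.loss.PoissonNLLLoss(from_logits=True)
print(loss_fn(pred, target).asscalar())

# from_logits=False with the Stirling correction for the factorial term
loss_fn_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True)
print(loss_fn_full(pred, target).asscalar())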
@@ -348,6 +348,61 @@ def test_triplet_loss():
            optimizer='adam')
    assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05

@with_seed()
def test_poisson_nllloss():

Review comment: Please separate this into 3 unit tests, clearly stating what condition you are testing.

Reply: @roywei Yes, I have separated the use cases, commenting on what exactly is being tested. I have not split them into different functions, to stay consistent with the tests for the other loss functions.

    pred = mx.nd.random.normal(shape=(3, 4))
    min_pred = mx.nd.min(pred)
    # Shift predictions to be non-negative to avoid an invalid log calculation
    pred[:] = pred + mx.nd.abs(min_pred)
    target = mx.nd.random.normal(shape=(3, 4))
    min_target = mx.nd.min(target)
    # Shift targets to be non-negative to avoid an invalid log calculation
    target[:] += mx.nd.abs(min_target)

    Loss = gluon.loss.PoissonNLLLoss(from_logits=True)
    Loss_no_logits = gluon.loss.PoissonNLLLoss(from_logits=False)

    # 1) Testing for flag from_logits = True, comparing against the loss
    # calculated by the brute-force formula
    brute_loss = np.mean(np.exp(pred.asnumpy()) - target.asnumpy() * pred.asnumpy())
    loss_withlogits = Loss(pred, target)
    assert_almost_equal(brute_loss, loss_withlogits.asscalar())
Review comment: why not test this loss function using the Module API?

Reply: Thank you for the suggestion; it is a valid point. However, it is difficult to construct a synthetic dataset in which random inputs X are correlated with a target that follows a Poisson distribution. I have tried training the model on random datasets, but no decrease in the loss was observed. This could be contributed incrementally if I come across a dataset that shows the loss decreasing over epochs, as with models such as logistic regression and linear regression.

Reply: Unit test added.

    # 2) Testing for flag from_logits = False
    loss_no_logits = Loss_no_logits(pred, target)
    np_loss_no_logits = np.mean(pred.asnumpy() - target.asnumpy() * np.log(pred.asnumpy() + 1e-08))
    if np.isnan(loss_no_logits.asscalar()):
        assert_almost_equal(np.isnan(np_loss_no_logits), np.isnan(loss_no_logits.asscalar()))
    else:
        assert_almost_equal(np_loss_no_logits, loss_no_logits.asscalar())

    # 3) Testing for the Stirling approximation
    np_pred = np.random.uniform(1, 5, (2, 3))
Review comment: Please add tests for the hybridized version as well.

Reply: Added the unit test for it (a sketch of such a check is shown after this test).
    np_target = np.random.uniform(1, 5, (2, 3))
    np_compute_full = np.mean((np_pred - np_target * np.log(np_pred + 1e-08)) +
                              ((np_target * np.log(np_target) - np_target +
                                0.5 * np.log(2 * np_target * np.pi)) * (np_target > 1)))
    Loss_compute_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True)
    loss_compute_full = Loss_compute_full(mx.nd.array(np_pred), mx.nd.array(np_target))
    assert_almost_equal(np_compute_full, loss_compute_full.asscalar())
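
The hybridized check mentioned in the review thread above is not shown in this hunk. A minimal sketch of such a test (the function name is illustrative, not from the PR):

@with_seed()
def test_poisson_nllloss_hybridized():
    # Illustrative sketch: hybridized and imperative execution should agree.
    pred = mx.nd.random.uniform(1, 5, shape=(3, 4))
    target = mx.nd.random.uniform(1, 5, shape=(3, 4))
    loss_fn = gluon.loss.PoissonNLLLoss(from_logits=False)
    imperative = loss_fn(pred, target)
    loss_fn.hybridize()  # compile the block into a static graph
    hybridized = loss_fn(pred, target)
    assert_almost_equal(imperative.asscalar(), hybridized.asscalar())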

@with_seed()
def test_poisson_nllloss_mod():
    N = 1000
    data = mx.random.poisson(shape=(N, 2))
    label = mx.random.poisson(lam=4, shape=(N, 1))
    data_iter = mx.io.NDArrayIter(data, label, batch_size=20, label_name='label', shuffle=True)
    output = mx.sym.exp(get_net(1))
    l = mx.symbol.Variable('label')
    Loss = gluon.loss.PoissonNLLLoss(from_logits=False)
    loss = Loss(output, l)
    loss = mx.sym.make_loss(loss)
    mod = mx.mod.Module(loss, data_names=('data',), label_names=('label',))
    mod.fit(data_iter, num_epoch=20, optimizer_params={'learning_rate': 0.01},
            initializer=mx.init.Normal(sigma=0.1), eval_metric=mx.metric.Loss(),
            optimizer='adam')
    assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05
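
Regarding the earlier review exchange about synthesizing inputs correlated with a Poisson target: a common approach (a sketch under assumed coefficients, not from the PR) is to draw targets from a Poisson distribution whose rate is an exponential-linear function of the inputs:

import numpy as np

rng = np.random.RandomState(0)
X = rng.uniform(-1, 1, size=(1000, 2))  # random inputs
w = np.array([0.5, -0.25])              # hypothetical true coefficients
lam = np.exp(X.dot(w))                  # Poisson rate linked to X via exp()
y = rng.poisson(lam)                    # counts correlated with the inputs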

if __name__ == '__main__':
    import nose
    nose.runmodule()
Review comment: nit: please add a line space.

Reply: Added it. Thanks for pointing out.