add pos_weight for SigmoidBinaryCrossEntropyLoss #13612
Conversation
@eureka7mt Thanks for the contribution, could you add a unit test for this case?
@mxnet-label-bot add [Gluon, pr-awaiting-review]
Minor comment on the unit test; please also fix the CI failure.
Rest LGTM.
@eureka7mt Could you fix the trailing whitespace issue?
I don't know why it failed in test_multinomial_generator(), which is in '/work/mxnet/tests/python/gpu/../unittest/test_random.py', on unix-gpu.
@eureka7mt Could you re-trigger the CI?
@eureka7mt Could you please look into the CI failures?
@mxnet-label-bot update [pr-awaiting-merge]
We will merge it after the CI passes.
Thanks for your contribution!
Adding the if-else statement causes an error. Though the default value of pos_weight is set to 1, pos_weight is usually a (1, N) NDArray, and an error seems to occur in the if-else statement when the input is a Symbol.
@eureka7mt Edit: I think the pos_weight is a scalar, since it is a binary classification loss. |
@wkcn It could be a scalar when classifying a single class, but for multi-class and multi-label classification it should be a tensor, because in that situation the numbers of positive and negative examples differ for each class.
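As an illustrative sketch (mine, not code from the PR), a per-class (1, N) pos_weight could be derived from class frequencies as below; the neg/pos ratio is a common heuristic, not something prescribed in this thread:

```python
import numpy as np

# labels: (num_samples, num_classes) multi-label 0/1 matrix
labels = np.array([[1, 0, 1],
                   [0, 0, 1],
                   [1, 0, 0],
                   [0, 1, 1]], dtype=np.float64)

pos = labels.sum(axis=0)                  # positives per class
neg = labels.shape[0] - pos               # negatives per class
# Up-weight classes with few positives; shape (1, N) so it broadcasts
# across the batch axis.
pos_weight = (neg / np.maximum(pos, 1.0)).reshape(1, -1)
print(pos_weight)  # [[1.  3.  0.333...]]
```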
I changed the order of the SigmoidBinaryCrossEntropyLoss inputs from
Sorry that I triggered the sanity problem.
Maybe the broadcast_mul isn't necessary. I think that an NDArray * an NDArray will do broadcast_mul automatically.
@eureka7mt I think we may pass a Symbol into SigmoidBinaryCrossEntropyLoss, and a Symbol will not broadcast_mul automatically in my test:

```python
import mxnet as mx
from mxnet.gluon import nn

class TestBlock(nn.HybridBlock):
    def __init__(self):
        super(TestBlock, self).__init__()

    def hybrid_forward(self, F, x, y):
        # Once hybridized, x and y are Symbols, and `x * y` lowers to
        # elemwise_mul, which requires matching shapes instead of
        # broadcasting.
        return x * y

block = TestBlock()
block.hybridize()
a = mx.nd.zeros((10, 1))
b = mx.nd.ones((1, 5))
c = block(a, b)   # fails at shape inference for mismatched shapes
print(c.asnumpy())
```
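For contrast, a minimal sketch (mine, not from the PR) of the imperative case, where NDArray arithmetic does broadcast automatically:

```python
import mxnet as mx

# Without hybridize(), `*` on NDArrays broadcasts like broadcast_mul:
a = mx.nd.zeros((10, 1))
b = mx.nd.ones((1, 5))
print((a * b).shape)  # expected: (10, 5)
```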
The PR has been merged. |
* add pos_weight for SigmoidBinaryCrossEntropyLoss in gluon.loss
* Update loss.py
* add test
* set the default value of pos_weight to be 1
* fix unittest
* set N be a random number
* fix issues
* test without random number
* test with random N
* fix
* fix errors
* fix errors
* fix order
* Update loss.py
* Update loss.py
* fix pylint
* default pos_weight=None
* add broadcast_mul and fix pylint
* fix unittest
* Update loss.py
* Update loss.py
* Update loss.py
Description
Add pos_weight for SigmoidBinaryCrossEntropyLoss.
A value of pos_weight > 1 decreases the false negative count, hence increasing the recall. Conversely, setting pos_weight < 1 decreases the false positive count and increases the precision. This can be seen from the fact that pos_weight is introduced as a multiplicative coefficient for the positive-targets term in the loss expression:

label * -log(sigmoid(pred)) * pos_weight + (1 - label) * -log(1 - sigmoid(pred))

It is adopted from TensorFlow's implementation.
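A minimal NumPy sketch of the expression above (illustrative only; the function name and the epsilon guard are my assumptions, not the PR's exact code):

```python
import numpy as np

def weighted_sigmoid_bce(pred, label, pos_weight):
    # Element-wise weighted sigmoid BCE, per the formula above.
    prob = 1.0 / (1.0 + np.exp(-pred))  # sigmoid(pred)
    eps = 1e-12                         # guard against log(0)
    return (label * -np.log(prob + eps) * pos_weight
            + (1 - label) * -np.log(1.0 - prob + eps))

# pos_weight broadcasts across the batch axis, e.g. shape (1, N):
pred = np.array([[1.5, -0.3], [-2.0, 0.8]])
label = np.array([[1.0, 0.0], [0.0, 1.0]])
pos_weight = np.array([[2.0, 0.5]])
print(weighted_sigmoid_bce(pred, label, pos_weight))
```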
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments