-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmulti_mcc_loss.py
23 lines (18 loc) · 1.08 KB
/
multi_mcc_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import tensorflow as tf
from tensorflow.keras import backend as K
def multi_mcc_loss(y_true, y_pred, false_pos_penal=1.0):
    """Return ``1 - MCC`` computed from a soft multi-class confusion matrix.

    The confusion matrix is ``transpose(y_true) @ y_pred``. Assuming both
    inputs are laid out as ``(batch, classes)`` (TODO confirm against the
    caller), entry ``[i, j]`` accumulates the predicted score for class ``j``
    over samples whose true class is ``i``.

    Args:
        y_true: 2-D tensor of targets, presumably one-hot encoded.
        y_pred: 2-D tensor of predictions, same shape as ``y_true``.
        false_pos_penal: Python scalar. When it differs from 1.0, the
            entries strictly above the diagonal are multiplied by it and the
            entries strictly below the diagonal are divided by it. Values
            below 1.0 therefore shift the penalty onto the lower triangle.

    Returns:
        Scalar tensor ``1 - mcc``. A NaN coefficient is replaced by 0, so the
        loss is 1 in that case.
    """
    confusion_m = tf.matmul(K.transpose(y_true), y_pred)

    if false_pos_penal != 1.0:
        # Penalize false positives symmetrically with false negatives: one
        # triangle is scaled up by the same factor the other is scaled down,
        # which keeps the resulting MCC values comparable.
        #
        # band_part(m, 0, -1) and band_part(m, -1, 0) both *include* the main
        # diagonal, so it has to be subtracted; otherwise correct predictions
        # would be weighted (1 + p + 1/p) instead of 1.
        diagonal = tf.linalg.band_part(confusion_m, 0, 0)
        strict_upper = tf.linalg.band_part(confusion_m, 0, -1) - diagonal
        strict_lower = tf.linalg.band_part(confusion_m, -1, 0) - diagonal
        confusion_m = (
            diagonal
            + strict_upper * false_pos_penal
            + strict_lower / false_pos_penal
        )

    total = K.sum(confusion_m)
    row_totals = K.sum(confusion_m, axis=1)
    col_totals = K.sum(confusion_m, axis=0)

    # sum(C @ C) == sum_k col_k * row_k, sum(C @ C^T) == sum_k col_k ** 2 and
    # sum(C^T @ C) == sum_k row_k ** 2, so the marginals give the same terms
    # without materializing three matrix products.
    numerator = (
        total * tf.linalg.trace(confusion_m)
        - K.sum(row_totals * col_totals)
    )
    total_sq = K.square(total)
    down_left = K.sqrt(total_sq - K.sum(K.square(col_totals)))
    down_right = K.sqrt(total_sq - K.sum(K.square(row_totals)))

    # epsilon keeps the division finite when one of the factors is zero.
    mcc = numerator / (down_left * down_right + K.epsilon())
    mcc = tf.where(tf.math.is_nan(mcc), tf.zeros_like(mcc), mcc)
    return 1 - mcc