diff --git a/example/named_entity_recognition/src/metrics.py b/example/named_entity_recognition/src/metrics.py index d3d73782c62e..ef5f64fb1af3 100644 --- a/example/named_entity_recognition/src/metrics.py +++ b/example/named_entity_recognition/src/metrics.py @@ -50,15 +50,20 @@ def classifer_metrics(label, pred): correct_entitites = np.sum(corr_pred[pred_is_entity]) #precision: when we predict entity, how often are we right? - precision = correct_entitites/entity_preds if entity_preds == 0: precision = np.nan + else: + precision = correct_entitites/entity_preds #recall: of the things that were an entity, how many did we catch? recall = correct_entitites / num_entities if num_entities == 0: recall = np.nan - f1 = 2 * precision * recall / (precision + recall) + # To prevent dozens of warning: RuntimeWarning: divide by zero encountered in long_scalars + if precision + recall == 0: + f1 = 0 + else: + f1 = 2 * precision * recall / (precision + recall) return precision, recall, f1 def entity_precision(label, pred):