diff --git a/lit_nlp/app.py b/lit_nlp/app.py
index d994893c..5453c009 100644
--- a/lit_nlp/app.py
+++ b/lit_nlp/app.py
@@ -477,6 +477,7 @@ def __init__(
         'multiclass': metrics.MulticlassMetrics(),
         'paired': metrics.MulticlassPairedMetrics(),
         'bleu': metrics.CorpusBLEU(),
+        'rouge': metrics.RougeL(),
     })
     gradient_map_interpreters = {
         'Grad L2 Norm': gradient_maps.GradientNorm(),
diff --git a/lit_nlp/components/metrics.py b/lit_nlp/components/metrics.py
index fc6c1401..7dd11e76 100644
--- a/lit_nlp/components/metrics.py
+++ b/lit_nlp/components/metrics.py
@@ -31,6 +31,7 @@
 from scipy.spatial import distance as scipy_distance
 from sklearn import metrics as sklearn_metrics
+from rouge_score import rouge_scorer
 
 JsonDict = types.JsonDict
 IndexedInput = types.IndexedInput
 Spec = types.Spec
@@ -468,6 +469,45 @@ def compute(self,
     return {'corpus_bleu' + name_suffix: bleu.score}
 
 
+class RougeL(SimpleMetrics):
+  """RougeL score for generation tasks."""
+
+  def __init__(self, *args, **kw):
+    super().__init__(*args, **kw)
+    self._scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
+
+  def _score(self, reference, prediction):
+    return self._scorer.score(
+        target=reference, prediction=prediction)['rougeL'].fmeasure
+
+  def is_compatible(self, field_spec: types.LitType) -> bool:
+    """Return true if compatible with this field."""
+    return isinstance(field_spec,
+                      (types.GeneratedText, types.GeneratedTextCandidates))
+
+  def compute(self,
+              labels: Sequence[Text],
+              preds: Sequence[Union[Text, types.ScoredTextCandidates]],
+              label_spec: types.TextSegment,
+              pred_spec: Union[types.GeneratedText,
+                               types.GeneratedTextCandidates],
+              config: Optional[JsonDict] = None) -> Dict[Text, float]:
+    """Compute metric(s) between labels and predictions."""
+    del label_spec
+    del config
+
+    if not labels or not preds:
+      return {}
+
+    name_suffix = ''
+    if isinstance(pred_spec, types.GeneratedTextCandidates):
+      preds = [types.GeneratedTextCandidates.top_text(v) for v in preds]
+      name_suffix = '@1'
+    scores = list(map(self._score, labels, preds))
+
+    return {'rougeL' + name_suffix: np.mean(scores)}
+
+
 class BinaryConfusionMetricsImpl(SimpleMetrics):
   """Confusion matrix values for binary classification."""
 
diff --git a/lit_nlp/components/metrics_test.py b/lit_nlp/components/metrics_test.py
index e38b7646..bb9698c5 100644
--- a/lit_nlp/components/metrics_test.py
+++ b/lit_nlp/components/metrics_test.py
@@ -211,45 +211,119 @@ def test_compute(self):
 class CorpusBLEUTest(absltest.TestCase):
 
   def test_is_compatible(self):
-    corpusblue_metrics = metrics.CorpusBLEU()
+    bleu_metrics = metrics.CorpusBLEU()
+
+    # Only compatible with generation types.
+    self.assertTrue(bleu_metrics.is_compatible(types.GeneratedText()))
+    self.assertTrue(bleu_metrics.is_compatible(types.GeneratedTextCandidates()))
 
-    # Only compatible with GeneratedText spec.
-    self.assertTrue(corpusblue_metrics.is_compatible(types.GeneratedText()))
     self.assertFalse(
-        corpusblue_metrics.is_compatible(types.MulticlassPreds(vocab=[''])))
-    self.assertFalse(corpusblue_metrics.is_compatible(types.RegressionScore()))
+        bleu_metrics.is_compatible(types.MulticlassPreds(vocab=[''])))
+    self.assertFalse(bleu_metrics.is_compatible(types.RegressionScore()))
 
   def test_compute(self):
-    corpusblue_metrics = metrics.CorpusBLEU()
+    bleu_metrics = metrics.CorpusBLEU()
 
     # All correct predictions.
-    result = corpusblue_metrics.compute(
+    result = bleu_metrics.compute(
         ['This is a test.', 'Test two', 'A third test example'],
         ['This is a test.', 'Test two', 'A third test example'],
         types.GeneratedText(), types.GeneratedText())
     testing_utils.assert_deep_almost_equal(self, result,
-                                           {'corpus_bleu': 100.00000})
+                                           {'corpus_bleu': 100.0000})
 
     # Some incorrect predictions.
-    result = corpusblue_metrics.compute(
+    result = bleu_metrics.compute(
         ['This is a test.', 'Test one', 'A third test'],
         ['This is a test.', 'Test two', 'A third test example'],
         types.GeneratedText(), types.GeneratedText())
     testing_utils.assert_deep_almost_equal(self, result,
                                            {'corpus_bleu': 68.037493})
 
-    result = corpusblue_metrics.compute(
+    result = bleu_metrics.compute(
         ['This is a test.', 'Test one', 'A third test'],
         ['these test.', 'Test two', 'A third test example'],
         types.GeneratedText(), types.GeneratedText())
     testing_utils.assert_deep_almost_equal(self, result,
-                                           {'corpus_bleu': 29.508062388758525})
+                                           {'corpus_bleu': 29.508062})
+
+    # Empty labels and predictions
+    result = bleu_metrics.compute([], [], types.GeneratedText(),
+                                  types.GeneratedText())
+    testing_utils.assert_deep_almost_equal(self, result, {})
+
+  def test_compute_with_candidates(self):
+    bleu_metrics = metrics.CorpusBLEU()
+
+    # Should only score the first one (@1).
+    labels = ['This is a test.', 'Test two']
+    preds = [
+        [('This is a test.', -1.0), ('foobar', -20.0)],
+        [('Test two', -1.0), ('spam', -20.0)],
+    ]
+
+    result = bleu_metrics.compute(labels, preds, types.TextSegment(),
+                                  types.GeneratedTextCandidates())
+    testing_utils.assert_deep_almost_equal(self, result,
+                                           {'corpus_bleu@1': 100.0000})
+
+
+class RougeLTest(absltest.TestCase):
+
+  def test_is_compatible(self):
+    rouge_metrics = metrics.RougeL()
+
+    # Only compatible with generation types.
+    self.assertTrue(rouge_metrics.is_compatible(types.GeneratedText()))
+    self.assertTrue(
+        rouge_metrics.is_compatible(types.GeneratedTextCandidates()))
+
+    self.assertFalse(
+        rouge_metrics.is_compatible(types.MulticlassPreds(vocab=[''])))
+    self.assertFalse(rouge_metrics.is_compatible(types.RegressionScore()))
+
+  def test_compute(self):
+    rouge_metrics = metrics.RougeL()
+
+    # All correct predictions.
+    result = rouge_metrics.compute(
+        ['This is a test.', 'Test two', 'A third test example'],
+        ['This is a test.', 'Test two', 'A third test example'],
+        types.TextSegment(), types.GeneratedText())
+    testing_utils.assert_deep_almost_equal(self, result, {'rougeL': 1.0})
+
+    # Some incorrect predictions.
+    result = rouge_metrics.compute(
+        ['This is a test.', 'Test one', 'A third test'],
+        ['This is a test.', 'Test two', 'A third test example'],
+        types.TextSegment(), types.GeneratedText())
+    testing_utils.assert_deep_almost_equal(self, result, {'rougeL': 0.785714})
+
+    result = rouge_metrics.compute(
+        ['This is a test.', 'Test one', 'A third test'],
+        ['these test.', 'Test two', 'A third test example'],
+        types.TextSegment(), types.GeneratedText())
+    testing_utils.assert_deep_almost_equal(self, result, {'rougeL': 0.563492})
 
     # Empty labels and predictions
-    result = corpusblue_metrics.compute([], [], types.GeneratedText(),
-                                        types.GeneratedText())
+    result = rouge_metrics.compute([], [], types.GeneratedText(),
+                                   types.GeneratedText())
     testing_utils.assert_deep_almost_equal(self, result, {})
 
+  def test_compute_with_candidates(self):
+    rouge_metrics = metrics.RougeL()
+
+    # Should only score the first one (@1).
+    labels = ['This is a test.', 'Test two']
+    preds = [
+        [('This is a test.', -1.0), ('foobar', -20.0)],
+        [('Test two', -1.0), ('spam', -20.0)],
+    ]
+
+    result = rouge_metrics.compute(labels, preds, types.TextSegment(),
+                                   types.GeneratedTextCandidates())
+    testing_utils.assert_deep_almost_equal(self, result, {'rougeL@1': 1.0})
+
 
 class ClassifcationMarginTest(absltest.TestCase):
 
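Note: a minimal standalone sketch (not part of the patch) of the scorer wrapped by the new RougeL metric, assuming the rouge-score package is installed; it mirrors the configuration in RougeL.__init__ and the call in RougeL._score above.

    from rouge_score import rouge_scorer

    # Same configuration as RougeL.__init__: ROUGE-L (longest common
    # subsequence) with stemming enabled.
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

    # score() returns a dict keyed by rouge type; the metric keeps only the
    # F-measure, so identical strings score 1.0 (as asserted in RougeLTest).
    score = scorer.score(target='This is a test.',
                         prediction='This is a test.')['rougeL'].fmeasure
    print(score)  # 1.0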