This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Simplify metrics (#5154)
* use dist_reduce_sum

* simplify metrics using dist_reduce_sum
AkshitaB authored Apr 27, 2021
1 parent 12f5b0f commit 530dae4
Showing 14 changed files with 57 additions and 184 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Changed

- Use `dist_reduce_sum` in distributed metrics.

### Added

- Added `TaskSuite` base class and command line functionality for running [`checklist`](/~https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.sanity_checks.task_checklists` module.
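Every file in this commit swaps a hand-rolled `is_distributed()` / `dist.all_reduce(...)` block for a single call to `dist_reduce_sum`, imported from `allennlp.nn.util`. The helper's implementation is not part of this diff, so the sketch below is only a mental model inferred from the call sites (tensors and plain Python numbers pass through unchanged outside distributed runs and are otherwise summed across workers); the real helper may differ in details such as device handling and cloning.

```python
# Hypothetical sketch of dist_reduce_sum, reconstructed from how this commit
# uses it -- not the actual allennlp.nn.util implementation.
from typing import Union

import torch
import torch.distributed as dist

from allennlp.common.util import is_distributed


def dist_reduce_sum_sketch(value: Union[torch.Tensor, int, float]):
    if not is_distributed():
        # Outside a distributed run, the value is returned unchanged.
        return value
    device = torch.device("cuda" if dist.get_backend() == "nccl" else "cpu")
    if isinstance(value, torch.Tensor):
        tensor = value.to(device)
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        return tensor
    # Plain Python numbers are wrapped, summed across workers, and unwrapped,
    # which is why `dist_reduce_sum(1)` in entropy.py adds the world size.
    tensor = torch.tensor(value, device=device)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor.item()
```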
28 changes: 7 additions & 21 deletions allennlp/training/metrics/attachment_scores.py
@@ -2,9 +2,8 @@

from overrides import overrides
import torch
import torch.distributed as dist

from allennlp.common.util import is_distributed
from allennlp.nn.util import dist_reduce_sum
from allennlp.training.metrics.metric import Metric


@@ -59,7 +58,6 @@ def __call__( # type: ignore
predicted_indices, predicted_labels, gold_indices, gold_labels, mask
)
predicted_indices, predicted_labels, gold_indices, gold_labels, mask = detached
device = predicted_indices.device

if mask is None:
mask = torch.ones_like(predicted_indices).bool()
@@ -83,24 +81,12 @@ def __call__( # type: ignore
total_sentences = correct_indices.size(0)
total_words = correct_indices.numel() - (~mask).sum()

if is_distributed():
dist.all_reduce(correct_indices, op=dist.ReduceOp.SUM)
dist.all_reduce(unlabeled_exact_match, op=dist.ReduceOp.SUM)
dist.all_reduce(correct_labels_and_indices, op=dist.ReduceOp.SUM)
dist.all_reduce(labeled_exact_match, op=dist.ReduceOp.SUM)
total_sentences = torch.tensor(total_sentences, device=device)
total_words = torch.tensor(total_words, device=device)
dist.all_reduce(total_sentences, op=dist.ReduceOp.SUM)
dist.all_reduce(total_words, op=dist.ReduceOp.SUM)
total_sentences = total_sentences.item()
total_words = total_words.item()

self._unlabeled_correct += correct_indices.sum()
self._exact_unlabeled_correct += unlabeled_exact_match.sum()
self._labeled_correct += correct_labels_and_indices.sum()
self._exact_labeled_correct += labeled_exact_match.sum()
self._total_sentences += total_sentences
self._total_words += total_words
self._unlabeled_correct += dist_reduce_sum(correct_indices).sum()
self._exact_unlabeled_correct += dist_reduce_sum(unlabeled_exact_match).sum()
self._labeled_correct += dist_reduce_sum(correct_labels_and_indices).sum()
self._exact_labeled_correct += dist_reduce_sum(labeled_exact_match).sum()
self._total_sentences += dist_reduce_sum(total_sentences)
self._total_words += dist_reduce_sum(total_words)

def get_metric(
self,
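The new accumulation path can be smoke-tested on a single process, where `dist_reduce_sum` should be a no-op. The tensor values below are made up, and the exact keys returned by `get_metric` are an assumption; inspect the returned dict.

```python
# Single-process sanity check of AttachmentScores (illustrative inputs only).
import torch

from allennlp.training.metrics import AttachmentScores

metric = AttachmentScores()
predicted_indices = torch.tensor([[1, 0, 2]])
predicted_labels = torch.tensor([[0, 1, 2]])
gold_indices = torch.tensor([[1, 0, 2]])   # all heads match the prediction
gold_labels = torch.tensor([[0, 1, 1]])    # one label differs
mask = torch.ones_like(gold_indices).bool()

metric(predicted_indices, predicted_labels, gold_indices, gold_labels, mask)
# Expect an unlabeled attachment score of 1.0 and a labeled score around 0.67
# under these assumptions.
print(metric.get_metric(reset=True))
```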
27 changes: 7 additions & 20 deletions allennlp/training/metrics/bleu.py
@@ -8,6 +8,7 @@

from allennlp.common.util import is_distributed
from allennlp.training.metrics.metric import Metric
from allennlp.nn.util import dist_reduce_sum


@Metric.register("bleu")
@@ -118,24 +119,18 @@ def __call__(
None
"""
predictions, gold_targets = self.detach_tensors(predictions, gold_targets)
device = gold_targets.device
if is_distributed():
world_size = dist.get_world_size()
else:
world_size = 1

for ngram_size, _ in enumerate(self._ngram_weights, start=1):
precision_matches, precision_totals = self._get_modified_precision_counts(
predictions, gold_targets, ngram_size
)
if is_distributed():
_precision_matches = torch.tensor(precision_matches, device=device)
_precision_totals = torch.tensor(precision_totals, device=device)
dist.all_reduce(_precision_matches, op=dist.ReduceOp.SUM)
dist.all_reduce(_precision_totals, op=dist.ReduceOp.SUM)
precision_matches = _precision_matches.item() / world_size
precision_totals = _precision_totals.item() / world_size

self._precision_matches[ngram_size] += precision_matches
self._precision_totals[ngram_size] += precision_totals
self._precision_matches[ngram_size] += dist_reduce_sum(precision_matches) / world_size
self._precision_totals[ngram_size] += dist_reduce_sum(precision_totals) / world_size

if not self._exclude_indices:
_prediction_lengths = predictions.size(0) * predictions.size(1)
@@ -149,16 +144,8 @@ def __call__(
_prediction_lengths = valid_predictions_mask.sum().item()
_reference_lengths = valid_gold_targets_mask.sum().item()

if is_distributed():
prediction_lengths = torch.tensor(_prediction_lengths, device=device)
reference_lengths = torch.tensor(_reference_lengths, device=device)
dist.all_reduce(prediction_lengths, op=dist.ReduceOp.SUM)
dist.all_reduce(reference_lengths, op=dist.ReduceOp.SUM)
_prediction_lengths = prediction_lengths.item()
_reference_lengths = reference_lengths.item()

self._prediction_lengths += _prediction_lengths
self._reference_lengths += _reference_lengths
self._prediction_lengths += dist_reduce_sum(_prediction_lengths)
self._reference_lengths += dist_reduce_sum(_reference_lengths)

@overrides
def get_metric(self, reset: bool = False) -> Dict[str, float]:
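BLEU is the one metric here that still queries the world size: the old code all-reduced the per-worker precision counts and then divided by `dist.get_world_size()`, so the rewrite keeps that division to preserve the averaged counts. A minimal sketch of the pattern, with a hypothetical per-worker value:

```python
# Sum-then-average pattern kept in Bleu.__call__ (7.0 is a made-up per-worker
# count; world_size falls back to 1 outside distributed runs).
import torch.distributed as dist

from allennlp.common.util import is_distributed
from allennlp.nn.util import dist_reduce_sum

world_size = dist.get_world_size() if is_distributed() else 1

precision_matches = 7.0  # per-worker modified n-gram matches for one order
averaged_matches = dist_reduce_sum(precision_matches) / world_size
```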
10 changes: 3 additions & 7 deletions allennlp/training/metrics/boolean_accuracy.py
@@ -2,9 +2,8 @@

from overrides import overrides
import torch
import torch.distributed as dist

from allennlp.common.util import is_distributed
from allennlp.nn.util import dist_reduce_sum
from allennlp.training.metrics.metric import Metric


@@ -87,11 +86,8 @@ def __call__(
_correct_count = (correct * keep).sum()
_total_count = keep.sum()

if is_distributed():
dist.all_reduce(_correct_count, op=dist.ReduceOp.SUM)
dist.all_reduce(_total_count, op=dist.ReduceOp.SUM)
self._correct_count += _correct_count.item()
self._total_count += _total_count.item()
self._correct_count += dist_reduce_sum(_correct_count).item()
self._total_count += dist_reduce_sum(_total_count).item()

def get_metric(self, reset: bool = False):
"""
14 changes: 3 additions & 11 deletions allennlp/training/metrics/categorical_accuracy.py
@@ -2,9 +2,8 @@

from overrides import overrides
import torch
import torch.distributed as dist

from allennlp.common.util import is_distributed
from allennlp.nn.util import dist_reduce_sum
from allennlp.common.checks import ConfigurationError
from allennlp.training.metrics.metric import Metric

@@ -98,15 +97,8 @@ def __call__(
_total_count = torch.tensor(gold_labels.numel())
_correct_count = correct.sum()

if is_distributed():
device = torch.device("cuda" if dist.get_backend() == "nccl" else "cpu")
_correct_count = _correct_count.to(device)
_total_count = _total_count.to(device)
dist.all_reduce(_correct_count, op=dist.ReduceOp.SUM)
dist.all_reduce(_total_count, op=dist.ReduceOp.SUM)

self.correct_count += _correct_count.item()
self.total_count += _total_count.item()
self.correct_count += dist_reduce_sum(_correct_count).item()
self.total_count += dist_reduce_sum(_total_count).item()

def get_metric(self, reset: bool = False):
"""
14 changes: 3 additions & 11 deletions allennlp/training/metrics/entropy.py
@@ -2,9 +2,8 @@

from overrides import overrides
import torch
import torch.distributed as dist

from allennlp.common.util import is_distributed
from allennlp.nn.util import dist_reduce_sum
from allennlp.training.metrics.metric import Metric


@@ -29,7 +28,6 @@ def __call__(
A masking tensor of shape (batch_size, ...).
"""
logits, mask = self.detach_tensors(logits, mask)
device = logits.device

if mask is None:
mask = torch.ones(logits.size()[:-1], device=logits.device).bool()
@@ -40,15 +38,9 @@ def __call__(
entropy = weighted_negative_likelihood.sum(-1)

_entropy = entropy.sum() / mask.sum()
_count = 1

if is_distributed():
count = torch.tensor(_count, device=device)
dist.all_reduce(_entropy, op=dist.ReduceOp.SUM)
dist.all_reduce(count, op=dist.ReduceOp.SUM)
_count = count.item()
self._entropy += _entropy.item()
self._count += _count
self._entropy += dist_reduce_sum(_entropy).item()
self._count += dist_reduce_sum(1)

@overrides
def get_metric(self, reset: bool = False):
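In `entropy.py` the reported value is a running mean of per-batch entropies, so the denominator grows by one per batch on each worker; summing the literal `1` across workers with `dist_reduce_sum(1)` adds the world size per update, just as the old all-reduced counter did. The numerator is the masked mean token entropy of the batch, illustrated standalone below (random logits, not the library class itself).

```python
# Hand-computed version of the per-batch value Entropy adds to its total
# (shapes and the random logits are illustrative only).
import torch

logits = torch.randn(2, 5, 10)   # (batch_size, seq_len, num_classes)
mask = torch.ones(2, 5).bool()

log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
probs = log_probs.exp() * mask.unsqueeze(-1)           # zero out padded positions
entropy_per_token = -(log_probs * probs).sum(-1)
batch_entropy = entropy_per_token.sum() / mask.sum()   # what gets accumulated
print(batch_entropy.item())
```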
23 changes: 4 additions & 19 deletions allennlp/training/metrics/evalb_bracketing_scorer.py
@@ -8,12 +8,9 @@
from overrides import overrides
from nltk import Tree

import torch
import torch.distributed as dist

from allennlp.common.util import is_distributed
from allennlp.common.checks import ConfigurationError
from allennlp.training.metrics.metric import Metric
from allennlp.nn.util import dist_reduce_sum

logger = logging.getLogger(__name__)

@@ -153,21 +150,9 @@ def __call__(self, predicted_trees: List[Tree], gold_trees: List[Tree]) -> None:

shutil.rmtree(tempdir)

if is_distributed():
device = torch.device("cuda" if dist.get_backend() == "nccl" else "cpu")
correct_predicted_brackets = torch.tensor(_correct_predicted_brackets, device=device)
predicted_brackets = torch.tensor(_predicted_brackets, device=device)
gold_brackets = torch.tensor(_gold_brackets, device=device)
dist.all_reduce(correct_predicted_brackets, op=dist.ReduceOp.SUM)
dist.all_reduce(predicted_brackets, op=dist.ReduceOp.SUM)
dist.all_reduce(gold_brackets, op=dist.ReduceOp.SUM)
_correct_predicted_brackets = correct_predicted_brackets.item()
_predicted_brackets = predicted_brackets.item()
_gold_brackets = gold_brackets.item()

self._correct_predicted_brackets += _correct_predicted_brackets
self._gold_brackets += _gold_brackets
self._predicted_brackets += _predicted_brackets
self._correct_predicted_brackets += dist_reduce_sum(_correct_predicted_brackets)
self._gold_brackets += dist_reduce_sum(_gold_brackets)
self._predicted_brackets += dist_reduce_sum(_predicted_brackets)

@overrides
def get_metric(self, reset: bool = False):
Expand Down
21 changes: 7 additions & 14 deletions allennlp/training/metrics/fbeta_measure.py
@@ -1,12 +1,12 @@
from typing import List, Optional, Union

import torch
import torch.distributed as dist
from overrides import overrides

from allennlp.common.util import is_distributed, nan_safe_tensor_divide
from allennlp.common.util import nan_safe_tensor_divide
from allennlp.common.checks import ConfigurationError
from allennlp.training.metrics.metric import Metric
from allennlp.nn.util import dist_reduce_sum


@Metric.register("fbeta")
@@ -110,7 +110,6 @@ def __call__(
A masking tensor the same size as `gold_labels`.
"""
predictions, gold_labels, mask = self.detach_tensors(predictions, gold_labels, mask)
device = gold_labels.device

# Calculate true_positive_sum, true_negative_sum, pred_sum, true_sum
num_classes = predictions.size(-1)
@@ -142,7 +141,7 @@ def __call__(
# Watch it:
# The total numbers of true positives under all _predicted_ classes are zeros.
if true_positives_bins.shape[0] == 0:
true_positive_sum = torch.zeros(num_classes, device=device)
true_positive_sum = torch.zeros(num_classes, device=predictions.device)
else:
true_positive_sum = torch.bincount(
true_positives_bins.long(), minlength=num_classes
@@ -154,7 +153,7 @@
if pred_bins.shape[0] != 0:
pred_sum = torch.bincount(pred_bins, minlength=num_classes).float()
else:
pred_sum = torch.zeros(num_classes, device=device)
pred_sum = torch.zeros(num_classes, device=predictions.device)

gold_labels_bins = gold_labels[mask].long()
if gold_labels.shape[0] != 0:
@@ -164,15 +163,9 @@

self._total_sum += mask.sum().to(torch.float)

if is_distributed():
true_positive_sum = torch.tensor(true_positive_sum, device=device)
dist.all_reduce(true_positive_sum, op=dist.ReduceOp.SUM)
dist.all_reduce(pred_sum, op=dist.ReduceOp.SUM)
dist.all_reduce(true_sum, op=dist.ReduceOp.SUM)

self._true_positive_sum += true_positive_sum
self._pred_sum += pred_sum
self._true_sum += true_sum
self._true_positive_sum += dist_reduce_sum(true_positive_sum)
self._pred_sum += dist_reduce_sum(pred_sum)
self._true_sum += dist_reduce_sum(true_sum)

@overrides
def get_metric(self, reset: bool = False):
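`fbeta_measure.py` already keeps its per-class statistics as tensors built with `torch.bincount`, so the old `torch.tensor(true_positive_sum, device=device)` re-wrapping is simply dropped and the tensors flow straight through `dist_reduce_sum`. A toy version of the counting that gets reduced (made-up labels, masking omitted):

```python
# Toy illustration of the bincount-based per-class counts FBetaMeasure
# aggregates; in a distributed run each vector would be passed through
# dist_reduce_sum before being added to the running totals.
import torch

num_classes = 4
gold = torch.tensor([0, 1, 1, 3, 3, 3])
pred = torch.tensor([0, 1, 2, 3, 3, 1])

true_positive_sum = torch.bincount(gold[gold == pred], minlength=num_classes).float()
pred_sum = torch.bincount(pred, minlength=num_classes).float()
true_sum = torch.bincount(gold, minlength=num_classes).float()
print(true_positive_sum, pred_sum, true_sum)
```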
18 changes: 4 additions & 14 deletions allennlp/training/metrics/fbeta_multi_label_measure.py
@@ -1,12 +1,11 @@
from typing import List, Optional

import torch
import torch.distributed as dist
from overrides import overrides

from allennlp.common.util import is_distributed
from allennlp.training.metrics import FBetaMeasure
from allennlp.training.metrics.metric import Metric
from allennlp.nn.util import dist_reduce_sum


@Metric.register("fbeta_multi_label")
@@ -95,7 +94,6 @@ def __call__(
A masking tensor the same size as `gold_labels`.
"""
predictions, gold_labels, mask = self.detach_tensors(predictions, gold_labels, mask)
device = gold_labels.device

# Calculate true_positive_sum, true_negative_sum, pred_sum, true_sum
num_classes = predictions.size(-1)
@@ -149,17 +147,9 @@ def __call__(

self._total_sum += mask.expand_as(gold_labels).sum().to(torch.float)

if is_distributed():
true_positive_sum = torch.tensor(true_positive_sum, device=device)
pred_sum = torch.tensor(pred_sum, device=device)
true_sum = torch.tensor(true_sum, device=device)
dist.all_reduce(true_positive_sum, op=dist.ReduceOp.SUM)
dist.all_reduce(pred_sum, op=dist.ReduceOp.SUM)
dist.all_reduce(true_sum, op=dist.ReduceOp.SUM)

self._true_positive_sum += true_positive_sum
self._pred_sum += pred_sum
self._true_sum += true_sum
self._true_positive_sum += dist_reduce_sum(true_positive_sum)
self._pred_sum += dist_reduce_sum(pred_sum)
self._true_sum += dist_reduce_sum(true_sum)

@property
def _true_negative_sum(self):
Expand Down
16 changes: 3 additions & 13 deletions allennlp/training/metrics/mean_absolute_error.py
@@ -2,10 +2,9 @@

from overrides import overrides
import torch
import torch.distributed as dist

from allennlp.common.util import is_distributed
from allennlp.training.metrics.metric import Metric
from allennlp.nn.util import dist_reduce_sum


@Metric.register("mean_absolute_error")
@@ -35,7 +34,6 @@ def __call__(
A tensor of the same shape as `predictions`.
"""
predictions, gold_labels, mask = self.detach_tensors(predictions, gold_labels, mask)
device = gold_labels.device

absolute_errors = torch.abs(predictions - gold_labels)

@@ -46,16 +44,8 @@ def __call__(
_total_count = gold_labels.numel()
_absolute_error = torch.sum(absolute_errors)

if is_distributed():
absolute_error = torch.tensor(_absolute_error, device=device)
total_count = torch.tensor(_total_count, device=device)
dist.all_reduce(absolute_error, op=dist.ReduceOp.SUM)
dist.all_reduce(total_count, op=dist.ReduceOp.SUM)
_absolute_error = absolute_error.item()
_total_count = total_count.item()

self._absolute_error += float(_absolute_error)
self._total_count += int(_total_count)
self._absolute_error += float(dist_reduce_sum(_absolute_error))
self._total_count += int(dist_reduce_sum(_total_count))

def get_metric(self, reset: bool = False) -> Dict[str, float]:
"""
Expand Down
2 changes: 0 additions & 2 deletions allennlp/training/metrics/pearson_correlation.py
@@ -7,8 +7,6 @@
from overrides import overrides
import torch

# import torch.distributed as dist

from allennlp.common.util import is_distributed
from allennlp.training.metrics.covariance import Covariance
from allennlp.training.metrics.metric import Metric
