
consolidate testing decorators (#4213)
epwalsh authored May 8, 2020
1 parent 72061b1 commit 0bcab36
Showing 6 changed files with 48 additions and 34 deletions.
37 changes: 36 additions & 1 deletion allennlp/common/testing/__init__.py
@@ -1,5 +1,40 @@
"""
Utilities and helpers for writing tests.
"""
from allennlp.common.testing.test_case import AllenNlpTestCase, multi_device
import torch
import pytest

from allennlp.common.testing.test_case import AllenNlpTestCase
from allennlp.common.testing.model_test_case import ModelTestCase


_available_devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


def multi_device(test_method):
    """
    Decorator that provides an argument `device` of type `str` for each available PyTorch device.
    """
    return pytest.mark.parametrize("device", _available_devices)(pytest.mark.gpu(test_method))


def requires_gpu(test_method):
    """
    Decorator to indicate that a test requires a GPU device.
    """
    return pytest.mark.gpu(
        pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device registered.")(
            test_method
        )
    )


def requires_multi_gpu(test_method):
    """
    Decorator to indicate that a test requires multiple GPU devices.
    """
    return pytest.mark.gpu(
        pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 or more GPUs required.")(
            test_method
        )
    )
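For reference, here is a minimal sketch of how a test module would use the consolidated decorators. The class name, test names, and assertions are illustrative only and are not part of this commit.

import torch

from allennlp.common.testing import (
    AllenNlpTestCase,
    multi_device,
    requires_gpu,
    requires_multi_gpu,
)


class TestDeviceDecorators(AllenNlpTestCase):  # hypothetical test class, for illustration only
    # Parametrized once per available device ("cpu", plus "cuda" when a GPU
    # is present) and tagged with the `gpu` pytest mark.
    @multi_device
    def test_runs_on_every_device(self, device: str):
        tensor = torch.ones(2, 3, device=device)
        assert tensor.sum().item() == 6.0

    # Skipped unless a CUDA device is available.
    @requires_gpu
    def test_needs_one_gpu(self):
        assert torch.cuda.is_available()

    # Skipped unless at least two CUDA devices are available.
    @requires_multi_gpu
    def test_needs_two_gpus(self):
        assert torch.cuda.device_count() >= 2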
13 changes: 0 additions & 13 deletions allennlp/common/testing/test_case.py
@@ -5,9 +5,6 @@
import tempfile
from unittest import mock

import torch
import pytest

from allennlp.common.checks import log_pytorch_version_info

TEST_DIR = tempfile.mkdtemp(prefix="allennlp_tests")
@@ -57,13 +54,3 @@ def _cleanup_archive_dir_without_logging(path: str):
    def teardown_method(self):
        shutil.rmtree(self.TEST_DIR)
        self.patcher.stop()


_available_devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


def multi_device(test_method):
    """
    Decorator that provides an argument `device` of type `str` for each available PyTorch device.
    """
    return pytest.mark.parametrize("device", _available_devices)(pytest.mark.gpu(test_method))
8 changes: 3 additions & 5 deletions allennlp/tests/commands/find_learning_rate_test.py
@@ -1,15 +1,14 @@
import argparse
import os
import pytest

import torch
import pytest

from allennlp.common import Params
from allennlp.data import Vocabulary
from allennlp.data import DataLoader
from allennlp.models import Model
from allennlp.common.checks import ConfigurationError
from allennlp.common.testing import AllenNlpTestCase
from allennlp.common.testing import AllenNlpTestCase, requires_multi_gpu
from allennlp.commands.find_learning_rate import (
search_learning_rate,
find_learning_rate_from_args,
@@ -131,8 +130,7 @@ def test_find_learning_rate_args(self):
parser.parse_args(["find-lr", "path/to/params"])
assert cm.exception.code == 2 # argparse code for incorrect usage

@pytest.mark.gpu
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need multiple GPUs.")
@requires_multi_gpu
def test_find_learning_rate_multi_gpu(self):
params = self.params()
del params["trainer"]["cuda_device"]
8 changes: 3 additions & 5 deletions allennlp/tests/commands/train_test.py
@@ -13,7 +13,7 @@
from allennlp.commands.train import Train, train_model, train_model_from_args, TrainModel
from allennlp.common import Params
from allennlp.common.checks import ConfigurationError
from allennlp.common.testing import AllenNlpTestCase
from allennlp.common.testing import AllenNlpTestCase, requires_gpu, requires_multi_gpu
from allennlp.data import DatasetReader, Instance, Vocabulary
from allennlp.data.dataloader import TensorDict
from allennlp.models import load_archive, Model
@@ -90,8 +90,7 @@ def test_train_model(self):
recover=True,
)

@pytest.mark.gpu
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need multiple GPUs.")
@requires_gpu
def test_train_model_distributed(self):
params = lambda: Params(
{
@@ -126,8 +125,7 @@ def test_train_model_distributed(self):
# Check we can load the serialized model
assert load_archive(out_dir).model

@pytest.mark.gpu
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need multiple GPUs.")
@requires_multi_gpu
def test_train_model_distributed_with_sharded_reader(self):
params = lambda: Params(
{
5 changes: 2 additions & 3 deletions allennlp/tests/modules/encoder_base_test.py
@@ -4,7 +4,7 @@
from torch.nn import LSTM, RNN

from allennlp.modules.encoder_base import _EncoderBase
from allennlp.common.testing import AllenNlpTestCase
from allennlp.common.testing import AllenNlpTestCase, requires_gpu
from allennlp.nn.util import sort_batch_by_length, get_lengths_from_binary_sequence_mask


@@ -310,8 +310,7 @@ def test_non_contiguous_initial_states_handled(self):
encoder_base._update_states([final_states[0]], self.restoration_indices)
encoder_base.sort_and_run_forward(self.rnn, self.tensor, self.mask)

@pytest.mark.gpu
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
@requires_gpu
def test_non_contiguous_initial_states_handled_on_gpu(self):
# Some PyTorch operations which produce contiguous tensors on the CPU produce
# non-contiguous tensors on the GPU (e.g. forward pass of an RNN when batch_first=True).
11 changes: 4 additions & 7 deletions allennlp/tests/training/trainer_test.py
@@ -20,7 +20,7 @@

from allennlp.common.checks import ConfigurationError
from allennlp.common.params import Params
from allennlp.common.testing import AllenNlpTestCase
from allennlp.common.testing import AllenNlpTestCase, requires_gpu, requires_multi_gpu
from allennlp.data import Vocabulary
from allennlp.data.dataloader import TensorDict
from allennlp.data.dataset_readers import SequenceTaggingDatasetReader
@@ -125,8 +125,7 @@ def test_trainer_can_run_exponential_moving_average(self):
)
trainer.train()

@pytest.mark.gpu
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device registered.")
@requires_gpu
def test_trainer_can_run_cuda(self):
self.model.cuda()
trainer = GradientDescentTrainer(
@@ -139,8 +138,7 @@ def test_trainer_can_run_cuda(self):
assert "peak_gpu_0_memory_MB" in metrics
assert isinstance(metrics["peak_gpu_0_memory_MB"], int)

@pytest.mark.gpu
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 or more GPUs required.")
@requires_multi_gpu
def test_passing_trainer_multiple_gpus_raises_error(self):
self.model.cuda()

@@ -989,8 +987,7 @@ def __call__(


class TestApexTrainer(TrainerTestBase):
@pytest.mark.gpu
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device registered.")
@requires_gpu
@pytest.mark.skipif(amp is None, reason="Apex is not installed.")
def test_trainer_can_run_amp(self):
self.model.cuda()
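All of these decorators rely on a custom `gpu` pytest mark. A conftest.py hook along the following lines would register that mark; this is an assumed sketch rather than part of this commit, and the AllenNLP repository may already declare the mark in its own pytest configuration.

# conftest.py (assumed sketch, not part of this commit)
def pytest_configure(config):
    # Register the custom `gpu` mark so pytest does not warn about an unknown
    # marker and so GPU tests can be selected or excluded, e.g. with `-m "not gpu"`.
    config.addinivalue_line("markers", "gpu: test requires a CUDA-capable GPU")

With the mark registered, running `pytest -m "not gpu"` deselects every test that the decorators above tag as GPU-dependent, which is convenient on CPU-only machines.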
