This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Enable multi-process training on CPU (#4272)
* Use torch.device everywhere

* Update changelog

* Run distributed tests even on CPU

* Fix bug when running distributed tests on CPU

* Remove unused imports

* Update CHANGELOG.md

Co-authored-by: Evan Pete Walsh <epwalsh10@gmail.com>

dirkgr and epwalsh authored May 21, 2020
1 parent 7e683dd commit f27475a
Showing 10 changed files with 112 additions and 71 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Additional CI checks to ensure docstrings are consistently formatted.
- Ability to train on CPU with multiple processes by setting `cuda_devices` to a list of negative integers in your training config. For example: `"distributed": {"cuda_devices": [-1, -1]}`. This is mainly to make it easier to test and debug distributed training code.

### Changed

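To illustrate the new changelog entry, here is a minimal sketch of a training config that requests two CPU worker processes. The dataset reader, model type, and paths are placeholders, not taken from this commit; the relevant part is the `distributed` block.

```python
from allennlp.common import Params

# Schematic config for the new CPU distributed mode. Reader/model sections are
# placeholders (a real config needs full model parameters); the point is the
# "distributed" block with negative device ids.
params = Params(
    {
        "dataset_reader": {"type": "sequence_tagging"},
        "train_data_path": "path/to/train.tsv",
        "validation_data_path": "path/to/dev.tsv",
        "model": {"type": "simple_tagger"},  # embedder/encoder details elided
        "data_loader": {"batch_size": 2},
        "trainer": {"num_epochs": 2, "optimizer": "adam"},
        "distributed": {"cuda_devices": [-1, -1]},  # two negative ids -> two CPU workers
    }
)
```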
22 changes: 15 additions & 7 deletions allennlp/commands/train.py
@@ -402,13 +402,21 @@ def _train_worker(
params["trainer"]["world_size"] = world_size
params["trainer"]["distributed"] = True

torch.cuda.set_device(int(gpu_id))
dist.init_process_group(
backend="nccl",
init_method=f"tcp://{master_addr}:{master_port}",
world_size=world_size,
rank=global_rank,
)
if gpu_id >= 0:
torch.cuda.set_device(int(gpu_id))
dist.init_process_group(
backend="nccl",
init_method=f"tcp://{master_addr}:{master_port}",
world_size=world_size,
rank=global_rank,
)
else:
dist.init_process_group(
backend="gloo",
init_method=f"tcp://{master_addr}:{master_port}",
world_size=world_size,
rank=global_rank,
)
logging.info(
f"Process group of world size {world_size} initialized "
f"for distributed training in worker {global_rank}"
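The same backend selection, factored into a standalone sketch. The helper name and the default address/port are illustrative, not part of this commit: NCCL is used when the worker owns a GPU, Gloo when it runs on CPU.

```python
import torch
import torch.distributed as dist

def init_worker_process_group(
    gpu_id: int, global_rank: int, world_size: int,
    master_addr: str = "127.0.0.1", master_port: int = 29500,
) -> None:
    # Gloo supports CPU tensors; NCCL requires a GPU bound to the process.
    backend = "nccl" if gpu_id >= 0 else "gloo"
    if gpu_id >= 0:
        torch.cuda.set_device(int(gpu_id))
    dist.init_process_group(
        backend=backend,
        init_method=f"tcp://{master_addr}:{master_port}",
        world_size=world_size,
        rank=global_rank,
    )
```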
65 changes: 36 additions & 29 deletions allennlp/common/checks.py
@@ -8,6 +8,7 @@
import re
import subprocess

import torch
from torch import cuda

logger = logging.getLogger(__name__)
@@ -100,36 +101,42 @@ def from_list(strings):
return int(cuda_device) # type: ignore


def check_for_gpu(device_id: Union[int, List[int]]):
if isinstance(device_id, list):
for did in device_id:
def check_for_gpu(device: Union[int, torch.device, List[Union[int, torch.device]]]):
if isinstance(device, list):
for did in device:
check_for_gpu(did)
elif device_id is not None and device_id >= 0:
num_devices_available = cuda.device_count()
if num_devices_available == 0:
# Torch will give a more informative exception than ours, so we want to include
# that context as well if it's available. For example, if you try to run torch 1.5
# on a machine with CUDA10.1 you'll get the following:
#
# The NVIDIA driver on your system is too old (found version 10010).
#
torch_gpu_error = ""
try:
cuda._check_driver()
except Exception as e:
torch_gpu_error = "\n{0}".format(e)

raise ConfigurationError(
"Experiment specified a GPU but none is available;"
" if you want to run on CPU use the override"
" 'trainer.cuda_device=-1' in the json config file." + torch_gpu_error
)
elif device_id >= num_devices_available:
raise ConfigurationError(
f"Experiment specified GPU device {device_id}"
f" but there are only {num_devices_available} devices "
f" available."
)
elif device is None:
return
else:
from allennlp.common.util import int_to_device

device = int_to_device(device)
if device != torch.device("cpu"):
num_devices_available = cuda.device_count()
if num_devices_available == 0:
# Torch will give a more informative exception than ours, so we want to include
# that context as well if it's available. For example, if you try to run torch 1.5
# on a machine with CUDA10.1 you'll get the following:
#
# The NVIDIA driver on your system is too old (found version 10010).
#
torch_gpu_error = ""
try:
cuda._check_driver()
except Exception as e:
torch_gpu_error = "\n{0}".format(e)

raise ConfigurationError(
"Experiment specified a GPU but none is available;"
" if you want to run on CPU use the override"
" 'trainer.cuda_device=-1' in the json config file." + torch_gpu_error
)
elif device.index >= num_devices_available:
raise ConfigurationError(
f"Experiment specified GPU device {device.index}"
f" but there are only {num_devices_available} devices "
f" available."
)


def check_for_java() -> bool:
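A quick usage sketch of the updated check; the device values below are illustrative.

```python
import torch
from allennlp.common.checks import check_for_gpu

check_for_gpu(-1)                   # CPU request: always passes
check_for_gpu(torch.device("cpu"))  # torch.device values are now accepted too
check_for_gpu([0, 1])               # raises ConfigurationError unless at least two GPUs are visible
```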
7 changes: 7 additions & 0 deletions allennlp/common/testing/__init__.py
@@ -38,3 +38,10 @@ def requires_multi_gpu(test_method):
test_method
)
)


def cpu_or_gpu(test_method):
"""
Decorator to indicate that a test should run on both CPU and GPU
"""
return pytest.mark.gpu(test_method)
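A hypothetical test using the new decorator, following the pattern the updated tests use to fall back to CPU workers when fewer than two GPUs are visible (the class and test names are placeholders).

```python
import torch
from allennlp.common.testing import AllenNlpTestCase, cpu_or_gpu

class TestDistributedSetup(AllenNlpTestCase):
    @cpu_or_gpu
    def test_device_selection(self):
        # Use two GPUs when available, otherwise two CPU worker processes.
        devices = [0, 1] if torch.cuda.device_count() >= 2 else [-1, -1]
        assert len(devices) == 2
```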
8 changes: 8 additions & 0 deletions allennlp/common/util.py
@@ -427,6 +427,14 @@ def is_lazy(iterable: Iterable[A]) -> bool:
return not isinstance(iterable, list)


def int_to_device(device: Union[int, torch.device]) -> torch.device:
if isinstance(device, torch.device):
return device
if device < 0:
return torch.device("cpu")
return torch.device(device)


def log_frozen_and_tunable_parameter_names(model: torch.nn.Module) -> None:
frozen_parameter_names, tunable_parameter_names = get_frozen_and_tunable_parameter_names(model)

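Behaviour of the new helper, as a minimal sketch:

```python
import torch
from allennlp.common.util import int_to_device

assert int_to_device(-1) == torch.device("cpu")                   # negative ids mean CPU
assert int_to_device(0) == torch.device("cuda", 0)                # non-negative ids map to that GPU
assert int_to_device(torch.device("cpu")) == torch.device("cpu")  # devices pass through unchanged
```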
7 changes: 5 additions & 2 deletions allennlp/nn/util.py
@@ -34,13 +34,16 @@ def has_tensor(obj) -> bool:
return False


def move_to_device(obj, cuda_device: int):
def move_to_device(obj, cuda_device: Union[torch.device, int]):
"""
Given a structure (possibly) containing Tensors on the CPU,
move all the Tensors to the specified GPU (or do nothing, if they should be on the CPU).
"""
from allennlp.common.util import int_to_device

if cuda_device < 0 or not has_tensor(obj):
cuda_device = int_to_device(cuda_device)

if cuda_device == torch.device("cpu") or not has_tensor(obj):
return obj
elif isinstance(obj, torch.Tensor):
return obj.cuda(cuda_device)
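A small usage sketch of the updated function; the batch contents are illustrative.

```python
import torch
from allennlp.nn import util

batch = {"tokens": torch.zeros(2, 3, dtype=torch.long), "mask": [torch.ones(2, 3)]}

# A CPU device (or any negative int) is now an explicit no-op:
assert util.move_to_device(batch, torch.device("cpu")) is batch

# With a GPU available, ints still work and are converted internally:
# gpu_batch = util.move_to_device(batch, 0)
```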
33 changes: 14 additions & 19 deletions allennlp/training/trainer.py
@@ -6,7 +6,9 @@
import time
import traceback
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

from allennlp.common.util import int_to_device

try:
from apex import amp
@@ -49,7 +51,7 @@ class Trainer(Registrable):
def __init__(
self,
serialization_dir: str,
cuda_device: int = -1,
cuda_device: Union[int, torch.device] = -1,
distributed: bool = False,
local_rank: int = 0,
world_size: int = 1,
@@ -65,28 +67,19 @@ def __init__(
"our Trainer always uses a single GPU per process."
)

if not isinstance(cuda_device, int):
raise ConfigurationError("Expected an int for cuda_device, got {}".format(cuda_device))

if distributed and world_size <= 1:
raise ConfigurationError(
"Distributed training can be performed only with more than 1 GPU device. Check "
"Distributed training can be performed only with more than 1 device. Check "
"`cuda_device` key in the experiment configuration."
)

self.cuda_device = cuda_device
self.cuda_device = int_to_device(cuda_device)

self._distributed = distributed
self._rank = local_rank
self._master = self._rank == 0
self._world_size = world_size

def _move_to_gpu(self, model: Model) -> Model:
if self.cuda_device != -1:
return model.cuda(self.cuda_device)
else:
return model

def train(self) -> Dict[str, Any]:
"""
Train a model and return the results.
@@ -383,7 +376,9 @@ def __init__(
# these places: `model.__call__`, `model.train` and `model.eval`.
if self._distributed:
self._pytorch_model = DistributedDataParallel(
self.model, device_ids=[self.cuda_device], find_unused_parameters=True
self.model,
device_ids=None if self.cuda_device == torch.device("cpu") else [self.cuda_device],
find_unused_parameters=True,
)
else:
self._pytorch_model = self.model
@@ -556,7 +551,7 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
train_reg_loss,
batches_this_epoch,
world_size=self._world_size,
cuda_device=[self.cuda_device],
cuda_device=self.cuda_device,
)

if self._master:
@@ -600,7 +595,7 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
batches_this_epoch,
reset=True,
world_size=self._world_size,
cuda_device=[self.cuda_device],
cuda_device=self.cuda_device,
)
metrics["cpu_memory_MB"] = peak_cpu_usage
for (gpu_num, memory) in gpu_usage:
@@ -672,7 +667,7 @@ def _validation_loss(self, epoch: int) -> Tuple[float, float, int]:
val_reg_loss,
batches_this_epoch,
world_size=self._world_size,
cuda_device=[self.cuda_device],
cuda_device=self.cuda_device,
)
description = training_util.description_from_metrics(val_metrics)
val_generator_tqdm.set_description(description, refresh=False)
@@ -693,7 +688,7 @@ def _validation_loss(self, epoch: int) -> Tuple[float, float, int]:
f"Worker {torch.distributed.get_rank()} completed its entire epoch (validation)."
)
# Indicate that we're done so that any workers that have remaining data stop validation early.
done = torch.tensor(1, device=self.cuda_device if self.cuda_device >= 0 else None)
done = torch.tensor(1, device=self.cuda_device)
torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
assert done.item()

@@ -764,7 +759,7 @@ def train(self) -> Dict[str, Any]:
num_batches,
reset=True,
world_size=self._world_size,
cuda_device=[self.cuda_device],
cuda_device=self.cuda_device,
)

# Check validation metric for early stopping
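The DDP wrapping above, reduced to a standalone sketch. It assumes the process group was already initialised with the appropriate backend; the helper name is illustrative.

```python
import torch
from torch.nn.parallel import DistributedDataParallel

def wrap_for_distributed(model: torch.nn.Module, cuda_device: torch.device) -> DistributedDataParallel:
    # On CPU, DDP must be constructed without device_ids; on GPU, pass the
    # single device this worker owns.
    device_ids = None if cuda_device == torch.device("cpu") else [cuda_device]
    return DistributedDataParallel(model, device_ids=device_ids, find_unused_parameters=True)
```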
9 changes: 3 additions & 6 deletions allennlp/training/util.py
@@ -5,7 +5,7 @@
import logging
import os
import shutil
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, Optional, Union

import torch
import torch.distributed as dist
@@ -284,7 +284,7 @@ def get_metrics(
num_batches: int,
reset: bool = False,
world_size: int = 1,
cuda_device: Union[int, List] = 0,
cuda_device: Union[int, torch.device] = torch.device("cpu"),
) -> Dict[str, float]:
"""
Gets the metrics but sets `"loss"` to
@@ -299,10 +299,7 @@
# In distributed mode, average out all metrics across GPUs
aggregated_metrics = {}
for metric_name, metric_val in metrics.items():
if isinstance(cuda_device, list):
metric_tensor = torch.tensor(metric_val).to(torch.device(cuda_device[0]))
else:
metric_tensor = torch.tensor(metric_val).to(torch.device(cuda_device))
metric_tensor = torch.tensor(metric_val).to(cuda_device)
dist.all_reduce(metric_tensor, op=dist.ReduceOp.SUM)
reduced_metric = metric_tensor.item() / world_size
aggregated_metrics[metric_name] = reduced_metric
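The simplified aggregation, sketched on its own (the function name is illustrative): every metric is summed across workers on the worker's device, CPU or a single GPU, then averaged by world size.

```python
import torch
import torch.distributed as dist

def average_across_workers(metric_val: float, cuda_device: torch.device, world_size: int) -> float:
    metric_tensor = torch.tensor(metric_val).to(cuda_device)
    dist.all_reduce(metric_tensor, op=dist.ReduceOp.SUM)  # works with gloo (CPU) and nccl (GPU)
    return metric_tensor.item() / world_size
```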
29 changes: 22 additions & 7 deletions tests/commands/train_test.py
@@ -15,7 +15,7 @@
from allennlp.commands.train import Train, train_model, train_model_from_args, TrainModel
from allennlp.common import Params
from allennlp.common.checks import ConfigurationError
from allennlp.common.testing import AllenNlpTestCase, requires_gpu, requires_multi_gpu
from allennlp.common.testing import AllenNlpTestCase, cpu_or_gpu
from allennlp.data import DatasetReader, Instance, Vocabulary
from allennlp.data.dataloader import TensorDict
from allennlp.models import load_archive, Model
@@ -111,8 +111,13 @@ def test_train_model(self):
recover=True,
)

@requires_gpu
@cpu_or_gpu
def test_train_model_distributed(self):
if torch.cuda.device_count() >= 2:
devices = [0, 1]
else:
devices = [-1, -1]

params = lambda: Params(
{
"model": {
Expand All @@ -127,7 +132,7 @@ def test_train_model_distributed(self):
"validation_data_path": SEQUENCE_TAGGING_DATA_PATH,
"data_loader": {"batch_size": 2},
"trainer": {"num_epochs": 2, "optimizer": "adam"},
"distributed": {"cuda_devices": [0, 1]},
"distributed": {"cuda_devices": devices},
}
)

@@ -146,9 +151,14 @@
# Check we can load the serialized model
assert load_archive(out_dir).model

@requires_multi_gpu
@cpu_or_gpu
@pytest.mark.parametrize("lazy", [True, False])
def test_train_model_distributed_with_sharded_reader(self, lazy):
if torch.cuda.device_count() >= 2:
devices = [0, 1]
else:
devices = [-1, -1]

params = lambda: Params(
{
"model": {
Expand All @@ -167,7 +177,7 @@ def test_train_model_distributed_with_sharded_reader(self, lazy):
"validation_data_path": SEQUENCE_TAGGING_SHARDS_PATH,
"data_loader": {"batch_size": 2},
"trainer": {"num_epochs": 2, "optimizer": "adam"},
"distributed": {"cuda_devices": [0, 1]},
"distributed": {"cuda_devices": devices},
}
)

@@ -232,9 +242,14 @@ def test_train_model_distributed_with_sharded_reader(self, lazy):
assert train_complete in worker1_log
assert validation_complete in worker1_log

@requires_multi_gpu
@cpu_or_gpu
@pytest.mark.parametrize("lazy", [True, False])
def test_train_model_distributed_without_sharded_reader(self, lazy: bool):
if torch.cuda.device_count() >= 2:
devices = [0, 1]
else:
devices = [-1, -1]

num_epochs = 2
params = lambda: Params(
{
@@ -256,7 +271,7 @@ def test_train_model_distributed_without_sharded_reader(self, lazy: bool):
"tests.commands.train_test.TrainingDataLoggerBatchCallback"
],
},
"distributed": {"cuda_devices": [0, 1]},
"distributed": {"cuda_devices": devices},
}
)

2 changes: 1 addition & 1 deletion tests/nn/util_test.py
@@ -1422,7 +1422,7 @@ class A(NamedTuple):
"b": FakeTensor(),
"c": (1, FakeTensor()),
}
new_device = 4
new_device = torch.device(4)
moved_obj = util.move_to_device(structured_obj, new_device)
assert moved_obj["a"][0].a == 1
assert moved_obj["a"][0].b._device == new_device