Adding mnist training example (pytorch) (#278)
* Add MNIST example with PyTorch
* Add mpi4py MNIST example
* Add allreduce timers, don't LB by default
* first commit adding pytorch cuda support
* adding cuda pytorch support to mpi4py impl
* working cuda device assignment
* basic readme
* resolving srun issue
* Do reshaping and division on device
* time step -> epoch
* typo
* remove unnecessary copy and set dev idx manually

Co-authored-by: Jaemin Choi <jchoi157@illinois.edu>
Co-authored-by: Zane Fink <finkzane@gmail.com>
1 parent 9e05b2b · commit 04d57a5 · 3 changed files with 416 additions and 0 deletions.
@@ -0,0 +1,18 @@
## Distributed MNIST with PyTorch

This example implements data-parallel distributed training on the MNIST dataset. The implementation uses PyTorch to do the computation, on the GPU if one is available. After each backward pass, gradients are moved to the CPU for a global reduction across workers (a minimal sketch of this reduction step appears at the end of this README).

### Running the example

First, install the necessary dependencies:

`pip install torch torchvision`

The Charm4py implementation can be run with srun or charmrun:

`srun python mnist.py`

`python -m charmrun.start mnist.py +p2`

To run the mpi4py example with mpiexec on 2 processes:

`mpiexec -n 2 python -m mpi4py mnist-mpi4py.py`
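
### The reduction step

Both implementations rely on the same gradient-averaging pattern. The sketch below is illustrative only: the `allreduce_average` helper and the standalone `grad` tensor are not part of the example code, and it assumes an MPI build that is not CUDA-aware (so gradients are staged through the CPU, as in `mnist-mpi4py.py`).

```python
# Illustrative sketch of the gradient-averaging step used by both examples.
import numpy as np
import torch
from mpi4py import MPI

comm = MPI.COMM_WORLD

def allreduce_average(grad):
    """Sum a gradient tensor across all ranks, then divide by the world size."""
    send = grad.cpu().numpy()      # stage the gradient on the CPU for MPI
    recv = np.empty_like(send)
    comm.Allreduce(send, recv, op=MPI.SUM)
    return torch.from_numpy(recv) / comm.Get_size()
```

Averaging gradients this way is equivalent to computing the gradient of one global batch that has been split evenly across the workers.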
@@ -0,0 +1,182 @@
# Distributed deep learning example with MNIST dataset, using mpi4py and PyTorch.
# Adapted from https://pytorch.org/tutorials/intermediate/dist_tuto.html
# and /~https://github.com/seba-1511/dist_tuto.pth/blob/gh-pages/train_dist.py.

import math
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
nprocs = comm.Get_size()

# Dataset partitioning helper
class Partition(object):

    def __init__(self, data, index):
        self.data = data
        self.index = index

    def __len__(self):
        return len(self.index)

    def __getitem__(self, index):
        data_idx = self.index[index]
        return self.data[data_idx]


class DataPartitioner(object):

    def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234):
        self.data = data
        self.partitions = []
        rng = random.Random()
        rng.seed(seed)
        data_len = len(data)
        indexes = [x for x in range(0, data_len)]
        rng.shuffle(indexes)

        for frac in sizes:
            part_len = int(frac * data_len)
            self.partitions.append(indexes[0:part_len])
            indexes = indexes[part_len:]

    def use(self, partition):
        return Partition(self.data, self.partitions[partition])
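
# Usage sketch (added, hypothetical values): with 4 workers,
# DataPartitioner(dataset, [0.25] * 4).use(rank) gives each rank a disjoint
# 15,000-sample shard of the 60,000-image MNIST training set.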

# Neural network architecture
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
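        # Shape note (added): each 28x28 input becomes 24x24 after conv1
        # (5x5 kernel), 12x12 after 2x2 pooling, 8x8 after conv2, and 4x4
        # after the second pool, so the flattened size is 20 * 4 * 4 = 320.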
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


# Worker object (1 per MPI rank)
class Worker(object):

    def __init__(self, num_workers, epochs):
        self.num_workers = num_workers
        self.epochs = epochs
        self.agg_time = 0.0
        self.time_cnt = 0
        self.agg_time_all = 0.0

    # Partitioning MNIST dataset
    def partition_dataset(self):
        dataset = datasets.MNIST('./data', train=True, download=True,
                                 transform=transforms.Compose([
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.1307,), (0.3081,))
                                 ]))
        size = self.num_workers
        bsz = int(128 / float(size))  # my batch size
        partition_sizes = [1.0 / size for _ in range(size)]
        partition = DataPartitioner(dataset, partition_sizes)
        partition = partition.use(rank)
        train_set = torch.utils.data.DataLoader(partition,
                                                batch_size=bsz,
                                                shuffle=True)
        return train_set, bsz
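
    # Note (added): bsz = 128 / num_workers, so the effective global batch
    # size stays near 128 regardless of the number of workers.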

    # Distributed SGD
    def run(self, device):
        # Starting a new run
        self.train_set, bsz = self.partition_dataset()
        self.model = Net().to(device)
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.5)
        self.num_batches = math.ceil(len(self.train_set.dataset) / float(bsz))
        self.epoch = 0

        while self.epoch < self.epochs:
            t0 = time.time()
            epoch_loss = 0.0
            for data, target in self.train_set:
                data, target = data.to(device), target.to(device)
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = F.nll_loss(output, target)
                epoch_loss += loss.item()
                loss.backward()
                self.average_gradients(self.model, device)
                self.optimizer.step()
            print(f'Rank {rank:4d} | Epoch {self.epoch:4d} | Loss {(epoch_loss / self.num_batches):9.3f} | Time {(time.time() - t0):9.3f}')
            self.epoch += 1

        print(f'Rank {rank:4d} training complete, average allreduce time (us): {((self.agg_time / self.time_cnt) * 1000000):9.3f}')
        agg_time_arr = np.array([self.agg_time])
        agg_time_all_arr = np.array([0.0])
        comm.Allreduce(agg_time_arr, agg_time_all_arr, op=MPI.SUM)
        self.agg_time_all = agg_time_all_arr[0]
        if rank == 0:
            print(f'Rank {rank:4d} all average allreduce time (us): {((self.agg_time_all / self.num_workers / self.time_cnt) * 1000000):9.3f}')

    # Gradient averaging
    def average_gradients(self, model, device):
        for param in model.parameters():
            param.grad.data = param.grad.data.cpu()
            # Obtain numpy arrays from gradient data
            data_shape = param.grad.data.shape
            send_data = param.grad.data.numpy()
            recv_data = np.empty_like(send_data)

            # Blocking allreduce
            start_time = time.time()
            comm.Allreduce(send_data, recv_data, op=MPI.SUM)
            self.agg_time += time.time() - start_time
            self.time_cnt += 1

            # Restore original shape of gradient data
            param.grad.data = torch.from_numpy(recv_data).to(device)
            param.grad.data = param.grad.data.reshape(data_shape) / float(self.num_workers)
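
    # Design note (added): summing with MPI.SUM and then dividing by the
    # worker count averages the gradients across ranks, matching SGD on the
    # combined global batch. Staging through the CPU assumes the MPI library
    # is not CUDA-aware; a CUDA-aware MPI could reduce device buffers directly.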


def main():
    # Initialize PyTorch on all PEs
    num_threads = 1
    torch.set_num_threads(num_threads)
    torch.manual_seed(1234)
    print(f'MPI rank {rank} initialized PyTorch with {num_threads} threads')

    if torch.cuda.is_available():
        # If multiple devices are available (running with mpirun, not srun),
        # assign them round-robin across ranks
        dev_id = rank % torch.cuda.device_count()
        device = torch.device("cuda:" + str(dev_id))
    else:
        device = torch.device("cpu")

    # Create the worker and start training
    epochs = 6
    worker = Worker(nprocs, epochs)
    t0 = time.time()
    print(f'Starting MNIST dataset training with {nprocs} MPI processes for {epochs} epochs on device {device}')
    worker.run(device)

    comm.Barrier()

    # Training complete
    if rank == 0:
        print(f'Done. Elapsed time: {(time.time() - t0):9.3f} s')


if __name__ == '__main__':
    main()