examples/pytorch_mnist/training/ddp/mnist/mnist_DDP.py

#!/usr/bin/env python
'''
Copyright 2018 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
'''
import datetime
import logging
import os
from math import ceil
from random import Random
import torch
import torch.distributed as dist
import torch.nn as nn # pylint: disable = all
import torch.nn.functional as F
import torch.optim as optim # pylint: disable = all
import torch.utils.data
import torch.utils.data.distributed
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.autograd import Variable
from torch.nn.modules import Module
from torchvision import datasets, transforms
gbatch_size = 128
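

# A minimal, hand-rolled stand-in for torch's DistributedDataParallel, used on
# the CPU path below: it broadcasts the weights from rank 0 on the first forward
# pass and registers autograd hooks that, once per backward pass, bucket the
# gradients by tensor type, flatten each bucket, all_reduce it, and divide by
# the world size so every rank ends up with the averaged gradients.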
class DistributedDataParallel(Module):

    def __init__(self, module):
        super(DistributedDataParallel, self).__init__()
        self.module = module
        self.first_call = True

        def allreduce_params():
            if self.needs_reduction:
                self.needs_reduction = False  # pylint: disable = attribute-defined-outside-init
                buckets = {}
                for param in self.module.parameters():
                    if param.requires_grad and param.grad is not None:
                        tp = type(param.data)
                        if tp not in buckets:
                            buckets[tp] = []
                        buckets[tp].append(param)
                for tp in buckets:
                    bucket = buckets[tp]
                    grads = [param.grad.data for param in bucket]
                    coalesced = _flatten_dense_tensors(grads)
                    dist.all_reduce(coalesced)
                    coalesced /= dist.get_world_size()
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)

        for param in list(self.module.parameters()):

            def allreduce_hook(*unused):  # pylint: disable = unused-argument
                Variable._execution_engine.queue_callback(allreduce_params)  # pylint: disable = protected-access

            if param.requires_grad:
                param.register_hook(allreduce_hook)

    def weight_broadcast(self):
        for param in self.module.parameters():
            dist.broadcast(param.data, 0)

    def forward(self, *inputs, **kwargs):  # pylint: disable = arguments-differ
        if self.first_call:
            logging.info("first broadcast start")
            self.weight_broadcast()
            self.first_call = False
            logging.info("first broadcast done")
        self.needs_reduction = True  # pylint: disable = attribute-defined-outside-init
        return self.module(*inputs, **kwargs)
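

# Partition and DataPartitioner implement manual dataset sharding. The training
# path below relies on torch.utils.data.distributed.DistributedSampler instead,
# so these two classes serve only as a self-contained reference alternative.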
class Partition(object):  # pylint: disable = all
    """ Dataset-like object that only accesses a subset of the wrapped dataset. """

    def __init__(self, data, index):
        self.data = data
        self.index = index

    def __len__(self):
        return len(self.index)

    def __getitem__(self, index):
        data_idx = self.index[index]
        return self.data[data_idx]

class DataPartitioner(object):  # pylint: disable = all
    """ Partitions a dataset into different chunks. """

    def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234):  # pylint: disable = dangerous-default-value
        self.data = data
        self.partitions = []
        rng = Random()
        rng.seed(seed)
        data_len = len(data)
        indexes = list(range(data_len))
        rng.shuffle(indexes)
        for frac in sizes:
            part_len = int(frac * data_len)
            self.partitions.append(indexes[0:part_len])
            indexes = indexes[part_len:]

    def use(self, partition):
        return Partition(self.data, self.partitions[partition])
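
# Illustrative use of DataPartitioner for manual sharding (not exercised by the
# code below): give every rank an equal fraction and pick this rank's chunk.
#
#     world_size = dist.get_world_size()
#     partitioner = DataPartitioner(dataset, sizes=[1.0 / world_size] * world_size)
#     local_subset = partitioner.use(dist.get_rank())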


class Net(nn.Module):
    """ Network architecture. """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):  # pylint: disable = arguments-differ
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
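

# partition_dataset downloads the full MNIST dataset for every rank but hands
# each rank a disjoint shard of indices via DistributedSampler, with the
# per-rank batch size set to the global batch size divided by the world size.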
def partition_dataset(rank):
    """ Partitioning MNIST """
    dataset = datasets.MNIST(
        './data{}'.format(rank),
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ]))
    size = dist.get_world_size()
    bsz = int(gbatch_size / float(size))
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    train_set = torch.utils.data.DataLoader(
        dataset, batch_size=bsz, shuffle=(train_sampler is None), sampler=train_sampler)
    return train_set, bsz
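

# average_gradients explicitly averages gradients across all ranks. Because the
# model is already wrapped in a DistributedDataParallel variant (which averages
# gradients during backward), every rank holds identical gradients by the time
# this runs, so the extra all_reduce-and-divide leaves them numerically unchanged.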
def average_gradients(model):
    """ Gradient averaging. """
    size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is None:
            continue
        # Sum over the default (world) group, then divide by the world size.
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size
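

# run() ties everything together for one worker: shard the data, wrap the model
# (torch's DistributedDataParallel on GPU, the hand-rolled variant on CPU),
# train for three epochs with plain SGD, and have rank 0 write the final weights.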
def run(modelpath, gpu):
    """ Distributed Synchronous SGD Example """
    rank = dist.get_rank()
    torch.manual_seed(1234)
    train_set, bsz = partition_dataset(rank)
    model = Net()
    if gpu:
        model = model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        model = DistributedDataParallel(model)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    model_dir = modelpath
    num_batches = ceil(len(train_set.dataset) / float(bsz))
    logging.info("num_batches = %s", num_batches)
    time_start = datetime.datetime.now()
    for epoch in range(3):
        epoch_loss = 0.0
        for data, target in train_set:
            if gpu:
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            epoch_loss += loss.item()
            loss.backward()
            average_gradients(model)
            optimizer.step()
        logging.info('Epoch %s Loss %.6f Global batch size %s on %s ranks',
                     epoch, epoch_loss / num_batches, gbatch_size, dist.get_world_size())
    # Ensure only the master node saves the model.
    main_proc = rank == 0
    if main_proc:
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        if gpu:
            model_path = model_dir + "/model_gpu.dat"
        else:
            model_path = model_dir + "/model_cpu.dat"
        logging.info("Saving model in %s", model_path)
        torch.save(model.module.state_dict(), model_path)
    if gpu:
        logging.info("GPU training time= %s", str(datetime.datetime.now() - time_start))
    else:
        logging.info("CPU training time= %s", str(datetime.datetime.now() - time_start))
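

# A typical launch, assuming the torchrun launcher and an illustrative model
# path (both are examples, not requirements of this script):
#
#     torchrun --nnodes=1 --nproc_per_node=2 mnist_DDP.py --modelpath /tmp/mnist
#
# init_process_group() below uses the default env:// rendezvous, so MASTER_ADDR,
# MASTER_PORT, RANK and WORLD_SIZE must be set in the environment; launchers
# such as torchrun or the Kubeflow PyTorchJob controller set them automatically.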
if __name__ == "__main__":
    import argparse
    logging.basicConfig(level=logging.INFO,
                        format=('%(levelname)s|%(asctime)s'
                                '|%(pathname)s|%(lineno)d| %(message)s'),
                        datefmt='%Y-%m-%dT%H:%M:%S',
                        )
    logging.getLogger().setLevel(logging.INFO)
    parser = argparse.ArgumentParser(description='Train PyTorch MNIST model using DDP')
    parser.add_argument('--gpu', action='store_true',
                        help='Use GPU and CUDA')
    parser.set_defaults(gpu=False)
    parser.add_argument('--modelpath', metavar='path', required=True,
                        help='Path to model, e.g., /mnt/kubeflow-gcfs/pytorch/model')
    args = parser.parse_args()
    if args.gpu:
        logging.info("\n======= CUDA INFO =======")
        logging.info("CUDA Availability: %s", torch.cuda.is_available())
        if torch.cuda.is_available():
            logging.info("CUDA Device Name: %s", torch.cuda.get_device_name(0))
            logging.info("CUDA Version: %s", torch.version.cuda)
        logging.info("=========================\n")
    # gloo supports both CPU and CUDA tensors for the collectives used here.
    dist.init_process_group(backend='gloo')
    run(modelpath=args.modelpath, gpu=args.gpu)
    dist.destroy_process_group()