mirror of https://github.com/kubeflow/examples.git
#!/usr/bin/env python
'''
Copyright 2018 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
'''
import datetime
import logging
import os
from math import ceil
from random import Random

import torch
import torch.distributed as dist
import torch.nn as nn  # pylint: disable = all
import torch.nn.functional as F
import torch.optim as optim  # pylint: disable = all
import torch.utils.data
import torch.utils.data.distributed
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.autograd import Variable
from torch.nn.modules import Module
from torchvision import datasets, transforms

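# Global batch size across all workers; partition_dataset() gives each worker a
# per-worker batch size of gbatch_size / world_size.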
gbatch_size = 128


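# A lightweight DistributedDataParallel used for the CPU/gloo path below: it
# broadcasts the weights from rank 0 on the first forward pass and registers
# backward hooks that flatten, all-reduce, and average gradients across workers.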
class DistributedDataParallel(Module):
  def __init__(self, module):
    super(DistributedDataParallel, self).__init__()
    self.module = module
    self.first_call = True

    def allreduce_params():
      if self.needs_reduction:
        self.needs_reduction = False  # pylint: disable = attribute-defined-outside-init
        buckets = {}
        for param in self.module.parameters():
          if param.requires_grad and param.grad is not None:
            tp = type(param.data)
            if tp not in buckets:
              buckets[tp] = []
            buckets[tp].append(param)
        for tp in buckets:
          bucket = buckets[tp]
          grads = [param.grad.data for param in bucket]
          coalesced = _flatten_dense_tensors(grads)
          dist.all_reduce(coalesced)
          coalesced /= dist.get_world_size()
          for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
            buf.copy_(synced)

    for param in list(self.module.parameters()):
      def allreduce_hook(*unused):  # pylint: disable = unused-argument
        Variable._execution_engine.queue_callback(allreduce_params)  # pylint: disable = protected-access

      if param.requires_grad:
        param.register_hook(allreduce_hook)

  def weight_broadcast(self):
    for param in self.module.parameters():
      dist.broadcast(param.data, 0)

  def forward(self, *inputs, **kwargs):  # pylint: disable = arguments-differ
    if self.first_call:
      logging.info("first broadcast start")
      self.weight_broadcast()
      self.first_call = False
      logging.info("first broadcast done")
    self.needs_reduction = True  # pylint: disable = attribute-defined-outside-init
    return self.module(*inputs, **kwargs)


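# Partition exposes only the rows of a dataset selected by `index`, remapping
# item lookups through that index list.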
class Partition(object):  # pylint: disable = all
  """ Dataset-like object that accesses only a subset of the underlying data. """

  def __init__(self, data, index):
    self.data = data
    self.index = index

  def __len__(self):
    return len(self.index)

  def __getitem__(self, index):
    data_idx = self.index[index]
    return self.data[data_idx]


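# DataPartitioner shuffles the dataset indices with a fixed seed and slices them
# into fractional chunks that can be retrieved with use(partition).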
class DataPartitioner(object):  # pylint: disable = all
  """ Partitions a dataset into different chunks. """

  def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234):  # pylint: disable = dangerous-default-value
    self.data = data
    self.partitions = []
    rng = Random()
    rng.seed(seed)
    data_len = len(data)
    indexes = [x for x in range(0, data_len)]
    rng.shuffle(indexes)

    for frac in sizes:
      part_len = int(frac * data_len)
      self.partitions.append(indexes[0:part_len])
      indexes = indexes[part_len:]

  def use(self, partition):
    return Partition(self.data, self.partitions[partition])


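# Net is a small ConvNet for 28x28 MNIST digits: two 5x5 conv layers (the second
# followed by Dropout2d) with max pooling, then 320 -> 50 -> 10 fully connected
# layers and a log-softmax output.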
class Net(nn.Module):
  """ Network architecture. """

  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.conv2_drop = nn.Dropout2d()
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):  # pylint: disable = arguments-differ
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
    x = x.view(-1, 320)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, training=self.training)
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)


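# partition_dataset downloads MNIST into a per-rank directory, wraps it in a
# DistributedSampler so each rank sees a disjoint shard, and sizes the DataLoader
# so that gbatch_size samples are processed globally per step.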
def partition_dataset(rank):
  """ Partitioning MNIST """
  dataset = datasets.MNIST(
    './data{}'.format(rank),
    train=True,
    download=True,
    transform=transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.1307,), (0.3081,))
    ]))
  size = dist.get_world_size()
  bsz = int(gbatch_size / float(size))
  train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
  train_set = torch.utils.data.DataLoader(
    dataset, batch_size=bsz, shuffle=(train_sampler is None), sampler=train_sampler)
  return train_set, bsz


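# average_gradients all-reduces each parameter's gradient (over a new group built
# from rank 0) and divides it by the world size.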
def average_gradients(model):
  """ Gradient averaging. """
  size = float(dist.get_world_size())
  group = dist.new_group([0])
  for param in model.parameters():
    dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM, group=group)
    param.grad.data /= size


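# run builds the model, wraps it in DistributedDataParallel (the built-in CUDA
# version on GPU, the custom CPU variant above otherwise), trains for 3 epochs
# with SGD while averaging gradients across ranks, and saves the state dict on
# rank 0 under modelpath.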
def run(modelpath, gpu):
  """ Distributed Synchronous SGD Example """
  rank = dist.get_rank()
  torch.manual_seed(1234)
  train_set, bsz = partition_dataset(rank)
  model = Net()
  if gpu:
    model = model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model)
  else:
    model = DistributedDataParallel(model)
  optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
  model_dir = modelpath

  num_batches = ceil(len(train_set.dataset) / float(bsz))
  logging.info("num_batches = %s", num_batches)
  time_start = datetime.datetime.now()
  for epoch in range(3):
    epoch_loss = 0.0
    for data, target in train_set:
      if gpu:
        data, target = Variable(data).cuda(), Variable(target).cuda()
      else:
        data, target = Variable(data), Variable(target)
      optimizer.zero_grad()
      output = model(data)
      loss = F.nll_loss(output, target)
      epoch_loss += loss.item()
      loss.backward()
      average_gradients(model)
      optimizer.step()
    logging.info('Epoch {} Loss {:.6f} Global batch size {} on {} ranks'.format(
      epoch, epoch_loss / num_batches, gbatch_size, dist.get_world_size()))
  # Ensure only the master node saves the model
  main_proc = rank == 0
  if main_proc:
    if not os.path.exists(model_dir):
      os.makedirs(model_dir)
    if gpu:
      model_path = model_dir + "/model_gpu.dat"
    else:
      model_path = model_dir + "/model_cpu.dat"
    logging.info("Saving model in {}".format(model_path))  # pylint: disable = logging-format-interpolation
    torch.save(model.module.state_dict(), model_path)
  if gpu:
    logging.info("GPU training time= {}".format(  # pylint: disable = logging-format-interpolation
      str(datetime.datetime.now() - time_start)))  # pylint: disable = logging-format-interpolation
  else:
    logging.info("CPU training time= {}".format(  # pylint: disable = logging-format-interpolation
      str(datetime.datetime.now() - time_start)))  # pylint: disable = logging-format-interpolation


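# init_process_group defaults to the env:// init method, so the launcher (for
# example the Kubeflow PyTorchJob operator) is expected to set MASTER_ADDR,
# MASTER_PORT, RANK, and WORLD_SIZE in the environment. Illustrative
# single-process run (hypothetical script name and model path):
#   MASTER_ADDR=127.0.0.1 MASTER_PORT=23456 RANK=0 WORLD_SIZE=1 \
#     python mnist_ddp.py --modelpath /tmp/mnist-model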
if __name__ == "__main__":
  import argparse
  logging.basicConfig(level=logging.INFO,
                      format=('%(levelname)s|%(asctime)s'
                              '|%(pathname)s|%(lineno)d| %(message)s'),
                      datefmt='%Y-%m-%dT%H:%M:%S',
                      )
  logging.getLogger().setLevel(logging.INFO)

  parser = argparse.ArgumentParser(description='Train PyTorch MNIST model using DDP')
  parser.add_argument('--gpu', action='store_true',
                      help='Use GPU and CUDA')
  parser.set_defaults(gpu=False)
  parser.add_argument('--modelpath', metavar='path', required=True,
                      help='Path to model, e.g., /mnt/kubeflow-gcfs/pytorch/model')
  args = parser.parse_args()
  if args.gpu:
    logging.info("\n======= CUDA INFO =======")
    logging.info("CUDA Availability: %s", torch.cuda.is_available())
    if torch.cuda.is_available():
      logging.info("CUDA Device Name: %s", torch.cuda.get_device_name(0))
      logging.info("CUDA Version: %s", torch.version.cuda)
    logging.info("=========================\n")
  dist.init_process_group(backend='gloo')
  run(modelpath=args.modelpath, gpu=args.gpu)
  dist.destroy_process_group()