mirror of https://github.com/kubeflow/examples.git
#!/usr/bin/env python
'''
Copyright 2018 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
'''
import datetime
import logging
import os
from math import ceil
from random import Random

import torch
import torch.distributed as dist
import torch.nn as nn  # pylint: disable = all
import torch.nn.functional as F
import torch.optim as optim  # pylint: disable = all
import torch.utils.data
import torch.utils.data.distributed
# Private PyTorch helpers used to coalesce per-parameter gradients into a single
# flat buffer before the all-reduce in DistributedDataParallel below.
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.autograd import Variable
from torch.nn.modules import Module
from torchvision import datasets, transforms

# Global batch size; each replica trains on gbatch_size / world_size samples per step.
gbatch_size = 128

class DistributedDataParallel(Module):
  """CPU-friendly wrapper for distributed data-parallel training.

  Broadcasts the weights from rank 0 on the first forward pass and registers a
  backward hook that all-reduces and averages the gradients across all ranks.
  """

  def __init__(self, module):
    super(DistributedDataParallel, self).__init__()
    self.module = module
    self.first_call = True

    def allreduce_params():
      if self.needs_reduction:
        self.needs_reduction = False  # pylint: disable = attribute-defined-outside-init
        # Bucket parameters by tensor type so each bucket can be flattened into
        # one contiguous buffer and reduced with a single collective call.
        buckets = {}
        for param in self.module.parameters():
          if param.requires_grad and param.grad is not None:
            tp = type(param.data)
            if tp not in buckets:
              buckets[tp] = []
            buckets[tp].append(param)
        for tp in buckets:
          bucket = buckets[tp]
          grads = [param.grad.data for param in bucket]
          coalesced = _flatten_dense_tensors(grads)
          dist.all_reduce(coalesced)
          coalesced /= dist.get_world_size()
          for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
            buf.copy_(synced)

    for param in list(self.module.parameters()):
      def allreduce_hook(*unused):  # pylint: disable = unused-argument
        Variable._execution_engine.queue_callback(allreduce_params)  # pylint: disable = protected-access

      if param.requires_grad:
        param.register_hook(allreduce_hook)

  def weight_broadcast(self):
    for param in self.module.parameters():
      dist.broadcast(param.data, 0)

  def forward(self, *inputs, **kwargs):  # pylint: disable = arguments-differ
    if self.first_call:
      logging.info("first broadcast start")
      self.weight_broadcast()
      self.first_call = False
      logging.info("first broadcast done")
    self.needs_reduction = True  # pylint: disable = attribute-defined-outside-init
    return self.module(*inputs, **kwargs)
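
# Usage sketch: run() below wraps the CPU model with this class, so an ordinary
# training step drives the synchronization:
#
#   model = DistributedDataParallel(Net())   # weights broadcast from rank 0 on first forward
#   loss = F.nll_loss(model(data), target)
#   loss.backward()                          # hook queues the gradient all-reduce
#   optimizer.step()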


class Partition(object):  # pylint: disable = all
  """ Dataset-like view that exposes only a subset of the underlying data. """

  def __init__(self, data, index):
    self.data = data
    self.index = index

  def __len__(self):
    return len(self.index)

  def __getitem__(self, index):
    data_idx = self.index[index]
    return self.data[data_idx]


class DataPartitioner(object):  # pylint: disable = all
  """ Partitions a dataset into different chunks.

  Note: partition_dataset() below uses torch's DistributedSampler instead of
  this class.
  """

  def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234):  # pylint: disable = dangerous-default-value
    self.data = data
    self.partitions = []
    rng = Random()
    rng.seed(seed)
    data_len = len(data)
    indexes = list(range(data_len))
    rng.shuffle(indexes)

    for frac in sizes:
      part_len = int(frac * data_len)
      self.partitions.append(indexes[0:part_len])
      indexes = indexes[part_len:]

  def use(self, partition):
    return Partition(self.data, self.partitions[partition])
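
# Minimal sketch (illustrative only; this script's partition_dataset() relies on
# DistributedSampler): splitting a dataset evenly across ranks with the helpers above.
#
#   world_size = dist.get_world_size()
#   partitioner = DataPartitioner(dataset, sizes=[1.0 / world_size] * world_size)
#   local_data = partitioner.use(dist.get_rank())
#   loader = torch.utils.data.DataLoader(local_data, batch_size=bsz, shuffle=True)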


class Net(nn.Module):
  """ Network architecture: two convolutional layers followed by two fully connected layers. """

  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.conv2_drop = nn.Dropout2d()
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):  # pylint: disable = arguments-differ
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
    x = x.view(-1, 320)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, training=self.training)
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)
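
# Shape check for the hard-coded 320 (MNIST inputs are 1x28x28):
#   conv1 (kernel 5) -> 10x24x24, max_pool2d(2) -> 10x12x12,
#   conv2 (kernel 5) -> 20x8x8,   max_pool2d(2) -> 20x4x4,
#   so the flattened feature vector is 20 * 4 * 4 = 320.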


def partition_dataset(rank):
  """ Download MNIST and build a DataLoader that samples this rank's shard. """
  dataset = datasets.MNIST(
    './data{}'.format(rank),
    train=True,
    download=True,
    transform=transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.1307,), (0.3081,))
    ]))
  size = dist.get_world_size()
  # Per-replica batch size: the global batch is split evenly across all ranks.
  bsz = int(gbatch_size / float(size))
  train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
  train_set = torch.utils.data.DataLoader(
    dataset, batch_size=bsz, shuffle=(train_sampler is None), sampler=train_sampler)
  return train_set, bsz


def average_gradients(model):
  """ Gradient averaging. """
  # Note: the DistributedDataParallel wrappers used in run() already average
  # gradients via backward hooks; this performs an additional explicit all-reduce
  # over a group containing only rank 0.
  size = float(dist.get_world_size())
  group = dist.new_group([0])
  for param in model.parameters():
    dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM, group=group)
    param.grad.data /= size


def run(modelpath, gpu):
  """ Distributed synchronous SGD example. """
  rank = dist.get_rank()
  torch.manual_seed(1234)
  train_set, bsz = partition_dataset(rank)
  model = Net()
  if gpu:
    model = model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model)
  else:
    model = DistributedDataParallel(model)
  optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
  model_dir = modelpath

  num_batches = ceil(len(train_set.dataset) / float(bsz))
  logging.info("num_batches = %s", num_batches)
  time_start = datetime.datetime.now()
  for epoch in range(3):
    epoch_loss = 0.0
    for data, target in train_set:
      if gpu:
        data, target = Variable(data).cuda(), Variable(target).cuda()
      else:
        data, target = Variable(data), Variable(target)
      optimizer.zero_grad()
      output = model(data)
      loss = F.nll_loss(output, target)
      epoch_loss += loss.item()
      loss.backward()
      average_gradients(model)
      optimizer.step()
    logging.info('Epoch {} Loss {:.6f} Global batch size {} on {} ranks'.format(
      epoch, epoch_loss / num_batches, gbatch_size, dist.get_world_size()))
  # Ensure only the master node saves the model
  main_proc = rank == 0
  if main_proc:
    if not os.path.exists(model_dir):
      os.makedirs(model_dir)
    if gpu:
      model_path = os.path.join(model_dir, "model_gpu.dat")
    else:
      model_path = os.path.join(model_dir, "model_cpu.dat")
    logging.info("Saving model in {}".format(model_path))  # pylint: disable = logging-format-interpolation
    # Save the underlying module's weights, not the DDP wrapper itself.
    torch.save(model.module.state_dict(), model_path)
  if gpu:
    logging.info("GPU training time= {}".format(  # pylint: disable = logging-format-interpolation
      str(datetime.datetime.now() - time_start)))  # pylint: disable = logging-format-interpolation
  else:
    logging.info("CPU training time= {}".format(  # pylint: disable = logging-format-interpolation
      str(datetime.datetime.now() - time_start)))  # pylint: disable = logging-format-interpolation


if __name__ == "__main__":
  import argparse
  logging.basicConfig(level=logging.INFO,
                      format=('%(levelname)s|%(asctime)s'
                              '|%(pathname)s|%(lineno)d| %(message)s'),
                      datefmt='%Y-%m-%dT%H:%M:%S',
                      )
  logging.getLogger().setLevel(logging.INFO)

  parser = argparse.ArgumentParser(description='Train a PyTorch MNIST model using DDP')
  parser.add_argument('--gpu', action='store_true',
                      help='Use GPU and CUDA')
  parser.set_defaults(gpu=False)
  parser.add_argument('--modelpath', metavar='path', required=True,
                      help='Path to model, e.g., /mnt/kubeflow-gcfs/pytorch/model')
  args = parser.parse_args()
  if args.gpu:
    logging.info("\n======= CUDA INFO =======")
    logging.info("CUDA Availability: %s", torch.cuda.is_available())
    if torch.cuda.is_available():
      logging.info("CUDA Device Name: %s", torch.cuda.get_device_name(0))
      logging.info("CUDA Version: %s", torch.version.cuda)
    logging.info("=========================\n")
  dist.init_process_group(backend='gloo')
  run(modelpath=args.modelpath, gpu=args.gpu)
  dist.destroy_process_group()
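
# Launch sketch (paths and script name are illustrative). init_process_group() with no
# init_method uses the env:// rendezvous, so each process needs MASTER_ADDR, MASTER_PORT,
# RANK and WORLD_SIZE in its environment, e.g. for two workers on one machine:
#
#   MASTER_ADDR=127.0.0.1 MASTER_PORT=23456 WORLD_SIZE=2 RANK=0 \
#     python mnist_ddp.py --modelpath /tmp/pytorch-model &
#   MASTER_ADDR=127.0.0.1 MASTER_PORT=23456 WORLD_SIZE=2 RANK=1 \
#     python mnist_ddp.py --modelpath /tmp/pytorch-model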