PyTorch使用Horovod多GPU训练的实现


Posted in Python on September 09, 2020

PyTorch在Horovod上训练的步骤分为以下几步:

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data.distributed
import horovod.torch as hvd

# Initialize Horovod.
hvd.init()

# Pin this process to the GPU matching its local rank (one GPU per process).
torch.cuda.set_device(hvd.local_rank())

# Define the dataset (placeholder).
train_dataset = ...

# Partition the dataset among workers with a DistributedSampler so each rank
# reads a disjoint shard.
train_sampler = torch.utils.data.distributed.DistributedSampler(
  train_dataset, num_replicas=hvd.size(), rank=hvd.rank())

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=..., sampler=train_sampler)

# Build the model (placeholder) and move it to the pinned GPU.
model = ...
model.cuda()

# Learning rate for a single worker; it is multiplied by hvd.size() because
# the effective (global) batch size grows with the number of workers.
base_lr = 0.01
# Print a progress line every `log_interval` batches.
log_interval = 10

optimizer = optim.SGD(model.parameters(), lr=base_lr * hvd.size())

# Wrap the original optimizer with Horovod's DistributedOptimizer so that
# gradients are combined across workers before each step.
optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())

# Broadcast the initial parameters and optimizer state from rank 0 so that
# every worker starts from identical values.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

for epoch in range(100):
  # Give the sampler the epoch so each epoch shuffles the shards differently.
  train_sampler.set_epoch(epoch)
  for batch_idx, (data, target) in enumerate(train_loader):
    # The model lives on the GPU, so the batch must be moved there too.
    data, target = data.cuda(), target.cuda()
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    # Log from rank 0 only, to avoid one duplicate line per worker.
    if batch_idx % log_interval == 0 and hvd.rank() == 0:
      print('Train Epoch: {} [{}/{}]\tLoss: {}'.format(
        epoch, batch_idx * len(data), len(train_sampler), loss.item()))

完整示例代码如下,在ImageNet上采用ResNet-50进行训练:

from __future__ import print_function
  
  import torch
  import argparse
  import torch.backends.cudnn as cudnn
  import torch.nn.functional as F
  import torch.optim as optim
  import torch.utils.data.distributed
  from torchvision import datasets, transforms, models
 import horovod.torch as hvd
 import os
 import math
 from tqdm import tqdm
 from distutils.version import LooseVersion
 
 # Training settings
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Example',
                                  formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 parser.add_argument('--train-dir', default=os.path.expanduser('~/imagenet/train'),
                     help='path to training data')
 parser.add_argument('--val-dir', default=os.path.expanduser('~/imagenet/validation'),
                     help='path to validation data')
 parser.add_argument('--log-dir', default='./logs',
                     help='tensorboard log directory')
 parser.add_argument('--checkpoint-format', default='./checkpoint-{epoch}.pth.tar',
                     help='checkpoint file format')
 # The compression setup reads args.fp16_allreduce, so the flag stores there;
 # '--fp-allreduce' is kept as a backward-compatible alias.
 parser.add_argument('--fp16-allreduce', '--fp-allreduce', dest='fp16_allreduce',
                     action='store_true', default=False,
                     help='use fp16 compression during allreduce')
 parser.add_argument('--batches-per-allreduce', type=int, default=1,
                     help='number of batches processed locally before '
                          'executing allreduce across workers; it multiplies '
                          'total batch size.')
 parser.add_argument('--use-adasum', action='store_true', default=False,
                     help='use adasum algorithm to do reduction')

 # Default settings from https://arxiv.org/abs/1706.02677.
 parser.add_argument('--batch-size', type=int, default=32,
                     help='input batch size for training')
 parser.add_argument('--val-batch-size', type=int, default=32,
                     help='input batch size for validation')
 parser.add_argument('--epochs', type=int, default=90,
                     help='number of epochs to train')
 parser.add_argument('--base-lr', type=float, default=0.0125,
                     help='learning rate for a single GPU')
 parser.add_argument('--warmup-epochs', type=float, default=5,
                     help='number of warmup epochs')
 parser.add_argument('--momentum', type=float, default=0.9,
                     help='SGD momentum')
 parser.add_argument('--wd', type=float, default=0.00005,
                     help='weight decay')

 parser.add_argument('--no-cuda', action='store_true', default=False,
                     help='disables CUDA training')
 parser.add_argument('--seed', type=int, default=42,
                     help='random seed')

 args = parser.parse_args()
 # Both sizes are used as a range() step / batch multiplier below, so they must
 # be positive; fail with a clear CLI error instead of a later ValueError.
 if args.batch_size < 1:
   parser.error('--batch-size must be >= 1')
 if args.batches_per_allreduce < 1:
   parser.error('--batches-per-allreduce must be >= 1')
 args.cuda = not args.no_cuda and torch.cuda.is_available()

 # Samples loaded per optimizer step on each worker: batches_per_allreduce
 # local sub-batches of batch_size each.
 allreduce_batch_size = args.batch_size * args.batches_per_allreduce
 
 # Horovod must be initialized before any rank/size query below.
 hvd.init()

 # Seed the CPU generator; the CUDA generator is seeded below once the device is set.
 torch.manual_seed(args.seed)
 if args.cuda:
   # Horovod: pin this process to the GPU matching its local rank, then seed it.
   local_gpu = hvd.local_rank()
   torch.cuda.set_device(local_gpu)
   torch.cuda.manual_seed(args.seed)

 # Let cuDNN benchmark convolution algorithms and reuse the fastest one.
 cudnn.benchmark = True
 
 # If set > 0, will resume training from a given checkpoint: take the highest
 # epoch number in [1, args.epochs] whose checkpoint file exists, or 0 if none.
 resume_from_epoch = next(
   (candidate for candidate in range(args.epochs, 0, -1)
    if os.path.exists(args.checkpoint_format.format(epoch=candidate))),
   0)

 # Horovod: broadcast resume_from_epoch from rank 0 (which will have
 # checkpoints) to other ranks, so every worker resumes at the same epoch.
 resume_from_epoch = hvd.broadcast(torch.tensor(resume_from_epoch), root_rank=0,
                                   name='resume_from_epoch').item()
 
 # Horovod: print logs on the first worker only.
 verbose = 1 if hvd.rank() == 0 else 0

 # Horovod: write TensorBoard logs on the first worker only.
 # Feature-detect the SummaryWriter rather than comparing torch versions with
 # distutils' LooseVersion (distutils is deprecated by PEP 632 and removed in
 # Python 3.12). Prefer the writer bundled with PyTorch, fall back to
 # tensorboardX, and disable TensorBoard logging if neither can be imported.
 try:
   try:
     from torch.utils.tensorboard import SummaryWriter
   except ImportError:
     from tensorboardX import SummaryWriter
   log_writer = SummaryWriter(args.log_dir) if hvd.rank() == 0 else None
 except ImportError:
   log_writer = None
 
 # Horovod: limit # of CPU threads to be used per worker.
 torch.set_num_threads(4)

 # Use DataLoader worker processes and pinned host memory only when training on GPU.
 kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}

 # Per-channel normalization shared by training and validation. Both pipelines
 # must use the same statistics, otherwise the model is evaluated on
 # differently scaled inputs than it was trained on.
 image_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])

 train_dataset = datasets.ImageFolder(
   args.train_dir,
   transform=transforms.Compose([
     transforms.RandomResizedCrop(224),
     transforms.RandomHorizontalFlip(),
     transforms.ToTensor(),
     image_normalize,
   ]))
 # Horovod: use DistributedSampler to partition data among workers. Manually specify
 # `num_replicas=hvd.size()` and `rank=hvd.rank()`.
 train_sampler = torch.utils.data.distributed.DistributedSampler(
   train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
 # Each loaded batch has allreduce_batch_size samples; training splits it into
 # sub-batches of args.batch_size.
 train_loader = torch.utils.data.DataLoader(
   train_dataset, batch_size=allreduce_batch_size,
   sampler=train_sampler, **kwargs)

 val_dataset = datasets.ImageFolder(
   args.val_dir,
   transform=transforms.Compose([
     transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     image_normalize,
   ]))
 val_sampler = torch.utils.data.distributed.DistributedSampler(
   val_dataset, num_replicas=hvd.size(), rank=hvd.rank())
 val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.val_batch_size,
                                          sampler=val_sampler, **kwargs)
 
 
 # Set up standard ResNet-50 model.
 model = models.resnet50()

 # Learning-rate multiplier. By default Adasum needs no scaling; sum/average
 # with gradient accumulation scales by batches_per_allreduce * world size.
 if args.use_adasum:
   lr_scaler = 1
 else:
   lr_scaler = args.batches_per_allreduce * hvd.size()

 if args.cuda:
   # Move model to GPU.
   model.cuda()
   # GPU Adasum allreduce (Horovod built with NCCL): scale by local_size instead.
   if args.use_adasum and hvd.nccl_built():
     lr_scaler = args.batches_per_allreduce * hvd.local_size()

 # Horovod: the optimizer starts at the scaled learning rate.
 optimizer = optim.SGD(model.parameters(), lr=args.base_lr * lr_scaler,
                       momentum=args.momentum, weight_decay=args.wd)
 
 # Horovod: (optional) fp16 compression of gradients during allreduce.
 if args.fp16_allreduce:
   compression = hvd.Compression.fp16
 else:
   compression = hvd.Compression.none

 # Horovod: reduction operator, either Adasum or plain averaging.
 reduce_op = hvd.Adasum if args.use_adasum else hvd.Average

 # Horovod: wrap optimizer with DistributedOptimizer; it expects
 # batches_per_allreduce backward passes before each step.
 optimizer = hvd.DistributedOptimizer(
   optimizer,
   named_parameters=model.named_parameters(),
   compression=compression,
   backward_passes_per_step=args.batches_per_allreduce,
   op=reduce_op)
 
 # Restore from a previous checkpoint if resume_from_epoch > 0.
 # Horovod: only the first worker reads the file; the broadcasts below then
 # copy its weights and optimizer state to all other workers.
 if resume_from_epoch > 0 and hvd.rank() == 0:
   filepath = args.checkpoint_format.format(epoch=resume_from_epoch)
   checkpoint = torch.load(filepath)
   model.load_state_dict(checkpoint['model'])
   optimizer.load_state_dict(checkpoint['optimizer'])

 # Horovod: broadcast parameters & optimizer state from rank 0.
 hvd.broadcast_parameters(model.state_dict(), root_rank=0)
 hvd.broadcast_optimizer_state(optimizer, root_rank=0)
 
 def train(epoch):
   """Run one training epoch over this worker's shard of the training set.

   Each loaded batch (up to ``allreduce_batch_size`` samples) is split into
   sub-batches of ``args.batch_size``. Their gradients accumulate locally, and a
   single ``optimizer.step()`` then applies them across all ranks.

   Args:
     epoch: zero-based epoch index (shown 1-based in the progress bar).
   """
   model.train()
   # Tell the sampler the epoch so each epoch shuffles the shard differently.
   train_sampler.set_epoch(epoch)
   train_loss = Metric('train_loss')
   train_accuracy = Metric('train_accuracy')

   with tqdm(total=len(train_loader),
             desc='Train Epoch   #{}'.format(epoch + 1),
             disable=not verbose) as bar:
     for batch_idx, (data, target) in enumerate(train_loader):
       adjust_learning_rate(epoch, batch_idx)

       if args.cuda:
         data, target = data.cuda(), target.cuda()
       optimizer.zero_grad()

       # Each sub-batch loss is divided by the sub-batch count, so the
       # accumulated gradient is the mean over the whole loaded batch.
       num_chunks = math.ceil(float(len(data)) / args.batch_size)
       for start in range(0, len(data), args.batch_size):
         stop = start + args.batch_size
         chunk_data = data[start:stop]
         chunk_target = target[start:stop]
         output = model(chunk_data)
         train_accuracy.update(accuracy(output, chunk_target))
         loss = F.cross_entropy(output, chunk_target)
         # Record the unscaled loss before dividing it in place.
         train_loss.update(loss)
         loss.div_(num_chunks)
         loss.backward()
       # Gradient is applied across all ranks.
       optimizer.step()
       bar.set_postfix({'loss': train_loss.avg.item(),
                        'accuracy': 100. * train_accuracy.avg.item()})
       bar.update(1)

   if log_writer:
     log_writer.add_scalar('train/loss', train_loss.avg, epoch)
     log_writer.add_scalar('train/accuracy', train_accuracy.avg, epoch)
 
 
 def validate(epoch):
   """Evaluate the model on this worker's validation shard.

   Loss and accuracy go through ``Metric``, which allreduces every update, so
   the running averages are computed over all workers.

   Args:
     epoch: zero-based epoch index (shown 1-based in the progress bar, the
       same way ``train`` displays it).
   """
   model.eval()
   val_loss = Metric('val_loss')
   val_accuracy = Metric('val_accuracy')

   with tqdm(total=len(val_loader),
             desc='Validate Epoch #{}'.format(epoch + 1),
             disable=not verbose) as t:
     with torch.no_grad():
       for data, target in val_loader:
         if args.cuda:
           data, target = data.cuda(), target.cuda()
         output = model(data)

         val_loss.update(F.cross_entropy(output, target))
         val_accuracy.update(accuracy(output, target))
         t.set_postfix({'loss': val_loss.avg.item(),
                        'accuracy': 100. * val_accuracy.avg.item()})
         # The bar's total is len(val_loader), so advance it once per batch.
         t.update(1)

   if log_writer:
     log_writer.add_scalar('val/loss', val_loss.avg, epoch)
     log_writer.add_scalar('val/accuracy', val_accuracy.avg, epoch)
 
 
 # Horovod: using `lr = base_lr * hvd.size()` from the very beginning leads to worse final
 # accuracy. Ramp the learning rate from `base_lr * lr_scaler / hvd.size()` up to
 # `base_lr * lr_scaler` during the first `args.warmup_epochs` epochs (5 by default).
 # See https://arxiv.org/abs/1706.02677 for details.
 # After the warmup reduce learning rate by 10 on the 30th, 60th and 80th epochs.
 def adjust_learning_rate(epoch, batch_idx):
   """Set the learning rate of every optimizer param group for the current step.

   The target rate is ``args.base_lr * lr_scaler``, the same multiplier the
   optimizer was created with. For Average reduction this equals
   ``args.base_lr * hvd.size() * args.batches_per_allreduce``, as before. Using
   ``lr_scaler`` keeps the Adasum setting from being overwritten on the first
   batch.

   Args:
     epoch: zero-based index of the current epoch.
     batch_idx: zero-based batch index within the epoch; used only during
       warmup so the rate ramps smoothly from batch to batch.
   """
   if epoch < args.warmup_epochs:
     # Fractional epoch progress, so the ramp advances every batch.
     progress = epoch + float(batch_idx + 1) / len(train_loader)
     lr_adj = 1. / hvd.size() * (progress * (hvd.size() - 1) / args.warmup_epochs + 1)
   elif epoch < 30:
     lr_adj = 1.
   elif epoch < 60:
     lr_adj = 1e-1
   elif epoch < 80:
     lr_adj = 1e-2
   else:
     lr_adj = 1e-3
   # NOTE(review): under Adasum the warmup still starts at 1 / hvd.size() of
   # the target rate; confirm that ramp is wanted there.
   for param_group in optimizer.param_groups:
     param_group['lr'] = args.base_lr * lr_scaler * lr_adj
 
 
 def accuracy(output, target):
   """Return the fraction of rows whose arg-max over dim 1 equals ``target``.

   The result is a 0-dim float tensor on the CPU.
   """
   # Index of the largest score in each row (the max log-probability).
   _, predicted = output.max(1, keepdim=True)
   hits = predicted.eq(target.view_as(predicted))
   return hits.cpu().float().mean()
 
 
 def save_checkpoint(epoch):
   """Save model and optimizer state after ``epoch``; only rank 0 writes.

   The file is numbered ``epoch + 1``. The resume logic then starts the epoch
   loop at that number, which is the next epoch still to train.
   """
   if hvd.rank() != 0:
     return
   state = {
     'model': model.state_dict(),
     'optimizer': optimizer.state_dict(),
   }
   torch.save(state, args.checkpoint_format.format(epoch=epoch + 1))
 
 
 # Horovod: average metrics from distributed training.
 class Metric(object):
   """Running mean of a scalar metric, reduced across Horovod workers.

   Attributes:
     name: allreduce tensor name used for this metric's updates.
     sum: running total of the reduced update values (0-dim tensor).
     n: number of updates so far (0-dim float tensor).
   """

   def __init__(self, name):
     self.name = name
     self.sum = torch.tensor(0.)
     self.n = torch.tensor(0.)

   def update(self, val):
     """Allreduce a detached CPU copy of ``val`` and add it to the total."""
     reduced = hvd.allreduce(val.detach().cpu(), name=self.name)
     self.sum += reduced
     self.n += 1

   @property
   def avg(self):
     """Mean of the reduced values seen so far (NaN before any update)."""
     return self.sum / self.n
 
 
 # Main loop: continue after the latest checkpoint, then train, validate and
 # checkpoint (rank 0 writes) once per epoch.
 for current_epoch in range(resume_from_epoch, args.epochs):
   train(current_epoch)
   validate(current_epoch)
   save_checkpoint(current_epoch)

到此这篇关于pytorch使用horovod多gpu训练的实现的文章就介绍到这了,更多相关pytorch horovod多gpu训练内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木! 

Python 相关文章推荐
Python中用Ctrl+C终止多线程程序的问题解决
Mar 30 Python
Python中的tuple元组详细介绍
Feb 02 Python
python操作redis的方法
Jul 07 Python
Python使用剪切板的方法
Jun 06 Python
Django 如何获取前端发送的头文件详解(推荐)
Aug 15 Python
浅谈python常用程序算法
Mar 22 Python
使用python写一个自动浏览文章的脚本实例
Dec 05 Python
自定义Django默认的sitemap站点地图样式
Mar 04 Python
Python调用C/C++的方法解析
Aug 05 Python
在终端启动Python时报错的解决方案
Nov 20 Python
Python urllib request模块发送请求实现过程解析
Dec 10 Python
Python趣味实战之手把手教你实现举牌小人生成器
Jun 07 Python
python,Java,JavaScript实现indexOf
Sep 09 #Python
python 5个顶级异步框架推荐
Sep 09 #Python
python PyAUtoGUI库实现自动化控制鼠标键盘
Sep 09 #Python
Pytorch生成随机数Tensor的方法汇总
Sep 09 #Python
详解python内置模块urllib
Sep 09 #Python
python语音识别指南终极版(有这一篇足矣)
Sep 09 #Python
python 爬取B站原视频的实例代码
Sep 09 #Python
You might like
IIS下配置Php+Mysql+zend的图文教程
2006/12/08 PHP
php删除页面记录 同时刷新页面 删除条件用GET方式获得
2012/01/10 PHP
解析PHP跳出循环的方法以及continue、break、exit的区别介绍
2013/07/01 PHP
xss防御之php利用httponly防xss攻击
2014/03/21 PHP
教你php如何实现验证码
2016/01/20 PHP
laravel获取不到session的三种解决办法【推荐】
2018/09/16 PHP
PHP 出现 http500 错误的解决方法
2021/03/09 PHP
js单独获取一个checkbox看其是否被选中
2014/09/22 Javascript
基于javascript实现右下角浮动广告效果
2016/01/08 Javascript
javascript原生ajax写法分享
2016/04/10 Javascript
jQuery动态修改字体大小的方法【测试可用】
2016/09/09 Javascript
jQuery实现 上升、下降、删除、添加一行代码
2017/03/06 Javascript
单击按钮发送验证码,出现倒计时的简单实例
2017/03/17 Javascript
Element-ui table中过滤条件变更表格内容的方法
2018/03/02 Javascript
nodejs+mongodb aggregate级联查询操作示例
2018/03/17 NodeJs
vue中rem的配置的方法示例
2018/08/30 Javascript
JavaScript实现的开关灯泡点击切换特效示例
2019/07/08 Javascript
[41:12]Liquid vs Secret 2019国际邀请赛淘汰赛 败者组 BO3 第一场 8.24
2019/09/10 DOTA
Python合并多个Excel数据的方法
2018/07/16 Python
python的scipy实现插值的示例代码
2019/11/12 Python
Python json模块与jsonpath模块区别详解
2020/03/05 Python
基于Python3.7.1无法导入Numpy的解决方式
2020/03/09 Python
Python使用Turtle模块绘制国旗的方法示例
2021/02/28 Python
HTML5的结构和语义(5):交互
2008/10/17 HTML / CSS
印尼太阳百货公司网站:Matahari
2018/02/04 全球购物
ASOS亚洲:ASOS Asia
2018/03/04 全球购物
购买美国制造的相框和画框架:Picture Frames
2018/08/14 全球购物
英国领先的维生素和补充剂品牌:Higher Nature
2019/08/26 全球购物
Shopee菲律宾:在线购买和出售
2019/11/25 全球购物
经贸日语专业个人求职信范文
2013/12/28 职场文书
教书育人演讲稿
2014/09/11 职场文书
计划生育证明格式及范本
2014/10/09 职场文书
商务宴请邀请函范文
2015/02/02 职场文书
《跨越海峡的生命桥》教学反思
2016/02/18 职场文书
竞聘演讲报告:基本写作有哪些?附开头范文
2019/10/16 职场文书
Spring this调用当前类方法无法拦截的示例代码
2022/03/20 Java/Android