Pytorch distributed 多卡并行载入模型操作


Posted in Python onJune 05, 2021

一、Pytorch distributed 多卡并行载入模型

这次来介绍下如何载入模型。

目前没有找到官方的distribute 载入模型的方式,所以采用如下方式。

大部分情况下,我们在测试时不需要多卡并行计算。

所以,我在测试时只使用单卡。

from collections import OrderedDict
device = torch.device("cuda")
model = DGCNN(args).to(device)  #自己的模型
state_dict = torch.load(args.model_path)    #存放模型的位置

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
    # load params
model.load_state_dict (new_state_dict)

二、pytorch DistributedParallel进行单机多卡训练

One_导入库:

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

Two_进程初始化:

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
# 添加必要参数
# local_rank:系统自动赋予的进程编号,可以利用该编号控制打印输出以及设置device

torch.distributed.init_process_group(backend="nccl", init_method='file://shared/sharedfile',
rank=local_rank, world_size=world_size)

# world_size:所创建的进程数,也就是所使用的GPU数量
# (初始化设置详见参考文档)

Three_数据分发:

dataset = datasets.ImageFolder(dataPath)
data_sampler = DistributedSampler(dataset, rank=local_rank, num_replicas=world_size)
# 使用DistributedSampler来为各个进程分发数据,其中num_replicas与world_size保持一致,用于将数据集等分成不重叠的数个子集

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=1,drop_last=True, pin_memory=True, sampler=data_sampler)
# 在Dataloader中指定sampler时,其中的shuffle必须为False,而DistributedSampler中的shuffle项默认为True,因此训练过程默认执行shuffle

Four_网络模型:

torch.cuda.set_device(local_rank)
device = torch.device('cuda:'+f'{local_rank}')
# 设置每个进程对应的GPU设备

D = Model()
D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(D).to(device)
# 由于在训练过程中各卡的前向后向传播均独立进行,因此无法进行统一的批归一化,如果想要将各卡的输出统一进行批归一化,需要将模型中的BN转换成SyncBN
   
D = torch.nn.parallel.DistributedDataParallel(
D, find_unused_parameters=True, device_ids=[local_rank], output_device=local_rank)
# 如果有forward的返回值如果不在计算loss的计算图里,那么需要find_unused_parameters=True,即返回值不进入backward去算grad,也不需要在不同进程之间进行通信。

Five_迭代:

data_sampler.set_epoch(epoch)
# 每个epoch需要为sampler设置当前epoch

Six_加载:

dist.barrier()
D.load_state_dict(torch.load('D.pth'), map_location=torch.device('cpu'))
dist.barrier()
# 加载模型前后用dist.barrier()来同步不同进程间的快慢

Seven_启动:

CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train.py --epochs 15000 --batchsize 10 --world_size 2
# 用-m torch.distributed.launch启动,nproc_per_node为所使用的卡数,batchsize设置为每张卡各自的批大小

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用scrapy解析js示例
Jan 23 Python
python中元类用法实例
Oct 10 Python
Python的动态重新封装的教程
Apr 11 Python
栈和队列数据结构的基本概念及其相关的Python实现
Aug 24 Python
python利用lxml读写xml格式的文件
Aug 10 Python
python hbase读取数据发送kafka的方法
Dec 27 Python
使用Python检测文章抄袭及去重算法原理解析
Jun 14 Python
python爬虫解决验证码的思路及示例
Aug 01 Python
Python Sympy计算梯度、散度和旋度的实例
Dec 06 Python
Python基于类路径字符串获取静态属性
Mar 12 Python
python输出数学符号实例
May 11 Python
numpy的Fancy Indexing和array比较详解
Jun 11 Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
OpenCV全景图像拼接的实现示例
opencv 分类白天与夜景视频的方法
You might like
php内嵌函数用法实例
2015/03/20 PHP
php自定义截取中文字符串-utf8版
2017/02/27 PHP
浅谈ThinkPHP中initialize和construct的区别
2017/04/01 PHP
如何通过View::first使用Laravel Blade的动态模板详解
2017/09/21 PHP
浅析PHP反序列化中过滤函数使用不当导致的对象注入问题
2020/02/15 PHP
ExtJs扩展之GroupPropertyGrid代码
2010/03/05 Javascript
一行代码实现纯数据json对象的深度克隆实现思路
2013/01/09 Javascript
js隐式全局变量造成的bug示例代码
2014/04/22 Javascript
IE浏览器IFrame对象内存不释放问题解决方法
2014/08/22 Javascript
JavaScript 性能优化小结
2015/10/12 Javascript
解决Angular.Js与Django标签冲突的方案
2016/12/20 Javascript
BootStrap下的弹出框加载select2框架失败的解决方法
2017/08/31 Javascript
bootstrap+jquery项目引入文件报错的解决方法
2018/01/22 jQuery
Vue中使用Sortable的示例代码
2018/04/07 Javascript
详解React中setState回调函数
2018/06/14 Javascript
nvm、nrm、npm 安装和使用详解(小结)
2019/01/17 Javascript
如何在vue中使用jointjs过程解析
2020/05/29 Javascript
[14:36]2014 DOTA2国际邀请赛中国区预选赛5.21 Orenda VS NE
2014/05/22 DOTA
[51:17]完美世界DOTA2联赛循环赛Inki vs DeMonsTer 第二场 10月30日
2020/10/31 DOTA
Python map和reduce函数用法示例
2015/02/26 Python
Python+Wordpress制作小说站
2017/04/14 Python
Python numpy实现数组合并实例(vstack,hstack)
2018/01/09 Python
python匹配两个短语之间的字符实例
2018/12/25 Python
django mysql数据库及图片上传接口详解
2019/07/18 Python
Python 根据日志级别打印不同颜色的日志的方法示例
2019/08/08 Python
基于Python实现签到脚本过程解析
2019/10/25 Python
Python实现链表反转的方法分析【迭代法与递归法】
2020/02/22 Python
纯CSS改变webkit内核浏览器的滚动条样式
2014/04/17 HTML / CSS
Myholidays美国:在线旅游网站
2019/08/16 全球购物
大学生物业管理求职信
2013/10/24 职场文书
客服专员岗位职责范本
2013/11/29 职场文书
《狐假虎威》教学反思
2014/02/07 职场文书
计算机应用专业毕业生求职信
2014/06/03 职场文书
2014年社区妇联工作总结
2014/12/02 职场文书
电子商务专业求职信范文
2015/03/19 职场文书
优化经济发展环境工作总结
2015/08/11 职场文书