Pytorch distributed 多卡并行载入模型操作


Posted in Python onJune 05, 2021

一、Pytorch distributed 多卡并行载入模型

这次来介绍下如何载入模型。

目前没有找到官方的distribute 载入模型的方式,所以采用如下方式。

大部分情况下,我们在测试时不需要多卡并行计算。

所以,我在测试时只使用单卡。

from collections import OrderedDict
device = torch.device("cuda")
model = DGCNN(args).to(device)  #自己的模型
state_dict = torch.load(args.model_path)    #存放模型的位置

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
    # load params
model.load_state_dict (new_state_dict)

二、pytorch DistributedParallel进行单机多卡训练

One_导入库:

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

Two_进程初始化:

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
# 添加必要参数
# local_rank:系统自动赋予的进程编号,可以利用该编号控制打印输出以及设置device

torch.distributed.init_process_group(backend="nccl", init_method='file://shared/sharedfile',
rank=local_rank, world_size=world_size)

# world_size:所创建的进程数,也就是所使用的GPU数量
# (初始化设置详见参考文档)

Three_数据分发:

dataset = datasets.ImageFolder(dataPath)
data_sampler = DistributedSampler(dataset, rank=local_rank, num_replicas=world_size)
# 使用DistributedSampler来为各个进程分发数据,其中num_replicas与world_size保持一致,用于将数据集等分成不重叠的数个子集

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=1,drop_last=True, pin_memory=True, sampler=data_sampler)
# 在Dataloader中指定sampler时,其中的shuffle必须为False,而DistributedSampler中的shuffle项默认为True,因此训练过程默认执行shuffle

Four_网络模型:

torch.cuda.set_device(local_rank)
device = torch.device('cuda:'+f'{local_rank}')
# 设置每个进程对应的GPU设备

D = Model()
D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(D).to(device)
# 由于在训练过程中各卡的前向后向传播均独立进行,因此无法进行统一的批归一化,如果想要将各卡的输出统一进行批归一化,需要将模型中的BN转换成SyncBN
   
D = torch.nn.parallel.DistributedDataParallel(
D, find_unused_parameters=True, device_ids=[local_rank], output_device=local_rank)
# 如果有forward的返回值如果不在计算loss的计算图里,那么需要find_unused_parameters=True,即返回值不进入backward去算grad,也不需要在不同进程之间进行通信。

Five_迭代:

data_sampler.set_epoch(epoch)
# 每个epoch需要为sampler设置当前epoch

Six_加载:

dist.barrier()
D.load_state_dict(torch.load('D.pth'), map_location=torch.device('cpu'))
dist.barrier()
# 加载模型前后用dist.barrier()来同步不同进程间的快慢

Seven_启动:

CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train.py --epochs 15000 --batchsize 10 --world_size 2
# 用-m torch.distributed.launch启动,nproc_per_node为所使用的卡数,batchsize设置为每张卡各自的批大小

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
pycharm 使用心得(九)解决No Python interpreter selected的问题
Jun 06 Python
如何处理Python3.4 使用pymssql 乱码问题
Jan 08 Python
python列表的增删改查实例代码
Jan 30 Python
Tensorflow 合并通道及加载子模型的方法
Jul 26 Python
老生常谈python中的重载
Nov 11 Python
python 将有序数组转换为二叉树的方法
Mar 26 Python
pyqt5实现登录界面的模板
May 30 Python
python3实现mysql导出excel的方法
Jul 31 Python
python 进程的几种创建方式详解
Aug 29 Python
python实现画图工具
Aug 27 Python
python hmac模块验证客户端的合法性
Nov 07 Python
selenium设置浏览器为headless无头模式(Chrome和Firefox)
Jan 08 Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
OpenCV全景图像拼接的实现示例
opencv 分类白天与夜景视频的方法
You might like
php中文字符串截取方法实例总结
2014/09/30 PHP
PHP中使用SimpleXML检查XML文件结构实例
2015/01/07 PHP
php判断目录存在的简单方法
2019/09/26 PHP
js拖动div 当鼠标移动时整个div也相应的移动
2013/11/21 Javascript
jQuery实现DIV层收缩展开的方法
2015/02/27 Javascript
Vue.js每天必学之构造器与生命周期
2016/09/05 Javascript
jQuery的事件预绑定
2016/12/05 Javascript
详解JavaScript中return的用法
2017/05/08 Javascript
JavaScript之Map和Set_动力节点Java学院整理
2017/06/29 Javascript
js实现鼠标拖拽缩放div实例代码
2019/03/25 Javascript
node.js使用mongoose操作数据库实现购物车的增、删、改、查功能示例
2019/12/23 Javascript
把MySQL表结构映射为Python中的对象的教程
2015/04/07 Python
在Python中操作字典之update()方法的使用
2015/05/22 Python
在Python中的Django框架中进行字符串翻译
2015/07/27 Python
Python 爬虫学习笔记之多线程爬虫
2016/09/21 Python
Python正则捕获操作示例
2017/08/19 Python
解决Django模板无法使用perms变量问题的方法
2017/09/10 Python
python pyinstaller 加载ui路径方法
2019/06/10 Python
Django用户认证系统 组与权限解析
2019/08/02 Python
python 图片二值化处理(处理后为纯黑白的图片)
2019/11/01 Python
PyQt5通过信号实现MVC的示例
2021/02/06 Python
python 求两个向量的顺时针夹角操作
2021/03/04 Python
html5 http的轮询和Websocket原理
2018/10/19 HTML / CSS
英国最大的女性服装零售商:Dorothy Perkins
2017/03/30 全球购物
贪睡宠物用品:Snoozer Pet Products
2020/02/04 全球购物
工业自动化毕业生自荐信范文
2014/01/04 职场文书
学生会主席演讲稿
2014/04/25 职场文书
节能环保标语
2014/06/12 职场文书
超市理货员岗位职责
2014/07/04 职场文书
机电一体化专业求职信
2014/07/22 职场文书
领导班子专题民主生活会情况想汇报
2014/09/30 职场文书
个人欠款协议书范本2014
2014/11/02 职场文书
2015应届毕业生自荐信范文
2015/03/05 职场文书
初中英语教学随笔
2015/08/15 职场文书
一文彻底理解js原生语法prototype,__proto__和constructor
2021/10/24 Javascript
人工智能深度学习OpenAI baselines的使用方法
2022/05/20 Python