Pytorch distributed 多卡并行载入模型操作


Posted in Python onJune 05, 2021

一、Pytorch distributed 多卡并行载入模型

这次来介绍下如何载入模型。

目前没有找到官方的distribute 载入模型的方式,所以采用如下方式。

大部分情况下,我们在测试时不需要多卡并行计算。

所以,我在测试时只使用单卡。

from collections import OrderedDict
device = torch.device("cuda")
model = DGCNN(args).to(device)  #自己的模型
state_dict = torch.load(args.model_path)    #存放模型的位置

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
    # load params
model.load_state_dict (new_state_dict)

二、pytorch DistributedParallel进行单机多卡训练

One_导入库:

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

Two_进程初始化:

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
# 添加必要参数
# local_rank:系统自动赋予的进程编号,可以利用该编号控制打印输出以及设置device

torch.distributed.init_process_group(backend="nccl", init_method='file://shared/sharedfile',
rank=local_rank, world_size=world_size)

# world_size:所创建的进程数,也就是所使用的GPU数量
# (初始化设置详见参考文档)

Three_数据分发:

dataset = datasets.ImageFolder(dataPath)
data_sampler = DistributedSampler(dataset, rank=local_rank, num_replicas=world_size)
# 使用DistributedSampler来为各个进程分发数据,其中num_replicas与world_size保持一致,用于将数据集等分成不重叠的数个子集

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=1,drop_last=True, pin_memory=True, sampler=data_sampler)
# 在Dataloader中指定sampler时,其中的shuffle必须为False,而DistributedSampler中的shuffle项默认为True,因此训练过程默认执行shuffle

Four_网络模型:

torch.cuda.set_device(local_rank)
device = torch.device('cuda:'+f'{local_rank}')
# 设置每个进程对应的GPU设备

D = Model()
D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(D).to(device)
# 由于在训练过程中各卡的前向后向传播均独立进行,因此无法进行统一的批归一化,如果想要将各卡的输出统一进行批归一化,需要将模型中的BN转换成SyncBN
   
D = torch.nn.parallel.DistributedDataParallel(
D, find_unused_parameters=True, device_ids=[local_rank], output_device=local_rank)
# 如果有forward的返回值如果不在计算loss的计算图里,那么需要find_unused_parameters=True,即返回值不进入backward去算grad,也不需要在不同进程之间进行通信。

Five_迭代:

data_sampler.set_epoch(epoch)
# 每个epoch需要为sampler设置当前epoch

Six_加载:

dist.barrier()
D.load_state_dict(torch.load('D.pth'), map_location=torch.device('cpu'))
dist.barrier()
# 加载模型前后用dist.barrier()来同步不同进程间的快慢

Seven_启动:

CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train.py --epochs 15000 --batchsize 10 --world_size 2
# 用-m torch.distributed.launch启动,nproc_per_node为所使用的卡数,batchsize设置为每张卡各自的批大小

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入理解Python中的元类(metaclass)
Feb 14 Python
简单掌握Python中glob模块查找文件路径的用法
Jul 05 Python
Python中 Lambda表达式全面解析
Nov 28 Python
Python开发中爬虫使用代理proxy抓取网页的方法示例
Sep 26 Python
书单|人生苦短,你还不用python!
Dec 29 Python
Python实现的径向基(RBF)神经网络示例
Feb 06 Python
python实现简易通讯录修改版
Mar 13 Python
PyQt5 在label显示的图片中绘制矩形的方法
Jun 17 Python
Python爬虫库BeautifulSoup的介绍与简单使用实例
Jan 25 Python
解决import tensorflow导致jupyter内核死亡的问题
Feb 06 Python
python3 删除所有自定义变量的操作
Apr 08 Python
Django程序的优化技巧
Apr 29 Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
OpenCV全景图像拼接的实现示例
opencv 分类白天与夜景视频的方法
You might like
Symfony控制层深入详解
2016/03/17 PHP
PHP自动识别当前使用移动终端
2018/05/21 PHP
laravel5.2表单验证,并显示错误信息的实例
2019/09/29 PHP
使用script的src实现跨域和类似ajax效果
2014/11/10 Javascript
JavaScript中的值是按值传递还是按引用传递问题探讨
2015/01/30 Javascript
jQuery EasyUI框架中的Datagrid数据表格组件结构详解
2016/06/09 Javascript
Javascript农历与公历相互转换的简单实例
2016/10/09 Javascript
Vue.js实现文章评论和回复评论功能
2020/05/30 Javascript
Vue中添加手机验证码组件功能操作方法
2017/12/07 Javascript
详解react-redux插件入门
2018/04/19 Javascript
vue-router3.0版本中 router.push 不能刷新页面的问题
2018/05/10 Javascript
利用vscode调试编译后的js代码详解
2018/05/14 Javascript
深入理解js A*寻路算法原理与具体实现过程
2018/12/13 Javascript
js实现图片跟随鼠标移动效果
2019/10/16 Javascript
JS面向对象编程基础篇(一) 对象和构造函数实例详解
2020/03/03 Javascript
[49:27]2018DOTA2亚洲邀请赛 4.4 淘汰赛 TNC vs VG 第一场
2018/04/05 DOTA
python中xrange和range的区别
2014/05/13 Python
Python之web模板应用
2017/12/26 Python
python实现串口通信的示例代码
2020/02/10 Python
Python依赖包迁移到断网环境操作
2020/07/13 Python
CSS3系列之3D制作方法案例
2017/08/14 HTML / CSS
英国精品买手店:Browns Fashion
2016/09/29 全球购物
美国踏板车和轻便摩托车销售网站:Mega Motor Madness
2020/02/26 全球购物
邮政员工辞职信
2014/01/16 职场文书
三年级数学教学反思
2014/01/31 职场文书
读群众路线心得体会
2014/03/07 职场文书
小学生手册家长评语
2014/04/16 职场文书
奥巴马的演讲稿
2014/05/15 职场文书
2014世界杯球队球队口号
2014/06/05 职场文书
2014学生会工作总结报告
2014/12/02 职场文书
挂靠协议书
2015/01/27 职场文书
2015年数学教师工作总结
2015/05/20 职场文书
教师学习心得体会范文
2016/01/21 职场文书
2016年圣诞节义工活动总结
2016/04/01 职场文书
Redis缓存-序列化对象存储乱码问题的解决
2021/06/21 Redis
JS setTimeout与setInterval的区别
2022/04/20 Javascript