Pytorch distributed 多卡并行载入模型操作


Posted in Python onJune 05, 2021

一、Pytorch distributed 多卡并行载入模型

这次来介绍下如何载入模型。

目前没有找到官方的distribute 载入模型的方式,所以采用如下方式。

大部分情况下,我们在测试时不需要多卡并行计算。

所以,我在测试时只使用单卡。

from collections import OrderedDict
device = torch.device("cuda")
model = DGCNN(args).to(device)  #自己的模型
state_dict = torch.load(args.model_path)    #存放模型的位置

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
    # load params
model.load_state_dict (new_state_dict)

二、pytorch DistributedParallel进行单机多卡训练

One_导入库:

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

Two_进程初始化:

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
# 添加必要参数
# local_rank:系统自动赋予的进程编号,可以利用该编号控制打印输出以及设置device

torch.distributed.init_process_group(backend="nccl", init_method='file://shared/sharedfile',
rank=local_rank, world_size=world_size)

# world_size:所创建的进程数,也就是所使用的GPU数量
# (初始化设置详见参考文档)

Three_数据分发:

dataset = datasets.ImageFolder(dataPath)
data_sampler = DistributedSampler(dataset, rank=local_rank, num_replicas=world_size)
# 使用DistributedSampler来为各个进程分发数据,其中num_replicas与world_size保持一致,用于将数据集等分成不重叠的数个子集

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=1,drop_last=True, pin_memory=True, sampler=data_sampler)
# 在Dataloader中指定sampler时,其中的shuffle必须为False,而DistributedSampler中的shuffle项默认为True,因此训练过程默认执行shuffle

Four_网络模型:

torch.cuda.set_device(local_rank)
device = torch.device('cuda:'+f'{local_rank}')
# 设置每个进程对应的GPU设备

D = Model()
D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(D).to(device)
# 由于在训练过程中各卡的前向后向传播均独立进行,因此无法进行统一的批归一化,如果想要将各卡的输出统一进行批归一化,需要将模型中的BN转换成SyncBN
   
D = torch.nn.parallel.DistributedDataParallel(
D, find_unused_parameters=True, device_ids=[local_rank], output_device=local_rank)
# 如果有forward的返回值如果不在计算loss的计算图里,那么需要find_unused_parameters=True,即返回值不进入backward去算grad,也不需要在不同进程之间进行通信。

Five_迭代:

data_sampler.set_epoch(epoch)
# 每个epoch需要为sampler设置当前epoch

Six_加载:

dist.barrier()
D.load_state_dict(torch.load('D.pth'), map_location=torch.device('cpu'))
dist.barrier()
# 加载模型前后用dist.barrier()来同步不同进程间的快慢

Seven_启动:

CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train.py --epochs 15000 --batchsize 10 --world_size 2
# 用-m torch.distributed.launch启动,nproc_per_node为所使用的卡数,batchsize设置为每张卡各自的批大小

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
linux系统使用python获取cpu信息脚本分享
Jan 15 Python
Python中__call__用法实例
Aug 29 Python
Python解析树及树的遍历
Feb 03 Python
解决nohup重定向python输出到文件不成功的问题
May 11 Python
python获取文件真实链接的方法,针对于302返回码
May 14 Python
对Tensorflow中权值和feature map的可视化详解
Jun 14 Python
Python查找第n个子串的技巧分享
Jun 27 Python
pandas 按照特定顺序输出的实现代码
Jul 10 Python
python动态视频下载器的实现方法
Sep 16 Python
matplotlib 曲线图 和 折线图 plt.plot()实例
Apr 17 Python
Django ORM 查询表中某列字段值的方法
Apr 30 Python
如何从csv文件构建Tensorflow的数据集
Sep 21 Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
OpenCV全景图像拼接的实现示例
opencv 分类白天与夜景视频的方法
You might like
支持oicq头像的留言簿(二)
2006/10/09 PHP
php通过获取头信息判断图片类型的方法
2015/06/26 PHP
mod_php、FastCGI、PHP-FPM等PHP运行方式对比
2015/07/02 PHP
PHP获取当前文件的父目录方法汇总
2016/07/21 PHP
PHP上传图片、删除图片简单实例
2016/11/12 PHP
Yii Framework框架使用PHPExcel组件的方法示例
2019/07/24 PHP
jQuery prev ~ siblings选择器使用介绍
2013/08/09 Javascript
js取值中form.all和不加all的区别介绍
2014/01/20 Javascript
jQuery对于显示和隐藏等常用状态的判断方法
2014/12/13 Javascript
15款jQuery分布引导插件分享
2015/02/04 Javascript
JS判断页面是否出现滚动条的方法
2015/07/17 Javascript
VueJs路由跳转——vue-router的使用详解
2017/01/10 Javascript
js选项卡的制作方法
2017/01/23 Javascript
详解nodejs微信公众号开发——6.自定义菜单
2017/04/13 NodeJs
vue页面离开后执行函数的实例
2018/03/13 Javascript
layui框架table 数据表格的方法级渲染详解
2018/08/19 Javascript
angularJS1 url中携带参数的获取方法
2018/10/09 Javascript
jQuery ajax仿Google自动提示SearchSuggess功能示例
2019/03/28 jQuery
vue flex 布局实现div均分自动换行的示例代码
2020/08/05 Javascript
布同 统计英文单词的个数的python代码
2011/03/13 Python
Windows上使用virtualenv搭建Python+Flask开发环境
2016/06/07 Python
python3爬取数据至mysql的方法
2018/06/26 Python
Flask框架URL管理操作示例【基于@app.route】
2018/07/23 Python
python pandas利用fillna方法实现部分自动填充功能
2020/03/16 Python
python代码区分大小写吗
2020/06/17 Python
Python模块常用四种安装方式
2020/10/20 Python
python基于Kivy写一个图形桌面时钟程序
2021/01/28 Python
CSS3制作气泡对话框的实例教程
2016/05/10 HTML / CSS
编码转换,怎样实现将GB2312编码的字符串转换为ISO-8859-1编码的字符串
2014/01/07 面试题
2014年局领导班子自身建设情况汇报
2014/11/21 职场文书
初中体育教学随笔
2015/08/15 职场文书
2016年寒假家长评语
2015/10/10 职场文书
2019大学毕业晚会主持词
2019/06/21 职场文书
Go语言并发编程 sync.Once
2021/10/16 Golang
JavaScript中MutationObServer监听DOM元素详情
2021/11/27 Javascript
Python超详细分步解析随机漫步
2022/03/17 Python