Pytorch distributed 多卡并行载入模型操作


Posted in Python onJune 05, 2021

一、Pytorch distributed 多卡并行载入模型

这次来介绍下如何载入模型。

目前没有找到官方的distribute 载入模型的方式,所以采用如下方式。

大部分情况下,我们在测试时不需要多卡并行计算。

所以,我在测试时只使用单卡。

from collections import OrderedDict
device = torch.device("cuda")
model = DGCNN(args).to(device)  #自己的模型
state_dict = torch.load(args.model_path)    #存放模型的位置

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
    # load params
model.load_state_dict (new_state_dict)

二、pytorch DistributedParallel进行单机多卡训练

One_导入库:

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

Two_进程初始化:

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
# 添加必要参数
# local_rank:系统自动赋予的进程编号,可以利用该编号控制打印输出以及设置device

torch.distributed.init_process_group(backend="nccl", init_method='file://shared/sharedfile',
rank=local_rank, world_size=world_size)

# world_size:所创建的进程数,也就是所使用的GPU数量
# (初始化设置详见参考文档)

Three_数据分发:

dataset = datasets.ImageFolder(dataPath)
data_sampler = DistributedSampler(dataset, rank=local_rank, num_replicas=world_size)
# 使用DistributedSampler来为各个进程分发数据,其中num_replicas与world_size保持一致,用于将数据集等分成不重叠的数个子集

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=1,drop_last=True, pin_memory=True, sampler=data_sampler)
# 在Dataloader中指定sampler时,其中的shuffle必须为False,而DistributedSampler中的shuffle项默认为True,因此训练过程默认执行shuffle

Four_网络模型:

torch.cuda.set_device(local_rank)
device = torch.device('cuda:'+f'{local_rank}')
# 设置每个进程对应的GPU设备

D = Model()
D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(D).to(device)
# 由于在训练过程中各卡的前向后向传播均独立进行,因此无法进行统一的批归一化,如果想要将各卡的输出统一进行批归一化,需要将模型中的BN转换成SyncBN
   
D = torch.nn.parallel.DistributedDataParallel(
D, find_unused_parameters=True, device_ids=[local_rank], output_device=local_rank)
# 如果有forward的返回值如果不在计算loss的计算图里,那么需要find_unused_parameters=True,即返回值不进入backward去算grad,也不需要在不同进程之间进行通信。

Five_迭代:

data_sampler.set_epoch(epoch)
# 每个epoch需要为sampler设置当前epoch

Six_加载:

dist.barrier()
D.load_state_dict(torch.load('D.pth'), map_location=torch.device('cpu'))
dist.barrier()
# 加载模型前后用dist.barrier()来同步不同进程间的快慢

Seven_启动:

CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train.py --epochs 15000 --batchsize 10 --world_size 2
# 用-m torch.distributed.launch启动,nproc_per_node为所使用的卡数,batchsize设置为每张卡各自的批大小

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用os模块的os.walk遍历文件夹示例
Jan 27 Python
Swift中的协议(protocol)学习教程
Jul 08 Python
Python基于pillow判断图片完整性的方法
Sep 18 Python
Python实现的快速排序算法详解
Aug 01 Python
基于python socketserver框架全面解析
Sep 21 Python
python在TXT文件中按照某一字符串取出该字符串所在的行方法
Dec 10 Python
详解Python locals()的陷阱
Mar 26 Python
python使用time、datetime返回工作日列表实例代码
May 09 Python
django 连接数据库 sqlite的例子
Aug 14 Python
详解使用Python下载文件的几种方法
Oct 13 Python
Python字典深浅拷贝与循环方式方法详解
Feb 09 Python
解析python中的jsonpath 提取器
Jan 18 Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
OpenCV全景图像拼接的实现示例
opencv 分类白天与夜景视频的方法
You might like
PHP中的加密功能
2006/10/09 PHP
Laravel5中contracts详解
2015/03/02 PHP
PHP getID3类的使用方法学习笔记【附getID3源码下载】
2019/10/18 PHP
细品javascript 寻址,闭包,对象模型和相关问题
2009/04/27 Javascript
JavaScript字符串对象toLowerCase方法入门实例(用于把字母转换为小写)
2014/10/17 Javascript
JavaScript Function函数类型介绍
2015/04/08 Javascript
javascript实现的全国省市县无刷新多级关联菜单效果代码
2016/08/01 Javascript
Ajax和Comet技术总结
2017/02/19 Javascript
xmlplus组件设计系列之选项卡(Tabbar)(5)
2017/05/03 Javascript
JavaScript中使用Async实现异步控制
2017/08/15 Javascript
Vue 2.5.2下axios + express 本地请求404的解决方法
2018/02/21 Javascript
使用vuex缓存数据并优化自己的vuex-cache
2018/05/30 Javascript
微信小程序获取音频时长与实时获取播放进度问题
2018/08/28 Javascript
Vue.js 中 axios 跨域访问错误问题及解决方法
2018/11/21 Javascript
基于vue-cli、elementUI的Vue超简单入门小例子(推荐)
2019/04/17 Javascript
详解Vue+ElementUI从零开始搭建自己的网站(一、环境搭建)
2019/04/30 Javascript
一文了解vue-router之hash模式和history模式
2019/05/31 Javascript
详解NodeJs项目 CentOs linux服务器线上部署
2019/09/16 NodeJs
vue中touch和click共存的解决方式
2020/07/28 Javascript
结合axios对项目中的api请求进行封装操作
2020/09/21 Javascript
[04:13]2018国际邀请赛典藏宝瓶Ⅱ饰品一览
2018/07/21 DOTA
[41:17]完美世界DOTA2联赛PWL S3 access vs CPG 第二场 12.13
2020/12/17 DOTA
Python开发实例分享bt种子爬虫程序和种子解析
2014/05/21 Python
Python3基础之基本数据类型概述
2014/08/13 Python
详解Python3操作Mongodb简明易懂教程
2017/05/25 Python
python_opencv用线段画封闭矩形的实例
2018/12/05 Python
对python周期性定时器的示例详解
2019/02/19 Python
Python对Excel按列值筛选并拆分表格到多个文件的代码
2019/11/05 Python
Django与pyecharts结合的实例代码
2020/05/13 Python
Python如何将将模块分割成多个文件
2020/08/04 Python
计算机求职信
2013/12/01 职场文书
医学专业应届生的自我评价
2014/02/28 职场文书
干部作风整顿自我剖析材料和整改措施
2014/09/18 职场文书
质量保证书格式模板
2015/02/27 职场文书
C#连接ORACLE出现乱码问题的解决方法
2021/10/05 Oracle
如何解决flex文本溢出问题小结
2022/07/15 HTML / CSS