Pytorch distributed 多卡并行载入模型操作


Posted in Python onJune 05, 2021

一、Pytorch distributed 多卡并行载入模型

这次来介绍下如何载入模型。

目前没有找到官方的distribute 载入模型的方式,所以采用如下方式。

大部分情况下,我们在测试时不需要多卡并行计算。

所以,我在测试时只使用单卡。

from collections import OrderedDict
device = torch.device("cuda")
model = DGCNN(args).to(device)  #自己的模型
state_dict = torch.load(args.model_path)    #存放模型的位置

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
    # load params
model.load_state_dict (new_state_dict)

二、pytorch DistributedParallel进行单机多卡训练

One_导入库:

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

Two_进程初始化:

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
# 添加必要参数
# local_rank:系统自动赋予的进程编号,可以利用该编号控制打印输出以及设置device

torch.distributed.init_process_group(backend="nccl", init_method='file://shared/sharedfile',
rank=local_rank, world_size=world_size)

# world_size:所创建的进程数,也就是所使用的GPU数量
# (初始化设置详见参考文档)

Three_数据分发:

dataset = datasets.ImageFolder(dataPath)
data_sampler = DistributedSampler(dataset, rank=local_rank, num_replicas=world_size)
# 使用DistributedSampler来为各个进程分发数据,其中num_replicas与world_size保持一致,用于将数据集等分成不重叠的数个子集

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=1,drop_last=True, pin_memory=True, sampler=data_sampler)
# 在Dataloader中指定sampler时,其中的shuffle必须为False,而DistributedSampler中的shuffle项默认为True,因此训练过程默认执行shuffle

Four_网络模型:

torch.cuda.set_device(local_rank)
device = torch.device('cuda:'+f'{local_rank}')
# 设置每个进程对应的GPU设备

D = Model()
D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(D).to(device)
# 由于在训练过程中各卡的前向后向传播均独立进行,因此无法进行统一的批归一化,如果想要将各卡的输出统一进行批归一化,需要将模型中的BN转换成SyncBN
   
D = torch.nn.parallel.DistributedDataParallel(
D, find_unused_parameters=True, device_ids=[local_rank], output_device=local_rank)
# 如果有forward的返回值如果不在计算loss的计算图里,那么需要find_unused_parameters=True,即返回值不进入backward去算grad,也不需要在不同进程之间进行通信。

Five_迭代:

data_sampler.set_epoch(epoch)
# 每个epoch需要为sampler设置当前epoch

Six_加载:

dist.barrier()
D.load_state_dict(torch.load('D.pth'), map_location=torch.device('cpu'))
dist.barrier()
# 加载模型前后用dist.barrier()来同步不同进程间的快慢

Seven_启动:

CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train.py --epochs 15000 --batchsize 10 --world_size 2
# 用-m torch.distributed.launch启动,nproc_per_node为所使用的卡数,batchsize设置为每张卡各自的批大小

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的内存泄漏及gc模块的使用分析
Jul 16 Python
python利用拉链法实现字典方法示例
Mar 25 Python
Windows环境下python环境安装使用图文教程
Mar 13 Python
python实现windows下文件备份脚本
May 27 Python
Django框架自定义session处理操作示例
May 27 Python
numpy中的ndarray方法和属性详解
May 27 Python
python飞机大战pygame游戏之敌机出场实现方法详解
Dec 17 Python
Python实现aes加密解密多种方法解析
May 15 Python
解决python对齐错误的方法
Jul 16 Python
关于Python3爬虫利器Appium的安装步骤
Jul 29 Python
Python爬取你好李焕英豆瓣短评生成词云的示例代码
Feb 24 Python
Python中的嵌套循环详情
Mar 23 Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
OpenCV全景图像拼接的实现示例
opencv 分类白天与夜景视频的方法
You might like
浅谈PHP语法(1)
2006/10/09 PHP
用PHP的超级变量$_GET获取HTML表单(Form) 数据
2011/05/07 PHP
ThinkPHP的MVC开发机制实例解析
2014/08/23 PHP
PHP性能分析工具XHProf安装使用教程
2015/05/13 PHP
CodeIgniter配置之database.php用法实例分析
2016/01/20 PHP
php图片裁剪函数
2018/10/31 PHP
jquery 弹出层注册页面等(asp.net后台)
2010/06/17 Javascript
最精简的JavaScript实现鼠标拖动效果的方法
2015/05/11 Javascript
jquery+json实现分页效果
2016/03/07 Javascript
JS框架之vue.js(深入三:组件1)
2016/09/29 Javascript
js如何获取网页所有图片
2017/05/12 Javascript
vue-loader教程介绍
2017/06/14 Javascript
js Date()日期函数浏览器兼容问题解决方法
2017/09/12 Javascript
详解node+express+ejs+bootstrap构建项目
2017/09/27 Javascript
vue2.0 根据状态值进行样式的改变展示方法
2018/03/13 Javascript
element 中 el-menu 组件的无限极循环思路代码详解
2020/04/26 Javascript
JavaScript实现简单动态表格
2020/12/02 Javascript
关于javascript中的promise的用法和注意事项(推荐)
2021/01/15 Javascript
vue实现简易计算器功能
2021/01/20 Vue.js
WebStorm无法正确识别Vue3组合式API的解决方案
2021/02/18 Vue.js
[01:23:35]Ti4主赛事胜者组 DK vs EG 1
2014/07/19 DOTA
python定时执行指定函数的方法
2015/05/27 Python
详解python发送各类邮件的主要方法
2016/12/22 Python
Python代码实现删除一个list里面重复元素的方法
2019/04/02 Python
详解Python高阶函数
2020/08/15 Python
Opencv 图片的OCR识别的实战示例
2021/03/02 Python
中国央视网签名寄语
2014/01/18 职场文书
民生工程实施方案
2014/03/22 职场文书
学生期末评语大全
2014/04/30 职场文书
党务公开方案
2014/05/06 职场文书
大学生活动总结模板
2014/07/02 职场文书
学校交通安全责任书
2014/08/25 职场文书
ktv服务员岗位职责
2015/02/09 职场文书
教师岗位职责范本
2015/04/02 职场文书
幽灵公主观后感
2015/06/09 职场文书
创业计划书之书店
2019/09/10 职场文书