Pytorch distributed 多卡并行载入模型操作


Posted in Python onJune 05, 2021

一、Pytorch distributed 多卡并行载入模型

这次来介绍下如何载入模型。

目前没有找到官方的distribute 载入模型的方式,所以采用如下方式。

大部分情况下,我们在测试时不需要多卡并行计算。

所以,我在测试时只使用单卡。

from collections import OrderedDict
device = torch.device("cuda")
model = DGCNN(args).to(device)  #自己的模型
state_dict = torch.load(args.model_path)    #存放模型的位置

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
    # load params
model.load_state_dict (new_state_dict)

二、pytorch DistributedParallel进行单机多卡训练

One_导入库:

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

Two_进程初始化:

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
# 添加必要参数
# local_rank:系统自动赋予的进程编号,可以利用该编号控制打印输出以及设置device

torch.distributed.init_process_group(backend="nccl", init_method='file://shared/sharedfile',
rank=local_rank, world_size=world_size)

# world_size:所创建的进程数,也就是所使用的GPU数量
# (初始化设置详见参考文档)

Three_数据分发:

dataset = datasets.ImageFolder(dataPath)
data_sampler = DistributedSampler(dataset, rank=local_rank, num_replicas=world_size)
# 使用DistributedSampler来为各个进程分发数据,其中num_replicas与world_size保持一致,用于将数据集等分成不重叠的数个子集

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=1,drop_last=True, pin_memory=True, sampler=data_sampler)
# 在Dataloader中指定sampler时,其中的shuffle必须为False,而DistributedSampler中的shuffle项默认为True,因此训练过程默认执行shuffle

Four_网络模型:

torch.cuda.set_device(local_rank)
device = torch.device('cuda:'+f'{local_rank}')
# 设置每个进程对应的GPU设备

D = Model()
D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(D).to(device)
# 由于在训练过程中各卡的前向后向传播均独立进行,因此无法进行统一的批归一化,如果想要将各卡的输出统一进行批归一化,需要将模型中的BN转换成SyncBN
   
D = torch.nn.parallel.DistributedDataParallel(
D, find_unused_parameters=True, device_ids=[local_rank], output_device=local_rank)
# 如果有forward的返回值如果不在计算loss的计算图里,那么需要find_unused_parameters=True,即返回值不进入backward去算grad,也不需要在不同进程之间进行通信。

Five_迭代:

data_sampler.set_epoch(epoch)
# 每个epoch需要为sampler设置当前epoch

Six_加载:

dist.barrier()
D.load_state_dict(torch.load('D.pth'), map_location=torch.device('cpu'))
dist.barrier()
# 加载模型前后用dist.barrier()来同步不同进程间的快慢

Seven_启动:

CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train.py --epochs 15000 --batchsize 10 --world_size 2
# 用-m torch.distributed.launch启动,nproc_per_node为所使用的卡数,batchsize设置为每张卡各自的批大小

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入理解python中的select模块
Apr 23 Python
Python cookbook(数据结构与算法)从任意长度的可迭代对象中分解元素操作示例
Feb 13 Python
python爬虫基本知识
Mar 05 Python
详解通过API管理或定制开发ECS实例
Sep 30 Python
flask框架自定义过滤器示例【markdown文件读取和展示功能】
Nov 08 Python
pygame库实现移动底座弹球小游戏
Apr 14 Python
Flask中endpoint的理解(小结)
Dec 11 Python
浅谈Python中range与Numpy中arange的比较
Mar 11 Python
查看keras各种网络结构各层的名字方式
Jun 11 Python
Python常用库Numpy进行矩阵运算详解
Jul 21 Python
解决import tensorflow导致jupyter内核死亡的问题
Feb 06 Python
Python爬虫基础讲解之请求
May 13 Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
OpenCV全景图像拼接的实现示例
opencv 分类白天与夜景视频的方法
You might like
php查看一个变量的占用内存的实例代码
2020/03/29 PHP
PHP程序守护进程化实现方法详解
2020/07/16 PHP
PHP7 新增功能
2021/03/09 PHP
如果文字过长,则将过长的部分变成省略号显示
2006/06/26 Javascript
Javascript isArray 数组类型检测函数
2009/10/08 Javascript
js数组Array sort方法使用深入分析
2013/02/21 Javascript
调用HttpHanlder的几种返回方式小结
2013/12/20 Javascript
jquery幻灯片插件bxslider样式改进实例
2014/10/15 Javascript
jquery通过ajax加载一段文本内容的方法
2015/01/15 Javascript
javascript实现禁止复制网页内容汇总
2015/12/30 Javascript
javascript每日必学之基础入门
2016/02/16 Javascript
JavaScript的==运算详解
2016/07/20 Javascript
jQuery+HTML5+CSS3制作支持响应式布局时间轴插件
2016/08/10 Javascript
超全面的javascript中变量命名规则
2017/02/09 Javascript
Angular 4依赖注入学习教程之ValueProvider的使用(七)
2017/06/04 Javascript
Vue-cli3.x + axios 跨域方案踩坑指北
2019/07/04 Javascript
Vue+Element ui 根据后台返回数据设置动态表头操作
2020/09/21 Javascript
windows 10下安装搭建django1.10.3和Apache2.4的方法
2017/04/05 Python
Python随机读取文件实现实例
2017/05/25 Python
python爬虫headers设置后无效的解决方法
2017/10/21 Python
对python dataframe逻辑取值的方法详解
2019/01/30 Python
对Python 简单串口收发GUI界面的实例详解
2019/06/12 Python
python内存动态分配过程详解
2019/07/15 Python
详解Python用三种方式统计词频的方法
2019/07/29 Python
python实现逆滤波与维纳滤波示例
2020/02/26 Python
自定义实现 PyQt5 下拉复选框 ComboCheckBox的完整代码
2020/03/30 Python
Jmeter调用Python脚本实现参数互相传递的实现
2021/01/22 Python
canvas之万花筒效果的简单实现(推荐)
2016/08/16 HTML / CSS
授权委托书样本及填写说明
2014/09/19 职场文书
机关党总支领导班子整改方案
2014/09/20 职场文书
公积金接收函格式
2015/01/30 职场文书
创业计划书之便利店
2019/09/05 职场文书
python如何在word中存储本地图片
2021/04/07 Python
Python连续赋值需要注意的一些问题
2021/06/03 Python
Pandas实现DataFrame的简单运算、统计与排序
2022/03/31 Python
JavaScript parseInt0.0000005打印5原理解析
2022/07/23 Javascript