解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题


Posted in Python onJune 23, 2020

背景

在公司用多卡训练模型,得到权值文件后保存,然后回到实验室,没有多卡的环境,用单卡训练,加载模型时出错,因为单卡机器上,没有使用DataParallel来加载模型,所以会出现加载错误。

原因

DataParallel包装的模型在保存时,权值参数前面会带有module字符,然而自己在单卡环境下,没有用DataParallel包装的模型权值参数不带module。本质上保存的权值文件是一个有序字典。

解决方法

1.在单卡环境下,用DataParallel包装模型。

2.自己重写Load函数,灵活。

from collections import OrderedDict
def myOwnLoad(model, check):
  modelState = model.state_dict()
  tempState = OrderedDict()
  for i in range(len(check.keys())-2):
    print modelState.keys()[i], check.keys()[i]
    tempState[modelState.keys()[i]] = check[check.keys()[i]]
  temp = [[0.02]*1024 for i in range(200)] # mean=0, std=0.02
  tempState['myFc.weight'] = torch.normal(mean=0, std=torch.FloatTensor(temp)).cuda()
  tempState['myFc.bias']  = torch.normal(mean=0, std=torch.FloatTensor([0]*200)).cuda()

  model.load_state_dict(tempState)
  return model

补充知识:Pytorch:多GPU训练网络与单GPU训练网络保存模型的区别

测试环境:Python3.6 + Pytorch0.4

在pytorch中,使用多GPU训练网络需要用到 【nn.DataParallel】:

gpu_ids = [0, 1, 2, 3]
device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet()
if len(gpu_ids) > 1:
  net = nn.DataParallel(net, device_ids=gpu_ids)
net = net.to(device)

而使用单GPU训练网络:

device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet().to(device)

由于多GPU训练使用了 nn.DataParallel(net, device_ids=gpu_ids) 对网络进行封装,因此在原始网络结构中添加了一层module。网络结构如下:

DataParallel(
 (module): LeNet(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
 )
)

而不使用多GPU训练的网络结构如下:

LeNet(
 (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
 (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
 (fc1): Linear(in_features=400, out_features=120, bias=True)
 (fc2): Linear(in_features=120, out_features=84, bias=True)
 (fc3): Linear(in_features=84, out_features=10, bias=True)
)

由于在测试模型时不需要用到多GPU测试,因此在保存模型时应该把module层去掉。如下:

if len(gpu_ids) > 1:
  t.save(net.module.state_dict(), "model.pth")
else:
  t.save(net.state_dict(), "model.pth")

以上这篇解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之有容乃大的list(3)
Sep 15 Python
Python使用urllib模块的urlopen超时问题解决方法
Nov 08 Python
Python中实现从目录中过滤出指定文件类型的文件
Feb 02 Python
Python中实现三目运算的方法
Jun 21 Python
python实现的用于搜索文件并进行内容替换的类实例
Jun 28 Python
python实现中文分词FMM算法实例
Jul 10 Python
Python实现的径向基(RBF)神经网络示例
Feb 06 Python
tensorflow实现简单的卷积网络
May 24 Python
Python3中bytes类型转换为str类型
Sep 27 Python
简单了解django orm中介模型
Jul 30 Python
Django多数据库配置及逆向生成model教程
Mar 28 Python
Python保存并浏览用户的历史记录
Apr 29 Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
Python使用Selenium实现淘宝抢单的流程分析
Jun 23 #Python
python2和python3哪个使用率高
Jun 23 #Python
You might like
DC动画很好看?新作烂得令人发指,名叫《红色之子》
2020/04/09 欧美动漫
php4的session功能评述(三)
2006/10/09 PHP
php下用cookie统计用户访问网页次数的代码
2010/05/09 PHP
PHP+JS实现的商品秒杀倒计时用法示例
2016/11/15 PHP
php多文件打包下载的实例代码
2017/07/12 PHP
php5.3/5.4/5.5/5.6/7常见新增特性汇总整理
2020/02/27 PHP
cloudgamer出品ImageZoom 图片放大效果
2010/04/01 Javascript
关于ExtJS4.1:快捷键支持的问题
2013/04/24 Javascript
js中如何把字符串转化为对象、数组示例代码
2013/07/17 Javascript
js类式继承的具体实现方法
2013/12/31 Javascript
Jquery通过JSON字符串创建JSON对象
2014/08/24 Javascript
使用phantomjs进行网页抓取的实现代码
2014/09/29 Javascript
javascript判断数组内是否重复的方法
2015/04/21 Javascript
jQuery简单获取键盘事件的方法
2016/01/22 Javascript
JS获取时间的相关函数及时间戳与时间日期之间的转换
2016/02/04 Javascript
JS动态计算移动端rem的解决方案
2016/10/14 Javascript
thinkjs之页面跳转同步异步操作
2017/02/05 Javascript
详解Angular路由 ng-route和ui-router的区别
2017/05/22 Javascript
Angularjs 根据一个select的值去设置另一个select的值方法
2018/08/13 Javascript
vue中子组件的methods中获取到props中的值方法
2018/08/27 Javascript
jquery实现自定义树形表格的方法【自定义树形结构table】
2019/07/12 jQuery
[01:45]亚洲邀请赛互动指南虚拟物品介绍
2015/01/30 DOTA
python网络编程学习笔记(二):socket建立网络客户端
2014/06/09 Python
详解Python list 与 NumPy.ndarry 切片之间的对比
2017/07/24 Python
关于python写入文件自动换行的问题
2018/06/23 Python
浅谈Scrapy网络爬虫框架的工作原理和数据采集
2019/02/07 Python
Python实现字典按key或者value进行排序操作示例【sorted】
2019/05/03 Python
PyQt5 控件字体样式等设置的实现
2020/05/13 Python
详解如何用canvas画一个微笑的表情
2019/03/14 HTML / CSS
香港个人化生活购物网站:Ballyhoo Limited
2016/09/10 全球购物
银行进社区活动总结
2014/07/07 职场文书
开展党的群众路线教育实践活动领导班子对照检查材料
2014/09/25 职场文书
如何写通讯稿
2015/07/22 职场文书
读《工匠精神》有感:热爱工作,精益求精
2019/12/28 职场文书
python中的None与NULL用法说明
2021/05/25 Python
JavaScript选择器函数querySelector和querySelectorAll
2021/11/27 Javascript