解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题


Posted in Python onJune 23, 2020

背景

在公司用多卡训练模型,得到权值文件后保存,然后回到实验室,没有多卡的环境,用单卡训练,加载模型时出错,因为单卡机器上,没有使用DataParallel来加载模型,所以会出现加载错误。

原因

DataParallel包装的模型在保存时,权值参数前面会带有module字符,然而自己在单卡环境下,没有用DataParallel包装的模型权值参数不带module。本质上保存的权值文件是一个有序字典。

解决方法

1.在单卡环境下,用DataParallel包装模型。

2.自己重写Load函数,灵活。

from collections import OrderedDict
def myOwnLoad(model, check):
  modelState = model.state_dict()
  tempState = OrderedDict()
  for i in range(len(check.keys())-2):
    print modelState.keys()[i], check.keys()[i]
    tempState[modelState.keys()[i]] = check[check.keys()[i]]
  temp = [[0.02]*1024 for i in range(200)] # mean=0, std=0.02
  tempState['myFc.weight'] = torch.normal(mean=0, std=torch.FloatTensor(temp)).cuda()
  tempState['myFc.bias']  = torch.normal(mean=0, std=torch.FloatTensor([0]*200)).cuda()

  model.load_state_dict(tempState)
  return model

补充知识:Pytorch:多GPU训练网络与单GPU训练网络保存模型的区别

测试环境:Python3.6 + Pytorch0.4

在pytorch中,使用多GPU训练网络需要用到 【nn.DataParallel】:

gpu_ids = [0, 1, 2, 3]
device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet()
if len(gpu_ids) > 1:
  net = nn.DataParallel(net, device_ids=gpu_ids)
net = net.to(device)

而使用单GPU训练网络:

device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet().to(device)

由于多GPU训练使用了 nn.DataParallel(net, device_ids=gpu_ids) 对网络进行封装,因此在原始网络结构中添加了一层module。网络结构如下:

DataParallel(
 (module): LeNet(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
 )
)

而不使用多GPU训练的网络结构如下:

LeNet(
 (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
 (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
 (fc1): Linear(in_features=400, out_features=120, bias=True)
 (fc2): Linear(in_features=120, out_features=84, bias=True)
 (fc3): Linear(in_features=84, out_features=10, bias=True)
)

由于在测试模型时不需要用到多GPU测试,因此在保存模型时应该把module层去掉。如下:

if len(gpu_ids) > 1:
  t.save(net.module.state_dict(), "model.pth")
else:
  t.save(net.state_dict(), "model.pth")

以上这篇解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python网络编程之UDP通信实例(含服务器端、客户端、UDP广播例子)
Apr 25 Python
Python实现的数据结构与算法之队列详解
Apr 22 Python
在Python中使用正则表达式的方法
Aug 13 Python
分数霸榜! python助你微信跳一跳拿高分
Jan 08 Python
python批量读取文件名并写入txt文件中
Sep 05 Python
python中p-value的实现方式
Dec 16 Python
pytorch 中pad函数toch.nn.functional.pad()的用法
Jan 08 Python
Python3.7黑帽编程之病毒篇(基础篇)
Feb 04 Python
Python开发之身份证验证库id_validator验证身份证号合法性及根据身份证号返回住址年龄等信息
Mar 20 Python
django和flask哪个值得研究学习
Jul 31 Python
python3爬虫中多线程的优势总结
Nov 24 Python
Python+腾讯云服务器实现每日自动健康打卡
Dec 06 Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
Python使用Selenium实现淘宝抢单的流程分析
Jun 23 #Python
python2和python3哪个使用率高
Jun 23 #Python
You might like
Symfony2联合查询实现方法
2016/03/18 PHP
PHP微商城开源代码实例
2019/03/27 PHP
Aster vs KG BO3 第三场2.18
2021/03/10 DOTA
jquer之ajaxQueue简单实现代码
2011/09/15 Javascript
对javascript的一点点认识总结《javascript高级程序设计》读书笔记
2011/11/30 Javascript
Javascript中的作用域和上下文深入理解
2015/07/03 Javascript
javascript for-in有序遍历json数据并探讨各个浏览器差异
2015/11/30 Javascript
原生js的数组除重复简单实例
2016/05/24 Javascript
浅析JavaScript中命名空间namespace模式
2016/06/22 Javascript
vue打包的时候自动将px转成rem的操作方法
2018/06/20 Javascript
AngularJs1.x自定义指令独立作用域的函数传入参数方法
2018/10/09 Javascript
vue实现form表单与table表格的数据关联功能示例
2019/01/29 Javascript
详解JS实现系统登录页的登录和验证
2019/04/29 Javascript
jQuery实现动态加载(按需加载)javascript文件的方法分析
2019/05/31 jQuery
vue 框架下自定义滚动条(easyscroll)实现方法
2019/08/29 Javascript
layer.msg()去掉默认时间,实现手动关闭的方法
2019/09/12 Javascript
vue实现在线翻译功能
2019/09/27 Javascript
vue 解决数组赋值无法渲染在页面的问题
2019/10/28 Javascript
[01:48]2018DOTA2亚洲邀请赛主赛事第二日五佳镜头 VG完美团战逆转TNC
2018/04/05 DOTA
[00:32]2018DOTA2亚洲邀请赛出场——VP
2018/04/04 DOTA
跟老齐学Python之网站的结构
2014/10/24 Python
Python写的Tkinter程序屏幕居中方法
2015/03/10 Python
Python使用正则表达式过滤或替换HTML标签的方法详解
2017/09/25 Python
python实现对excel进行数据剔除操作实例
2017/12/07 Python
Python判断两个对象相等的原理
2017/12/12 Python
Python Django实现layui风格+django分页功能的例子
2019/08/29 Python
HTML5 Canvas中使用路径描画二阶、三阶贝塞尔曲线
2015/01/01 HTML / CSS
高清安全摄像头系统:Lorex Technology
2018/07/20 全球购物
法国床上用品商店:La Compagnie du lit
2019/12/26 全球购物
工程业务员工作职责
2013/12/07 职场文书
物业门卫岗位职责
2013/12/28 职场文书
个人职业生涯规划书1500字
2013/12/31 职场文书
社团成立邀请函
2014/01/08 职场文书
电焊工工作岗位职责
2014/02/06 职场文书
战友聚会主持词
2014/04/02 职场文书
CSS3鼠标悬浮过渡缩放效果
2021/04/17 HTML / CSS