解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题


Posted in Python onJune 23, 2020

背景

在公司用多卡训练模型,得到权值文件后保存,然后回到实验室,没有多卡的环境,用单卡训练,加载模型时出错,因为单卡机器上,没有使用DataParallel来加载模型,所以会出现加载错误。

原因

DataParallel包装的模型在保存时,权值参数前面会带有module字符,然而自己在单卡环境下,没有用DataParallel包装的模型权值参数不带module。本质上保存的权值文件是一个有序字典。

解决方法

1.在单卡环境下,用DataParallel包装模型。

2.自己重写Load函数,灵活。

from collections import OrderedDict
def myOwnLoad(model, check):
  modelState = model.state_dict()
  tempState = OrderedDict()
  for i in range(len(check.keys())-2):
    print modelState.keys()[i], check.keys()[i]
    tempState[modelState.keys()[i]] = check[check.keys()[i]]
  temp = [[0.02]*1024 for i in range(200)] # mean=0, std=0.02
  tempState['myFc.weight'] = torch.normal(mean=0, std=torch.FloatTensor(temp)).cuda()
  tempState['myFc.bias']  = torch.normal(mean=0, std=torch.FloatTensor([0]*200)).cuda()

  model.load_state_dict(tempState)
  return model

补充知识:Pytorch:多GPU训练网络与单GPU训练网络保存模型的区别

测试环境:Python3.6 + Pytorch0.4

在pytorch中,使用多GPU训练网络需要用到 【nn.DataParallel】:

gpu_ids = [0, 1, 2, 3]
device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet()
if len(gpu_ids) > 1:
  net = nn.DataParallel(net, device_ids=gpu_ids)
net = net.to(device)

而使用单GPU训练网络:

device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet().to(device)

由于多GPU训练使用了 nn.DataParallel(net, device_ids=gpu_ids) 对网络进行封装,因此在原始网络结构中添加了一层module。网络结构如下:

DataParallel(
 (module): LeNet(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
 )
)

而不使用多GPU训练的网络结构如下:

LeNet(
 (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
 (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
 (fc1): Linear(in_features=400, out_features=120, bias=True)
 (fc2): Linear(in_features=120, out_features=84, bias=True)
 (fc3): Linear(in_features=84, out_features=10, bias=True)
)

由于在测试模型时不需要用到多GPU测试,因此在保存模型时应该把module层去掉。如下:

if len(gpu_ids) > 1:
  t.save(net.module.state_dict(), "model.pth")
else:
  t.save(net.state_dict(), "model.pth")

以上这篇解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
通过Python爬虫代理IP快速增加博客阅读量
Dec 14 Python
python3.0 模拟用户登录,三次错误锁定的实例
Nov 02 Python
python使用sqlite3时游标使用方法
Mar 13 Python
TensorFlow打印tensor值的实现方法
Jul 27 Python
tensorflow 打印内存中的变量方法
Jul 30 Python
对python tkinter窗口弹出置顶的方法详解
Jun 14 Python
利用python numpy+matplotlib绘制股票k线图的方法
Jun 26 Python
python读取并定位excel数据坐标系详解
Jun 26 Python
python 反编译exe文件为py文件的实例代码
Jun 27 Python
pandas计算最大连续间隔的方法
Jul 04 Python
解决django FileFIELD的编码问题
Mar 30 Python
使用pygame实现垃圾分类小游戏功能(已获校级二等奖)
Jul 23 Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
Python使用Selenium实现淘宝抢单的流程分析
Jun 23 #Python
python2和python3哪个使用率高
Jun 23 #Python
You might like
PHP个人网站架设连环讲(二)
2006/10/09 PHP
深入探讨<br />和 \r\n两者有什么区别??
2013/06/05 PHP
PHP+Mysql基于事务处理实现转账功能的方法
2015/07/08 PHP
EarthLiveSharp中cloudinary的CDN图片缓存自动清理python脚本
2017/04/04 PHP
php 静态属性和静态方法区别详解
2017/04/09 PHP
PHPTree――php快速生成无限级分类
2018/03/30 PHP
PHP安装memcache扩展的步骤讲解
2019/02/14 PHP
ie 处理 gif动画 的onload 事件的一个 bug
2007/04/12 Javascript
基于jquery的给文章加入关键字链接
2010/10/26 Javascript
关于jquery.validate1.9.0前台验证的使用介绍
2013/04/26 Javascript
浅析JavaScript中的delete运算符
2013/11/30 Javascript
简介JavaScript中的setTime()方法的使用
2015/06/11 Javascript
javascript实现的闭包简单实例
2015/07/17 Javascript
WEB前端开发都应知道的jquery小技巧及jquery三个简写
2015/11/15 Javascript
Knockoutjs 学习系列(二)花式捆绑
2016/06/07 Javascript
深入理解Node.js的HTTP模块
2016/10/12 Javascript
JavaScript实现格式化字符串函数String.format
2016/12/16 Javascript
求js数组的最大值和最小值的四种方法
2017/03/03 Javascript
Express的HTTP重定向到HTTPS的方法
2018/06/06 Javascript
React之PureComponent的使用作用
2018/07/10 Javascript
js实现的格式化数字和金额功能简单示例
2019/07/30 Javascript
js中的this的指向问题详解
2019/08/29 Javascript
ElementUI中el-tree节点的操作的实现
2020/02/27 Javascript
python完成FizzBuzzWhizz问题(拉勾网面试题)示例
2014/05/05 Python
Python的Flask框架及Nginx实现静态文件访问限制功能
2016/06/27 Python
django文档学习之applications使用详解
2018/01/29 Python
使用Python 自动生成 Word 文档的教程
2020/02/13 Python
Python 多线程共享变量的实现示例
2020/04/17 Python
keras 两种训练模型方式详解fit和fit_generator(节省内存)
2020/07/03 Python
python中spy++的使用超详细教程
2021/01/29 Python
美国电视购物:QVC
2017/02/06 全球购物
法律专业自荐信
2014/06/03 职场文书
大学生学年个人总结
2015/02/15 职场文书
毕业设计致谢词
2015/05/14 职场文书
元旦主持词开场白
2015/05/29 职场文书
Pytorch可视化的几种实现方法
2021/06/10 Python