解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题


Posted in Python onJune 23, 2020

背景

在公司用多卡训练模型,得到权值文件后保存,然后回到实验室,没有多卡的环境,用单卡训练,加载模型时出错,因为单卡机器上,没有使用DataParallel来加载模型,所以会出现加载错误。

原因

DataParallel包装的模型在保存时,权值参数前面会带有module字符,然而自己在单卡环境下,没有用DataParallel包装的模型权值参数不带module。本质上保存的权值文件是一个有序字典。

解决方法

1.在单卡环境下,用DataParallel包装模型。

2.自己重写Load函数,灵活。

from collections import OrderedDict
def myOwnLoad(model, check):
  modelState = model.state_dict()
  tempState = OrderedDict()
  for i in range(len(check.keys())-2):
    print modelState.keys()[i], check.keys()[i]
    tempState[modelState.keys()[i]] = check[check.keys()[i]]
  temp = [[0.02]*1024 for i in range(200)] # mean=0, std=0.02
  tempState['myFc.weight'] = torch.normal(mean=0, std=torch.FloatTensor(temp)).cuda()
  tempState['myFc.bias']  = torch.normal(mean=0, std=torch.FloatTensor([0]*200)).cuda()

  model.load_state_dict(tempState)
  return model

补充知识:Pytorch:多GPU训练网络与单GPU训练网络保存模型的区别

测试环境:Python3.6 + Pytorch0.4

在pytorch中,使用多GPU训练网络需要用到 【nn.DataParallel】:

gpu_ids = [0, 1, 2, 3]
device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet()
if len(gpu_ids) > 1:
  net = nn.DataParallel(net, device_ids=gpu_ids)
net = net.to(device)

而使用单GPU训练网络:

device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet().to(device)

由于多GPU训练使用了 nn.DataParallel(net, device_ids=gpu_ids) 对网络进行封装,因此在原始网络结构中添加了一层module。网络结构如下:

DataParallel(
 (module): LeNet(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
 )
)

而不使用多GPU训练的网络结构如下:

LeNet(
 (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
 (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
 (fc1): Linear(in_features=400, out_features=120, bias=True)
 (fc2): Linear(in_features=120, out_features=84, bias=True)
 (fc3): Linear(in_features=84, out_features=10, bias=True)
)

由于在测试模型时不需要用到多GPU测试,因此在保存模型时应该把module层去掉。如下:

if len(gpu_ids) > 1:
  t.save(net.module.state_dict(), "model.pth")
else:
  t.save(net.state_dict(), "model.pth")

以上这篇解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现微信公众平台自定义菜单实例
Mar 20 Python
Python3.2中Print函数用法实例详解
May 19 Python
python实现批量修改文件名代码
Sep 10 Python
Python解析命令行读取参数--argparse模块使用方法
Jan 23 Python
python和shell获取文本内容的方法
Jun 05 Python
pybind11在Windows下的使用教程
Jul 04 Python
python+numpy实现的基本矩阵操作示例
Jul 19 Python
python实现桌面托盘气泡提示
Jul 29 Python
Python 使用 docopt 解析json参数文件过程讲解
Aug 13 Python
Python多线程爬取豆瓣影评API接口
Oct 22 Python
python 读取、写入txt文件的示例
Sep 27 Python
python数据分析之用sklearn预测糖尿病
Apr 22 Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
Python使用Selenium实现淘宝抢单的流程分析
Jun 23 #Python
python2和python3哪个使用率高
Jun 23 #Python
You might like
PHP设计模式之注册树模式分析
2018/01/26 PHP
收集的10个免费的jQuery相册
2011/02/26 Javascript
JavaScript函数详解
2014/11/17 Javascript
jquery+php实现搜索框自动提示
2014/11/28 Javascript
js实现仿QQ秀换装效果的方法
2015/03/04 Javascript
js中日期的加减法
2015/05/06 Javascript
JavaScript中Date.toSource()方法的使用教程
2015/06/12 Javascript
jQuery+CSS实现滑动的标签分栏切换效果
2015/12/17 Javascript
基于ajax与msmq技术的消息推送功能实现代码
2016/12/26 Javascript
vue组件中点击按钮后修改输入框的状态实例代码
2017/04/14 Javascript
js Date()日期函数浏览器兼容问题解决方法
2017/09/12 Javascript
AngularJS实现的省市二级联动功能示例【可对选项实现增删】
2017/10/26 Javascript
Angular动态绑定样式及改变UI框架样式的方法小结
2018/09/03 Javascript
轻量级富文本编辑器wangEditor结合vue使用方法示例
2018/10/10 Javascript
对layui中的onevent 和event的使用详解
2019/09/06 Javascript
js中关于Blob对象的介绍与使用
2019/11/29 Javascript
详解vue3.0 的 Composition API 的一种使用方法
2020/10/26 Javascript
Python数组条件过滤filter函数使用示例
2014/07/22 Python
神经网络python源码分享
2017/12/15 Python
简单了解Django模板的使用
2017/12/20 Python
Python实现的knn算法示例
2018/06/14 Python
pycharm运行出现ImportError:No module named的解决方法
2018/10/13 Python
Django 拆分model和view的实现方法
2019/08/16 Python
Python连接SQLite数据库并进行增册改查操作方法详解
2020/02/18 Python
Python实现屏幕录制功能的代码
2020/03/02 Python
Python制作数据预测集成工具(值得收藏)
2020/08/21 Python
CSS3实现精美横向滚动菜单按钮
2017/04/14 HTML / CSS
SIDESTEP荷兰:在线购买鞋子
2019/11/18 全球购物
商务英语毕业生自荐信范文
2013/11/08 职场文书
教师评优事迹材料
2014/01/10 职场文书
安全生产管理合理化建议书
2014/03/12 职场文书
班主任工作经验交流材料
2014/05/13 职场文书
2015年挂职干部工作总结
2015/05/14 职场文书
创业的9条正确思考方式
2019/08/26 职场文书
springboot+VUE实现登录注册
2021/05/27 Vue.js
十大经典日本动漫排行榜 海贼王第三,犬夜叉仅第八
2022/03/18 日漫