解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题


Posted in Python onJune 23, 2020

背景

在公司用多卡训练模型,得到权值文件后保存,然后回到实验室,没有多卡的环境,用单卡训练,加载模型时出错,因为单卡机器上,没有使用DataParallel来加载模型,所以会出现加载错误。

原因

DataParallel包装的模型在保存时,权值参数前面会带有module字符,然而自己在单卡环境下,没有用DataParallel包装的模型权值参数不带module。本质上保存的权值文件是一个有序字典。

解决方法

1.在单卡环境下,用DataParallel包装模型。

2.自己重写Load函数,灵活。

from collections import OrderedDict
def myOwnLoad(model, check):
  modelState = model.state_dict()
  tempState = OrderedDict()
  for i in range(len(check.keys())-2):
    print modelState.keys()[i], check.keys()[i]
    tempState[modelState.keys()[i]] = check[check.keys()[i]]
  temp = [[0.02]*1024 for i in range(200)] # mean=0, std=0.02
  tempState['myFc.weight'] = torch.normal(mean=0, std=torch.FloatTensor(temp)).cuda()
  tempState['myFc.bias']  = torch.normal(mean=0, std=torch.FloatTensor([0]*200)).cuda()

  model.load_state_dict(tempState)
  return model

补充知识:Pytorch:多GPU训练网络与单GPU训练网络保存模型的区别

测试环境:Python3.6 + Pytorch0.4

在pytorch中,使用多GPU训练网络需要用到 【nn.DataParallel】:

gpu_ids = [0, 1, 2, 3]
device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet()
if len(gpu_ids) > 1:
  net = nn.DataParallel(net, device_ids=gpu_ids)
net = net.to(device)

而使用单GPU训练网络:

device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet().to(device)

由于多GPU训练使用了 nn.DataParallel(net, device_ids=gpu_ids) 对网络进行封装,因此在原始网络结构中添加了一层module。网络结构如下:

DataParallel(
 (module): LeNet(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
 )
)

而不使用多GPU训练的网络结构如下:

LeNet(
 (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
 (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
 (fc1): Linear(in_features=400, out_features=120, bias=True)
 (fc2): Linear(in_features=120, out_features=84, bias=True)
 (fc3): Linear(in_features=84, out_features=10, bias=True)
)

由于在测试模型时不需要用到多GPU测试,因此在保存模型时应该把module层去掉。如下:

if len(gpu_ids) > 1:
  t.save(net.module.state_dict(), "model.pth")
else:
  t.save(net.state_dict(), "model.pth")

以上这篇解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在漏洞利用Python代码真的很爽
Aug 26 Python
python数据结构树和二叉树简介
Apr 29 Python
Django集成百度富文本编辑器uEditor攻略
Jul 04 Python
Python3中常用的处理时间和实现定时任务的方法的介绍
Apr 07 Python
用Python计算三角函数之acos()方法的使用
May 15 Python
在Python的Django框架中编写错误提示页面
Jul 22 Python
解决Numpy中sum函数求和结果维度的问题
Dec 06 Python
Python计算机视觉里的IOU计算实例
Jan 17 Python
python 截取XML中bndbox的坐标中的图像,另存为jpg的实例
Mar 10 Python
Python configparser模块封装及构造配置文件
Aug 07 Python
python判断一个变量是否已经设置的方法
Aug 13 Python
python实现KNN近邻算法
Dec 30 Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
Python使用Selenium实现淘宝抢单的流程分析
Jun 23 #Python
python2和python3哪个使用率高
Jun 23 #Python
You might like
php 图片上传类代码
2009/07/17 PHP
回帖脱衣服的图片实现代码
2014/02/15 PHP
浅谈laravel orm 中的一对多关系 hasMany
2019/10/21 PHP
JAVASCRIPT  THIS详解 面向对象
2009/03/25 Javascript
EXTJS内使用ACTIVEX控件引起崩溃问题的解决方法
2010/03/31 Javascript
快速解决jQuery与其他库冲突的方法介绍
2014/01/02 Javascript
提高NodeJS中SSL服务的性能
2014/07/15 NodeJs
JavaScript DOM操作表格及样式
2015/04/13 Javascript
ECMAScript6中Set/WeakSet详解
2015/06/12 Javascript
JS实现霓虹灯文字效果的方法
2015/08/06 Javascript
基于javascript编写简单日历
2016/05/02 Javascript
深入浅析search 搜索框的写法
2016/08/02 Javascript
使用vue编写一个点击数字计时小游戏
2016/08/31 Javascript
详解React开发中使用require.ensure()按需加载ES6组件
2017/05/12 Javascript
jQuery+ajax实现动态添加表格tr td功能示例
2018/04/23 jQuery
原生JS实现DOM加载完成马上执行JS代码的方法
2018/09/07 Javascript
微信小程序列表中item左滑删除功能
2018/11/07 Javascript
Bootstrap 按钮样式与使用代码详解
2018/12/09 Javascript
小程序云开发获取不到数据库记录的解决方法
2019/05/18 Javascript
Node.js API详解之 string_decoder用法实例分析
2020/04/29 Javascript
详解Vue之计算属性
2020/06/20 Javascript
Echarts.js无法引入问题解决方案
2020/10/30 Javascript
Python文件的读写和异常代码示例
2017/10/31 Python
详谈python3中用for循环删除列表中元素的坑
2018/04/19 Python
python 实现登录网页的操作方法
2018/05/11 Python
TensorFlow 合并/连接数组的方法
2018/07/27 Python
详解如何将python3.6软件的py文件打包成exe程序
2018/10/09 Python
对Python函数设计规范详解
2019/07/19 Python
基于Python绘制个人足迹地图
2020/06/01 Python
美国糖果店:Sugarfina
2019/02/21 全球购物
广告学专业毕业生自荐信
2013/09/24 职场文书
就业表自我评价分享
2014/02/06 职场文书
农村优秀教师事迹材料
2014/08/27 职场文书
2015年高校教师个人工作总结
2015/05/25 职场文书
pytorch 预训练模型读取修改相关参数的填坑问题
2021/06/05 Python
详细聊聊关于Mysql联合查询的那些事儿
2021/10/24 MySQL