解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题


Posted in Python onJune 23, 2020

背景

在公司用多卡训练模型,得到权值文件后保存,然后回到实验室,没有多卡的环境,用单卡训练,加载模型时出错,因为单卡机器上,没有使用DataParallel来加载模型,所以会出现加载错误。

原因

DataParallel包装的模型在保存时,权值参数前面会带有module字符,然而自己在单卡环境下,没有用DataParallel包装的模型权值参数不带module。本质上保存的权值文件是一个有序字典。

解决方法

1.在单卡环境下,用DataParallel包装模型。

2.自己重写Load函数,灵活。

from collections import OrderedDict
def myOwnLoad(model, check):
  modelState = model.state_dict()
  tempState = OrderedDict()
  for i in range(len(check.keys())-2):
    print modelState.keys()[i], check.keys()[i]
    tempState[modelState.keys()[i]] = check[check.keys()[i]]
  temp = [[0.02]*1024 for i in range(200)] # mean=0, std=0.02
  tempState['myFc.weight'] = torch.normal(mean=0, std=torch.FloatTensor(temp)).cuda()
  tempState['myFc.bias']  = torch.normal(mean=0, std=torch.FloatTensor([0]*200)).cuda()

  model.load_state_dict(tempState)
  return model

补充知识:Pytorch:多GPU训练网络与单GPU训练网络保存模型的区别

测试环境:Python3.6 + Pytorch0.4

在pytorch中,使用多GPU训练网络需要用到 【nn.DataParallel】:

gpu_ids = [0, 1, 2, 3]
device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet()
if len(gpu_ids) > 1:
  net = nn.DataParallel(net, device_ids=gpu_ids)
net = net.to(device)

而使用单GPU训练网络:

device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet().to(device)

由于多GPU训练使用了 nn.DataParallel(net, device_ids=gpu_ids) 对网络进行封装,因此在原始网络结构中添加了一层module。网络结构如下:

DataParallel(
 (module): LeNet(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
 )
)

而不使用多GPU训练的网络结构如下:

LeNet(
 (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
 (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
 (fc1): Linear(in_features=400, out_features=120, bias=True)
 (fc2): Linear(in_features=120, out_features=84, bias=True)
 (fc3): Linear(in_features=84, out_features=10, bias=True)
)

由于在测试模型时不需要用到多GPU测试,因此在保存模型时应该把module层去掉。如下:

if len(gpu_ids) > 1:
  t.save(net.module.state_dict(), "model.pth")
else:
  t.save(net.state_dict(), "model.pth")

以上这篇解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3简单实现微信爬虫
Apr 09 Python
Python打造出适合自己的定制化Eclipse IDE
Mar 02 Python
获取python的list中含有重复值的index方法
Jun 27 Python
对pycharm 修改程序运行所需内存详解
Dec 03 Python
详解python爬虫系列之初识爬虫
Apr 06 Python
程序员的七夕用30行代码让Python化身表白神器
Aug 07 Python
python 实现手机自动拨打电话的方法(通话压力测试)
Aug 08 Python
Python操作SQLite/MySQL/LMDB数据库的方法
Nov 07 Python
Python中包的用法及安装
Feb 11 Python
Spring Boot中使用IntelliJ IDEA插件EasyCode一键生成代码详细方法
Mar 20 Python
Python基于QQ邮箱实现SSL发送
Apr 26 Python
python3 字符串str和bytes相互转换
Mar 23 Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
Python使用Selenium实现淘宝抢单的流程分析
Jun 23 #Python
python2和python3哪个使用率高
Jun 23 #Python
You might like
Yii学习总结之数据访问对象 (DAO)
2015/02/22 PHP
codeigniter显示所有脚本执行时间的方法
2015/03/21 PHP
php workerman定时任务的实现代码
2018/12/23 PHP
详解laravel passport OAuth2.0的4种模式
2019/11/04 PHP
Jquery submit()无法提交问题
2013/04/21 Javascript
js获取滚动距离的方法
2015/05/30 Javascript
详解JavaScript中的客户端消息框架设计原理
2015/06/24 Javascript
详解javascript遍历方式
2015/11/11 Javascript
javascript设置页面背景色及背景图片的方法
2015/12/29 Javascript
利用JS判断字符串是否含有数字与特殊字符的方法小结
2016/11/25 Javascript
angular实现表单验证及提交功能
2017/02/01 Javascript
微信小程序中做用户登录与登录态维护的实现详解
2017/05/17 Javascript
将angular.js项目整合到.net mvc中的方法详解
2017/06/29 Javascript
AngularJS实现进度条功能示例
2017/07/05 Javascript
View.post() 不靠谱的地方你知道多少
2017/08/29 Javascript
Nodejs 发布自己的npm包并制作成命令行工具的实例讲解
2018/05/15 NodeJs
Vue触发隐藏input file的方法实例详解
2019/08/14 Javascript
JS操作Fckeditor的一些常用方法(获取、插入等)
2020/02/19 Javascript
Vue用mixin合并重复代码的实现
2020/11/27 Vue.js
[33:15]2018DOTA2亚洲邀请赛3月30日 小组赛B组 VP VS Mineski
2018/03/31 DOTA
[01:08:56]DOTA2-DPC中国联赛 正赛 Magma vs LBZS BO3 第一场 2月7日
2021/03/11 DOTA
Python入门篇之字符串
2014/10/17 Python
跟老齐学Python之通过Python连接数据库
2014/10/28 Python
python计算列表内各元素的个数实例
2018/06/29 Python
Python爬虫PyQuery库基本用法入门教程
2018/08/04 Python
python+opencv像素的加减和加权操作的实现
2019/07/14 Python
Python3实现配置文件差异对比脚本
2019/11/18 Python
jupyter lab文件导出/下载方式
2020/04/22 Python
毕业生机械建模求职信
2013/10/14 职场文书
幼师自荐信
2013/10/26 职场文书
原材料检验岗位职责
2014/03/15 职场文书
聘任书的写作格式及范文
2014/03/29 职场文书
党的群众路线教育实践活动总结大会主持词
2014/10/30 职场文书
幼儿园大班教学反思
2016/03/02 职场文书
学会掌握自己命运的十条黄金法则:
2019/08/08 职场文书
十大经典日本动漫排行榜 海贼王第三,犬夜叉仅第八
2022/03/18 日漫