解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题


Posted in Python onJune 23, 2020

背景

在公司用多卡训练模型,得到权值文件后保存,然后回到实验室,没有多卡的环境,用单卡训练,加载模型时出错,因为单卡机器上,没有使用DataParallel来加载模型,所以会出现加载错误。

原因

DataParallel包装的模型在保存时,权值参数前面会带有module字符,然而自己在单卡环境下,没有用DataParallel包装的模型权值参数不带module。本质上保存的权值文件是一个有序字典。

解决方法

1.在单卡环境下,用DataParallel包装模型。

2.自己重写Load函数,灵活。

from collections import OrderedDict
def myOwnLoad(model, check):
  modelState = model.state_dict()
  tempState = OrderedDict()
  for i in range(len(check.keys())-2):
    print modelState.keys()[i], check.keys()[i]
    tempState[modelState.keys()[i]] = check[check.keys()[i]]
  temp = [[0.02]*1024 for i in range(200)] # mean=0, std=0.02
  tempState['myFc.weight'] = torch.normal(mean=0, std=torch.FloatTensor(temp)).cuda()
  tempState['myFc.bias']  = torch.normal(mean=0, std=torch.FloatTensor([0]*200)).cuda()

  model.load_state_dict(tempState)
  return model

补充知识:Pytorch:多GPU训练网络与单GPU训练网络保存模型的区别

测试环境:Python3.6 + Pytorch0.4

在pytorch中,使用多GPU训练网络需要用到 【nn.DataParallel】:

gpu_ids = [0, 1, 2, 3]
device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet()
if len(gpu_ids) > 1:
  net = nn.DataParallel(net, device_ids=gpu_ids)
net = net.to(device)

而使用单GPU训练网络:

device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet().to(device)

由于多GPU训练使用了 nn.DataParallel(net, device_ids=gpu_ids) 对网络进行封装,因此在原始网络结构中添加了一层module。网络结构如下:

DataParallel(
 (module): LeNet(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
 )
)

而不使用多GPU训练的网络结构如下:

LeNet(
 (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
 (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
 (fc1): Linear(in_features=400, out_features=120, bias=True)
 (fc2): Linear(in_features=120, out_features=84, bias=True)
 (fc3): Linear(in_features=84, out_features=10, bias=True)
)

由于在测试模型时不需要用到多GPU测试,因此在保存模型时应该把module层去掉。如下:

if len(gpu_ids) > 1:
  t.save(net.module.state_dict(), "model.pth")
else:
  t.save(net.state_dict(), "model.pth")

以上这篇解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用htpasswd实现基本认证授权的例子
Jun 10 Python
Django如何配置mysql数据库
May 04 Python
python3.5基于TCP实现文件传输
Mar 20 Python
解决安装python库时windows error5 报错的问题
Oct 21 Python
对python中数据集划分函数StratifiedShuffleSplit的使用详解
Dec 11 Python
python3.4爬虫demo
Jan 22 Python
python下载微信公众号相关文章
Feb 26 Python
Laravel框架表单验证格式化输出的方法
Sep 25 Python
PyCharm导入python项目并配置虚拟环境的教程详解
Oct 13 Python
PyTorch中的padding(边缘填充)操作方式
Jan 03 Python
Python threading模块condition原理及运行流程详解
Oct 05 Python
Python并发编程实例教程之线程的玩法
Jun 20 Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
Python使用Selenium实现淘宝抢单的流程分析
Jun 23 #Python
python2和python3哪个使用率高
Jun 23 #Python
You might like
php判断输入不超过mysql的varchar字段的长度范围
2011/06/24 PHP
php模拟ping命令(php exec函数的使用方法)
2013/10/25 PHP
php使用CURL不依赖COOKIEJAR获取COOKIE的方法
2015/06/17 PHP
一个非常实用的php文件上传类
2017/07/04 PHP
javascript生成随机颜色示例代码
2014/05/05 Javascript
浅析JavaScript 箭头函数 generator Date JSON
2016/05/23 Javascript
微信小程序中的swiper组件详解
2017/04/14 Javascript
AngularJS实现进度条功能示例
2017/07/05 Javascript
原生JS实现多个小球碰撞反弹效果示例
2018/01/31 Javascript
Vue仿今日头条实例详解
2018/02/06 Javascript
React中的render何时执行过程
2018/04/13 Javascript
基于D3.js实现时钟效果
2018/07/17 Javascript
Vue中用props给data赋初始值遇到的问题解决
2018/11/27 Javascript
详解js 创建对象的几种方法
2019/03/08 Javascript
微信小程序如何连接Java后台
2019/08/08 Javascript
微信小程序 scroll-view 水平滚动实现过程解析
2019/10/12 Javascript
Python实现两款计算器功能示例
2017/12/19 Python
python实现requests发送/上传多个文件的示例
2018/06/04 Python
对python:print打印时加u的含义详解
2018/12/15 Python
浅谈pandas筛选出表中满足另一个表所有条件的数据方法
2019/02/08 Python
django基于存储在前端的token用户认证解析
2019/08/06 Python
python3爬取torrent种子链接实例
2020/01/16 Python
django queryset相加和筛选教程
2020/05/18 Python
keras使用Sequence类调用大规模数据集进行训练的实现
2020/06/22 Python
Django实现微信小程序支付的示例代码
2020/09/03 Python
运动会四百米广播稿
2014/01/19 职场文书
幼儿园运动会加油词
2014/02/14 职场文书
《童年的发现》教学反思
2014/02/14 职场文书
道路交通安全实施方案
2014/03/12 职场文书
代办委托书怎么写
2014/08/01 职场文书
个人总结怎么写
2015/02/26 职场文书
2015年乡镇卫生院工作总结
2015/04/22 职场文书
2016年教师政治思想表现评语
2015/12/02 职场文书
史上最全书信经典范文大全(建议收藏)
2019/07/10 职场文书
深入理解go缓存库freecache的使用
2022/02/15 Golang
Python绘制散点图之可视化神器pyecharts
2022/07/07 Python