解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题


Posted in Python onJune 23, 2020

背景

在公司用多卡训练模型,得到权值文件后保存,然后回到实验室,没有多卡的环境,用单卡训练,加载模型时出错,因为单卡机器上,没有使用DataParallel来加载模型,所以会出现加载错误。

原因

DataParallel包装的模型在保存时,权值参数前面会带有module字符,然而自己在单卡环境下,没有用DataParallel包装的模型权值参数不带module。本质上保存的权值文件是一个有序字典。

解决方法

1.在单卡环境下,用DataParallel包装模型。

2.自己重写Load函数,灵活。

from collections import OrderedDict
def myOwnLoad(model, check):
  modelState = model.state_dict()
  tempState = OrderedDict()
  for i in range(len(check.keys())-2):
    print modelState.keys()[i], check.keys()[i]
    tempState[modelState.keys()[i]] = check[check.keys()[i]]
  temp = [[0.02]*1024 for i in range(200)] # mean=0, std=0.02
  tempState['myFc.weight'] = torch.normal(mean=0, std=torch.FloatTensor(temp)).cuda()
  tempState['myFc.bias']  = torch.normal(mean=0, std=torch.FloatTensor([0]*200)).cuda()

  model.load_state_dict(tempState)
  return model

补充知识:Pytorch:多GPU训练网络与单GPU训练网络保存模型的区别

测试环境:Python3.6 + Pytorch0.4

在pytorch中,使用多GPU训练网络需要用到 【nn.DataParallel】:

gpu_ids = [0, 1, 2, 3]
device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet()
if len(gpu_ids) > 1:
  net = nn.DataParallel(net, device_ids=gpu_ids)
net = net.to(device)

而使用单GPU训练网络:

device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet().to(device)

由于多GPU训练使用了 nn.DataParallel(net, device_ids=gpu_ids) 对网络进行封装,因此在原始网络结构中添加了一层module。网络结构如下:

DataParallel(
 (module): LeNet(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
 )
)

而不使用多GPU训练的网络结构如下:

LeNet(
 (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
 (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
 (fc1): Linear(in_features=400, out_features=120, bias=True)
 (fc2): Linear(in_features=120, out_features=84, bias=True)
 (fc3): Linear(in_features=84, out_features=10, bias=True)
)

由于在测试模型时不需要用到多GPU测试,因此在保存模型时应该把module层去掉。如下:

if len(gpu_ids) > 1:
  t.save(net.module.state_dict(), "model.pth")
else:
  t.save(net.state_dict(), "model.pth")

以上这篇解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python爬虫_自动获取seebug的poc实例
Aug 05 Python
Pyinstaller将py打包成exe的实例
Mar 31 Python
PyQt5每天必学之工具提示功能
Apr 19 Python
Python实现按照指定要求逆序输出一个数字的方法
Apr 19 Python
详解python 注释、变量、类型
Aug 10 Python
Python 支付整合开发包的实现
Jan 23 Python
Django框架设置cookies与获取cookies操作详解
May 27 Python
Python 函数绘图及函数图像微分与积分
Nov 20 Python
Python pygame绘制文字制作滚动文字过程解析
Dec 12 Python
pytorch的梯度计算以及backward方法详解
Jan 10 Python
使用pycharm和pylint检查python代码规范操作
Jun 09 Python
python中pop()函数的语法与实例
Dec 01 Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
Python使用Selenium实现淘宝抢单的流程分析
Jun 23 #Python
python2和python3哪个使用率高
Jun 23 #Python
You might like
PHP.MVC的模板标签系统(二)
2006/09/05 PHP
提升PHP执行速度全攻略(上)
2006/10/09 PHP
php auth_http类库进行身份效验
2009/03/19 PHP
PHP中的生成XML文件的4种方法分享
2012/10/06 PHP
Javascript将string类型转换int类型
2010/12/09 Javascript
如何让你的Lightbox支持滚轮缩放及Base64图片
2014/12/04 Javascript
JS设置下拉列表框当前所选值的方法
2015/12/22 Javascript
js获取当前日期时间及其它日期操作汇总
2016/03/08 Javascript
几种经典排序算法的JS实现方法
2016/03/25 Javascript
node.js连接mongoDB数据库 快速搭建自己的web服务
2016/04/17 Javascript
jQuery可见性过滤选择器用法示例
2016/09/09 Javascript
详解AngularJs中$resource和restfu服务端数据交互
2016/09/21 Javascript
web前端开发中常见的多列布局解决方案整理(一定要看)
2017/10/15 Javascript
vue+element table表格实现动态列筛选的示例代码
2021/01/14 Vue.js
Django1.7+python 2.78+pycharm配置mysql数据库
2016/10/09 Python
浅谈Python实现Apriori算法介绍
2017/12/20 Python
使用Django2快速开发Web项目的详细步骤
2019/01/06 Python
Python模拟百度自动输入搜索功能的实例
2019/02/14 Python
Python使用微信itchat接口实现查看自己微信的信息功能详解
2019/08/22 Python
python 一篇文章搞懂装饰器所有用法(建议收藏)
2019/08/23 Python
Python 限定函数参数的类型及默认值方式
2019/12/24 Python
解决Python spyder显示不全df列和行的问题
2020/04/20 Python
Django 解决开发自定义抛出异常的问题
2020/05/21 Python
selenium判断元素是否存在的两种方法小结
2020/12/07 Python
pycharm 复制代码出现空格的解决方式
2021/01/15 Python
整理的15个非常有用的 HTML5 开发教程和速查手册
2011/10/18 HTML / CSS
英国屋顶用品和材料超市:Roofing Supplies UK
2019/08/24 全球购物
聚网科技C++面试笔试题
2015/09/01 面试题
高中运动会入场词
2014/02/14 职场文书
安全生产网格化管理实施方案
2014/03/01 职场文书
学校党委副书记个人对照检查材料思想汇报
2014/09/28 职场文书
渠道运营商合作协议书范本
2014/10/06 职场文书
幼儿园老师新年寄语2015
2014/12/08 职场文书
2014年共青团工作总结
2014/12/10 职场文书
一年级语文下册复习计划
2015/01/17 职场文书
关于Nginx中虚拟主机的一些冷门知识小结
2022/03/03 Servers