解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题


Posted in Python onJune 23, 2020

背景

在公司用多卡训练模型,得到权值文件后保存,然后回到实验室,没有多卡的环境,用单卡训练,加载模型时出错,因为单卡机器上,没有使用DataParallel来加载模型,所以会出现加载错误。

原因

DataParallel包装的模型在保存时,权值参数前面会带有module字符,然而自己在单卡环境下,没有用DataParallel包装的模型权值参数不带module。本质上保存的权值文件是一个有序字典。

解决方法

1.在单卡环境下,用DataParallel包装模型。

2.自己重写Load函数,灵活。

from collections import OrderedDict
def myOwnLoad(model, check):
  modelState = model.state_dict()
  tempState = OrderedDict()
  for i in range(len(check.keys())-2):
    print modelState.keys()[i], check.keys()[i]
    tempState[modelState.keys()[i]] = check[check.keys()[i]]
  temp = [[0.02]*1024 for i in range(200)] # mean=0, std=0.02
  tempState['myFc.weight'] = torch.normal(mean=0, std=torch.FloatTensor(temp)).cuda()
  tempState['myFc.bias']  = torch.normal(mean=0, std=torch.FloatTensor([0]*200)).cuda()

  model.load_state_dict(tempState)
  return model

补充知识:Pytorch:多GPU训练网络与单GPU训练网络保存模型的区别

测试环境:Python3.6 + Pytorch0.4

在pytorch中,使用多GPU训练网络需要用到 【nn.DataParallel】:

gpu_ids = [0, 1, 2, 3]
device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet()
if len(gpu_ids) > 1:
  net = nn.DataParallel(net, device_ids=gpu_ids)
net = net.to(device)

而使用单GPU训练网络:

device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet().to(device)

由于多GPU训练使用了 nn.DataParallel(net, device_ids=gpu_ids) 对网络进行封装,因此在原始网络结构中添加了一层module。网络结构如下:

DataParallel(
 (module): LeNet(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
 )
)

而不使用多GPU训练的网络结构如下:

LeNet(
 (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
 (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
 (fc1): Linear(in_features=400, out_features=120, bias=True)
 (fc2): Linear(in_features=120, out_features=84, bias=True)
 (fc3): Linear(in_features=84, out_features=10, bias=True)
)

由于在测试模型时不需要用到多GPU测试,因此在保存模型时应该把module层去掉。如下:

if len(gpu_ids) > 1:
  t.save(net.module.state_dict(), "model.pth")
else:
  t.save(net.state_dict(), "model.pth")

以上这篇解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中非常实用的一些功能和函数分享
Feb 14 Python
Python去除字符串两端空格的方法
May 21 Python
Python urls.py的三种配置写法实例详解
Apr 28 Python
详解用python实现简单的遗传算法
Jan 02 Python
python sys,os,time模块的使用(包括时间格式的各种转换)
Apr 27 Python
Python函数参数操作详解
Aug 03 Python
python使用PIL模块获取图片像素点的方法
Jan 08 Python
Django框架中序列化和反序列化的例子
Aug 06 Python
python实现电子词典
Mar 03 Python
配置python的编程环境之Anaconda + VSCode的教程
Mar 29 Python
Python类的继承super相关原理解析
Oct 22 Python
Python下载的11种姿势(小结)
Nov 18 Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
Python使用Selenium实现淘宝抢单的流程分析
Jun 23 #Python
python2和python3哪个使用率高
Jun 23 #Python
You might like
浅析php适配器模式(Adapter)
2014/11/25 PHP
php实现贪吃蛇小游戏
2016/07/26 PHP
php获取微信共享收货地址的方法
2017/12/21 PHP
jquery实现弹出div,始终显示在屏幕正中间的简单实例
2014/03/08 Javascript
禁用Enter键表单自动提交实现代码
2014/05/22 Javascript
用html+css+js实现的一个简单的图片切换特效
2014/05/28 Javascript
nodejs实现黑名单中间件设计
2014/06/17 NodeJs
JavaScript实现SHA-1加密算法的方法
2015/03/11 Javascript
javascript基础语法——全面理解变量和标识符
2016/06/02 Javascript
Vue.js教程之axios与网络传输的学习实践
2017/04/29 Javascript
vue axios 二次封装的示例代码
2017/12/08 Javascript
详解Angular Forms中自定义ngModel绑定值的方式
2018/12/10 Javascript
JavaScript面试技巧之数组的一些不low操作
2019/03/22 Javascript
Vue 动态添加路由及生成菜单的方法示例
2019/06/20 Javascript
nodejs文件夹深层复制功能
2019/09/03 NodeJs
Vue+Element-UI实现上传图片并压缩
2019/11/26 Javascript
jQuery 判断元素是否存在然后按需加载内容的实现代码
2020/01/16 jQuery
javascript 数组(list)添加/删除的实现
2020/12/17 Javascript
vue实现图书管理系统
2020/12/29 Vue.js
[26:24]完美副总裁、DOTA2负责人蔡玮专访:电竞如人生
2014/09/11 DOTA
python中的sort方法使用详解
2014/07/25 Python
Python编码类型转换方法详解
2016/07/01 Python
Python多线程threading和multiprocessing模块实例解析
2018/01/29 Python
python爬虫正则表达式之处理换行符
2018/06/08 Python
浅谈pytorch卷积核大小的设置对全连接神经元的影响
2020/01/10 Python
python3安装OCR识别库tesserocr过程图解
2020/04/02 Python
Python用Jira库来操作Jira
2020/12/28 Python
澳大利亚最大的女装零售商:Millers
2017/09/10 全球购物
Ootori在线按摩椅店:一家专业的按摩椅制造商
2019/04/10 全球购物
会计顶岗实习心得
2014/01/25 职场文书
食品采购员岗位职责
2014/04/14 职场文书
2014年城市管理工作总结
2014/12/02 职场文书
2016年教师节感恩寄语
2015/12/04 职场文书
vue实现简单数据双向绑定
2021/04/28 Vue.js
Nginx反向代理学习实例教程
2021/10/24 Servers
MySQL 自动填充 create_time 和 update_time
2022/05/20 MySQL