解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题


Posted in Python onJune 23, 2020

背景

在公司用多卡训练模型,得到权值文件后保存,然后回到实验室,没有多卡的环境,用单卡训练,加载模型时出错,因为单卡机器上,没有使用DataParallel来加载模型,所以会出现加载错误。

原因

DataParallel包装的模型在保存时,权值参数前面会带有module字符,然而自己在单卡环境下,没有用DataParallel包装的模型权值参数不带module。本质上保存的权值文件是一个有序字典。

解决方法

1.在单卡环境下,用DataParallel包装模型。

2.自己重写Load函数,灵活。

from collections import OrderedDict
def myOwnLoad(model, check):
  modelState = model.state_dict()
  tempState = OrderedDict()
  for i in range(len(check.keys())-2):
    print modelState.keys()[i], check.keys()[i]
    tempState[modelState.keys()[i]] = check[check.keys()[i]]
  temp = [[0.02]*1024 for i in range(200)] # mean=0, std=0.02
  tempState['myFc.weight'] = torch.normal(mean=0, std=torch.FloatTensor(temp)).cuda()
  tempState['myFc.bias']  = torch.normal(mean=0, std=torch.FloatTensor([0]*200)).cuda()

  model.load_state_dict(tempState)
  return model

补充知识:Pytorch:多GPU训练网络与单GPU训练网络保存模型的区别

测试环境:Python3.6 + Pytorch0.4

在pytorch中,使用多GPU训练网络需要用到 【nn.DataParallel】:

gpu_ids = [0, 1, 2, 3]
device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet()
if len(gpu_ids) > 1:
  net = nn.DataParallel(net, device_ids=gpu_ids)
net = net.to(device)

而使用单GPU训练网络:

device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能单GPU运行
net = LeNet().to(device)

由于多GPU训练使用了 nn.DataParallel(net, device_ids=gpu_ids) 对网络进行封装,因此在原始网络结构中添加了一层module。网络结构如下:

DataParallel(
 (module): LeNet(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
 )
)

而不使用多GPU训练的网络结构如下:

LeNet(
 (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
 (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
 (fc1): Linear(in_features=400, out_features=120, bias=True)
 (fc2): Linear(in_features=120, out_features=84, bias=True)
 (fc3): Linear(in_features=84, out_features=10, bias=True)
)

由于在测试模型时不需要用到多GPU测试,因此在保存模型时应该把module层去掉。如下:

if len(gpu_ids) > 1:
  t.save(net.module.state_dict(), "model.pth")
else:
  t.save(net.state_dict(), "model.pth")

以上这篇解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python不规范的日期字符串处理类
Jun 10 Python
python脚本爬取字体文件的实现方法
Apr 29 Python
python读取与写入csv格式文件的示例代码
Dec 16 Python
python numpy和list查询其中某个数的个数及定位方法
Jun 27 Python
Python使用pydub库对mp3与wav格式进行互转的方法
Jan 10 Python
python 自动轨迹绘制的实例代码
Jul 05 Python
Python Web程序搭建简单的Web服务器
Jul 31 Python
Django文件存储 默认存储系统解析
Aug 02 Python
用Python做一个久坐提醒小助手的示例代码
Feb 10 Python
30行Python代码实现高分辨率图像导航的方法
May 22 Python
Numpy 多维数据数组的实现
Jun 18 Python
Python进行特征提取的示例代码
Oct 15 Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
Python使用Selenium实现淘宝抢单的流程分析
Jun 23 #Python
python2和python3哪个使用率高
Jun 23 #Python
You might like
极典R601SW收音机
2021/03/02 无线电
php输出echo、print、print_r、printf、sprintf、var_dump的区别比较
2013/06/21 PHP
PHP中include与require使用方法区别详解
2013/10/19 PHP
采用thinkphp自带方法生成静态html文件详解
2014/06/13 PHP
destoon后台网站设置变成空白的解决方法
2014/06/21 PHP
PHP迭代器和生成器用法实例分析
2019/09/28 PHP
php使用pthreads v3多线程实现抓取新浪新闻信息操作示例
2020/02/21 PHP
获取页面高度,窗口高度,滚动条高度等参数值getPageSize,getPageScroll
2006/09/22 Javascript
js判断页面中是否有指定控件的简单实例
2014/03/04 Javascript
js获得当前系统日期时间的方法
2015/05/06 Javascript
AngularJS框架的ng-app指令与自动加载实现方法分析
2017/01/04 Javascript
原生js仿浏览器滚动条效果
2017/03/02 Javascript
javascript 中iframe高度自适应(同域)实例详解
2017/05/16 Javascript
jquery实现放大镜简洁代码(推荐)
2017/06/08 jQuery
JS模拟实现哈希表及应用详解
2018/05/04 Javascript
基于JavaScript实现简单的轮播图
2021/03/03 Javascript
[05:26]2014DOTA2西雅图国际邀请赛 iG战队巡礼
2014/07/07 DOTA
[01:12:35]Spirit vs Navi Supermajor小组赛 A组败者组第一轮 BO3 第二场 6.2
2018/06/03 DOTA
Python random模块常用方法
2014/11/03 Python
python递归删除指定目录及其所有内容的方法
2017/01/13 Python
树莓派实现移动拍照
2019/06/22 Python
python paramiko远程服务器终端操作过程解析
2019/12/14 Python
python实现扫雷小游戏
2020/04/24 Python
Python如何创建装饰器时保留函数元信息
2020/08/07 Python
使用Python Tkinter实现剪刀石头布小游戏功能
2020/10/23 Python
美国知名的女性服饰品牌:LOFT(洛芙特)
2016/08/05 全球购物
美国婴儿和儿童家具网上商店:ABaby.com
2018/07/02 全球购物
俄罗斯厨房产品购物网站:COOK HOUSE
2021/03/15 全球购物
企业文化建设实施方案
2014/03/22 职场文书
超市商业计划书
2014/05/04 职场文书
安全隐患整改报告
2014/11/06 职场文书
承诺函格式模板
2015/01/21 职场文书
三方协议书
2015/01/27 职场文书
MySQL 聚合函数排序
2021/07/16 MySQL
源码分析Redis中 set 和 sorted set 的使用方法
2022/03/22 Redis
Win10加载疑难解答时出错发生意外错误的解决方法
2022/07/07 数码科技