pytorch加载自定义网络权重的实现


Posted in Python onJanuary 07, 2020

在将自定义的网络权重加载到网络中时,报错:

AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

我们一步一步分析。

模型网络权重保存额代码是:torch.save(net.state_dict(),'net.pkl')

(1)查看获取模型权重的源码:

pytorch源码:net.state_dict()

def state_dict(self, destination=None, prefix='', keep_vars=False):
  r"""Returns a dictionary containing a whole state of the module.

  Both parameters and persistent buffers (e.g. running averages) are
  included. Keys are corresponding parameter and buffer names.

  Returns:
    dict:
      a dictionary containing a whole state of the module

  Example::

    >>> module.state_dict().keys()
    ['bias', 'weight']

  """

将网络中所有的状态保存到一个字典中了,我自己构建的就是一个字典,没问题!

(2)查看保存模型权重的源码:

pytorch源码:torch.save()

def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
  """Saves an object to a disk file.

  See also: :ref:`recommend-saving-models`

  Args:
    obj: saved object
    f: a file-like object (has to implement write and flush) or a string
      containing a file name
    pickle_module: module used for pickling metadata and objects
    pickle_protocol: can be specified to override the default protocol

  .. warning::
    If you are using Python 2, torch.save does NOT support StringIO.StringIO
    as a valid file-like object. This is because the write method should return
    the number of bytes written; StringIO.write() does not do this.

    Please use something like io.BytesIO instead.

函数功能是将字典保存为磁盘文件(二进制数据),那么我们在torch.load()时,就是在内存中加载二进制数据,这就是报错点。

解决方案:将字典保存为BytesIO文件之后,模型再net.load_state_dict()

#b为自定义的字典
torch.save(b,'new.pkl')
net.load_state_dict(torch.load(b))

解决方法很简单,主要记录解决思路。

以上这篇pytorch加载自定义网络权重的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 实现归并排序算法
Jun 05 Python
python列表去重的二种方法
Feb 14 Python
Python实现包含min函数的栈
Apr 29 Python
Python模拟三级菜单效果
Sep 11 Python
详解Python使用tensorflow入门指南
Feb 09 Python
python numpy格式化打印的实例
May 14 Python
python 读取鼠标点击坐标的实例
Dec 29 Python
python抓取搜狗微信公众号文章
Apr 01 Python
Python连接mysql数据库及简单增删改查操作示例代码
Aug 03 Python
python+excel接口自动化获取token并作为请求参数进行传参操作
Nov 10 Python
python 利用opencv实现图像网络传输
Nov 12 Python
用pip给python安装matplotlib库的详细教程
Feb 24 Python
Matplotlib绘制雷达图和三维图的示例代码
Jan 07 #Python
Pytorch 神经网络—自定义数据集上实现教程
Jan 07 #Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
pytorch自定义二值化网络层方式
Jan 07 #Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
You might like
php处理单文件、多文件上传代码分享
2016/08/24 PHP
php+redis实现商城秒杀功能
2020/11/19 PHP
解决jquery异步按一定的时间间隔刷新问题
2012/12/10 Javascript
js获取URL的参数的方法(getQueryString)示例
2013/09/29 Javascript
js+html5实现canvas绘制网页时钟的方法
2016/05/21 Javascript
只要1K 纯JS脚本送你一朵3D红色玫瑰
2016/08/09 Javascript
js实现ATM机存取款功能
2020/10/27 Javascript
JS实现图片轮播效果实例详解【可自动和手动】
2019/04/04 Javascript
django中使用vue.js的要点总结
2019/07/07 Javascript
详解element-ui中表单验证的三种方式
2019/09/18 Javascript
vue+element-ui JYAdmin后台管理系统模板解析
2020/07/28 Javascript
JavaScript实现多球运动效果
2020/09/07 Javascript
[01:25:33]完美世界DOTA2联赛PWL S3 INK ICE vs Magma 第二场 12.20
2020/12/23 DOTA
python zip文件 压缩
2008/12/24 Python
Python+PIL实现支付宝AR红包
2018/02/09 Python
Python使用matplotlib模块绘制图像并设置标题与坐标轴等信息示例
2018/05/04 Python
基于数据归一化以及Python实现方式
2018/07/11 Python
Python在图片中插入大量文字并且自动换行
2019/01/02 Python
python pickle存储、读取大数据量列表、字典数据的方法
2019/07/07 Python
学习和使用python的13个理由
2019/07/30 Python
Python 单例设计模式用法实例分析
2019/09/23 Python
线程安全及Python中的GIL原理分析
2019/10/29 Python
python如何通过pyqt5实现进度条
2020/01/20 Python
Python用access判断文件是否被占用的实例方法
2020/12/17 Python
be2台湾单身男女交友:全球网路婚姻介绍的领导品牌
2019/10/11 全球购物
什么是组件架构
2016/05/15 面试题
Java面试题:请说出如下代码的输出结果
2013/04/22 面试题
法律专业推荐信范文
2013/11/29 职场文书
企业安全生产责任书范本
2014/07/28 职场文书
医院反腐倡廉演讲稿
2014/09/16 职场文书
面试通知短信
2015/04/20 职场文书
失恋33天观后感
2015/06/11 职场文书
缅怀先烈主题班会
2015/08/14 职场文书
中国现代文学之经典散文三篇
2019/09/18 职场文书
Nginx解决403 forbidden的完整步骤
2021/04/01 Servers
js面向对象编程OOP及函数式编程FP区别
2022/07/07 Javascript