pytorch加载自定义网络权重的实现


Posted in Python onJanuary 07, 2020

在将自定义的网络权重加载到网络中时,报错:

AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

我们一步一步分析。

模型网络权重保存额代码是:torch.save(net.state_dict(),'net.pkl')

(1)查看获取模型权重的源码:

pytorch源码:net.state_dict()

def state_dict(self, destination=None, prefix='', keep_vars=False):
  r"""Returns a dictionary containing a whole state of the module.

  Both parameters and persistent buffers (e.g. running averages) are
  included. Keys are corresponding parameter and buffer names.

  Returns:
    dict:
      a dictionary containing a whole state of the module

  Example::

    >>> module.state_dict().keys()
    ['bias', 'weight']

  """

将网络中所有的状态保存到一个字典中了,我自己构建的就是一个字典,没问题!

(2)查看保存模型权重的源码:

pytorch源码:torch.save()

def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
  """Saves an object to a disk file.

  See also: :ref:`recommend-saving-models`

  Args:
    obj: saved object
    f: a file-like object (has to implement write and flush) or a string
      containing a file name
    pickle_module: module used for pickling metadata and objects
    pickle_protocol: can be specified to override the default protocol

  .. warning::
    If you are using Python 2, torch.save does NOT support StringIO.StringIO
    as a valid file-like object. This is because the write method should return
    the number of bytes written; StringIO.write() does not do this.

    Please use something like io.BytesIO instead.

函数功能是将字典保存为磁盘文件(二进制数据),那么我们在torch.load()时,就是在内存中加载二进制数据,这就是报错点。

解决方案:将字典保存为BytesIO文件之后,模型再net.load_state_dict()

#b为自定义的字典
torch.save(b,'new.pkl')
net.load_state_dict(torch.load(b))

解决方法很简单,主要记录解决思路。

以上这篇pytorch加载自定义网络权重的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 获取本机ip地址的两个方法
Feb 25 Python
Python中itertools模块用法详解
Sep 25 Python
python模块之StringIO使用示例
Apr 08 Python
python中kmeans聚类实现代码
Feb 23 Python
Python 实现某个功能每隔一段时间被执行一次的功能方法
Oct 14 Python
详解python 模拟豆瓣登录(豆瓣6.0)
Apr 18 Python
使用Fabric自动化部署Django项目的实现
Sep 27 Python
使用Python实现分别输出每个数组
Dec 06 Python
Python telnet登陆功能实现代码
Apr 16 Python
基于SpringBoot构造器注入循环依赖及解决方式
Apr 26 Python
在echarts中图例legend和坐标系grid实现左右布局实例
May 16 Python
Python预测2020高考分数和录取情况
Jul 08 Python
Matplotlib绘制雷达图和三维图的示例代码
Jan 07 #Python
Pytorch 神经网络—自定义数据集上实现教程
Jan 07 #Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
pytorch自定义二值化网络层方式
Jan 07 #Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
You might like
PHP 文件类型判断代码
2009/03/13 PHP
跟我学Laravel之配置Laravel
2014/10/15 PHP
PHP下使用mysqli的函数连接mysql出现warning: mysqli::real_connect(): (hy000/1040): ...
2016/02/14 PHP
PHP简单预防sql注入的方法
2016/09/27 PHP
js中的escape及unescape函数的php实现代码
2007/09/04 Javascript
使用javascript实现json数据以csv格式下载
2015/01/09 Javascript
对于jQuery性能的一些优化建议
2015/08/13 Javascript
jQuery获取attr()与prop()属性值的方法及区别介绍
2016/07/06 Javascript
Js查找字符串中出现次数最多的字符及个数实例解析
2016/09/05 Javascript
jQuery实现最简单实用的分秒倒计时
2017/02/05 Javascript
Vue.js实现移动端短信验证码功能
2017/03/29 Javascript
mint-ui在vue中的使用示例
2018/04/05 Javascript
Python实现SMTP发送邮件详细教程
2021/03/02 Python
Python实现Mysql数据库连接池实例详解
2017/04/11 Python
python matplotlib中文显示参数设置解析
2017/12/15 Python
python学习之hook钩子的原理和使用
2018/10/25 Python
python文件拆分与重组实例
2018/12/10 Python
浅谈pandas筛选出表中满足另一个表所有条件的数据方法
2019/02/08 Python
零基础使用Python读写处理Excel表格的方法
2019/05/02 Python
python移位运算的实现
2019/07/15 Python
django的ORM操作 删除和编辑实现详解
2019/07/24 Python
python按修改时间顺序排列文件的实例代码
2019/07/25 Python
Python Pandas 如何shuffle(打乱)数据
2019/07/30 Python
详解Django admin高级用法
2019/11/06 Python
Python新手学习标准库模块命名
2020/05/29 Python
CSS3线性渐变简单实现以及该属性在浏览器中的不同
2012/12/12 HTML / CSS
HTML5 自动聚焦(autofocus)属性使用介绍
2013/08/07 HTML / CSS
澳大利亚第一旅行车和房车配件店:Caravan RV Camping
2020/12/26 全球购物
戴尔马来西亚官网:Dell Malaysia
2020/05/02 全球购物
营业经理岗位职责
2013/11/10 职场文书
夏季奶茶店创业计划书
2014/01/16 职场文书
教师节促销活动方案
2014/02/14 职场文书
授权委托书格式范文
2014/08/02 职场文书
党在我心中演讲稿
2014/09/02 职场文书
文员岗位职责
2015/02/04 职场文书
MySQL中EXPLAIN语句及用法
2022/05/20 MySQL