Pytorch之保存读取模型实例


Posted in Python onDecember 30, 2019

pytorch保存数据

pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式。而pth文件是python中存储文件的常用格式。而在keras中则是使用.h5文件。

# 保存模型示例代码
print('===> Saving models...')
state = {
  'state': model.state_dict(),
  'epoch': epoch          # 将epoch一并保存
}
if not os.path.isdir('checkpoint'):
  os.mkdir('checkpoint')
torch.save(state, './checkpoint/autoencoder.t7')

保存用到torch.save函数,注意该函数第一个参数可以是单个值也可以是字典,字典可以存更多你要保存的参数(不仅仅是权重数据)。

pytorch读取数据

pytorch读取数据使用的方法和我们平时使用预训练参数所用的方法是一样的,都是使用load_state_dict这个函数。

下方的代码和上方的保存代码可以搭配使用。

print('===> Try resume from checkpoint')
if os.path.isdir('checkpoint'):
  try:
    checkpoint = torch.load('./checkpoint/autoencoder.t7')
    model.load_state_dict(checkpoint['state'])    # 从字典中依次读取
    start_epoch = checkpoint['epoch']
    print('===> Load last checkpoint data')
  except FileNotFoundError:
    print('Can\'t found autoencoder.t7')
else:
  start_epoch = 0
  print('===> Start from scratch')

以上是pytorch读取的方法汇总,但是要注意,在使用官方的预处理模型进行读取时,一般使用的格式是pth,使用官方的模型读取命令会检查你模型的格式是否正确,如果不是使用官方提供模型通过下面的函数强行读取模型(将其他模型例如caffe模型转过来的模型放到指定目录下)会发生错误。

def vgg19(pretrained=False, **kwargs):
  """VGG 19-layer model (configuration "E")
 
  Args:
    pretrained (bool): If True, returns a model pre-trained on ImageNet
  """
  model = VGG(make_layers(cfg['E']), **kwargs)
  if pretrained:
    model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
  return model

假如我们有从caffe模型转过来的pytorch模型([0-255,BGR]),我们可以使用:

model_dir = '自己的模型地址'
model = VGG()
model.load_state_dict(torch.load(model_dir + 'vgg_conv.pth'))

也就是pytorch的读取函数进行读取即可。

以上这篇Pytorch之保存读取模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python给文本创立向量空间模型的教程
Apr 23 Python
Python中super()函数简介及用法分享
Jul 11 Python
彻底理解Python list切片原理
Oct 27 Python
python的pandas工具包,保存.csv文件时不要表头的实例
Jun 14 Python
Python从单元素字典中获取key和value的实例
Dec 31 Python
django项目简单调取百度翻译接口的方法
Aug 06 Python
详解Python3迁移接口变化采坑记
Oct 11 Python
Tensorflow实现在训练好的模型上进行测试
Jan 20 Python
解决python -m pip install --upgrade pip 升级不成功问题
Mar 05 Python
Python Django路径配置实现过程解析
Nov 05 Python
python脚本定时发送邮件
Dec 22 Python
Pandas自定义选项option设置
Jul 25 Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
Python类反射机制使用实例解析
Dec 30 #Python
Python读取YAML文件过程详解
Dec 30 #Python
You might like
坏狼的PHP学习教程之第2天
2008/06/15 PHP
两款万能的php分页类
2015/11/12 PHP
php 升级到 5.3+ 后出现的一些错误,如 ereg(); ereg_replace(); 函数报错
2015/12/07 PHP
PHP MPDF中文乱码的解决方式
2015/12/08 PHP
php车辆违章查询数据示例
2016/10/14 PHP
Javascript编程之继承实例汇总
2015/11/28 Javascript
基于JavaScript实现瀑布流布局(二)
2016/01/26 Javascript
深入理解bootstrap框架之第二章整体架构
2016/10/09 Javascript
jQuery插件HighCharts绘制2D半圆环图效果示例【附demo源码下载】
2017/03/09 Javascript
微信小程序 实现动态显示和隐藏某个控件
2017/04/27 Javascript
简单的网页广告特效实例
2017/08/19 Javascript
canvas+gif.js打造自己的数字雨头像的示例代码
2017/10/26 Javascript
详解使用jest对vue项目进行单元测试
2018/09/07 Javascript
Vue瀑布流插件的使用示例
2018/09/19 Javascript
在 Vue-CLI 中引入 simple-mock实现简易的 API Mock 接口数据模拟
2018/11/28 Javascript
微信小程序-API接口安全详解
2019/07/16 Javascript
[01:57]DOTA2上海特锦赛小组赛解说单车采访花絮
2016/02/27 DOTA
python计数排序和基数排序算法实例
2014/04/25 Python
Python中处理时间的几种方法小结
2015/04/09 Python
Python获取linux主机ip的简单实现方法
2016/04/18 Python
pygame实现俄罗斯方块游戏
2018/06/26 Python
Django中信号signals的简单使用方法
2019/07/04 Python
解决Python正则表达式匹配反斜杠''\''问题
2019/07/17 Python
css3实例教程 一款纯css3实现的发光屏幕旋转特效
2014/12/07 HTML / CSS
利用CSS3伪元素实现逐渐发光的方格边框
2017/05/07 HTML / CSS
Lime Crime官网:美国一家主打梦幻精灵系的彩妆品牌
2019/03/22 全球购物
能否解释一下XSS cookie盗窃是什么意思
2012/06/02 面试题
会计专业自荐信
2013/12/02 职场文书
旷课检讨书2000字
2014/01/14 职场文书
2014年清明节网上祭英烈寄语
2014/04/09 职场文书
纪律教育月活动总结
2014/08/26 职场文书
财政局党的群众路线教育实践活动整改方案
2014/09/21 职场文书
2014高三学生考试作弊检讨书
2014/12/14 职场文书
2016年党员公开承诺书范文
2016/03/24 职场文书
你对自己的信用报告有过了解吗?
2019/07/09 职场文书
python基于turtle绘制几何图形
2021/06/15 Python