Pytorch之保存读取模型实例


Posted in Python onDecember 30, 2019

pytorch保存数据

pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式。而pth文件是python中存储文件的常用格式。而在keras中则是使用.h5文件。

# 保存模型示例代码
print('===> Saving models...')
state = {
  'state': model.state_dict(),
  'epoch': epoch          # 将epoch一并保存
}
if not os.path.isdir('checkpoint'):
  os.mkdir('checkpoint')
torch.save(state, './checkpoint/autoencoder.t7')

保存用到torch.save函数,注意该函数第一个参数可以是单个值也可以是字典,字典可以存更多你要保存的参数(不仅仅是权重数据)。

pytorch读取数据

pytorch读取数据使用的方法和我们平时使用预训练参数所用的方法是一样的,都是使用load_state_dict这个函数。

下方的代码和上方的保存代码可以搭配使用。

print('===> Try resume from checkpoint')
if os.path.isdir('checkpoint'):
  try:
    checkpoint = torch.load('./checkpoint/autoencoder.t7')
    model.load_state_dict(checkpoint['state'])    # 从字典中依次读取
    start_epoch = checkpoint['epoch']
    print('===> Load last checkpoint data')
  except FileNotFoundError:
    print('Can\'t found autoencoder.t7')
else:
  start_epoch = 0
  print('===> Start from scratch')

以上是pytorch读取的方法汇总,但是要注意,在使用官方的预处理模型进行读取时,一般使用的格式是pth,使用官方的模型读取命令会检查你模型的格式是否正确,如果不是使用官方提供模型通过下面的函数强行读取模型(将其他模型例如caffe模型转过来的模型放到指定目录下)会发生错误。

def vgg19(pretrained=False, **kwargs):
  """VGG 19-layer model (configuration "E")
 
  Args:
    pretrained (bool): If True, returns a model pre-trained on ImageNet
  """
  model = VGG(make_layers(cfg['E']), **kwargs)
  if pretrained:
    model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
  return model

假如我们有从caffe模型转过来的pytorch模型([0-255,BGR]),我们可以使用:

model_dir = '自己的模型地址'
model = VGG()
model.load_state_dict(torch.load(model_dir + 'vgg_conv.pth'))

也就是pytorch的读取函数进行读取即可。

以上这篇Pytorch之保存读取模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python yield使用方法示例
Dec 04 Python
python spyder中读取txt为图片的方法
Apr 27 Python
python使用openpyxl库修改excel表格数据方法
May 03 Python
浅谈关于Python3中venv虚拟环境
Aug 01 Python
python整小时 整天时间戳获取算法示例
Feb 20 Python
详解Python解决抓取内容乱码问题(decode和encode解码)
Mar 29 Python
Python PIL读取的图像发生自动旋转的实现方法
Jul 05 Python
Python参数类型以及常见的坑详解
Jul 08 Python
python函数声明和调用定义及原理详解
Dec 02 Python
Python try except else使用详解
Jan 12 Python
Python实现网络聊天室的示例代码(支持多人聊天与私聊)
Jan 27 Python
Pandas-DataFrame知识点汇总
Mar 16 Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
Python类反射机制使用实例解析
Dec 30 #Python
Python读取YAML文件过程详解
Dec 30 #Python
You might like
IIS+PHP+MySQL+Zend配置 (视频教程)
2006/12/13 PHP
php使用number_format函数截取小数的方法分析
2016/05/27 PHP
适合PHP初学者阅读的4本经典书籍
2016/09/23 PHP
PHP魔术方法以及关于独立实例与相连实例的全面讲解
2016/10/18 PHP
JavaScript 对象链式操作测试代码
2010/04/25 Javascript
JS代码判断IE6,IE7,IE8,IE9的函数代码
2013/08/02 Javascript
让checkbox不选中即将选中的checkbox不选中
2014/07/11 Javascript
JS实现不使用图片仿Windows右键菜单效果代码
2015/10/22 Javascript
后端接收不到AngularJs中$http.post发送的数据原因分析及解决办法
2016/07/05 Javascript
js select实现省市区联动选择
2020/04/17 Javascript
使用JQuery选择HTML遍历函数的方法
2016/09/17 Javascript
AngularJS递归指令实现Tree View效果示例
2016/11/07 Javascript
Jquery Easyui选项卡组件Tab使用详解(10)
2016/12/18 Javascript
非常优秀的JS图片轮播插件Swiper的用法
2017/01/03 Javascript
jQuery Ajax前后端使用JSON进行交互示例
2017/03/17 Javascript
ES6中class类用法实例浅析
2017/04/06 Javascript
JavaScript设置名字输入不合法的实现方法
2017/05/23 Javascript
Angular4 中常用的指令入门总结
2017/06/12 Javascript
Vue2.x通用条件搜索组件的封装及应用详解
2019/05/28 Javascript
vue spa应用中的路由缓存问题与解决方案
2019/05/31 Javascript
vscode 插件开发 + vue的操作方法
2020/06/05 Javascript
[02:29]完美世界高校联赛上海赛区回顾
2015/12/15 DOTA
在Python下进行UDP网络编程的教程
2015/04/29 Python
Python2实现的图片文本识别功能详解
2018/07/11 Python
python的一些加密方法及python 加密模块
2019/07/11 Python
Python安装selenium包详细过程
2019/07/23 Python
python字符串格式化方式解析
2019/10/19 Python
Python数据结构dict常用操作代码实例
2020/03/12 Python
python爬虫基础之urllib的使用
2020/12/31 Python
selenium+python自动化78-autoit参数化与批量上传功能的实现
2021/03/04 Python
CSS实现限制字数功能当对象内文本溢出时显示省略标记
2014/08/20 HTML / CSS
斯洛伐克最大的婴儿食品和用品网上商店:Feedo.sk
2020/12/21 全球购物
TCP/IP中的TCP和IP分别承担什么责任
2012/04/21 面试题
退伍老兵事迹材料
2014/01/31 职场文书
文员试用期转正自我鉴定
2014/09/14 职场文书
Python实现批量将文件复制到新的目录中再修改名称
2022/04/12 Python