Pytorch之保存读取模型实例


Posted in Python onDecember 30, 2019

pytorch保存数据

pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式。而pth文件是python中存储文件的常用格式。而在keras中则是使用.h5文件。

# 保存模型示例代码
print('===> Saving models...')
state = {
  'state': model.state_dict(),
  'epoch': epoch          # 将epoch一并保存
}
if not os.path.isdir('checkpoint'):
  os.mkdir('checkpoint')
torch.save(state, './checkpoint/autoencoder.t7')

保存用到torch.save函数,注意该函数第一个参数可以是单个值也可以是字典,字典可以存更多你要保存的参数(不仅仅是权重数据)。

pytorch读取数据

pytorch读取数据使用的方法和我们平时使用预训练参数所用的方法是一样的,都是使用load_state_dict这个函数。

下方的代码和上方的保存代码可以搭配使用。

print('===> Try resume from checkpoint')
if os.path.isdir('checkpoint'):
  try:
    checkpoint = torch.load('./checkpoint/autoencoder.t7')
    model.load_state_dict(checkpoint['state'])    # 从字典中依次读取
    start_epoch = checkpoint['epoch']
    print('===> Load last checkpoint data')
  except FileNotFoundError:
    print('Can\'t found autoencoder.t7')
else:
  start_epoch = 0
  print('===> Start from scratch')

以上是pytorch读取的方法汇总,但是要注意,在使用官方的预处理模型进行读取时,一般使用的格式是pth,使用官方的模型读取命令会检查你模型的格式是否正确,如果不是使用官方提供模型通过下面的函数强行读取模型(将其他模型例如caffe模型转过来的模型放到指定目录下)会发生错误。

def vgg19(pretrained=False, **kwargs):
  """VGG 19-layer model (configuration "E")
 
  Args:
    pretrained (bool): If True, returns a model pre-trained on ImageNet
  """
  model = VGG(make_layers(cfg['E']), **kwargs)
  if pretrained:
    model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
  return model

假如我们有从caffe模型转过来的pytorch模型([0-255,BGR]),我们可以使用:

model_dir = '自己的模型地址'
model = VGG()
model.load_state_dict(torch.load(model_dir + 'vgg_conv.pth'))

也就是pytorch的读取函数进行读取即可。

以上这篇Pytorch之保存读取模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之坑爹的字符编码
Sep 28 Python
Python 基于Twisted框架的文件夹网络传输源码
Aug 28 Python
关于Python中异常(Exception)的汇总
Jan 18 Python
Python使用QRCode模块生成二维码实例详解
Jun 14 Python
TensorFlow损失函数专题详解
Apr 26 Python
在Python中定义一个常量的方法
Nov 10 Python
Django框架搭建的简易图书信息网站案例
May 25 Python
python函数与方法的区别总结
Jun 23 Python
详解利用python+opencv识别图片中的圆形(霍夫变换)
Jul 01 Python
python将类似json的数据存储到MySQL中的实例
Jul 12 Python
带你学习Python如何实现回归树模型
Jul 16 Python
Python的三个重要函数详解
Jan 18 Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
Python类反射机制使用实例解析
Dec 30 #Python
Python读取YAML文件过程详解
Dec 30 #Python
You might like
php实现加减法验证码代码
2014/02/14 PHP
ThinkPHP3.1新特性之对分组支持的改进与完善概述
2014/06/19 PHP
PHP图像处理之imagecreate、imagedestroy函数介绍
2014/11/19 PHP
PHP实现双链表删除与插入节点的方法示例
2017/11/11 PHP
thinkPHP+mysql+ajax实现的仿百度一下即时搜索效果详解
2019/07/15 PHP
PHP+Redis事务解决高并发下商品超卖问题(推荐)
2020/08/03 PHP
javascript 字符 Escape,encodeURI,encodeURIComponent
2009/07/09 Javascript
使用jQuery.Validate进行客户端验证(初级篇) 不使用微软验证控件的理由
2010/06/28 Javascript
JavaScript实现复制功能各浏览器支持情况实测
2013/07/18 Javascript
jq实现酷炫的鼠标经过图片翻滚效果
2014/03/12 Javascript
jQuery中:button选择器用法实例
2015/01/04 Javascript
js中数组结合字符串实现查找(屏蔽广告判断url等)
2016/03/30 Javascript
JS简单实现仿百度控制台输出信息效果
2016/09/04 Javascript
Vue.js实战之通过监听滚动事件实现动态锚点
2017/04/04 Javascript
基于rollup的组件库打包体积优化小结
2018/06/18 Javascript
深入理解使用Vue实现Context-Menu的思考与总结
2019/03/09 Javascript
微信小程序实现人脸识别登陆的示例代码
2019/04/02 Javascript
JavaScript/TypeScript 实现并发请求控制的示例代码
2021/01/18 Javascript
Python实现的异步代理爬虫及代理池
2017/03/17 Python
一个基于flask的web应用诞生(1)
2017/04/11 Python
Python中使用多进程来实现并行处理的方法小结
2017/08/09 Python
Python3导入CSV文件的实例(跟Python2有些许的不同)
2018/06/22 Python
Flask框架WTForm表单用法示例
2018/07/20 Python
Python全局变量与局部变量区别及用法分析
2018/09/03 Python
python 对多个csv文件分别进行处理的方法
2019/01/07 Python
Python二维码生成识别实例详解
2019/07/16 Python
使用Python爬取弹出窗口信息的实例
2020/03/14 Python
如何利用python发送邮件
2020/09/26 Python
纽约现代艺术博物馆商店:MoMA STORE(室内家具和杂货商品)
2016/08/02 全球购物
肯尼亚网上商城:Kilimall
2016/08/20 全球购物
L’Artisan Parfumeur官网:法国香水品牌
2020/08/11 全球购物
WebSphere面试题:在WebSphere里面如何部署一个应用
2015/08/02 面试题
初一英语教学反思
2014/01/11 职场文书
《月亮湾》教学反思
2014/04/14 职场文书
航空学院求职信
2014/06/11 职场文书
零基础学java之带参数以及返回值的方法
2022/04/10 Java/Android