Pytorch之保存读取模型实例


Posted in Python onDecember 30, 2019

pytorch保存数据

pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式。而pth文件是python中存储文件的常用格式。而在keras中则是使用.h5文件。

# 保存模型示例代码
print('===> Saving models...')
state = {
  'state': model.state_dict(),
  'epoch': epoch          # 将epoch一并保存
}
if not os.path.isdir('checkpoint'):
  os.mkdir('checkpoint')
torch.save(state, './checkpoint/autoencoder.t7')

保存用到torch.save函数,注意该函数第一个参数可以是单个值也可以是字典,字典可以存更多你要保存的参数(不仅仅是权重数据)。

pytorch读取数据

pytorch读取数据使用的方法和我们平时使用预训练参数所用的方法是一样的,都是使用load_state_dict这个函数。

下方的代码和上方的保存代码可以搭配使用。

print('===> Try resume from checkpoint')
if os.path.isdir('checkpoint'):
  try:
    checkpoint = torch.load('./checkpoint/autoencoder.t7')
    model.load_state_dict(checkpoint['state'])    # 从字典中依次读取
    start_epoch = checkpoint['epoch']
    print('===> Load last checkpoint data')
  except FileNotFoundError:
    print('Can\'t found autoencoder.t7')
else:
  start_epoch = 0
  print('===> Start from scratch')

以上是pytorch读取的方法汇总,但是要注意,在使用官方的预处理模型进行读取时,一般使用的格式是pth,使用官方的模型读取命令会检查你模型的格式是否正确,如果不是使用官方提供模型通过下面的函数强行读取模型(将其他模型例如caffe模型转过来的模型放到指定目录下)会发生错误。

def vgg19(pretrained=False, **kwargs):
  """VGG 19-layer model (configuration "E")
 
  Args:
    pretrained (bool): If True, returns a model pre-trained on ImageNet
  """
  model = VGG(make_layers(cfg['E']), **kwargs)
  if pretrained:
    model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
  return model

假如我们有从caffe模型转过来的pytorch模型([0-255,BGR]),我们可以使用:

model_dir = '自己的模型地址'
model = VGG()
model.load_state_dict(torch.load(model_dir + 'vgg_conv.pth'))

也就是pytorch的读取函数进行读取即可。

以上这篇Pytorch之保存读取模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中virtualenvwrapper安装与使用
May 20 Python
详谈Python 窗体(tkinter)表格数据(Treeview)
Oct 11 Python
解决Python2.7中IDLE启动没有反应的问题
Nov 30 Python
python 实现读取一个excel多个sheet表并合并的方法
Feb 12 Python
Python3实现的简单三级菜单功能示例
Mar 12 Python
Python实现的读取文件内容并写入其他文件操作示例
Apr 09 Python
python 3.6.7实现端口扫描器
Sep 04 Python
Python3 Tkinkter + SQLite实现登录和注册界面
Nov 19 Python
pytorch方法测试详解——归一化(BatchNorm2d)
Jan 15 Python
Python列表解析操作实例总结
Feb 26 Python
浅谈numpy中np.array()与np.asarray的区别以及.tolist
Jun 03 Python
一篇文章弄懂Python关键字、标识符和变量
Jul 15 Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
Python类反射机制使用实例解析
Dec 30 #Python
Python读取YAML文件过程详解
Dec 30 #Python
You might like
PHP4和PHP5性能测试和对比 测试代码与环境
2007/08/17 PHP
PHP 输出缓存详解
2009/06/20 PHP
使用GROUP BY的时候如何统计记录条数 COUNT(*) DISTINCT
2011/04/23 PHP
解析PHP跨站刷票的实现代码
2013/06/18 PHP
php自定义session示例分享
2014/04/22 PHP
mod_php、FastCGI、PHP-FPM等PHP运行方式对比
2015/07/02 PHP
Yii框架创建cronjob定时任务的方法分析
2017/05/23 PHP
深入document.write()与HTML4.01的非成对标签的详解
2013/05/08 Javascript
js onclick事件传参讲解
2013/11/06 Javascript
不提示直接关闭网页窗口的JS示例代码
2013/12/17 Javascript
javascript圆盘抽奖程序实现原理和完整代码例子
2014/06/03 Javascript
JS上传图片前实现图片预览效果的方法
2015/03/02 Javascript
js实现简单的联动菜单效果
2015/08/19 Javascript
微信小程序 欢迎界面开发的实例详解
2016/11/30 Javascript
jQuery实现鼠标滑过图片移动特效
2016/12/08 Javascript
jQuery实现验证码功能
2017/03/17 Javascript
JSONP基础知识详解
2017/03/19 Javascript
基于Cesium绘制抛物弧线
2020/11/18 Javascript
使用python调用浏览器并打开一个网址的例子
2014/06/05 Python
浅谈Python中的闭包
2015/07/08 Python
网易有道2017内推编程题 洗牌(python)
2019/06/19 Python
Python文本文件的合并操作方法代码实例
2020/03/31 Python
pandas使用函数批量处理数据(map、apply、applymap)
2020/11/27 Python
Django数据模型中on_delete使用详解
2020/11/30 Python
python 利用matplotlib在3D空间中绘制平面的案例
2021/02/06 Python
基于HTML5超酷摄像头(HTML5 webcam)拍照功能实现代码
2012/12/13 HTML / CSS
html5 canvas手势解锁源码分享
2020/01/07 HTML / CSS
Crocs卡骆驰洞洞鞋日本官方网站:Crocs日本
2016/08/25 全球购物
Internet主要有哪些网络群组成
2015/12/24 面试题
师范生实习的个人自我鉴定
2013/10/20 职场文书
无故旷工检讨书
2014/01/26 职场文书
汽车促销活动方案
2014/03/31 职场文书
初三学生评语大全
2014/04/24 职场文书
答谢会策划方案
2014/05/12 职场文书
小学生学习雷锋倡议书
2014/05/15 职场文书
青年志愿者活动感想
2015/08/07 职场文书