Pytorch之保存读取模型实例


Posted in Python onDecember 30, 2019

pytorch保存数据

pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式。而pth文件是python中存储文件的常用格式。而在keras中则是使用.h5文件。

# 保存模型示例代码
print('===> Saving models...')
state = {
  'state': model.state_dict(),
  'epoch': epoch          # 将epoch一并保存
}
if not os.path.isdir('checkpoint'):
  os.mkdir('checkpoint')
torch.save(state, './checkpoint/autoencoder.t7')

保存用到torch.save函数,注意该函数第一个参数可以是单个值也可以是字典,字典可以存更多你要保存的参数(不仅仅是权重数据)。

pytorch读取数据

pytorch读取数据使用的方法和我们平时使用预训练参数所用的方法是一样的,都是使用load_state_dict这个函数。

下方的代码和上方的保存代码可以搭配使用。

print('===> Try resume from checkpoint')
if os.path.isdir('checkpoint'):
  try:
    checkpoint = torch.load('./checkpoint/autoencoder.t7')
    model.load_state_dict(checkpoint['state'])    # 从字典中依次读取
    start_epoch = checkpoint['epoch']
    print('===> Load last checkpoint data')
  except FileNotFoundError:
    print('Can\'t found autoencoder.t7')
else:
  start_epoch = 0
  print('===> Start from scratch')

以上是pytorch读取的方法汇总,但是要注意,在使用官方的预处理模型进行读取时,一般使用的格式是pth,使用官方的模型读取命令会检查你模型的格式是否正确,如果不是使用官方提供模型通过下面的函数强行读取模型(将其他模型例如caffe模型转过来的模型放到指定目录下)会发生错误。

def vgg19(pretrained=False, **kwargs):
  """VGG 19-layer model (configuration "E")
 
  Args:
    pretrained (bool): If True, returns a model pre-trained on ImageNet
  """
  model = VGG(make_layers(cfg['E']), **kwargs)
  if pretrained:
    model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
  return model

假如我们有从caffe模型转过来的pytorch模型([0-255,BGR]),我们可以使用:

model_dir = '自己的模型地址'
model = VGG()
model.load_state_dict(torch.load(model_dir + 'vgg_conv.pth'))

也就是pytorch的读取函数进行读取即可。

以上这篇Pytorch之保存读取模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用点操作符访问字典(dict)数据的方法
Mar 16 Python
Python代码实现KNN算法
Dec 20 Python
python机器学习实战之树回归详解
Dec 20 Python
对python周期性定时器的示例详解
Feb 19 Python
详解用python写网络爬虫-爬取新浪微博评论
May 10 Python
pyqt5 禁止窗口最大化和禁止窗口拉伸的方法
Jun 18 Python
Python 分享10个PyCharm技巧
Jul 13 Python
在notepad++中实现直接运行python代码
Dec 18 Python
解决flask接口返回的内容中文乱码的问题
Apr 03 Python
什么是python的列表推导式
May 26 Python
python中wheel的用法整理
Jun 15 Python
Python 使用SFTP和FTP实现对服务器的文件下载功能
Dec 17 Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
Python类反射机制使用实例解析
Dec 30 #Python
Python读取YAML文件过程详解
Dec 30 #Python
You might like
PHP+MySQL 手工注入语句大全 推荐
2009/10/30 PHP
PHP5.4中json_encode中文转码的变化小结
2013/01/30 PHP
php一行代码获取文件后缀名实例分析
2014/11/12 PHP
PHP打开和关闭文件操作函数总结
2014/11/18 PHP
php中Redis的应用--消息传递
2017/03/28 PHP
thinkphp自定义权限管理之名称判断方法
2017/04/01 PHP
thinkPHP5实现的查询数据库并返回json数据实例
2017/10/23 PHP
PHP实现的字符串匹配算法示例【sunday算法】
2017/12/19 PHP
JQuery 技巧和窍门整理(8个)
2010/04/22 Javascript
教你如何在 Javascript 文件里使用 .Net MVC Razor 语法
2014/07/23 Javascript
Json实现异步请求提交评论无需跳转其他页面
2014/10/11 Javascript
javascript实现依次输入input自动定焦
2014/12/23 Javascript
javascript实现瀑布流自适应遇到的问题及解决方案
2015/01/28 Javascript
常见JS验证脚本汇总
2015/12/01 Javascript
js实现将json数组显示前台table中
2017/01/10 Javascript
select下拉框插件jquery.editable-select详解
2017/01/22 Javascript
weui框架实现上传、预览和删除图片功能代码
2017/08/24 Javascript
原生JS实现多个小球碰撞反弹效果示例
2018/01/31 Javascript
vue项目强制清除页面缓存的例子
2019/11/06 Javascript
详细介绍Python中的偏函数
2015/04/27 Python
Python脚本获取操作系统版本信息
2016/12/17 Python
Python学习入门之区块链详解
2017/07/25 Python
python中requests库session对象的妙用详解
2017/10/30 Python
Python cookbook(数据结构与算法)实现优先级队列的方法示例
2018/02/18 Python
Python3 翻转二叉树的实现
2019/09/30 Python
Python利用逻辑回归分类实现模板
2020/02/15 Python
Python如何实现小程序 无限求和平均
2020/02/18 Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
2020/06/23 Python
春秋航空官方网站:Spring Airlines
2017/09/27 全球购物
世界顶级俱乐部的官方球衣和套装:Subside Sports
2018/04/22 全球购物
意大利巧克力店:Chocolate Shop
2019/07/24 全球购物
Dr. Martens马汀博士法国官网:马丁靴鼻祖
2020/01/15 全球购物
授权委托书样本
2014/04/03 职场文书
考试作弊检讨
2015/01/27 职场文书
使用ORM新增数据在Mysql中的操作步骤
2021/07/26 MySQL
postgreSQL数据库基础知识介绍
2022/04/12 PostgreSQL