Pytorch之保存读取模型实例


Posted in Python onDecember 30, 2019

pytorch保存数据

pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式。而pth文件是python中存储文件的常用格式。而在keras中则是使用.h5文件。

# 保存模型示例代码
print('===> Saving models...')
state = {
  'state': model.state_dict(),
  'epoch': epoch          # 将epoch一并保存
}
if not os.path.isdir('checkpoint'):
  os.mkdir('checkpoint')
torch.save(state, './checkpoint/autoencoder.t7')

保存用到torch.save函数,注意该函数第一个参数可以是单个值也可以是字典,字典可以存更多你要保存的参数(不仅仅是权重数据)。

pytorch读取数据

pytorch读取数据使用的方法和我们平时使用预训练参数所用的方法是一样的,都是使用load_state_dict这个函数。

下方的代码和上方的保存代码可以搭配使用。

print('===> Try resume from checkpoint')
if os.path.isdir('checkpoint'):
  try:
    checkpoint = torch.load('./checkpoint/autoencoder.t7')
    model.load_state_dict(checkpoint['state'])    # 从字典中依次读取
    start_epoch = checkpoint['epoch']
    print('===> Load last checkpoint data')
  except FileNotFoundError:
    print('Can\'t found autoencoder.t7')
else:
  start_epoch = 0
  print('===> Start from scratch')

以上是pytorch读取的方法汇总,但是要注意,在使用官方的预处理模型进行读取时,一般使用的格式是pth,使用官方的模型读取命令会检查你模型的格式是否正确,如果不是使用官方提供模型通过下面的函数强行读取模型(将其他模型例如caffe模型转过来的模型放到指定目录下)会发生错误。

def vgg19(pretrained=False, **kwargs):
  """VGG 19-layer model (configuration "E")
 
  Args:
    pretrained (bool): If True, returns a model pre-trained on ImageNet
  """
  model = VGG(make_layers(cfg['E']), **kwargs)
  if pretrained:
    model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
  return model

假如我们有从caffe模型转过来的pytorch模型([0-255,BGR]),我们可以使用:

model_dir = '自己的模型地址'
model = VGG()
model.load_state_dict(torch.load(model_dir + 'vgg_conv.pth'))

也就是pytorch的读取函数进行读取即可。

以上这篇Pytorch之保存读取模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python字典get()方法用法分析
Apr 17 Python
Python爬虫框架Scrapy实战之批量抓取招聘信息
Aug 07 Python
Python实现针对含中文字符串的截取功能示例
Sep 22 Python
Python实现学生成绩管理系统
Apr 05 Python
使用50行Python代码从零开始实现一个AI平衡小游戏
Nov 21 Python
Python二叉树的镜像转换实现方法示例
Mar 06 Python
浅谈Python中eval的强大与危害
Mar 13 Python
python调用pyaudio使用麦克风录制wav声音文件的教程
Jun 26 Python
Django接收照片储存文件的实例代码
Mar 07 Python
Pycharm pyuic5实现将ui文件转为py文件,让UI界面成功显示
Apr 08 Python
spyder 在控制台(console)执行python文件,debug python程序方式
Apr 20 Python
Python实现socket库网络通信套接字
Jun 04 Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
Python类反射机制使用实例解析
Dec 30 #Python
Python读取YAML文件过程详解
Dec 30 #Python
You might like
使用php get_headers 判断URL是否有效的解决办法
2013/04/27 PHP
解决ajax+php中文乱码的方法详解
2013/06/09 PHP
使用laravel根据用户类型来显示或隐藏字段
2019/10/17 PHP
jquery中:input和input的区别分析
2011/07/13 Javascript
Microsfot .NET Framework4.0框架 安装失败的解决方法
2013/08/14 Javascript
JavaScript将一个数组插入到另一个数组的方法
2015/03/19 Javascript
JavaScript小技巧整理
2015/12/30 Javascript
js css+html实现简单的日历
2016/07/14 Javascript
jquery实现图片放大点击切换
2017/06/06 jQuery
理解nodejs的stream和pipe机制的原理和实现
2017/08/12 NodeJs
VUE长按事件需求详解
2017/10/18 Javascript
es6在react中的应用代码解析
2017/11/08 Javascript
微信小程序与后台PHP交互的方法实例分析
2018/12/10 Javascript
动态实现element ui的el-table某列数据不同样式的示例
2021/01/22 Javascript
[02:14]完美“圣”典2016风云人物:xiao8专访
2016/12/01 DOTA
[55:03]完美世界DOTA2联赛PWL S2 LBZS vs FTD.C 第二场 11.20
2020/11/20 DOTA
Python urlopen()函数 示例分享
2014/06/12 Python
python实现批量下载新浪博客的方法
2015/06/15 Python
详解Python中的Cookie模块使用
2015/07/06 Python
Python中的数学运算操作符使用进阶
2016/06/20 Python
使用python存储网页上的图片实例
2018/05/22 Python
pyspark操作MongoDB的方法步骤
2019/01/04 Python
Python3.5基础之函数的定义与使用实例详解【参数、作用域、递归、重载等】
2019/04/26 Python
使用Matplotlib 绘制精美的数学图形例子
2019/12/13 Python
Python django框架开发发布会签到系统(web开发)
2020/02/12 Python
VSCode基础使用与VSCode调试python程序入门的图文教程
2020/03/30 Python
Python Pandas 对列/行进行选择,增加,删除操作
2020/05/17 Python
Python爬虫之Selenium多窗口切换的实现
2020/12/04 Python
CSS3实现淘宝留白的方法
2020/06/05 HTML / CSS
百联网上商城:i百联
2017/01/28 全球购物
您在慕尼黑的跑步商店:Lauf-bar
2019/10/11 全球购物
铁路工务反思材料
2014/02/07 职场文书
购房协议书范本(无房产证)
2014/10/07 职场文书
2014年工程工作总结
2014/11/25 职场文书
2016七夕情人节感言
2015/12/09 职场文书
Python如何使用循环结构和分支结构
2022/04/13 Python