Pytorch之保存读取模型实例


Posted in Python onDecember 30, 2019

pytorch保存数据

pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式。而pth文件是python中存储文件的常用格式。而在keras中则是使用.h5文件。

# 保存模型示例代码
print('===> Saving models...')
state = {
  'state': model.state_dict(),
  'epoch': epoch          # 将epoch一并保存
}
if not os.path.isdir('checkpoint'):
  os.mkdir('checkpoint')
torch.save(state, './checkpoint/autoencoder.t7')

保存用到torch.save函数,注意该函数第一个参数可以是单个值也可以是字典,字典可以存更多你要保存的参数(不仅仅是权重数据)。

pytorch读取数据

pytorch读取数据使用的方法和我们平时使用预训练参数所用的方法是一样的,都是使用load_state_dict这个函数。

下方的代码和上方的保存代码可以搭配使用。

print('===> Try resume from checkpoint')
if os.path.isdir('checkpoint'):
  try:
    checkpoint = torch.load('./checkpoint/autoencoder.t7')
    model.load_state_dict(checkpoint['state'])    # 从字典中依次读取
    start_epoch = checkpoint['epoch']
    print('===> Load last checkpoint data')
  except FileNotFoundError:
    print('Can\'t found autoencoder.t7')
else:
  start_epoch = 0
  print('===> Start from scratch')

以上是pytorch读取的方法汇总,但是要注意,在使用官方的预处理模型进行读取时,一般使用的格式是pth,使用官方的模型读取命令会检查你模型的格式是否正确,如果不是使用官方提供模型通过下面的函数强行读取模型(将其他模型例如caffe模型转过来的模型放到指定目录下)会发生错误。

def vgg19(pretrained=False, **kwargs):
  """VGG 19-layer model (configuration "E")
 
  Args:
    pretrained (bool): If True, returns a model pre-trained on ImageNet
  """
  model = VGG(make_layers(cfg['E']), **kwargs)
  if pretrained:
    model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
  return model

假如我们有从caffe模型转过来的pytorch模型([0-255,BGR]),我们可以使用:

model_dir = '自己的模型地址'
model = VGG()
model.load_state_dict(torch.load(model_dir + 'vgg_conv.pth'))

也就是pytorch的读取函数进行读取即可。

以上这篇Pytorch之保存读取模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python进行稳定可靠的文件操作详解
Dec 31 Python
python判断字符串是否纯数字的方法
Nov 19 Python
简单讲解Python编程中namedtuple类的用法
Jun 21 Python
python虚拟环境的安装配置图文教程
Oct 20 Python
示例详解Python3 or Python2 两者之间的差异
Aug 23 Python
Python+OpenCV感兴趣区域ROI提取方法
Jan 10 Python
python广度优先搜索得到两点间最短路径
Jan 17 Python
详解Python中的format格式化函数的使用方法
Nov 20 Python
python实现大战外星人小游戏实例代码
Dec 26 Python
python实现调用摄像头并拍照发邮箱
Apr 27 Python
Python selenium模拟网页点击爬虫交管12123违章数据
May 26 Python
分享提高 Python 代码的可读性的技巧
Mar 03 Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
Python类反射机制使用实例解析
Dec 30 #Python
Python读取YAML文件过程详解
Dec 30 #Python
You might like
php 获取mysql数据库信息代码
2009/03/12 PHP
Zend Framework开发入门经典教程
2016/03/23 PHP
PHP中str_split()函数的用法讲解
2019/04/11 PHP
JavaScript Timer实现代码
2010/02/17 Javascript
JavaScript版DateAdd和DateDiff函数代码
2012/03/01 Javascript
Flex通过JS获取客户端IP和计算机名的实例代码
2013/11/21 Javascript
利用jquery动画特效和css打造的侧边弹出垂直导航
2014/04/04 Javascript
浅谈javascript的分号的使用
2015/05/12 Javascript
jQuery Validate插件实现表单强大的验证功能
2015/12/18 Javascript
JavaScript 函数的执行过程
2016/05/09 Javascript
对Angular.js Controller如何进行单元测试
2016/10/25 Javascript
Chrome不支持showModalDialog模态对话框和无法返回returnValue问题的解决方法
2016/10/30 Javascript
Vue中使用的EventBus有生命周期
2018/07/12 Javascript
Vue模拟数据,实现路由进入商品详情页面的示例
2018/08/31 Javascript
p5.js绘制旋转的正方形
2019/10/23 Javascript
Webpack3+React16代码分割的实现
2021/03/03 Javascript
python实现kNN算法
2017/12/20 Python
Django权限机制实现代码详解
2018/02/05 Python
python中返回矩阵的行列方法
2018/04/04 Python
pandas通过loc生成新的列方法
2018/11/28 Python
Python3 SSH远程连接服务器的方法示例
2018/12/29 Python
python实现梯度下降和逻辑回归
2020/03/24 Python
python 实现逻辑回归
2020/12/30 Python
aec加密 php_php aes加密解密类(兼容php5、php7)
2021/03/14 PHP
基于HTML5的WebGL实现json和echarts图表展现在同一个界面
2017/10/26 HTML / CSS
Mio Skincare法国官网:身体紧致及孕期身体护理
2018/04/04 全球购物
超市促销实习自我鉴定
2013/09/23 职场文书
应用电子专业学生的自我评价
2013/10/16 职场文书
大四学生思想汇报
2014/01/13 职场文书
公司门卫岗位职责
2014/03/15 职场文书
找工作求职信
2014/07/07 职场文书
门店店长岗位职责
2015/04/14 职场文书
2015年事业单位工作总结
2015/04/27 职场文书
2016年入党心得体会范文
2016/01/23 职场文书
Java生成读取条形码和二维码的简单示例
2021/07/09 Java/Android
利用uni-app生成微信小程序的踩坑记录
2022/04/05 Javascript