TensorFlow入门使用 tf.train.Saver()保存模型


Posted in Python onApril 24, 2018

关于模型保存的一点心得

saver = tf.train.Saver(max_to_keep=3)

在定义 saver 的时候一般会定义最多保存模型的数量,一般来说,如果模型本身很大,我们需要考虑到硬盘大小。如果你需要在当前训练好的模型的基础上进行 fine-tune,那么尽可能多的保存模型,后继 fine-tune 不一定从最好的 ckpt 进行,因为有可能一下子就过拟合了。但是如果保存太多,硬盘也有压力呀。如果只想保留最好的模型,方法就是每次迭代到一定步数就在验证集上计算一次 accuracy 或者 f1 值,如果本次结果比上次好才保存新的模型,否则没必要保存。

如果你想用不同 epoch 保存下来的模型进行融合的话,3到5 个模型已经足够了,假设这各融合的模型成为 M,而最好的一个单模型称为 m_best, 这样融合的话对于M 确实可以比 m_best 更好。但是如果拿这个模型和其他结构的模型再做融合的话,M 的效果并没有 m_best 好,因为M 相当于做了平均操作,减少了该模型的“特性”。

但是又有一种新的融合方式,就是利用调整学习率来获取多个局部最优点,就是当 loss 降不下了,保存一个 ckpt, 然后开大学习率继续寻找下一个局部最优点,然后用这些 ckpt 来做融合,还没试过,单模型肯定是有提高的,就是不知道还会不会出现上面再与其他模型融合就没提高的情况。

如何使用 tf.train.Saver() 来保存模型

之前一直出错,主要是因为坑爹的编码问题。所以要注意文件的路径绝对不不要出现什么中文呀。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([1.0, 2.3], name="v1")
v2 = tf.Variable(55.5, name="v2")

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

ckpt_path = './ckpt/test-model.ckpt'
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
sess.run(init_op)
save_path = saver.save(sess, ckpt_path, global_step=1)
print("Model saved in file: %s" % save_path)

Model saved in file: ./ckpt/test-model.ckpt-1

注意,在上面保存完了模型之后。应该把 kernel restart 之后才能使用下面的模型导入。否则会因为两次命名 “v1” 而导致名字错误。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")
v2 = tf.Variable(33.5, name="v2")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)
print sess.run(v2)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]
55.5

导入模型之前,必须重新再定义一遍变量。

但是并不需要全部变量都重新进行定义,只定义我们需要的变量就行了。

也就是说,你所定义的变量一定要在 checkpoint 中存在;但不是所有在checkpoint中的变量,你都要重新定义。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]

tf.Saver([tensors_to_be_saved]) 中可以传入一个 list,把要保存的 tensors 传入,如果没有给定这个list的话,他会默认保存当前所有的 tensors。一般来说,tf.Saver 可以和 tf.variable_scope() 巧妙搭配,可以参考: 【迁移学习】往一个已经保存好的模型添加新的变量并进行微调

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python里使用正则表达式的ASCII模式
Nov 02 Python
Python读取txt某几列绘图的方法
Oct 14 Python
在cmder下安装ipython以及环境的搭建
Oct 19 Python
python实现批量注册网站用户的示例
Feb 22 Python
浅析Python 中几种字符串格式化方法及其比较
Jul 02 Python
Python 字符串类型列表转换成真正列表类型过程解析
Aug 26 Python
Python迷宫生成和迷宫破解算法实例
Dec 24 Python
pytorch自定义二值化网络层方式
Jan 07 Python
浅谈pytorch中torch.max和F.softmax函数的维度解释
Jun 28 Python
使用python将微信image下.dat文件解密为.png的方法
Nov 30 Python
python实现计算器简易版
Dec 17 Python
详细介绍python操作RabbitMq
Apr 12 Python
Python使用 Beanstalkd 做异步任务处理的方法
Apr 24 #Python
Windows上使用Python增加或删除权限的方法
Apr 24 #Python
python编写暴力破解zip文档程序的实例讲解
Apr 24 #Python
解决python删除文件的权限错误问题
Apr 24 #Python
python3+PyQt5实现自定义流体混合窗口部件
Apr 24 #Python
python3+PyQt5实现拖放功能
Apr 24 #Python
python3+PyQt5使用数据库表视图
Apr 24 #Python
You might like
采用PHP函数memory_get_usage获取PHP内存清耗量的方法
2011/12/06 PHP
php中的一些数组排序方法分享
2012/07/20 PHP
thinkPHP实现表单自动验证
2014/12/24 PHP
理解PHP中的Session及对Session有效期的控制
2016/01/08 PHP
PHP与Web页面的交互示例详解二
2020/08/04 PHP
jquery实现的超出屏幕时把固定层变为定位层的代码
2010/02/23 Javascript
Jquery结合HTML5实现文件上传
2015/06/25 Javascript
javascript单页面手势滑屏切换原理详解
2016/03/21 Javascript
jquery跟随屏幕滚动效果的实现代码
2016/04/13 Javascript
微信小程序 数组中的push与concat的区别
2017/01/05 Javascript
详解如何使用vue-cli脚手架搭建Vue.js项目
2017/05/19 Javascript
谈谈VUE种methods watch和compute的区别和联系
2017/08/01 Javascript
vue2.0 + element UI 中 el-table 数据导出Excel的方法
2018/03/02 Javascript
vue实现同一个页面可以有多个router-view的方法
2018/09/20 Javascript
JS根据json数组多个字段排序及json数组常用操作
2019/06/06 Javascript
详解微信小程序开发(项目从零开始)
2019/06/06 Javascript
JavaScript实现PC端四格密码输入框功能
2020/02/19 Javascript
Vue实现简单计算器案例
2020/02/25 Javascript
vue 解决data中定义图片相对路径页面不显示的问题
2020/08/13 Javascript
python返回昨天日期的方法
2015/05/13 Python
python numpy元素的区间查找方法
2018/11/14 Python
python 实现视频 图像帧提取
2019/12/10 Python
PyTorch的SoftMax交叉熵损失和梯度用法
2020/01/15 Python
python使用smtplib模块发送邮件
2020/12/17 Python
解决Python import .pyd 可能遇到路径的问题
2021/03/04 Python
canvas里面如何基于随机点绘制一个多边形的方法
2018/06/13 HTML / CSS
汉森冲浪板:Hansen Surfboards
2018/05/19 全球购物
法拉利英国精品店:Ferraris Boutique UK
2019/07/20 全球购物
白俄罗斯女装和针织品网上商店:Presli.by
2019/10/13 全球购物
简单租房协议书
2014/04/09 职场文书
音乐教师个人工作总结
2015/02/06 职场文书
中学教师个人总结
2015/02/10 职场文书
2015应届毕业生求职信范文
2015/03/20 职场文书
求职简历自荐信怎么写
2015/03/26 职场文书
Python实现8种常用抽样方法
2021/06/27 Python
Centos7中MySQL数据库使用mysqldump进行每日自动备份的编写
2021/08/02 MySQL