TensorFlow入门使用 tf.train.Saver()保存模型


Posted in Python onApril 24, 2018

关于模型保存的一点心得

saver = tf.train.Saver(max_to_keep=3)

在定义 saver 的时候一般会定义最多保存模型的数量,一般来说,如果模型本身很大,我们需要考虑到硬盘大小。如果你需要在当前训练好的模型的基础上进行 fine-tune,那么尽可能多的保存模型,后继 fine-tune 不一定从最好的 ckpt 进行,因为有可能一下子就过拟合了。但是如果保存太多,硬盘也有压力呀。如果只想保留最好的模型,方法就是每次迭代到一定步数就在验证集上计算一次 accuracy 或者 f1 值,如果本次结果比上次好才保存新的模型,否则没必要保存。

如果你想用不同 epoch 保存下来的模型进行融合的话,3到5 个模型已经足够了,假设这各融合的模型成为 M,而最好的一个单模型称为 m_best, 这样融合的话对于M 确实可以比 m_best 更好。但是如果拿这个模型和其他结构的模型再做融合的话,M 的效果并没有 m_best 好,因为M 相当于做了平均操作,减少了该模型的“特性”。

但是又有一种新的融合方式,就是利用调整学习率来获取多个局部最优点,就是当 loss 降不下了,保存一个 ckpt, 然后开大学习率继续寻找下一个局部最优点,然后用这些 ckpt 来做融合,还没试过,单模型肯定是有提高的,就是不知道还会不会出现上面再与其他模型融合就没提高的情况。

如何使用 tf.train.Saver() 来保存模型

之前一直出错,主要是因为坑爹的编码问题。所以要注意文件的路径绝对不不要出现什么中文呀。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([1.0, 2.3], name="v1")
v2 = tf.Variable(55.5, name="v2")

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

ckpt_path = './ckpt/test-model.ckpt'
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
sess.run(init_op)
save_path = saver.save(sess, ckpt_path, global_step=1)
print("Model saved in file: %s" % save_path)

Model saved in file: ./ckpt/test-model.ckpt-1

注意,在上面保存完了模型之后。应该把 kernel restart 之后才能使用下面的模型导入。否则会因为两次命名 “v1” 而导致名字错误。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")
v2 = tf.Variable(33.5, name="v2")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)
print sess.run(v2)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]
55.5

导入模型之前,必须重新再定义一遍变量。

但是并不需要全部变量都重新进行定义,只定义我们需要的变量就行了。

也就是说,你所定义的变量一定要在 checkpoint 中存在;但不是所有在checkpoint中的变量,你都要重新定义。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]

tf.Saver([tensors_to_be_saved]) 中可以传入一个 list,把要保存的 tensors 传入,如果没有给定这个list的话,他会默认保存当前所有的 tensors。一般来说,tf.Saver 可以和 tf.variable_scope() 巧妙搭配,可以参考: 【迁移学习】往一个已经保存好的模型添加新的变量并进行微调

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
go和python调用其它程序并得到程序输出
Feb 10 Python
深入理解 Python 中的多线程 新手必看
Nov 20 Python
利用python求相邻数的方法示例
Aug 18 Python
Python使用matplotlib绘图无法显示中文问题的解决方法
Mar 14 Python
使用Python处理BAM的方法
Sep 28 Python
Python解决线性代数问题之矩阵的初等变换方法
Dec 12 Python
python 在右键菜单中加入复制目标文件的有效存放路径(单斜杠或者双反斜杠)
Apr 08 Python
基于python连接oracle导并出数据文件
Apr 28 Python
python中Ansible模块的Playbook的具体使用
May 28 Python
python PIL模块的基本使用
Sep 29 Python
用python爬虫批量下载pdf的实现
Dec 01 Python
pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
May 22 Python
Python使用 Beanstalkd 做异步任务处理的方法
Apr 24 #Python
Windows上使用Python增加或删除权限的方法
Apr 24 #Python
python编写暴力破解zip文档程序的实例讲解
Apr 24 #Python
解决python删除文件的权限错误问题
Apr 24 #Python
python3+PyQt5实现自定义流体混合窗口部件
Apr 24 #Python
python3+PyQt5实现拖放功能
Apr 24 #Python
python3+PyQt5使用数据库表视图
Apr 24 #Python
You might like
PHP一些有意思的小区别
2006/12/06 PHP
PHP判断一个gif图片是否为动态图片的方法
2014/11/19 PHP
PHP网站开发中常用的8个小技巧
2015/02/13 PHP
phpStorm2020 注册码
2020/09/17 PHP
html读出文本文件内容
2007/01/22 Javascript
一些常用且实用的原生JavaScript函数
2010/09/08 Javascript
javascript高级程序设计第二版第十二章事件要点总结(常用的跨浏览器检测方法)
2012/08/22 Javascript
javascript对下拉列表框(select)的操作实例讲解
2013/11/29 Javascript
基于豆瓣API+Angular开发的web App
2015/01/02 Javascript
js实现按钮颜色渐变动画效果
2015/08/20 Javascript
js判断主流浏览器类型和版本号的简单实现代码
2016/05/26 Javascript
jQuery实现查找最近父节点的方法
2016/06/23 Javascript
通过网页查看JS源码中汉字显示乱码的解决方法
2016/10/26 Javascript
Jquery鼠标放上去显示全名的实现方法
2017/02/06 Javascript
javascript实现圣旨卷轴展开效果(代码分享)
2017/03/23 Javascript
Bootstrap DateTime Picker日历控件简单应用
2017/03/25 Javascript
JS优化与惰性载入函数实例分析
2017/04/06 Javascript
解决使用Vue.js显示数据的时,页面闪现原始代码的问题
2018/02/11 Javascript
Vue中v-show添加表达式的问题(判断是否显示)
2018/03/26 Javascript
vue模块拖拽实现示例代码
2019/03/09 Javascript
vue 路由meta 设置导航隐藏与显示功能的示例代码
2020/09/04 Javascript
[38:39]KG vs Mineski 2019国际邀请赛小组赛 BO2 第一场 8.15
2019/08/16 DOTA
Python实现输出程序执行进度百分比的方法
2017/09/16 Python
python运用pygame库实现双人弹球小游戏
2019/11/25 Python
基于python plotly交互式图表大全
2019/12/07 Python
python 解决tqdm模块不能单行显示的问题
2020/02/19 Python
Python爬虫教程知识点总结
2020/10/19 Python
台湾最大银发乐活百货:乐龄网
2018/05/21 全球购物
Ego Shoes官网:英国时髦鞋类品牌
2020/10/19 全球购物
十佳护士先进事迹
2014/05/08 职场文书
奉献爱心演讲稿
2014/09/04 职场文书
重阳节演讲稿:尊敬帮助老人 弘扬传统美德
2014/09/25 职场文书
城镇居民医疗保险工作总结
2015/08/10 职场文书
(开源)微信小程序+mqtt,esp8266温湿度读取
2021/04/02 Javascript
OpenCV-Python模板匹配人眼的实例
2021/06/08 Python
日本动漫十大公认神作:第五现已全网禁播,《死亡笔记》在榜
2022/03/18 日漫