TensorFlow入门使用 tf.train.Saver()保存模型


Posted in Python onApril 24, 2018

关于模型保存的一点心得

saver = tf.train.Saver(max_to_keep=3)

在定义 saver 的时候一般会定义最多保存模型的数量,一般来说,如果模型本身很大,我们需要考虑到硬盘大小。如果你需要在当前训练好的模型的基础上进行 fine-tune,那么尽可能多的保存模型,后继 fine-tune 不一定从最好的 ckpt 进行,因为有可能一下子就过拟合了。但是如果保存太多,硬盘也有压力呀。如果只想保留最好的模型,方法就是每次迭代到一定步数就在验证集上计算一次 accuracy 或者 f1 值,如果本次结果比上次好才保存新的模型,否则没必要保存。

如果你想用不同 epoch 保存下来的模型进行融合的话,3到5 个模型已经足够了,假设这各融合的模型成为 M,而最好的一个单模型称为 m_best, 这样融合的话对于M 确实可以比 m_best 更好。但是如果拿这个模型和其他结构的模型再做融合的话,M 的效果并没有 m_best 好,因为M 相当于做了平均操作,减少了该模型的“特性”。

但是又有一种新的融合方式,就是利用调整学习率来获取多个局部最优点,就是当 loss 降不下了,保存一个 ckpt, 然后开大学习率继续寻找下一个局部最优点,然后用这些 ckpt 来做融合,还没试过,单模型肯定是有提高的,就是不知道还会不会出现上面再与其他模型融合就没提高的情况。

如何使用 tf.train.Saver() 来保存模型

之前一直出错,主要是因为坑爹的编码问题。所以要注意文件的路径绝对不不要出现什么中文呀。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([1.0, 2.3], name="v1")
v2 = tf.Variable(55.5, name="v2")

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

ckpt_path = './ckpt/test-model.ckpt'
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
sess.run(init_op)
save_path = saver.save(sess, ckpt_path, global_step=1)
print("Model saved in file: %s" % save_path)

Model saved in file: ./ckpt/test-model.ckpt-1

注意,在上面保存完了模型之后。应该把 kernel restart 之后才能使用下面的模型导入。否则会因为两次命名 “v1” 而导致名字错误。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")
v2 = tf.Variable(33.5, name="v2")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)
print sess.run(v2)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]
55.5

导入模型之前,必须重新再定义一遍变量。

但是并不需要全部变量都重新进行定义,只定义我们需要的变量就行了。

也就是说,你所定义的变量一定要在 checkpoint 中存在;但不是所有在checkpoint中的变量,你都要重新定义。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]

tf.Saver([tensors_to_be_saved]) 中可以传入一个 list,把要保存的 tensors 传入,如果没有给定这个list的话,他会默认保存当前所有的 tensors。一般来说,tf.Saver 可以和 tf.variable_scope() 巧妙搭配,可以参考: 【迁移学习】往一个已经保存好的模型添加新的变量并进行微调

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python变量不能以数字打头详解
Jul 06 Python
Python错误提示:[Errno 24] Too many open files的分析与解决
Feb 16 Python
python列表的增删改查实例代码
Jan 30 Python
使用python装饰器计算函数运行时间的实例
Apr 21 Python
Python之pymysql的使用小结
Jul 01 Python
Django后端接收嵌套Json数据及解析详解
Jul 17 Python
Python流程控制 while循环实现解析
Sep 02 Python
Python程序控制语句用法实例分析
Jan 14 Python
Python selenium使用autoIT上传附件过程详解
May 26 Python
关于Python不换行输出和不换行输出end=““不显示的问题(亲测已解决)
Oct 27 Python
Python Web项目Cherrypy使用方法镜像
Nov 05 Python
Django Admin后台模型列表页面如何添加自定义操作按钮
Nov 11 Python
Python使用 Beanstalkd 做异步任务处理的方法
Apr 24 #Python
Windows上使用Python增加或删除权限的方法
Apr 24 #Python
python编写暴力破解zip文档程序的实例讲解
Apr 24 #Python
解决python删除文件的权限错误问题
Apr 24 #Python
python3+PyQt5实现自定义流体混合窗口部件
Apr 24 #Python
python3+PyQt5实现拖放功能
Apr 24 #Python
python3+PyQt5使用数据库表视图
Apr 24 #Python
You might like
TP(thinkPHP)框架多层控制器和多级控制器的使用示例
2018/06/13 PHP
PHP count_chars()函数讲解
2019/02/14 PHP
javaScript 读取和设置文档元素的样式属性
2009/04/14 Javascript
js获取指定日期周数以及星期几的小例子
2014/06/27 Javascript
Firefox下无法正常显示年份的解决方法
2014/09/04 Javascript
JS实现仿PS的调色板效果完整实例
2016/12/21 Javascript
Bootstrap学习笔记之进度条、媒体对象实例详解
2017/03/09 Javascript
JS之if语句对接事件动作逻辑(详解)
2017/06/28 Javascript
使用vux实现上拉刷新功能遇到的坑
2018/02/08 Javascript
微信小程序实现的3d轮播图效果示例【基于swiper组件】
2018/12/11 Javascript
[49:07]VGJ.T vs Optic Supermajor小组赛D组 BO3 第二场 6.3
2018/06/04 DOTA
python根据开头和结尾字符串获取中间字符串的方法
2015/03/26 Python
Python安装第三方库的3种方法
2015/06/21 Python
python web基础之加载静态文件实例
2018/03/20 Python
python编写暴力破解zip文档程序的实例讲解
2018/04/24 Python
python3+pyqt5+itchat微信定时发送消息的方法
2019/02/20 Python
基于wxPython的GUI实现输入对话框(2)
2019/02/27 Python
cProfile Python性能分析工具使用详解
2019/07/22 Python
Python3爬虫中关于中文分词的详解
2020/07/29 Python
Canvas引入跨域的图片导致toDataURL()报错的问题的解决
2018/09/19 HTML / CSS
介绍一下EJB的分类及其各自的功能及应用
2016/08/23 面试题
会展中心部门工作职责
2013/11/27 职场文书
客服专员岗位职责范本
2013/11/29 职场文书
员工薪酬福利制度
2014/01/17 职场文书
司机岗位职责说明书
2014/07/29 职场文书
我的中国心演讲稿
2014/09/04 职场文书
计算机实训报告总结
2014/11/05 职场文书
2014年政协委员工作总结
2014/12/01 职场文书
赤壁观后感(2)
2015/06/15 职场文书
大学生暑假实习总结
2015/07/13 职场文书
2016年度基层党建工作公开承诺书
2016/03/25 职场文书
感恩信:写给爸爸妈妈的一封感谢信
2019/09/12 职场文书
redis连接被拒绝的解决方案
2021/04/12 Redis
浅谈如何提高PHP代码质量之单元测试
2021/05/28 PHP
Linux系统下安装PHP7.3版本
2021/06/26 PHP
Django中session进行权限管理的使用
2021/07/09 Python