TensorFlow入门使用 tf.train.Saver()保存模型


Posted in Python onApril 24, 2018

关于模型保存的一点心得

saver = tf.train.Saver(max_to_keep=3)

在定义 saver 的时候一般会定义最多保存模型的数量,一般来说,如果模型本身很大,我们需要考虑到硬盘大小。如果你需要在当前训练好的模型的基础上进行 fine-tune,那么尽可能多的保存模型,后继 fine-tune 不一定从最好的 ckpt 进行,因为有可能一下子就过拟合了。但是如果保存太多,硬盘也有压力呀。如果只想保留最好的模型,方法就是每次迭代到一定步数就在验证集上计算一次 accuracy 或者 f1 值,如果本次结果比上次好才保存新的模型,否则没必要保存。

如果你想用不同 epoch 保存下来的模型进行融合的话,3到5 个模型已经足够了,假设这各融合的模型成为 M,而最好的一个单模型称为 m_best, 这样融合的话对于M 确实可以比 m_best 更好。但是如果拿这个模型和其他结构的模型再做融合的话,M 的效果并没有 m_best 好,因为M 相当于做了平均操作,减少了该模型的“特性”。

但是又有一种新的融合方式,就是利用调整学习率来获取多个局部最优点,就是当 loss 降不下了,保存一个 ckpt, 然后开大学习率继续寻找下一个局部最优点,然后用这些 ckpt 来做融合,还没试过,单模型肯定是有提高的,就是不知道还会不会出现上面再与其他模型融合就没提高的情况。

如何使用 tf.train.Saver() 来保存模型

之前一直出错,主要是因为坑爹的编码问题。所以要注意文件的路径绝对不不要出现什么中文呀。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([1.0, 2.3], name="v1")
v2 = tf.Variable(55.5, name="v2")

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

ckpt_path = './ckpt/test-model.ckpt'
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
sess.run(init_op)
save_path = saver.save(sess, ckpt_path, global_step=1)
print("Model saved in file: %s" % save_path)

Model saved in file: ./ckpt/test-model.ckpt-1

注意,在上面保存完了模型之后。应该把 kernel restart 之后才能使用下面的模型导入。否则会因为两次命名 “v1” 而导致名字错误。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")
v2 = tf.Variable(33.5, name="v2")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)
print sess.run(v2)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]
55.5

导入模型之前,必须重新再定义一遍变量。

但是并不需要全部变量都重新进行定义,只定义我们需要的变量就行了。

也就是说,你所定义的变量一定要在 checkpoint 中存在;但不是所有在checkpoint中的变量,你都要重新定义。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]

tf.Saver([tensors_to_be_saved]) 中可以传入一个 list,把要保存的 tensors 传入,如果没有给定这个list的话,他会默认保存当前所有的 tensors。一般来说,tf.Saver 可以和 tf.variable_scope() 巧妙搭配,可以参考: 【迁移学习】往一个已经保存好的模型添加新的变量并进行微调

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python urlopen 使用小示例
Sep 06 Python
Python编程中装饰器的使用示例解析
Jun 20 Python
python做量化投资系列之比特币初始配置
Jan 23 Python
Python编程中NotImplementedError的使用方法
Apr 21 Python
用Django写天气预报查询网站
Oct 21 Python
Python3.0中普通方法、类方法和静态方法的比较
May 03 Python
用Python实现二叉树、二叉树非递归遍历及绘制的例子
Aug 09 Python
python3中numpy函数tile的用法详解
Dec 04 Python
TensorBoard 计算图的可视化实现
Feb 15 Python
Python os模块常用方法和属性总结
Feb 20 Python
Python使用Matlab命令过程解析
Jun 04 Python
python读取excel数据绘制简单曲线图的完整步骤记录
Oct 30 Python
Python使用 Beanstalkd 做异步任务处理的方法
Apr 24 #Python
Windows上使用Python增加或删除权限的方法
Apr 24 #Python
python编写暴力破解zip文档程序的实例讲解
Apr 24 #Python
解决python删除文件的权限错误问题
Apr 24 #Python
python3+PyQt5实现自定义流体混合窗口部件
Apr 24 #Python
python3+PyQt5实现拖放功能
Apr 24 #Python
python3+PyQt5使用数据库表视图
Apr 24 #Python
You might like
PHP生成便于打印的网页
2006/10/09 PHP
PHP 模板高级篇总结
2006/12/21 PHP
ajax实现无刷新分页(php)
2010/07/18 PHP
php的memcached客户端memcached
2011/06/14 PHP
php shell超强免杀、减少体积工具实现代码
2012/10/16 PHP
PHP JS Ip地址及域名格式检测代码
2013/09/27 PHP
PHP里面把16进制的图片数据显示在html的img标签上(实现方法)
2017/05/02 PHP
thinkPHP框架动态配置用法实例分析
2018/06/14 PHP
PHP中常用的三种设计模式详解【单例模式、工厂模式、观察者模式】
2019/06/14 PHP
JS 控制非法字符的输入代码
2009/12/04 Javascript
Jquery Ajax学习实例6 向WebService发出请求,返回DataSet(XML) 异步调用
2010/03/18 Javascript
javascript 三种方法实现获得和设置以及移除元素属性
2013/03/20 Javascript
javascript 树形导航菜单实例代码
2013/08/13 Javascript
提取字符串中年月日的函数代码
2013/11/05 Javascript
JS实现将人民币金额转换为大写的示例代码
2014/02/13 Javascript
将List对象列表转换成JSON格式的类实现方法
2016/07/04 Javascript
浅谈JQ中mouseover和mouseenter的区别
2016/09/13 Javascript
jquery easyui validatebox remote的使用详解
2016/11/09 Javascript
vue 动态绑定背景图片的方法
2018/08/10 Javascript
JS实现吸顶特效
2020/01/08 Javascript
[26:52]LGD vs EG 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/18 DOTA
python端口扫描系统实现方法
2014/11/19 Python
在Django中进行用户注册和邮箱验证的方法
2016/05/09 Python
Python爬虫动态ip代理防止被封的方法
2019/07/07 Python
python basemap 画出经纬度并标定的实例
2019/07/09 Python
python pygame实现挡板弹球游戏
2019/11/25 Python
Keras 实现加载预训练模型并冻结网络的层
2020/06/15 Python
详解HTML5中垂直上下居中的解决方案
2017/12/20 HTML / CSS
美国二手奢侈品寄售网站:TheRealReal
2016/10/29 全球购物
Funko官方商店:源自美国,畅销全球搪胶收藏玩偶
2018/09/15 全球购物
大学自主招生自荐信范文
2014/02/26 职场文书
物业管理专业求职信
2014/06/11 职场文书
党员创先争优心得体会
2014/09/11 职场文书
导游词之藏龙百瀑景区
2019/12/30 职场文书
Oracle数据库中通用的函数实例详解
2022/03/25 Oracle
SpringBoot使用ip2region获取地理位置信息的方法
2022/06/21 Java/Android