TensorFlow入门使用 tf.train.Saver()保存模型


Posted in Python onApril 24, 2018

关于模型保存的一点心得

saver = tf.train.Saver(max_to_keep=3)

在定义 saver 的时候一般会定义最多保存模型的数量,一般来说,如果模型本身很大,我们需要考虑到硬盘大小。如果你需要在当前训练好的模型的基础上进行 fine-tune,那么尽可能多的保存模型,后继 fine-tune 不一定从最好的 ckpt 进行,因为有可能一下子就过拟合了。但是如果保存太多,硬盘也有压力呀。如果只想保留最好的模型,方法就是每次迭代到一定步数就在验证集上计算一次 accuracy 或者 f1 值,如果本次结果比上次好才保存新的模型,否则没必要保存。

如果你想用不同 epoch 保存下来的模型进行融合的话,3到5 个模型已经足够了,假设这各融合的模型成为 M,而最好的一个单模型称为 m_best, 这样融合的话对于M 确实可以比 m_best 更好。但是如果拿这个模型和其他结构的模型再做融合的话,M 的效果并没有 m_best 好,因为M 相当于做了平均操作,减少了该模型的“特性”。

但是又有一种新的融合方式,就是利用调整学习率来获取多个局部最优点,就是当 loss 降不下了,保存一个 ckpt, 然后开大学习率继续寻找下一个局部最优点,然后用这些 ckpt 来做融合,还没试过,单模型肯定是有提高的,就是不知道还会不会出现上面再与其他模型融合就没提高的情况。

如何使用 tf.train.Saver() 来保存模型

之前一直出错,主要是因为坑爹的编码问题。所以要注意文件的路径绝对不不要出现什么中文呀。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([1.0, 2.3], name="v1")
v2 = tf.Variable(55.5, name="v2")

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

ckpt_path = './ckpt/test-model.ckpt'
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
sess.run(init_op)
save_path = saver.save(sess, ckpt_path, global_step=1)
print("Model saved in file: %s" % save_path)

Model saved in file: ./ckpt/test-model.ckpt-1

注意,在上面保存完了模型之后。应该把 kernel restart 之后才能使用下面的模型导入。否则会因为两次命名 “v1” 而导致名字错误。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")
v2 = tf.Variable(33.5, name="v2")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)
print sess.run(v2)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]
55.5

导入模型之前,必须重新再定义一遍变量。

但是并不需要全部变量都重新进行定义,只定义我们需要的变量就行了。

也就是说,你所定义的变量一定要在 checkpoint 中存在;但不是所有在checkpoint中的变量,你都要重新定义。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]

tf.Saver([tensors_to_be_saved]) 中可以传入一个 list,把要保存的 tensors 传入,如果没有给定这个list的话,他会默认保存当前所有的 tensors。一般来说,tf.Saver 可以和 tf.variable_scope() 巧妙搭配,可以参考: 【迁移学习】往一个已经保存好的模型添加新的变量并进行微调

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中多线程thread与threading的实现方法
Aug 18 Python
Python Web框架Flask下网站开发入门实例
Feb 08 Python
探究数组排序提升Python程序的循环的运行效率的原因
Apr 01 Python
简单理解Python中基于生成器的状态机
Apr 13 Python
在win和Linux系统中python命令行运行的不同
Jul 03 Python
关于Python 3中print函数的换行详解
Aug 08 Python
利用python实现.dcm格式图像转为.jpg格式
Jan 13 Python
Django Session和Cookie分别实现记住用户登录状态操作
Jul 02 Python
一篇文章教你用python画动态爱心表白
Nov 22 Python
Pandas中DataFrame交换列顺序的方法实现
Dec 14 Python
pytorch中Schedule与warmup_steps的用法说明
May 24 Python
Python查找算法的实现 (线性、二分,分块、插值查找算法)
Apr 24 Python
Python使用 Beanstalkd 做异步任务处理的方法
Apr 24 #Python
Windows上使用Python增加或删除权限的方法
Apr 24 #Python
python编写暴力破解zip文档程序的实例讲解
Apr 24 #Python
解决python删除文件的权限错误问题
Apr 24 #Python
python3+PyQt5实现自定义流体混合窗口部件
Apr 24 #Python
python3+PyQt5实现拖放功能
Apr 24 #Python
python3+PyQt5使用数据库表视图
Apr 24 #Python
You might like
第十三节 对象串行化 [13]
2006/10/09 PHP
不用GD库生成当前时间的PNG格式图象的程序
2006/10/09 PHP
PHP4实际应用经验篇(8)
2006/10/09 PHP
我的群发邮件程序
2006/10/09 PHP
如何利用PHP执行.SQL文件
2013/07/05 PHP
php使用session二维数组实例
2014/11/06 PHP
PHP在线调试执行的实现方法(附demo源码)
2016/04/28 PHP
php实现替换手机号中间数字为*号及隐藏IP最后几位的方法
2016/11/16 PHP
vmware linux系统安装最新的php7图解
2019/04/14 PHP
js 文件引入实现代码
2010/04/23 Javascript
JavaScript几种形式的树结构菜单
2010/05/10 Javascript
动态创建样式表在各浏览器中的差异测试代码
2011/09/13 Javascript
基于jQuery实现仿51job城市选择功能实例代码
2016/03/02 Javascript
webpack独立打包和缓存处理详解
2017/04/03 Javascript
ES6新特性六:promise对象实例详解
2017/04/21 Javascript
简单谈谈关于 npm 5.0 的新坑
2017/06/08 Javascript
jQuery实现的隔行变色功能【案例】
2019/02/18 jQuery
浅谈Vue CLI 3结合Lerna进行UI框架设计
2019/04/14 Javascript
vue-cli 项目打包完成后运行文件路径报错问题
2019/07/19 Javascript
pandas 数据实现行间计算的方法
2018/06/08 Python
pytorch中nn.Conv1d的用法详解
2019/12/31 Python
Python 捕获代码中所有异常的方法
2020/08/03 Python
python爬虫爬取网页数据并解析数据
2020/09/18 Python
预订全球最佳旅行体验:Viator
2018/03/30 全球购物
捷克鲜花配送:Florea.cz
2018/10/29 全球购物
大一期末自我鉴定
2013/12/13 职场文书
学员自我鉴定
2014/03/19 职场文书
个人承诺书
2014/03/26 职场文书
高等学院职业生涯规划书范文
2014/09/16 职场文书
先进人物事迹材料
2014/12/29 职场文书
女性健康知识讲座通知
2015/04/23 职场文书
2015年医药代表工作总结
2015/04/25 职场文书
委托收款证明
2015/06/23 职场文书
MySQL主从复制断开的常用修复方法
2021/04/07 MySQL
浏览器常用基本操作之python3+selenium4自动化测试(基础篇3)
2021/05/21 Python
MySQL 条件查询的常用操作
2022/04/28 MySQL