TensorFlow入门使用 tf.train.Saver()保存模型


Posted in Python onApril 24, 2018

关于模型保存的一点心得

saver = tf.train.Saver(max_to_keep=3)

在定义 saver 的时候一般会定义最多保存模型的数量,一般来说,如果模型本身很大,我们需要考虑到硬盘大小。如果你需要在当前训练好的模型的基础上进行 fine-tune,那么尽可能多的保存模型,后继 fine-tune 不一定从最好的 ckpt 进行,因为有可能一下子就过拟合了。但是如果保存太多,硬盘也有压力呀。如果只想保留最好的模型,方法就是每次迭代到一定步数就在验证集上计算一次 accuracy 或者 f1 值,如果本次结果比上次好才保存新的模型,否则没必要保存。

如果你想用不同 epoch 保存下来的模型进行融合的话,3到5 个模型已经足够了,假设这各融合的模型成为 M,而最好的一个单模型称为 m_best, 这样融合的话对于M 确实可以比 m_best 更好。但是如果拿这个模型和其他结构的模型再做融合的话,M 的效果并没有 m_best 好,因为M 相当于做了平均操作,减少了该模型的“特性”。

但是又有一种新的融合方式,就是利用调整学习率来获取多个局部最优点,就是当 loss 降不下了,保存一个 ckpt, 然后开大学习率继续寻找下一个局部最优点,然后用这些 ckpt 来做融合,还没试过,单模型肯定是有提高的,就是不知道还会不会出现上面再与其他模型融合就没提高的情况。

如何使用 tf.train.Saver() 来保存模型

之前一直出错,主要是因为坑爹的编码问题。所以要注意文件的路径绝对不不要出现什么中文呀。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([1.0, 2.3], name="v1")
v2 = tf.Variable(55.5, name="v2")

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

ckpt_path = './ckpt/test-model.ckpt'
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
sess.run(init_op)
save_path = saver.save(sess, ckpt_path, global_step=1)
print("Model saved in file: %s" % save_path)

Model saved in file: ./ckpt/test-model.ckpt-1

注意,在上面保存完了模型之后。应该把 kernel restart 之后才能使用下面的模型导入。否则会因为两次命名 “v1” 而导致名字错误。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")
v2 = tf.Variable(33.5, name="v2")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)
print sess.run(v2)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]
55.5

导入模型之前,必须重新再定义一遍变量。

但是并不需要全部变量都重新进行定义,只定义我们需要的变量就行了。

也就是说,你所定义的变量一定要在 checkpoint 中存在;但不是所有在checkpoint中的变量,你都要重新定义。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]

tf.Saver([tensors_to_be_saved]) 中可以传入一个 list,把要保存的 tensors 传入,如果没有给定这个list的话,他会默认保存当前所有的 tensors。一般来说,tf.Saver 可以和 tf.variable_scope() 巧妙搭配,可以参考: 【迁移学习】往一个已经保存好的模型添加新的变量并进行微调

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python程序中进行文件读取和写入操作的教程
Apr 28 Python
批处理与python代码混合编程的方法
May 19 Python
python将文本分每两行一组并保存到文件
Mar 19 Python
python实现csv格式文件转为asc格式文件的方法
Mar 23 Python
python dataframe常见操作方法:实现取行、列、切片、统计特征值
Jun 09 Python
Django框架的中的setting.py文件说明详解
Oct 15 Python
Python学习笔记之函数的参数和返回值的使用
Nov 20 Python
Python操作多维数组输出和矩阵运算示例
Nov 28 Python
Python测试Kafka集群(pykafka)实例
Dec 23 Python
Keras官方中文文档:性能评估Metrices详解
Jun 15 Python
python利用线程实现多任务
Sep 18 Python
python爬虫破解字体加密案例详解
Mar 02 Python
Python使用 Beanstalkd 做异步任务处理的方法
Apr 24 #Python
Windows上使用Python增加或删除权限的方法
Apr 24 #Python
python编写暴力破解zip文档程序的实例讲解
Apr 24 #Python
解决python删除文件的权限错误问题
Apr 24 #Python
python3+PyQt5实现自定义流体混合窗口部件
Apr 24 #Python
python3+PyQt5实现拖放功能
Apr 24 #Python
python3+PyQt5使用数据库表视图
Apr 24 #Python
You might like
PHP编码规范的深入探讨
2013/06/06 PHP
php简单实现无限分类树形列表的方法
2015/03/27 PHP
PHP版QQ互联OAuth示例代码分享
2015/07/05 PHP
PHP中explode函数和split函数的区别小结
2016/08/24 PHP
javascript &&和||运算法的另类使用技巧
2009/11/28 Javascript
表单切换,用回车键替换Tab健(不支持IE)
2011/07/20 Javascript
一行代码实现纯数据json对象的深度克隆实现思路
2013/01/09 Javascript
jquery实现背景墙聚光灯效果示例分享
2014/03/02 Javascript
Node.js安装教程和NPM包管理器使用详解
2014/08/16 Javascript
ECMAScript 5严格模式(Strict Mode)介绍
2015/03/02 Javascript
JavaScipt中栈的实现方法
2016/02/17 Javascript
JS图片等比例缩放方法完整示例
2016/08/03 Javascript
Windows环境下npm install 报错: operation not permitted, rename的解决方法
2016/09/26 Javascript
BootStrap按钮标签及基本样式
2016/11/23 Javascript
微信小程序之获取当前位置经纬度以及地图显示详解
2017/05/09 Javascript
基于打包工具Webpack进行项目开发实例
2018/05/29 Javascript
jQuery实现的鼠标拖动画矩形框示例【可兼容IE8】
2019/05/17 jQuery
前端Electron新手入门教程详解
2019/06/21 Javascript
React Native 混合开发多入口加载方式详解
2019/09/23 Javascript
Vue组件模板及组件互相引用代码实例
2020/03/11 Javascript
JS实现斐波那契数列的五种方式(小结)
2020/09/09 Javascript
在Python中操作字典之clear()方法的使用
2015/05/21 Python
Django日志模块logging的配置详解
2017/02/14 Python
Python数据类型之Set集合实例详解
2019/05/07 Python
Python提取PDF内容的方法(文本、图像、线条等)
2019/09/25 Python
Bench加拿大官方网站:英国城市服装品牌
2017/11/03 全球购物
应届生求职推荐信
2013/10/28 职场文书
函授毕业自我鉴定
2014/02/04 职场文书
军训自我鉴定范文
2014/02/13 职场文书
元旦联欢会感言
2014/03/04 职场文书
电影地道战观后感
2015/06/04 职场文书
毕业生自我鉴定范文
2019/05/13 职场文书
HTML基础详解(下)
2021/10/16 HTML / CSS
如何创建一个创建MySQL数据库中的datetime类型
2022/03/21 MySQL
python神经网络ResNet50模型
2022/05/06 Python
Android开发EditText禁止输入监听及InputFilter字符过滤
2022/06/10 Java/Android