对TensorFlow中的variables_to_restore函数详解


Posted in Python onJuly 30, 2018

variables_to_restore函数,是TensorFlow为滑动平均值提供。之前,也介绍过通过使用滑动平均值可以让神经网络模型更加的健壮。我们也知道,其实在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身。

1、滑动平均值模型文件的保存

import tensorflow as tf
 
if __name__ == "__main__":
 v = tf.Variable(0.,name="v")
 #设置滑动平均模型的系数
 ema = tf.train.ExponentialMovingAverage(0.99)
 #设置变量v使用滑动平均模型,tf.all_variables()设置所有变量
 op = ema.apply([v])
 #获取变量v的名字
 print(v.name)
 #v:0
 #创建一个保存模型的对象
 save = tf.train.Saver()
 sess = tf.Session()
 #初始化所有变量
 init = tf.initialize_all_variables()
 sess.run(init)
 #给变量v重新赋值
 sess.run(tf.assign(v,10))
 #应用平均滑动设置
 sess.run(op)
 #保存模型文件
 save.save(sess,"./model.ckpt")
 #输出变量v之前的值和使用滑动平均模型之后的值
 print(sess.run([v,ema.average(v)]))
 #[10.0, 0.099999905]

上面的代码,是如何来保存一个滑动平均值的模型文件,之前有介绍过滑动平均值和模型文件的保存,所以这里就不再重复了。

2、滑动平均值模型文件的读取

v = tf.Variable(1.,name="v")
 #定义模型对象
 saver = tf.train.Saver({"v/ExponentialMovingAverage":v})
 sess = tf.Session()
 saver.restore(sess,"./model.ckpt")
 print(sess.run(v))
 #0.0999999

对于模型文件的读取,在上一篇博客中有介绍过,这里特别需要注意的一个地方就是,在使用tf.train.Saver函数中,所传递的模型参数是{"v/ExponentialMovingAverage":v}而不是{"v":v},如果你使用的是后面的参数,那么你得到的结果将是10而不是0.09,那是因为后者获取的是变量本身而不是影子变量。是不是感觉使用这种方式来读取模型文件的时候,还需要输入一大串的变量名称。

3、variables_to_restore函数的使用

v = tf.Variable(1.,name="v")
 #滑动模型的参数的大小并不会影响v的值
 ema = tf.train.ExponentialMovingAverage(0.99)
 print(ema.variables_to_restore())
 #{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
 sess = tf.Session()
 saver = tf.train.Saver(ema.variables_to_restore())
 saver.restore(sess,"./model.ckpt")
 print(sess.run(v))
 #0.0999999

通过使用variables_to_restore函数,可以使在加载模型的时候将影子变量直接映射到变量的本身,所以我们在获取变量的滑动平均值的时候只需要获取到变量的本身值而不需要去获取影子变量。

以上这篇对TensorFlow中的variables_to_restore函数详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
利用Python获取操作系统信息实例
Sep 02 Python
Python 描述符(Descriptor)入门
Nov 20 Python
使用python遍历指定城市的一周气温
Mar 31 Python
人机交互程序 python实现人机对话
Nov 14 Python
Flask web开发处理POST请求实现(登录案例)
Jul 26 Python
python交互模式下输入换行/输入多行命令的方法
Jul 02 Python
Django实现基于类的分页功能
Oct 31 Python
Python实现获取当前目录下文件名代码详解
Mar 10 Python
实现ECharts双Y轴左右刻度线一致的例子
May 16 Python
Python中格式化字符串的四种实现
May 26 Python
详解Python+OpenCV进行基础的图像操作
Feb 15 Python
Python实现文字pdf转换图片pdf效果
Apr 03 Python
Python实现模拟浏览器请求及会话保持操作示例
Jul 30 #Python
tensorflow 打印内存中的变量方法
Jul 30 #Python
Python实现的多叉树寻找最短路径算法示例
Jul 30 #Python
tensorflow: variable的值与variable.read_value()的值区别详解
Jul 30 #Python
Tensorflow 实现修改张量特定元素的值方法
Jul 30 #Python
python用BeautifulSoup库简单爬虫实例分析
Jul 30 #Python
对TensorFlow的assign赋值用法详解
Jul 30 #Python
You might like
PHP中捕获超时事件的方法实例
2015/02/12 PHP
CI框架源码解读之URI.php中_fetch_uri_string()函数用法分析
2016/05/18 PHP
PHP购物车类Cart.class.php定义与用法示例
2016/07/20 PHP
浅析PHP数据导出知识点
2018/02/17 PHP
PHP实现可精确验证身份证号码的工具类示例
2018/05/31 PHP
在JavaScript中实现类的方式探讨
2013/08/28 Javascript
js中的referrer返回上一页使用介绍
2013/09/26 Javascript
js二维数组排序的简单示例代码
2014/01/24 Javascript
基于jQuery全屏焦点图左右切换插件responsiveslides
2015/09/07 Javascript
详解JavaScript数组和字符串中去除重复值的方法
2016/03/07 Javascript
JavaScript 数组的深度复制解析
2016/11/02 Javascript
Bootstrap fileinput组件封装及使用详解
2017/03/10 Javascript
Angular2搜索和重置按钮过场动画
2017/05/24 Javascript
nodejs基础之多进程实例详解
2018/12/27 NodeJs
vue实现侧边栏导航效果
2019/10/21 Javascript
在NodeJs中使用node-schedule增加定时器任务的方法
2020/06/08 NodeJs
[42:24]完美世界DOTA2联赛循环赛 LBZS vs DM BO2第一场 11.01
2020/11/02 DOTA
Python入门_条件控制(详解)
2017/05/16 Python
python调用matlab的m自定义函数方法
2019/02/18 Python
react+django清除浏览器缓存的几种方法小结
2019/07/17 Python
Python爬取爱奇艺电影信息代码实例
2019/11/26 Python
Python3操作YAML文件格式方法解析
2020/04/10 Python
利用Canvas模仿百度贴吧客户端loading小球的方法示例
2017/08/13 HTML / CSS
html5使用canvas画一条线
2014/12/15 HTML / CSS
澳大利亚女装精品店:Alannah Hill
2020/07/29 全球购物
会计出纳岗位职责
2013/12/25 职场文书
初中体育教学反思
2014/01/14 职场文书
优秀团员个人事迹材料
2014/01/29 职场文书
社区工作感言
2014/02/21 职场文书
颐和园的导游词
2015/01/30 职场文书
惹女朋友生气检讨书
2015/05/06 职场文书
六五普法先进个人主要事迹材料
2015/11/03 职场文书
如何利用map实现Nginx允许多个域名跨域
2021/03/31 Servers
68行Python代码实现带难度升级的贪吃蛇
2022/01/18 Python
Win10防火墙白名单怎么设置?Win10添加防火墙白名单方法
2022/04/06 数码科技
修改Nginx配置返回指定content-type的方法
2022/09/23 Servers