TensorFlow 滑动平均的示例代码


Posted in Python onJune 19, 2018

滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量。

1、滑动平均求解对象初始化

ema = tf.train.ExponentialMovingAverage(decay,num_updates)

参数decay

`shadow_variable = decay * shadow_variable + (1 - decay) * variable`

参数num_updates

`min(decay, (1 + num_updates) / (10 + num_updates))`

2、添加/更新变量

添加目标变量,为之维护影子变量

注意维护不是自动的,需要每轮训练中运行此句,所以一般都会使用tf.control_dependencies使之和train_op绑定,以至于每次train_op都会更新影子变量

ema.apply([var0, var1])

3、获取影子变量值

这一步不需要定义图中,从影子变量集合中提取目标值

sess.run(ema.average([var0, var1]))

4、保存&载入影子变量

我们知道,在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身。

保存影子变量

建立tf.train.ExponentialMovingAverage对象后,Saver正常保存就会存入影子变量,命名规则是"v/ExponentialMovingAverage"对应变量”v“

import tensorflow as tf 
if __name__ == "__main__": 

  v = tf.Variable(0.,name="v") 

  #设置滑动平均模型的系数 

  ema = tf.train.ExponentialMovingAverage(0.99) 

  #设置变量v使用滑动平均模型,tf.all_variables()设置所有变量 

  op = ema.apply([v]) 

  #获取变量v的名字 

  print(v.name) 

  #v:0 

  #创建一个保存模型的对象 

  save = tf.train.Saver() 

  sess = tf.Session() 

  #初始化所有变量 

  init = tf.initialize_all_variables() 

  sess.run(init) 

  #给变量v重新赋值 

  sess.run(tf.assign(v,10)) 

  #应用平均滑动设置 

  sess.run(op) 

  #保存模型文件 

  save.save(sess,"./model.ckpt") 

  #输出变量v之前的值和使用滑动平均模型之后的值 

  print(sess.run([v,ema.average(v)])) 

  #[10.0, 0.099999905]

载入影子变量并映射到变量

利用了Saver载入模型的变量名映射功能,实际上对所有的变量都可以如此操作『TensorFlow』模型载入方法汇总

v = tf.Variable(1.,name="v") 

#定义模型对象 

saver = tf.train.Saver({"v/ExponentialMovingAverage":v}) 

sess = tf.Session() 

saver.restore(sess,"./model.ckpt") 

print(sess.run(v)) 

#0.0999999

这里特别需要注意的一个地方就是,在使用tf.train.Saver函数中,所传递的模型参数是{"v/ExponentialMovingAverage":v}而不是{"v":v},如果你使用的是后面的参数,那么你得到的结果将是10而不是0.09,那是因为后者获取的是变量本身而不是影子变量。

使用这种方式来读取模型文件的时候,还需要输入一大串的变量名称。

variables_to_restore函数的使用

v = tf.Variable(1.,name="v") 

#滑动模型的参数的大小并不会影响v的值 

ema = tf.train.ExponentialMovingAverage(0.99) 

print(ema.variables_to_restore()) 

#{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>} 

sess = tf.Session() 

saver = tf.train.Saver(ema.variables_to_restore()) 

saver.restore(sess,"./model.ckpt") 

print(sess.run(v)) 

#0.0999999

variables_to_restore会识别网络中的变量,并自动生成影子变量名。

通过使用variables_to_restore函数,可以使在加载模型的时候将影子变量直接映射到变量的本身,所以我们在获取变量的滑动平均值的时候只需要获取到变量的本身值而不需要去获取影子变量。

5、官方文档例子

官方文档中将每次apply更新就会自动训练一边模型,实际上可以反过来两者关系,《tf实战google》P128中有示例

| Example usage when creating a training model:
 | 
 | ```python
 | # Create variables.
 | var0 = tf.Variable(...)
 | var1 = tf.Variable(...)
 | # ... use the variables to build a training model...
 | ...
 | # Create an op that applies the optimizer. This is what we usually
 | # would use as a training op.
 | opt_op = opt.minimize(my_loss, [var0, var1])
 | 
 | # Create an ExponentialMovingAverage object
 | ema = tf.train.ExponentialMovingAverage(decay=0.9999)
 | 
 | with tf.control_dependencies([opt_op]):
 |   # Create the shadow variables, and add ops to maintain moving averages
 |   # of var0 and var1. This also creates an op that will update the moving
 |   # averages after each training step. This is what we will use in place
 |   # of the usual training op.
 |   training_op = ema.apply([var0, var1])
 | 
 | ...train the model by running training_op...
 | ```

6、batch_normal的例子

和上面不太一样的是,batch_normal中不太容易绑定到train_op(位于函数体外面),则强行将两个variable的输出过程化为节点,绑定给参数更新步骤

def batch_norm(x,beta,gamma,phase_train,scope='bn',decay=0.9,eps=1e-5):

  with tf.variable_scope(scope):

    # beta = tf.get_variable(name='beta', shape=[n_out], initializer=tf.constant_initializer(0.0), trainable=True)

    # gamma = tf.get_variable(name='gamma', shape=[n_out],

    #             initializer=tf.random_normal_initializer(1.0, stddev), trainable=True)

    batch_mean,batch_var = tf.nn.moments(x,[0,1,2],name='moments')

    ema = tf.train.ExponentialMovingAverage(decay=decay)

 

    def mean_var_with_update():

      ema_apply_op = ema.apply([batch_mean,batch_var])

      with tf.control_dependencies([ema_apply_op]):

        return tf.identity(batch_mean),tf.identity(batch_var)

        # identity之后会把Variable转换为Tensor并入图中,

        # 否则由于Variable是独立于Session的,不会被图控制control_dependencies限制

 

    mean,var = tf.cond(phase_train,

              mean_var_with_update,

              lambda: (ema.average(batch_mean),ema.average(batch_var)))

   normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)

  return normed

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python求列表交集的方法汇总
Nov 10 Python
python检测是文件还是目录的方法
Jul 03 Python
python制作简单五子棋游戏
Jun 18 Python
postman模拟访问具有Session的post请求方法
Jul 15 Python
python如何实现数据的线性拟合
Jul 19 Python
python使用pip安装模块出现ReadTimeoutError: HTTPSConnectionPool的解决方法
Oct 04 Python
python pyenv多版本管理工具的使用
Dec 23 Python
Python基于BeautifulSoup爬取京东商品信息
Jun 01 Python
Keras 数据增强ImageDataGenerator多输入多输出实例
Jul 03 Python
快速解释如何使用pandas的inplace参数的使用
Jul 23 Python
python四个坐标点对图片区域最小外接矩形进行裁剪
Jun 04 Python
Python爬虫实战之爬取京东商品数据并实实现数据可视化
Jun 07 Python
python3个性签名设计实现代码
Jun 19 #Python
TensorFlow 模型载入方法汇总(小结)
Jun 19 #Python
python3爬虫之设计签名小程序
Jun 19 #Python
Python GUI Tkinter简单实现个性签名设计
Jun 19 #Python
TensorFlow数据输入的方法示例
Jun 19 #Python
深入分析python中整型不会溢出问题
Jun 18 #Python
Python登录注册验证功能实现
Jun 18 #Python
You might like
php文件夹的创建与删除方法
2015/01/24 PHP
PHP实现合并discuz用户
2015/08/05 PHP
PHP 计算两个时间段之间交集的天数示例
2019/10/24 PHP
PHP实现基本留言板功能原理与步骤详解
2020/03/26 PHP
Jquery ajax不能解析json对象,报Invalid JSON错误的原因和解决方法
2010/03/27 Javascript
JavaScript中的style.display属性操作
2013/03/27 Javascript
页面按钮禁用与解除禁用的方法
2014/02/19 Javascript
小结Node.js中非阻塞IO和事件循环
2014/09/18 Javascript
JavaScript sup方法入门实例(把字符串显示为上标)
2014/10/20 Javascript
js实现鼠标划过给div加透明度的方法
2015/05/25 Javascript
Angular使用过滤器uppercase/lowercase实现字母大小写转换功能示例
2018/03/27 Javascript
微信小程序文章详情页跳转案例详解
2019/07/09 Javascript
JS实现容器模块左右拖动效果
2020/01/14 Javascript
[02:36]DOTA2英雄基础教程 帕格纳
2014/01/20 DOTA
pycharm 使用心得(五)断点调试
2014/06/06 Python
Python 常用 PEP8 编码规范详解
2017/01/22 Python
Python实现的朴素贝叶斯分类器示例
2018/01/06 Python
win10 64bit下python NLTK安装教程
2018/09/19 Python
python高级特性和高阶函数及使用详解
2018/10/17 Python
pyqt5 tablewidget 利用线程动态刷新数据的方法
2019/06/17 Python
利用Python的sympy包求解一元三次方程示例
2019/11/22 Python
关于numpy数组轴的使用详解
2019/12/05 Python
Python如何基于smtplib发不同格式的邮件
2019/12/30 Python
详解Anaconda安装tensorflow报错问题解决方法
2020/11/01 Python
塔吉特百货公司官网:Target
2017/04/27 全球购物
《要下雨了》教学反思
2014/02/17 职场文书
本科毕业自我鉴定
2014/03/20 职场文书
房产继承公证书
2014/04/09 职场文书
五年级学生评语大全
2014/12/26 职场文书
证婚人致辞精选
2015/07/28 职场文书
2016年清明节期间群众祭祀活动工作总结
2016/04/01 职场文书
2016年第29个世界无烟日宣传活动总结
2016/04/06 职场文书
Go语言切片前或中间插入项与内置copy()函数详解
2021/04/27 Golang
基于tensorflow权重文件的解读
2021/05/26 Python
Python 的演示平台支持 WSGI 接口的应用
2022/04/20 Python