TensorFlow 滑动平均的示例代码


Posted in Python onJune 19, 2018

滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量。

1、滑动平均求解对象初始化

ema = tf.train.ExponentialMovingAverage(decay,num_updates)

参数decay

`shadow_variable = decay * shadow_variable + (1 - decay) * variable`

参数num_updates

`min(decay, (1 + num_updates) / (10 + num_updates))`

2、添加/更新变量

添加目标变量,为之维护影子变量

注意维护不是自动的,需要每轮训练中运行此句,所以一般都会使用tf.control_dependencies使之和train_op绑定,以至于每次train_op都会更新影子变量

ema.apply([var0, var1])

3、获取影子变量值

这一步不需要定义图中,从影子变量集合中提取目标值

sess.run(ema.average([var0, var1]))

4、保存&载入影子变量

我们知道,在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身。

保存影子变量

建立tf.train.ExponentialMovingAverage对象后,Saver正常保存就会存入影子变量,命名规则是"v/ExponentialMovingAverage"对应变量”v“

import tensorflow as tf 
if __name__ == "__main__": 

  v = tf.Variable(0.,name="v") 

  #设置滑动平均模型的系数 

  ema = tf.train.ExponentialMovingAverage(0.99) 

  #设置变量v使用滑动平均模型,tf.all_variables()设置所有变量 

  op = ema.apply([v]) 

  #获取变量v的名字 

  print(v.name) 

  #v:0 

  #创建一个保存模型的对象 

  save = tf.train.Saver() 

  sess = tf.Session() 

  #初始化所有变量 

  init = tf.initialize_all_variables() 

  sess.run(init) 

  #给变量v重新赋值 

  sess.run(tf.assign(v,10)) 

  #应用平均滑动设置 

  sess.run(op) 

  #保存模型文件 

  save.save(sess,"./model.ckpt") 

  #输出变量v之前的值和使用滑动平均模型之后的值 

  print(sess.run([v,ema.average(v)])) 

  #[10.0, 0.099999905]

载入影子变量并映射到变量

利用了Saver载入模型的变量名映射功能,实际上对所有的变量都可以如此操作『TensorFlow』模型载入方法汇总

v = tf.Variable(1.,name="v") 

#定义模型对象 

saver = tf.train.Saver({"v/ExponentialMovingAverage":v}) 

sess = tf.Session() 

saver.restore(sess,"./model.ckpt") 

print(sess.run(v)) 

#0.0999999

这里特别需要注意的一个地方就是,在使用tf.train.Saver函数中,所传递的模型参数是{"v/ExponentialMovingAverage":v}而不是{"v":v},如果你使用的是后面的参数,那么你得到的结果将是10而不是0.09,那是因为后者获取的是变量本身而不是影子变量。

使用这种方式来读取模型文件的时候,还需要输入一大串的变量名称。

variables_to_restore函数的使用

v = tf.Variable(1.,name="v") 

#滑动模型的参数的大小并不会影响v的值 

ema = tf.train.ExponentialMovingAverage(0.99) 

print(ema.variables_to_restore()) 

#{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>} 

sess = tf.Session() 

saver = tf.train.Saver(ema.variables_to_restore()) 

saver.restore(sess,"./model.ckpt") 

print(sess.run(v)) 

#0.0999999

variables_to_restore会识别网络中的变量,并自动生成影子变量名。

通过使用variables_to_restore函数,可以使在加载模型的时候将影子变量直接映射到变量的本身,所以我们在获取变量的滑动平均值的时候只需要获取到变量的本身值而不需要去获取影子变量。

5、官方文档例子

官方文档中将每次apply更新就会自动训练一边模型,实际上可以反过来两者关系,《tf实战google》P128中有示例

| Example usage when creating a training model:
 | 
 | ```python
 | # Create variables.
 | var0 = tf.Variable(...)
 | var1 = tf.Variable(...)
 | # ... use the variables to build a training model...
 | ...
 | # Create an op that applies the optimizer. This is what we usually
 | # would use as a training op.
 | opt_op = opt.minimize(my_loss, [var0, var1])
 | 
 | # Create an ExponentialMovingAverage object
 | ema = tf.train.ExponentialMovingAverage(decay=0.9999)
 | 
 | with tf.control_dependencies([opt_op]):
 |   # Create the shadow variables, and add ops to maintain moving averages
 |   # of var0 and var1. This also creates an op that will update the moving
 |   # averages after each training step. This is what we will use in place
 |   # of the usual training op.
 |   training_op = ema.apply([var0, var1])
 | 
 | ...train the model by running training_op...
 | ```

6、batch_normal的例子

和上面不太一样的是,batch_normal中不太容易绑定到train_op(位于函数体外面),则强行将两个variable的输出过程化为节点,绑定给参数更新步骤

def batch_norm(x,beta,gamma,phase_train,scope='bn',decay=0.9,eps=1e-5):

  with tf.variable_scope(scope):

    # beta = tf.get_variable(name='beta', shape=[n_out], initializer=tf.constant_initializer(0.0), trainable=True)

    # gamma = tf.get_variable(name='gamma', shape=[n_out],

    #             initializer=tf.random_normal_initializer(1.0, stddev), trainable=True)

    batch_mean,batch_var = tf.nn.moments(x,[0,1,2],name='moments')

    ema = tf.train.ExponentialMovingAverage(decay=decay)

 

    def mean_var_with_update():

      ema_apply_op = ema.apply([batch_mean,batch_var])

      with tf.control_dependencies([ema_apply_op]):

        return tf.identity(batch_mean),tf.identity(batch_var)

        # identity之后会把Variable转换为Tensor并入图中,

        # 否则由于Variable是独立于Session的,不会被图控制control_dependencies限制

 

    mean,var = tf.cond(phase_train,

              mean_var_with_update,

              lambda: (ema.average(batch_mean),ema.average(batch_var)))

   normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)

  return normed

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python splitlines使用技巧
Sep 06 Python
Perl中著名的Schwartzian转换问题解决实现
Jun 02 Python
python通过加号运算符操作列表的方法
Jul 28 Python
Python全局变量用法实例分析
Jul 19 Python
python使用循环打印所有三位数水仙花数的实例
Nov 13 Python
对python 自定义协议的方法详解
Feb 13 Python
Python爬虫:url中带字典列表参数的编码转换方法
Aug 21 Python
pytorch中nn.Conv1d的用法详解
Dec 31 Python
python 的topk算法实例
Apr 02 Python
浅谈Python中的继承
Jun 19 Python
Python 操作 MySQL数据库
Sep 18 Python
Python必备技巧之函数的使用详解
Apr 04 Python
python3个性签名设计实现代码
Jun 19 #Python
TensorFlow 模型载入方法汇总(小结)
Jun 19 #Python
python3爬虫之设计签名小程序
Jun 19 #Python
Python GUI Tkinter简单实现个性签名设计
Jun 19 #Python
TensorFlow数据输入的方法示例
Jun 19 #Python
深入分析python中整型不会溢出问题
Jun 18 #Python
Python登录注册验证功能实现
Jun 18 #Python
You might like
星际中的相关伤害
2020/03/04 星际争霸
php中多维数组按指定value排序的实现代码
2014/08/19 PHP
PHP数组排序之sort、asort与ksort用法实例
2014/09/08 PHP
详解PHP中的外观模式facade pattern
2018/02/05 PHP
javaScript Array(数组)相关方法简述
2009/07/25 Javascript
(jQuery,mootools,dojo)使用适合自己的编程别名命名
2010/09/14 Javascript
Javascript实现简单的富文本编辑器附演示
2014/06/16 Javascript
JavaScript实现穷举排列(permutation)算法谜题解答
2014/12/29 Javascript
JavaScript中字符串(string)转json的2种方法
2015/06/25 Javascript
fullpage.js全屏滚动插件使用实例
2016/09/06 Javascript
NodeJS和BootStrap分页效果的实现代码
2016/11/07 NodeJs
ES6新特性二:Iterator(遍历器)和for-of循环详解
2017/04/20 Javascript
基于AngularJS实现表单验证功能
2017/07/28 Javascript
详解vue开发中调用微信jssdk的问题
2019/04/16 Javascript
Javascript生成器(Generator)的介绍与使用
2021/01/31 Javascript
[01:19:34]2014 DOTA2国际邀请赛中国区预选赛 New Element VS Dream time
2014/05/22 DOTA
[03:10]超级美酒第四天 fy拉比克秀 大合集
2018/06/05 DOTA
[49:11]完美世界DOTA2联赛PWL S3 INK ICE vs DLG 第二场 12.20
2020/12/23 DOTA
计算机二级python学习教程(2) python语言基本语法元素
2019/05/16 Python
pygame实现烟雨蒙蒙下彩虹雨
2019/11/11 Python
python实现连续变量最优分箱详解--CART算法
2019/11/22 Python
Python数据存储之 h5py详解
2019/12/26 Python
Python连接Oracle之环境配置、实例代码及报错解决方法详解
2020/02/11 Python
css3的focus-within选择器的使用
2020/05/11 HTML / CSS
举例详解HTML5中使用JSON格式提交表单
2015/06/16 HTML / CSS
周仰杰(JIMMY CHOO)英国官方网站:闻名世界的鞋子品牌
2018/10/28 全球购物
用Python匹配HTML tag的时候,<.*>和<.*?>有什么区别
2012/11/04 面试题
计算机工程学院个人求职信
2013/10/05 职场文书
中专毕业生自荐信
2013/11/16 职场文书
2014全国两会学习心得体会2000字
2014/03/10 职场文书
第一军规观后感
2015/06/12 职场文书
房产遗嘱范本
2015/08/06 职场文书
《圆的周长》教学反思
2016/02/17 职场文书
Windows11里微软已经将驱动程序安装位置A盘删除
2021/11/21 数码科技
Nginx 反向代理解决跨域问题多种情况分析
2022/01/18 Servers
Oracle使用别名的好处
2022/04/19 Oracle