TensorFlow的权值更新方法


Posted in Python onJune 14, 2018

一. MovingAverage权值滑动平均更新

1.1 示例代码:

def create_target_q_network(self,state_dim,action_dim,net):
  state_input = tf.placeholder("float",[None,state_dim])
  action_input = tf.placeholder("float",[None,action_dim])

  ema = tf.train.ExponentialMovingAverage(decay=1-TAU)
  target_update = ema.apply(net)
  target_net = [ema.average(x) for x in net]

  layer1 = tf.nn.relu(tf.matmul(state_input,target_net[0]) + target_net[1])
  layer2 = tf.nn.relu(tf.matmul(layer1,target_net[2]) + tf.matmul(action_input,target_net[3]) + target_net[4])
  q_value_output = tf.identity(tf.matmul(layer2,target_net[5]) + target_net[6])

  return state_input,action_input,q_value_output,target_update

def update_target(self):
  self.sess.run(self.target_update)

其中,TAU=0.001,net是原始网络(该示例代码来自DDPG算法,经过滑动更新后的target_net是目标网络 )

第一句 tf.train.ExponentialMovingAverage,创建一个权值滑动平均的实例;

第二句 apply创建所训练模型参数的一个复制品(shadow_variable),并对这个复制品增加一个保留权值滑动平均的op,函数average()或average_name()可以用来获取最终这个复制品(平滑后)的值的。

更新公式为:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

在上述代码段中,target_net是shadow_variable,net是variable

1.2 tf.train.ExponentialMovingAverage.apply(var_list=None)

var_list必须是Variable或Tensor形式的列表。这个方法对var_list中所有元素创建一个复制,当其是Variable类型时,shadow_variable被初始化为variable的初值,当其是Tensor类型时,初始化为0,无偏。

函数返回一个进行权值平滑的op,因此更新目标网络时单独run这个函数就行。

1.3 tf.train.ExponentialMovingAverage.average(var)

用于获取var的滑动平均结果。

二. tf.train.Optimizer更新网络权值

2.1 tf.train.Optimizer

tf.train.Optimizer允许网络通过minimize()损失函数自动进行权值更新,此时tf.train.Optimizer.minimize()做了两件事:计算梯度,并把梯度自动更新到权值上。

此外,tensorflow也允许用户自己计算梯度,并做处理后应用给权值进行更新,此时分为以下三个步骤:

1.利用tf.train.Optimizer.compute_gradients计算梯度

2.对梯度进行自定义处理

3.利用tf.train.Optimizer.apply_gradients更新权值

tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False, grad_loss=None)

返回一个(梯度,权值)的列表对。

tf.train.Optimizer.apply_gradients(grads_and_vars, global_step=None, name=None)

返回一个更新权值的op,因此可以用它的返回值ret进行sess.run(ret)

2.2 其它

此外,tensorflow还提供了其它计算梯度的方法:

• tf.gradients(ys, xs, grad_ys=None, name='gradients', colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None)

该函数计算ys在xs方向上的梯度,需要注意与train.compute_gradients所不同的地方是,该函数返回一组dydx dydx的列表,而不是梯度-权值对。

其中,gate_gradients是在ys方向上的初始梯度,个人理解可以看做是偏微分链式求导中所需要的。

• tf.stop_gradient(input, name=None)

该函数告知整个graph图中,对input不进行梯度计算,将其伪装成一个constant常量。比如,可以用在类似于DQN算法中的目标函数:

cost=|r+Q next −Q current | cost=|r+Qnext−Qcurrent|

可以事先声明

y=tf.stop_gradient(r+Q next r+Qnext)

以上这篇TensorFlow的权值更新方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现下载文件的三种方法
Feb 09 Python
python爬虫headers设置后无效的解决方法
Oct 21 Python
Python面向对象编程基础解析(二)
Oct 26 Python
Python登录并获取CSDN博客所有文章列表代码实例
Dec 28 Python
python自动重试第三方包retrying模块的方法
Apr 24 Python
在python中bool函数的取值方法
Nov 01 Python
python 实现一次性在文件中写入多行的方法
Jan 28 Python
python实现定时压缩指定文件夹发送邮件
Dec 22 Python
Python发展史及网络爬虫
Jun 19 Python
python3连接kafka模块pykafka生产者简单封装代码
Dec 23 Python
基于Python获取照片的GPS位置信息
Jan 20 Python
keras训练浅层卷积网络并保存和加载模型实例
Jul 02 Python
python字符串常用方法
Jun 14 #Python
tensorflow 输出权重到csv或txt的实例
Jun 14 #Python
修复 Django migration 时遇到的问题解决
Jun 14 #Python
tensorflow 获取模型所有参数总和数量的方法
Jun 14 #Python
tensorflow 获取变量&打印权值的实例讲解
Jun 14 #Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
You might like
PhpMyAdmin中无法导入sql文件的解决办法
2010/01/08 PHP
PHP遍历某个目录下的所有文件和子文件夹的实现代码
2013/06/28 PHP
php将字符串转化成date存入数据库的两种方式
2014/04/28 PHP
PHP实现微信退款的方法示例
2019/03/26 PHP
测试你的JS的掌握程度的代码
2009/12/09 Javascript
javascript instanceof 内部机制探析
2010/10/15 Javascript
基于Jquery的仿照flash放大图片效果代码
2011/03/16 Javascript
从面试题学习Javascript 面向对象(创建对象)
2012/03/30 Javascript
JQuery select控件的相关操作实现代码
2012/09/14 Javascript
js中的前绑定和后绑定详解
2013/08/01 Javascript
jQuery判断checkbox(复选框)是否被选中以及全选、反选实现代码
2014/02/21 Javascript
javascript下拉框选项单击事件的例子分享
2015/03/04 Javascript
七夕情人节丘比特射箭小游戏
2015/08/20 Javascript
ANGULARJS中使用JQUERY分页控件
2015/09/16 Javascript
JS更改select内option属性的方法
2015/10/14 Javascript
使用jQuery操作HTML的table表格的实例解析
2016/03/13 Javascript
EasyUI 中combotree 默认不能选择父节点的实现方法
2016/11/07 Javascript
简单实现js悬浮导航效果
2017/02/05 Javascript
利用原生JS与jQuery实现数字线性变化的动画
2017/02/24 Javascript
vue 解决form表单提交但不跳转页面的问题
2019/10/30 Javascript
Python脚本实时处理log文件的方法
2016/11/21 Python
利用python实现简易版的贪吃蛇游戏(面向python小白)
2018/12/30 Python
django 读取图片到页面实例
2020/03/27 Python
Win 10下Anaconda虚拟环境的教程
2020/05/18 Python
基于opencv的selenium滑动验证码的实现
2020/07/24 Python
阿迪达斯荷兰官方网站:adidas荷兰
2018/03/16 全球购物
Lancome兰蔻官方旗舰店:来自法国的世界知名美妆品牌
2018/06/14 全球购物
香港最大的洋酒零售连锁店:屈臣氏酒窖(Watson’s Wine)
2018/12/10 全球购物
医学生个人求职信范文
2013/09/24 职场文书
医学生实习自荐信
2013/10/01 职场文书
中英文自我评价常用句型
2013/12/19 职场文书
五年级学生评语
2014/04/22 职场文书
第二批党的群众路线教育实践活动个人整改方案
2014/10/31 职场文书
小学中队活动总结
2015/05/11 职场文书
Oracle中update和select 关联操作
2022/01/18 Oracle
Go语言测试库testify使用学习
2022/07/23 Golang