TensorFlow的权值更新方法


Posted in Python onJune 14, 2018

一. MovingAverage权值滑动平均更新

1.1 示例代码:

def create_target_q_network(self,state_dim,action_dim,net):
  state_input = tf.placeholder("float",[None,state_dim])
  action_input = tf.placeholder("float",[None,action_dim])

  ema = tf.train.ExponentialMovingAverage(decay=1-TAU)
  target_update = ema.apply(net)
  target_net = [ema.average(x) for x in net]

  layer1 = tf.nn.relu(tf.matmul(state_input,target_net[0]) + target_net[1])
  layer2 = tf.nn.relu(tf.matmul(layer1,target_net[2]) + tf.matmul(action_input,target_net[3]) + target_net[4])
  q_value_output = tf.identity(tf.matmul(layer2,target_net[5]) + target_net[6])

  return state_input,action_input,q_value_output,target_update

def update_target(self):
  self.sess.run(self.target_update)

其中,TAU=0.001,net是原始网络(该示例代码来自DDPG算法,经过滑动更新后的target_net是目标网络 )

第一句 tf.train.ExponentialMovingAverage,创建一个权值滑动平均的实例;

第二句 apply创建所训练模型参数的一个复制品(shadow_variable),并对这个复制品增加一个保留权值滑动平均的op,函数average()或average_name()可以用来获取最终这个复制品(平滑后)的值的。

更新公式为:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

在上述代码段中,target_net是shadow_variable,net是variable

1.2 tf.train.ExponentialMovingAverage.apply(var_list=None)

var_list必须是Variable或Tensor形式的列表。这个方法对var_list中所有元素创建一个复制,当其是Variable类型时,shadow_variable被初始化为variable的初值,当其是Tensor类型时,初始化为0,无偏。

函数返回一个进行权值平滑的op,因此更新目标网络时单独run这个函数就行。

1.3 tf.train.ExponentialMovingAverage.average(var)

用于获取var的滑动平均结果。

二. tf.train.Optimizer更新网络权值

2.1 tf.train.Optimizer

tf.train.Optimizer允许网络通过minimize()损失函数自动进行权值更新,此时tf.train.Optimizer.minimize()做了两件事:计算梯度,并把梯度自动更新到权值上。

此外,tensorflow也允许用户自己计算梯度,并做处理后应用给权值进行更新,此时分为以下三个步骤:

1.利用tf.train.Optimizer.compute_gradients计算梯度

2.对梯度进行自定义处理

3.利用tf.train.Optimizer.apply_gradients更新权值

tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False, grad_loss=None)

返回一个(梯度,权值)的列表对。

tf.train.Optimizer.apply_gradients(grads_and_vars, global_step=None, name=None)

返回一个更新权值的op,因此可以用它的返回值ret进行sess.run(ret)

2.2 其它

此外,tensorflow还提供了其它计算梯度的方法:

• tf.gradients(ys, xs, grad_ys=None, name='gradients', colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None)

该函数计算ys在xs方向上的梯度,需要注意与train.compute_gradients所不同的地方是,该函数返回一组dydx dydx的列表,而不是梯度-权值对。

其中,gate_gradients是在ys方向上的初始梯度,个人理解可以看做是偏微分链式求导中所需要的。

• tf.stop_gradient(input, name=None)

该函数告知整个graph图中,对input不进行梯度计算,将其伪装成一个constant常量。比如,可以用在类似于DQN算法中的目标函数:

cost=|r+Q next −Q current | cost=|r+Qnext−Qcurrent|

可以事先声明

y=tf.stop_gradient(r+Q next r+Qnext)

以上这篇TensorFlow的权值更新方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python sort、sorted高级排序技巧
Nov 21 Python
python去除扩展名的实例讲解
Apr 23 Python
Python使用Dijkstra算法实现求解图中最短路径距离问题详解
May 16 Python
Python判断有效的数独算法示例
Feb 23 Python
对python中 math模块下 atan 和 atan2的区别详解
Jan 17 Python
python 实现线程之间的通信示例
Feb 14 Python
推荐8款常用的Python GUI图形界面开发框架
Feb 23 Python
django实现模板中的字符串文字和自动转义
Mar 31 Python
Spark处理数据排序问题如何避免OOM
May 21 Python
使用tkinter实现三子棋游戏
Feb 25 Python
Python包管理工具pip的15 个使用小技巧
May 17 Python
python对文档中元素删除,替换操作
Apr 02 Python
python字符串常用方法
Jun 14 #Python
tensorflow 输出权重到csv或txt的实例
Jun 14 #Python
修复 Django migration 时遇到的问题解决
Jun 14 #Python
tensorflow 获取模型所有参数总和数量的方法
Jun 14 #Python
tensorflow 获取变量&打印权值的实例讲解
Jun 14 #Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
You might like
[原创]php常用字符串输出方法分析(echo,print,printf及sprintf)
2016/07/09 PHP
PHP进制转换实例分析(2,8,16,36,64进制至10进制相互转换)
2017/02/04 PHP
php面向对象的用户登录身份验证
2017/06/08 PHP
PHP PDOStatement::errorCode讲解
2019/01/31 PHP
JavaScript中的Web worker多线程API研究
2014/12/06 Javascript
jQuery实现企业网站横幅焦点图切换功能实例
2015/04/30 Javascript
jQuery 翻页组件yunm.pager.js实现div局部刷新的思路
2016/08/11 Javascript
Angularjs 动态改变title标题(兼容ios)
2016/12/29 Javascript
BootStrap Fileinput插件和Bootstrap table表格插件相结合实现文件上传、预览、提交的导入Excel数据操作步骤
2017/08/07 Javascript
解决vue 更改计算属性后select选中值不更改的问题
2018/03/02 Javascript
vue集成百度UEditor富文本编辑器使用教程
2018/09/21 Javascript
Vux+Axios拦截器增加loading的问题及实现方法
2018/11/08 Javascript
layui 解决富文本框form表单提交为空的问题
2019/10/26 Javascript
关于element-ui表单中限制输入纯数字的解决方式
2020/09/08 Javascript
使用Vue实现一个树组件的示例
2020/11/06 Javascript
解决vue初始化项目一直停在downloading template的问题
2020/11/09 Javascript
如何使用七牛Python SDK写一个同步脚本及使用教程
2015/08/23 Python
python 网络爬虫初级实现代码
2016/02/27 Python
python多维数组切片方法
2018/04/13 Python
基于Django与ajax之间的json传输方法
2018/05/29 Python
python2与python3中关于对NaN类型数据的判断和转换方法
2018/10/30 Python
Python中那些 Pythonic的写法详解
2019/07/02 Python
python 使用pygame工具包实现贪吃蛇游戏(多彩版)
2019/10/30 Python
Tensorflow的常用矩阵生成方式
2020/01/04 Python
利用CSS3的transition属性实现滑动效果
2015/08/05 HTML / CSS
高性能钓鱼服装:Huk Gear
2019/02/20 全球购物
马来西亚排名第一的宠物用品店:Pets Wonderland
2020/04/16 全球购物
研究生毕业鉴定
2014/01/29 职场文书
个人委托书
2014/07/31 职场文书
三好学生先进事迹材料
2014/08/28 职场文书
单位员工收入证明样本
2014/10/09 职场文书
惊天动地观后感
2015/06/10 职场文书
采购部2015年度工作总结
2015/07/24 职场文书
七年级作文(600字3篇)
2019/09/24 职场文书
九大龙王魂骨,山龙王留下躯干骨,榜首死的最憋屈(被捏碎)
2022/03/18 国漫
MySQL的prepare使用以及遇到的bug
2022/05/11 MySQL