TensorFlow的权值更新方法


Posted in Python onJune 14, 2018

一. MovingAverage权值滑动平均更新

1.1 示例代码:

def create_target_q_network(self,state_dim,action_dim,net):
  state_input = tf.placeholder("float",[None,state_dim])
  action_input = tf.placeholder("float",[None,action_dim])

  ema = tf.train.ExponentialMovingAverage(decay=1-TAU)
  target_update = ema.apply(net)
  target_net = [ema.average(x) for x in net]

  layer1 = tf.nn.relu(tf.matmul(state_input,target_net[0]) + target_net[1])
  layer2 = tf.nn.relu(tf.matmul(layer1,target_net[2]) + tf.matmul(action_input,target_net[3]) + target_net[4])
  q_value_output = tf.identity(tf.matmul(layer2,target_net[5]) + target_net[6])

  return state_input,action_input,q_value_output,target_update

def update_target(self):
  self.sess.run(self.target_update)

其中,TAU=0.001,net是原始网络(该示例代码来自DDPG算法,经过滑动更新后的target_net是目标网络 )

第一句 tf.train.ExponentialMovingAverage,创建一个权值滑动平均的实例;

第二句 apply创建所训练模型参数的一个复制品(shadow_variable),并对这个复制品增加一个保留权值滑动平均的op,函数average()或average_name()可以用来获取最终这个复制品(平滑后)的值的。

更新公式为:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

在上述代码段中,target_net是shadow_variable,net是variable

1.2 tf.train.ExponentialMovingAverage.apply(var_list=None)

var_list必须是Variable或Tensor形式的列表。这个方法对var_list中所有元素创建一个复制,当其是Variable类型时,shadow_variable被初始化为variable的初值,当其是Tensor类型时,初始化为0,无偏。

函数返回一个进行权值平滑的op,因此更新目标网络时单独run这个函数就行。

1.3 tf.train.ExponentialMovingAverage.average(var)

用于获取var的滑动平均结果。

二. tf.train.Optimizer更新网络权值

2.1 tf.train.Optimizer

tf.train.Optimizer允许网络通过minimize()损失函数自动进行权值更新,此时tf.train.Optimizer.minimize()做了两件事:计算梯度,并把梯度自动更新到权值上。

此外,tensorflow也允许用户自己计算梯度,并做处理后应用给权值进行更新,此时分为以下三个步骤:

1.利用tf.train.Optimizer.compute_gradients计算梯度

2.对梯度进行自定义处理

3.利用tf.train.Optimizer.apply_gradients更新权值

tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False, grad_loss=None)

返回一个(梯度,权值)的列表对。

tf.train.Optimizer.apply_gradients(grads_and_vars, global_step=None, name=None)

返回一个更新权值的op,因此可以用它的返回值ret进行sess.run(ret)

2.2 其它

此外,tensorflow还提供了其它计算梯度的方法:

• tf.gradients(ys, xs, grad_ys=None, name='gradients', colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None)

该函数计算ys在xs方向上的梯度,需要注意与train.compute_gradients所不同的地方是,该函数返回一组dydx dydx的列表,而不是梯度-权值对。

其中,gate_gradients是在ys方向上的初始梯度,个人理解可以看做是偏微分链式求导中所需要的。

• tf.stop_gradient(input, name=None)

该函数告知整个graph图中,对input不进行梯度计算,将其伪装成一个constant常量。比如,可以用在类似于DQN算法中的目标函数:

cost=|r+Q next −Q current | cost=|r+Qnext−Qcurrent|

可以事先声明

y=tf.stop_gradient(r+Q next r+Qnext)

以上这篇TensorFlow的权值更新方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
仅用50行Python代码实现一个简单的代理服务器
Apr 08 Python
pygame学习笔记(3):运动速率、时间、事件、文字
Apr 15 Python
简单分析Python中用fork()函数生成的子进程
May 04 Python
Python中有趣在__call__函数
Jun 21 Python
Python for Informatics 第11章 正则表达式(一)
Apr 21 Python
Python正则简单实例分析
Mar 21 Python
Python实现的文本对比报告生成工具示例
May 22 Python
解决DataFrame排序sort的问题
Jun 07 Python
Python Django 添加首页尾页上一页下一页代码实例
Aug 21 Python
python虚拟环境模块venv使用及示例
Mar 04 Python
python 机器学习的标准化、归一化、正则化、离散化和白化
Apr 16 Python
matplotlib之pyplot模块实现添加子图subplot的使用
Apr 25 Python
python字符串常用方法
Jun 14 #Python
tensorflow 输出权重到csv或txt的实例
Jun 14 #Python
修复 Django migration 时遇到的问题解决
Jun 14 #Python
tensorflow 获取模型所有参数总和数量的方法
Jun 14 #Python
tensorflow 获取变量&打印权值的实例讲解
Jun 14 #Python
利用python对Excel中的特定数据提取并写入新表的方法
Jun 14 #Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 #Python
You might like
example2.php
2006/10/09 PHP
php登陆页的密码处理方式分享
2013/10/14 PHP
微信开发之网页授权获取用户信息(二)
2016/01/08 PHP
JavaScript 私有成员分析
2009/01/13 Javascript
jquery学习笔记二 实现可编辑的表格
2010/04/09 Javascript
为你的网站增加亮点的9款jQuery插件推荐
2011/05/03 Javascript
jQuery实现鼠标滚轮动态改变样式或效果
2015/01/05 Javascript
javascript框架设计之浏览器的嗅探和特征侦测
2015/06/23 Javascript
基于Node.js实现nodemailer邮件发送
2016/01/26 Javascript
详解JavaScript中基于原型prototype的继承特性
2016/05/05 Javascript
Angular页面间切换及传值的4种方法
2016/11/04 Javascript
js图片延迟加载(Lazyload)三种实现方式
2017/03/01 Javascript
Angular中自定义Debounce Click指令防止重复点击
2017/07/26 Javascript
js处理包含中文的字符串实例
2017/10/11 Javascript
vue中使用iview自定义验证关键词输入框问题及解决方法
2018/03/26 Javascript
关于微信公众号开发无法支付的问题解决
2018/12/28 Javascript
Vue 前端实现登陆拦截及axios 拦截器的使用
2019/07/17 Javascript
加速vue组件渲染之性能优化
2020/04/09 Javascript
Threejs实现滴滴官网首页地球动画功能
2020/07/13 Javascript
Python实现的径向基(RBF)神经网络示例
2018/02/06 Python
python调用c++返回带成员指针的类指针实例
2019/12/12 Python
python安装读取grib库总结(推荐)
2020/06/24 Python
Python加载数据的5种不同方式(收藏)
2020/11/13 Python
python 如何把docker-compose.yaml导入到数据库相关条目里
2021/01/15 Python
英国标志性奢侈品牌:Burberry
2016/07/28 全球购物
Senreve官网:美国旧金山的奢侈手袋品牌
2019/03/21 全球购物
哥德堡通行证:Gothenburg Pass
2019/12/09 全球购物
线程问题:wait()方法是定义在哪个类里面
2015/07/07 面试题
暑期教师培训方案
2014/06/07 职场文书
药店促销活动总结
2014/07/10 职场文书
设立有限责任公司出资协议书
2014/11/01 职场文书
2019职场单身人才调研报告:互联网行业单身比例最高
2019/08/07 职场文书
MySQL分库分表详情
2021/09/25 MySQL
python基础之函数的定义和调用
2021/10/24 Python
《辉夜大小姐想让我告白》第三季正式预告
2022/03/20 日漫
Mysql数据库事务的脏读幻读及不可重复读详解
2022/05/30 MySQL