Tensorflow实现部分参数梯度更新操作


Posted in Python onJanuary 23, 2020

在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层。

本文主要介绍,使用tensorflow部分更新模型参数的方法。

1. 根据Variable scope剔除需要固定参数的变量

def get_variable_via_scope(scope_lst):
  vars = []
  for sc in scope_lst:
    sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope)
    vars.extend(sc_variable)
  return vars
 
trainable_vars = tf.trainable_variables()
no_change_scope = ['your_unchange_scope_name']
 
no_change_vars = get_variable_via_scope(no_change_scope)
 
for v in no_change_vars:
  trainable_vars.remove(v)
 
grads, _ = tf.gradients(loss, trainable_vars)
 
optimizer = tf.train.AdamOptimizer(lr)
 
train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)

2. 使用tf.stop_gradient()函数

在建立Graph过程中使用该函数,非常简洁地避免了使用scope获取参数

3. 一个矩阵中部分行或列参数更新

如果一个矩阵,只有部分行或列需要更新参数,其它保持不变,该场景很常见,例如word embedding中,一些预定义的领域相关词保持不变(使用领域相关word embedding初始化),而另一些通用词变化。

import tensorflow as tf
import numpy as np
 
def entry_stop_gradients(target, mask):
  mask_h = tf.abs(mask-1)
  return tf.stop_gradient(mask_h * target) + mask * target
 
mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)
 
emb = tf.constant(np.ones([10, 5]))
 
matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))
 
parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)
 
loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
print matrix
with tf.Session() as sess:
  print sess.run(loss)
  print sess.run([grad1, grad2])

以上这篇Tensorflow实现部分参数梯度更新操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
教你如何将 Sublime 3 打造成 Python/Django IDE开发利器
Jul 04 Python
Python中MYSQLdb出现乱码的解决方法
Oct 11 Python
python使用装饰器和线程限制函数执行时间的方法
Apr 18 Python
Python实现多进程共享数据的方法分析
Dec 04 Python
Python编程把二叉树打印成多行代码
Jan 04 Python
Pipenv一键搭建python虚拟环境的方法
May 22 Python
Sanic框架异常处理与中间件操作实例分析
Jul 16 Python
Python 中包/模块的 `import` 操作代码
Apr 22 Python
Django模型中字段属性choice使用说明
Mar 30 Python
初学者学习Python好还是Java好
May 26 Python
如何在Python3中使用telnetlib模块连接网络设备
Sep 21 Python
利用Python网络爬虫爬取各大音乐评论的代码
Apr 13 Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 #Python
tensorflow 实现打印pb模型的所有节点
Jan 23 #Python
TensorFlow命名空间和TensorBoard图节点实例
Jan 23 #Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 #Python
如何定义TensorFlow输入节点
Jan 23 #Python
django 文件上传功能的相关实例代码(简单易懂)
Jan 22 #Python
You might like
example2.php
2006/10/09 PHP
桌面中心(二)数据库写入
2006/10/09 PHP
php生成静态文件的多种方法分享
2012/07/17 PHP
php简单防盗链实现方法
2015/07/29 PHP
PHP+Ajax实现无刷新分页实例详解(附demo源码下载)
2016/04/07 PHP
ThinkPHP实现更新数据实例详解(demo)
2016/06/29 PHP
解决php写入数据库乱码的问题
2019/09/17 PHP
在视频前插入广告
2006/11/20 Javascript
解决FireFox下[使用event很麻烦]的问题
2006/11/26 Javascript
jQuery.extend 函数详解
2012/02/03 Javascript
根据邮箱的域名跳转到相应的登录页面的代码
2012/02/27 Javascript
indexOf 和 lastIndexOf 使用示例介绍
2014/09/02 Javascript
嵌入式iframe子页面与父页面js通信的方法
2015/01/20 Javascript
vue.js入门教程之基础语法小结
2016/09/01 Javascript
jQuery插件FusionCharts绘制的3D饼状图效果实例【附demo源码下载】
2017/03/03 Javascript
详解vue数据渲染出现闪烁问题
2017/06/29 Javascript
解决Vue使用mint-ui loadmore实现上拉加载与下拉刷新出现一个页面使用多个上拉加载后冲突问题
2017/11/07 Javascript
微信小程序下拉框组件使用方法详解
2018/12/28 Javascript
详解vue项目打包步骤
2019/03/29 Javascript
[02:34]肉山说——泡妞篇
2014/09/16 DOTA
Python中使用装饰器来优化尾递归的示例
2016/06/18 Python
Tornado高并发处理方法实例代码
2018/01/15 Python
selenium+python自动化测试环境搭建步骤
2019/06/03 Python
浅谈Python描述数据结构之KMP篇
2020/09/06 Python
详解CSS3选择器的使用方法汇总
2015/11/24 HTML / CSS
css3实现3D文本悬停改变效果的示例代码
2019/01/16 HTML / CSS
使用 HTML5 Canvas 制作水波纹效果点击图片就会触发
2014/09/15 HTML / CSS
Daniel Wellington官方海外旗舰店:丹尼尔惠灵顿DW手表
2018/02/22 全球购物
芭比波朗加拿大官方网站:Bobbi Brown Cosmetics CA
2020/11/05 全球购物
创联软件面试题笔试题
2012/10/07 面试题
留学自荐信的技巧
2013/10/17 职场文书
校庆接待方案
2014/03/18 职场文书
2014国庆节标语口号
2014/09/19 职场文书
2015年学校综合治理工作总结
2015/07/20 职场文书
2020年个人安全保证书参考模板
2020/01/08 职场文书
用python批量解压带密码的压缩包
2021/05/31 Python