Tensorflow实现部分参数梯度更新操作


Posted in Python onJanuary 23, 2020

在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层。

本文主要介绍,使用tensorflow部分更新模型参数的方法。

1. 根据Variable scope剔除需要固定参数的变量

def get_variable_via_scope(scope_lst):
  vars = []
  for sc in scope_lst:
    sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope)
    vars.extend(sc_variable)
  return vars
 
trainable_vars = tf.trainable_variables()
no_change_scope = ['your_unchange_scope_name']
 
no_change_vars = get_variable_via_scope(no_change_scope)
 
for v in no_change_vars:
  trainable_vars.remove(v)
 
grads, _ = tf.gradients(loss, trainable_vars)
 
optimizer = tf.train.AdamOptimizer(lr)
 
train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)

2. 使用tf.stop_gradient()函数

在建立Graph过程中使用该函数,非常简洁地避免了使用scope获取参数

3. 一个矩阵中部分行或列参数更新

如果一个矩阵,只有部分行或列需要更新参数,其它保持不变,该场景很常见,例如word embedding中,一些预定义的领域相关词保持不变(使用领域相关word embedding初始化),而另一些通用词变化。

import tensorflow as tf
import numpy as np
 
def entry_stop_gradients(target, mask):
  mask_h = tf.abs(mask-1)
  return tf.stop_gradient(mask_h * target) + mask * target
 
mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)
 
emb = tf.constant(np.ones([10, 5]))
 
matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))
 
parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)
 
loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
print matrix
with tf.Session() as sess:
  print sess.run(loss)
  print sess.run([grad1, grad2])

以上这篇Tensorflow实现部分参数梯度更新操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于Python的身份证号码自动生成程序
Aug 15 Python
简单介绍Python中的JSON模块
Apr 08 Python
python清除指定目录内所有文件中script的方法
Jun 30 Python
在python中使用正则表达式查找可嵌套字符串组
Oct 24 Python
python+matplotlib实现礼盒柱状图实例代码
Jan 16 Python
python 读取txt,json和hdf5文件的实例
Jun 05 Python
pandas表连接 索引上的合并方法
Jun 08 Python
Python面向对象之类和对象实例详解
Dec 10 Python
由Python编写的MySQL管理工具代码实例
Apr 09 Python
python向字符串中添加元素的实例方法
Jun 28 Python
Python configparser模块操作代码实例
Jun 08 Python
Python如何合并多个字典或映射
Jul 24 Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 #Python
tensorflow 实现打印pb模型的所有节点
Jan 23 #Python
TensorFlow命名空间和TensorBoard图节点实例
Jan 23 #Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 #Python
如何定义TensorFlow输入节点
Jan 23 #Python
django 文件上传功能的相关实例代码(简单易懂)
Jan 22 #Python
You might like
用session做客户验证时的注意事项
2006/10/09 PHP
用PHP实现 上一篇、下一篇的代码
2012/09/29 PHP
PHP管理依赖(dependency)关系工具 Composer的自动加载(autoload)
2014/08/18 PHP
PHP实现 APP端微信支付功能
2018/06/22 PHP
Extjs列表详细信息窗口新建后自动加载解决方法
2010/04/02 Javascript
游览器中javascript的执行过程(图文)
2012/05/20 Javascript
Javascript中找到子元素在父元素内相对位置的代码
2012/07/21 Javascript
jQuery之自动完成组件的深入解析
2013/06/19 Javascript
extJS中常用的4种Ajax异步提交方式
2014/03/07 Javascript
深入理解javascript构造函数和原型对象
2014/09/23 Javascript
jQuery中ajax和post处理json的不同示例对比
2014/11/02 Javascript
jQuery中prevUntil()方法用法实例
2015/01/08 Javascript
jQuery实现简易的天天爱消除小游戏
2015/10/16 Javascript
Bootstrap入门书籍之(三)栅格系统
2016/02/17 Javascript
JS学习之表格的排序简单实例
2016/05/16 Javascript
JS正则截取两个字符串之间及字符串前后内容的方法
2017/01/06 Javascript
Angular17之Angular自定义指令详解
2018/01/21 Javascript
vue组件实现可搜索下拉框扩展
2020/10/23 Javascript
学习jQuery中的noConflict()用法
2018/09/28 jQuery
js判断浏览器的环境(pc端,移动端,还是微信浏览器)
2020/12/24 Javascript
python BeautifulSoup使用方法详解
2013/11/21 Python
python实现图片变亮或者变暗的方法
2015/06/01 Python
对Python闭包与延迟绑定的方法详解
2019/01/07 Python
python远程调用rpc模块xmlrpclib的方法
2019/01/11 Python
浅析Python 读取图像文件的性能对比
2019/03/07 Python
pytorch查看torch.Tensor和model是否在CUDA上的实例
2020/01/03 Python
详解django中Template语言
2020/02/22 Python
pycharm中import呈现灰色原因的解决方法
2020/03/04 Python
Python对象的属性访问过程详解
2020/03/05 Python
PyQt5 文本输入框自动补全QLineEdit的实现示例
2020/05/13 Python
python interpolate插值实例
2020/07/06 Python
10个示例带你掌握python中的元组
2020/11/23 Python
神话般的珠宝:Ross-Simons
2020/07/13 全球购物
俄罗斯达美乐比萨外送服务:Domino’s Pizza
2020/12/18 全球购物
毕业生求职简历的自我评价
2013/10/23 职场文书
管理部副部长岗位职责范文
2014/03/09 职场文书