Tensorflow实现部分参数梯度更新操作


Posted in Python onJanuary 23, 2020

在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层。

本文主要介绍,使用tensorflow部分更新模型参数的方法。

1. 根据Variable scope剔除需要固定参数的变量

def get_variable_via_scope(scope_lst):
  vars = []
  for sc in scope_lst:
    sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope)
    vars.extend(sc_variable)
  return vars
 
trainable_vars = tf.trainable_variables()
no_change_scope = ['your_unchange_scope_name']
 
no_change_vars = get_variable_via_scope(no_change_scope)
 
for v in no_change_vars:
  trainable_vars.remove(v)
 
grads, _ = tf.gradients(loss, trainable_vars)
 
optimizer = tf.train.AdamOptimizer(lr)
 
train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)

2. 使用tf.stop_gradient()函数

在建立Graph过程中使用该函数,非常简洁地避免了使用scope获取参数

3. 一个矩阵中部分行或列参数更新

如果一个矩阵,只有部分行或列需要更新参数,其它保持不变,该场景很常见,例如word embedding中,一些预定义的领域相关词保持不变(使用领域相关word embedding初始化),而另一些通用词变化。

import tensorflow as tf
import numpy as np
 
def entry_stop_gradients(target, mask):
  mask_h = tf.abs(mask-1)
  return tf.stop_gradient(mask_h * target) + mask * target
 
mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)
 
emb = tf.constant(np.ones([10, 5]))
 
matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))
 
parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)
 
loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
print matrix
with tf.Session() as sess:
  print sess.run(loss)
  print sess.run([grad1, grad2])

以上这篇Tensorflow实现部分参数梯度更新操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 打印出所有的对象/模块的属性(实例代码)
Sep 11 Python
Python 包含汉字的文件读写之每行末尾加上特定字符
Dec 12 Python
Python之日期与时间处理模块(date和datetime)
Feb 16 Python
Python中如何优雅的合并两个字典(dict)方法示例
Aug 09 Python
使用apidocJs快速生成在线文档的实例讲解
Feb 07 Python
python实现windows下文件备份脚本
May 27 Python
Python3.6+Django2.0以上 xadmin站点的配置和使用教程图解
Jun 04 Python
Python中请不要再用re.compile了
Jun 30 Python
tensorflow之获取tensor的shape作为max_pool的ksize实例
Jan 04 Python
python下对hsv颜色空间进行量化操作
Jun 04 Python
用Python编写简单的gRPC服务的详细过程
Jul 04 Python
python计算列表元素与乘积详情
Aug 05 Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 #Python
tensorflow 实现打印pb模型的所有节点
Jan 23 #Python
TensorFlow命名空间和TensorBoard图节点实例
Jan 23 #Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 #Python
如何定义TensorFlow输入节点
Jan 23 #Python
django 文件上传功能的相关实例代码(简单易懂)
Jan 22 #Python
You might like
PHP实现邮件群发的源码
2013/06/18 PHP
PHP无限分类(树形类)
2013/09/28 PHP
php获取文件夹路径内的图片以及分页显示示例
2014/03/11 PHP
PHP调用VC编写的COM组件实例
2014/03/29 PHP
PHP strcmp()和strcasecmp()的区别实例
2016/11/05 PHP
php图形jpgraph操作实例分析
2017/02/22 PHP
PHP连接MYSQL数据库的3种常用方法
2017/02/27 PHP
使用Laravel中的查询构造器实现增删改查功能
2019/09/03 PHP
在线游戏大家来找茬II
2006/09/30 Javascript
学习JavaScript的最佳方法分享
2011/10/21 Javascript
JQuery插件Style定制化方法的分析与比较
2012/05/03 Javascript
解决js数据包含加号+通过ajax传到后台时出现连接错误
2013/08/01 Javascript
js 文本滚动效果的实例代码
2013/08/17 Javascript
javacript使用break内层跳出外层循环分析
2015/01/12 Javascript
win7下安装配置node.js+express开发环境
2015/12/06 Javascript
分分钟玩转Vue.js组件(二)
2017/03/01 Javascript
深入理解Promise.all
2018/08/08 Javascript
9102年webpack4搭建vue项目的方法步骤
2019/02/20 Javascript
通过js随机函数Math.random实现乱序
2020/05/19 Javascript
微信小游戏中three.js离屏画布的示例代码
2020/10/12 Javascript
简单的Apache+FastCGI+Django配置指南
2015/07/22 Python
python中的格式化输出用法总结
2016/07/28 Python
django基础之数据库操作方法(详解)
2017/05/24 Python
详解Python3 中hasattr()、getattr()、setattr()、delattr()函数及示例代码数
2018/04/18 Python
Django给admin添加Action的步骤详解
2019/05/01 Python
selenium与xpath之获取指定位置的元素的实现
2021/01/26 Python
Java里面有没有全局变量?为什么?
2015/02/06 面试题
非常详细的C#面试题集
2016/07/13 面试题
户外活动策划方案
2014/03/12 职场文书
博士毕业生自我鉴定范文
2014/04/13 职场文书
政治学求职信
2014/06/03 职场文书
大班亲子运动会方案
2014/06/10 职场文书
防灾减灾日活动总结
2014/08/26 职场文书
采购内勤岗位职责
2015/04/13 职场文书
小学运动会入场口号
2015/12/24 职场文书
Apache Hudi 加速传统的批处理模式
2022/04/24 Servers