Tensorflow实现部分参数梯度更新操作


Posted in Python onJanuary 23, 2020

在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层。

本文主要介绍,使用tensorflow部分更新模型参数的方法。

1. 根据Variable scope剔除需要固定参数的变量

def get_variable_via_scope(scope_lst):
  vars = []
  for sc in scope_lst:
    sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope)
    vars.extend(sc_variable)
  return vars
 
trainable_vars = tf.trainable_variables()
no_change_scope = ['your_unchange_scope_name']
 
no_change_vars = get_variable_via_scope(no_change_scope)
 
for v in no_change_vars:
  trainable_vars.remove(v)
 
grads, _ = tf.gradients(loss, trainable_vars)
 
optimizer = tf.train.AdamOptimizer(lr)
 
train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)

2. 使用tf.stop_gradient()函数

在建立Graph过程中使用该函数,非常简洁地避免了使用scope获取参数

3. 一个矩阵中部分行或列参数更新

如果一个矩阵,只有部分行或列需要更新参数,其它保持不变,该场景很常见,例如word embedding中,一些预定义的领域相关词保持不变(使用领域相关word embedding初始化),而另一些通用词变化。

import tensorflow as tf
import numpy as np
 
def entry_stop_gradients(target, mask):
  mask_h = tf.abs(mask-1)
  return tf.stop_gradient(mask_h * target) + mask * target
 
mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)
 
emb = tf.constant(np.ones([10, 5]))
 
matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))
 
parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)
 
loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
print matrix
with tf.Session() as sess:
  print sess.run(loss)
  print sess.run([grad1, grad2])

以上这篇Tensorflow实现部分参数梯度更新操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django 添加静态文件的两种实现方法(必看篇)
Jul 14 Python
numpy中实现ndarray数组返回符合特定条件的索引方法
Apr 17 Python
python3利用venv配置虚拟环境及过程中的小问题小结
Aug 01 Python
对python3.4 字符串转16进制的实例详解
Jun 12 Python
python用pip install时安装失败的一系列问题及解决方法
Feb 24 Python
Windows10+anacond+GPU+pytorch安装详细过程
Mar 24 Python
python函数调用,循环,列表复制实例
May 03 Python
Python3创建Django项目的几种方法(3种)
Jun 03 Python
Django xadmin安装及使用详解
Oct 26 Python
Python趣味挑战之教你用pygame画进度条
May 31 Python
python自动计算图像数据集的RGB均值
Jun 18 Python
5行Python代码实现一键批量扣图
Jun 29 Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 #Python
tensorflow 实现打印pb模型的所有节点
Jan 23 #Python
TensorFlow命名空间和TensorBoard图节点实例
Jan 23 #Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 #Python
如何定义TensorFlow输入节点
Jan 23 #Python
django 文件上传功能的相关实例代码(简单易懂)
Jan 22 #Python
You might like
php实现的支持断点续传的文件下载类
2014/09/23 PHP
php计算title标题相似比的方法
2015/07/29 PHP
php安装php_rar扩展实现rar文件读取和解压的方法
2016/11/17 PHP
php进程daemon化的正确实现方法
2018/09/06 PHP
PHP实现微信提现功能
2018/09/30 PHP
PHP实现会员账号单唯一登录的方法分析
2019/03/07 PHP
javascript字符串拼接的效率问题
2010/12/25 Javascript
运用JQuery的toggle实现网页加载完成自动弹窗
2014/03/18 Javascript
Ajax局部更新导致JS事件重复触发问题的解决方法
2014/10/14 Javascript
jQuery中clearQueue()方法用法实例
2014/12/29 Javascript
javascript图片切换综合实例(循环切换、顺序切换)
2016/01/13 Javascript
jQuery模仿阿里云购买服务器选择购买时间长度的代码
2016/04/29 Javascript
基于jQuery实现的单行公告活动轮播效果
2017/08/23 jQuery
vuex与组件联合使用的方法
2018/05/10 Javascript
JavaScript中的执行环境和作用域链
2020/09/04 Javascript
[45:34]完美世界DOTA2联赛PWL S3 Rebirth vs CPG 第一场 12.18
2020/12/19 DOTA
按日期打印Python的Tornado框架中的日志的方法
2015/05/02 Python
Python 根据日志级别打印不同颜色的日志的方法示例
2019/08/08 Python
python通过txt文件批量安装依赖包的实现步骤
2019/08/13 Python
TensorFlow tf.nn.conv2d_transpose是怎样实现反卷积的
2020/04/20 Python
Python实现aes加密解密多种方法解析
2020/05/15 Python
在python中使用pyspark读写Hive数据操作
2020/06/06 Python
keras.utils.to_categorical和one hot格式解析
2020/07/02 Python
解决python运行效率不高的问题
2020/07/20 Python
HTML5 Canvas画线技巧——实现绘制一个像素宽的细线
2013/08/02 HTML / CSS
健康监测猫砂:Pretty Litter
2017/05/25 全球购物
澳大利亚首屈一指的鞋类品牌:Tony Bianco
2018/03/13 全球购物
Java Servlet的主要功能和作用是什么
2014/02/14 面试题
农场厂长岗位职责
2013/12/28 职场文书
工作总结与自我评价
2014/09/18 职场文书
小学五一劳动节活动总结
2015/02/09 职场文书
以权谋私检举信范文
2015/03/02 职场文书
2015年监理工作总结范文
2015/04/07 职场文书
2016优秀班主任个人先进事迹材料
2016/02/26 职场文书
Django实现翻页的示例代码
2021/05/24 Python
python 学习GCN图卷积神经网络
2022/05/11 Python