Tensorflow实现部分参数梯度更新操作


Posted in Python onJanuary 23, 2020

在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层。

本文主要介绍,使用tensorflow部分更新模型参数的方法。

1. 根据Variable scope剔除需要固定参数的变量

def get_variable_via_scope(scope_lst):
  vars = []
  for sc in scope_lst:
    sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope)
    vars.extend(sc_variable)
  return vars
 
trainable_vars = tf.trainable_variables()
no_change_scope = ['your_unchange_scope_name']
 
no_change_vars = get_variable_via_scope(no_change_scope)
 
for v in no_change_vars:
  trainable_vars.remove(v)
 
grads, _ = tf.gradients(loss, trainable_vars)
 
optimizer = tf.train.AdamOptimizer(lr)
 
train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)

2. 使用tf.stop_gradient()函数

在建立Graph过程中使用该函数,非常简洁地避免了使用scope获取参数

3. 一个矩阵中部分行或列参数更新

如果一个矩阵,只有部分行或列需要更新参数,其它保持不变,该场景很常见,例如word embedding中,一些预定义的领域相关词保持不变(使用领域相关word embedding初始化),而另一些通用词变化。

import tensorflow as tf
import numpy as np
 
def entry_stop_gradients(target, mask):
  mask_h = tf.abs(mask-1)
  return tf.stop_gradient(mask_h * target) + mask * target
 
mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)
 
emb = tf.constant(np.ones([10, 5]))
 
matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))
 
parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)
 
loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
print matrix
with tf.Session() as sess:
  print sess.run(loss)
  print sess.run([grad1, grad2])

以上这篇Tensorflow实现部分参数梯度更新操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
全面解读Python Web开发框架Django
Jun 30 Python
python中实现php的var_dump函数功能
Jan 21 Python
python高手之路python处理excel文件(方法汇总)
Jan 07 Python
Python如何爬取实时变化的WebSocket数据的方法
Mar 09 Python
PyQt5 在label显示的图片中绘制矩形的方法
Jun 17 Python
Form表单及django的form表单的补充
Jul 25 Python
手机使用python操作图片文件(pydroid3)过程详解
Sep 25 Python
python的sys.path模块路径添加方式
Mar 09 Python
Python matplotlib读取excel数据并用for循环画多个子图subplot操作
Jul 14 Python
python+selenium 简易地疫情信息自动打卡签到功能的实现代码
Aug 22 Python
python import 上级目录的导入
Nov 03 Python
安装python依赖包psycopg2来调用postgresql的操作
Jan 01 Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 #Python
tensorflow 实现打印pb模型的所有节点
Jan 23 #Python
TensorFlow命名空间和TensorBoard图节点实例
Jan 23 #Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 #Python
如何定义TensorFlow输入节点
Jan 23 #Python
django 文件上传功能的相关实例代码(简单易懂)
Jan 22 #Python
You might like
ThinkPHP3.1新特性之动态设置自动完成和自动验证示例
2014/06/19 PHP
PHP实现长文章分页实例代码(附源码)
2016/02/03 PHP
Yii列表定义与使用分页方法小结(3种方法)
2016/07/15 PHP
PHP图像识别技术原理与实现
2016/10/27 PHP
JavaScript学习历程和心得小结
2010/08/16 Javascript
Extjs中的GridPanel隐藏列会显示在menuDisabled中解决方法
2013/01/27 Javascript
jquery win 7透明弹出层效果的简单代码
2013/08/06 Javascript
JS控制一个DIV层在指定时间内消失的方法
2014/02/17 Javascript
jQuery实现鼠标悬停显示提示信息窗口的方法
2015/04/30 Javascript
Bootstrap每天必学之栅格系统(布局)
2015/11/25 Javascript
jquery插件bootstrapValidator数据验证详解
2016/11/09 Javascript
JS实现线性表的链式表示方法示例【经典数据结构】
2017/04/11 Javascript
使用nodeJs来安装less及编译less文件为css文件的方法
2017/11/20 NodeJs
vue 实现axios拦截、页面跳转和token 验证
2018/07/17 Javascript
JS简单判断是否在微信浏览器打开的方法示例
2019/01/08 Javascript
详解Vuex下Store的模块化拆分实践
2019/07/31 Javascript
element-ui table组件如何使用render属性的实现
2019/11/04 Javascript
Python实时获取cmd的输出
2015/12/13 Python
Python实现批量更换指定目录下文件扩展名的方法
2016/09/19 Python
Python标准模块--ContextManager上下文管理器的具体用法
2017/11/27 Python
浅谈用VSCode写python的正确姿势
2017/12/16 Python
Python基于更相减损术实现求解最大公约数的方法
2018/04/04 Python
Python实现快速傅里叶变换的方法(FFT)
2018/07/21 Python
Python GUI编程学习笔记之tkinter中messagebox、filedialog控件用法详解
2020/03/30 Python
通过canvas转换颜色为RGBA格式及性能问题的解决
2019/11/22 HTML / CSS
台湾生鲜宅配:大口市集
2017/10/14 全球购物
在校硕士自我鉴定
2014/01/23 职场文书
电脑售后服务承诺书
2014/03/27 职场文书
排查整治工作方案
2014/06/09 职场文书
会计专业毕业生求职信
2014/07/04 职场文书
公安机关纪律作风整顿剖析
2014/10/10 职场文书
现场施工员岗位职责
2015/04/11 职场文书
六年级数学教学反思
2016/02/16 职场文书
七年级作文之关于奶奶
2019/10/29 职场文书
用golang如何替换某个文件中的字符串
2021/04/25 Golang
如何利用Python实现n*n螺旋矩阵
2022/01/18 Python