Tensorflow实现部分参数梯度更新操作


Posted in Python onJanuary 23, 2020

在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层。

本文主要介绍,使用tensorflow部分更新模型参数的方法。

1. 根据Variable scope剔除需要固定参数的变量

def get_variable_via_scope(scope_lst):
  vars = []
  for sc in scope_lst:
    sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope)
    vars.extend(sc_variable)
  return vars
 
trainable_vars = tf.trainable_variables()
no_change_scope = ['your_unchange_scope_name']
 
no_change_vars = get_variable_via_scope(no_change_scope)
 
for v in no_change_vars:
  trainable_vars.remove(v)
 
grads, _ = tf.gradients(loss, trainable_vars)
 
optimizer = tf.train.AdamOptimizer(lr)
 
train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)

2. 使用tf.stop_gradient()函数

在建立Graph过程中使用该函数,非常简洁地避免了使用scope获取参数

3. 一个矩阵中部分行或列参数更新

如果一个矩阵,只有部分行或列需要更新参数,其它保持不变,该场景很常见,例如word embedding中,一些预定义的领域相关词保持不变(使用领域相关word embedding初始化),而另一些通用词变化。

import tensorflow as tf
import numpy as np
 
def entry_stop_gradients(target, mask):
  mask_h = tf.abs(mask-1)
  return tf.stop_gradient(mask_h * target) + mask * target
 
mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)
 
emb = tf.constant(np.ones([10, 5]))
 
matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))
 
parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)
 
loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
print matrix
with tf.Session() as sess:
  print sess.run(loss)
  print sess.run([grad1, grad2])

以上这篇Tensorflow实现部分参数梯度更新操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
进一步了解Python中的XML 工具
Apr 13 Python
python中字典(Dictionary)用法实例详解
May 30 Python
python3.4用函数操作mysql5.7数据库
Jun 23 Python
Python中turtle作图示例
Nov 15 Python
python模仿网页版微信发送消息功能
Feb 24 Python
Python面向对象类编写细节分析【类,方法,继承,超类,接口等】
Jan 05 Python
Python3实现统计单词表中每个字母出现频率的方法示例
Jan 28 Python
详解python中list的使用
Mar 15 Python
python连接PostgreSQL数据库的过程详解
Sep 18 Python
使用 Supervisor 监控 Python3 进程方式
Dec 05 Python
Pytorch数据拼接与拆分操作实现图解
Apr 30 Python
Python如何实现爬取B站视频
May 20 Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 #Python
tensorflow 实现打印pb模型的所有节点
Jan 23 #Python
TensorFlow命名空间和TensorBoard图节点实例
Jan 23 #Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 #Python
如何定义TensorFlow输入节点
Jan 23 #Python
django 文件上传功能的相关实例代码(简单易懂)
Jan 22 #Python
You might like
一周学会PHP(视频)Http下载
2006/12/12 PHP
PHP获取当前文件所在目录 getcwd()函数
2009/05/13 PHP
PHP levenshtein()函数用法讲解
2019/03/08 PHP
Yii框架自定义数据库操作组件示例
2019/11/11 PHP
PHP7 其他修改
2021/03/09 PHP
屏蔽网页右键复制和ctrl+c复制的js代码
2013/01/04 Javascript
jquery 选取方法都有哪些
2014/05/18 Javascript
JS应用正则表达式转换大小写示例
2014/09/18 Javascript
jquery使用正则表达式验证email地址的方法
2015/01/22 Javascript
Nodejs初级阶段之express
2015/11/23 NodeJs
Google 地图控件集详解及实例代码
2016/08/06 Javascript
nodejs连接mongodb数据库实现增删改查
2016/12/01 NodeJs
vue2.0获取自定义属性的值
2017/03/28 Javascript
详解angularjs利用ui-route异步加载组件
2017/05/21 Javascript
AngularJS表单验证功能
2017/10/19 Javascript
vue2.0.js的多级联动选择器实现方法
2018/02/09 Javascript
vue.js2.0点击获取自己的属性和jquery方法
2018/02/23 jQuery
Vue单页应用引用单独的样式文件的两种方式
2018/03/30 Javascript
微信小程序中遇到的iOS兼容性问题小结
2018/11/14 Javascript
swiper自定义分页器的样式
2020/09/14 Javascript
[04:54]DOTA2-DPC中国联赛1月31日Recap集锦
2021/03/11 DOTA
wxpython学习笔记(推荐查看)
2014/06/09 Python
Python中urllib2模块的8个使用细节分享
2015/01/01 Python
Python 中迭代器与生成器实例详解
2017/03/29 Python
浅谈Python中的全局锁(GIL)问题
2019/01/11 Python
如何在mac环境中用python处理protobuf
2019/12/25 Python
Python结合Window计划任务监测邮件的示例代码
2020/08/05 Python
HTML5 File接口在web页面上使用文件下载
2017/02/27 HTML / CSS
Mountain Warehouse澳大利亚官网:欧洲家庭户外品牌倡导者
2016/11/20 全球购物
澳大利亚领先的武术用品和健身器材供应商:SMAI
2019/03/24 全球购物
俄罗斯品牌服装和鞋子在线商店:BRIONITY
2020/03/26 全球购物
上海方立数码笔试题
2013/10/18 面试题
加入学生会演讲稿
2014/04/24 职场文书
语文教师求职信范文
2015/03/20 职场文书
2015年城管个人工作总结
2015/05/15 职场文书
Redis 报错 error:NOAUTH Authentication required
2022/05/15 Redis