Tensorflow实现部分参数梯度更新操作


Posted in Python onJanuary 23, 2020

在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层。

本文主要介绍,使用tensorflow部分更新模型参数的方法。

1. 根据Variable scope剔除需要固定参数的变量

def get_variable_via_scope(scope_lst):
  vars = []
  for sc in scope_lst:
    sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope)
    vars.extend(sc_variable)
  return vars
 
trainable_vars = tf.trainable_variables()
no_change_scope = ['your_unchange_scope_name']
 
no_change_vars = get_variable_via_scope(no_change_scope)
 
for v in no_change_vars:
  trainable_vars.remove(v)
 
grads, _ = tf.gradients(loss, trainable_vars)
 
optimizer = tf.train.AdamOptimizer(lr)
 
train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)

2. 使用tf.stop_gradient()函数

在建立Graph过程中使用该函数,非常简洁地避免了使用scope获取参数

3. 一个矩阵中部分行或列参数更新

如果一个矩阵,只有部分行或列需要更新参数,其它保持不变,该场景很常见,例如word embedding中,一些预定义的领域相关词保持不变(使用领域相关word embedding初始化),而另一些通用词变化。

import tensorflow as tf
import numpy as np
 
def entry_stop_gradients(target, mask):
  mask_h = tf.abs(mask-1)
  return tf.stop_gradient(mask_h * target) + mask * target
 
mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)
 
emb = tf.constant(np.ones([10, 5]))
 
matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))
 
parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)
 
loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
print matrix
with tf.Session() as sess:
  print sess.run(loss)
  print sess.run([grad1, grad2])

以上这篇Tensorflow实现部分参数梯度更新操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过ftplib登录到ftp服务器的方法
May 08 Python
Python操作RabbitMQ服务器实现消息队列的路由功能
Jun 29 Python
Python数据分析之双色球中蓝红球分析统计示例
Feb 03 Python
Python cookbook(数据结构与算法)从序列中移除重复项且保持元素间顺序不变的方法
Mar 13 Python
python list元素为tuple时的排序方法
Apr 18 Python
Pandas DataFrame 取一行数据会得到Series的方法
Nov 10 Python
Python3+OpenCV2实现图像的几何变换(平移、镜像、缩放、旋转、仿射)
May 13 Python
Python基础学习之函数方法实例详解
Jun 18 Python
OpenCV 边缘检测
Jul 10 Python
浅谈python多线程和多线程变量共享问题介绍
Apr 17 Python
浅析python中的del用法
Sep 02 Python
python之语音识别speech模块
Sep 09 Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 #Python
tensorflow 实现打印pb模型的所有节点
Jan 23 #Python
TensorFlow命名空间和TensorBoard图节点实例
Jan 23 #Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 #Python
如何定义TensorFlow输入节点
Jan 23 #Python
django 文件上传功能的相关实例代码(简单易懂)
Jan 22 #Python
You might like
PHP7 新特性详细介绍
2016/09/06 PHP
php的socket编程详解
2016/11/20 PHP
php删除一个路径下的所有文件夹和文件的方法
2018/02/07 PHP
PHP封装curl的调用接口及常用函数详解
2018/05/31 PHP
javascript 新浪背投广告实现代码
2009/07/07 Javascript
JS动态添加option和删除option(附实例代码)
2013/04/01 Javascript
javascript 上下banner替换具体实现
2013/11/14 Javascript
js/jQuery简单实现选项卡功能
2014/01/02 Javascript
使用jQuery快速解决input中placeholder值在ie中无法支持的问题
2014/01/02 Javascript
javascript中的循环语句for语句深入理解
2014/04/04 Javascript
JS获取iframe中marginHeight和marginWidth属性的方法
2015/04/01 Javascript
详解网站中图片日常使用以及优化手法
2017/01/09 Javascript
node.js实现复制文本到剪切板的功能
2017/01/23 Javascript
JavaScript 正则命名分组【推荐】
2018/06/07 Javascript
详解小程序原生使用ES7 async/await语法
2018/08/06 Javascript
javascript实现的字符串转换成数组操作示例
2019/06/13 Javascript
原生js添加一个或多个类名的方法分析
2019/07/30 Javascript
JS+CSS实现炫酷光感效果
2020/09/05 Javascript
[43:24]完美世界DOTA2联赛PWL S3 INK ICE vs DLG 第二场 12.12
2020/12/17 DOTA
python结合selenium获取XX省交通违章数据的实现思路及代码
2016/06/26 Python
Python实现的凯撒密码算法示例
2018/04/12 Python
python获取代理IP的实例分享
2018/05/07 Python
Python设计模式之建造者模式实例详解
2019/01/17 Python
用Python实现将一张图片分成9宫格的示例
2019/07/05 Python
Tensorflow实现神经网络拟合线性回归
2019/07/19 Python
python3.6连接mysql数据库及增删改查操作详解
2020/02/10 Python
基于Python采集爬取微信公众号历史数据
2020/11/27 Python
HTML5+CSS3实现机器猫
2016/10/17 HTML / CSS
localStorage的过期时间设置的方法详解
2018/11/26 HTML / CSS
英国最红的高街时尚品牌:Topshop
2016/08/05 全球购物
美国排名第一的葡萄酒俱乐部:Firstleaf Wine Club
2020/01/02 全球购物
外贸业务员岗位职责
2013/11/24 职场文书
英语自我评价范文
2014/01/24 职场文书
课堂教学改革实施方案
2014/03/17 职场文书
2014年安全生产工作总结
2014/11/13 职场文书
村官个人总结范文
2015/03/03 职场文书