Tensorflow实现部分参数梯度更新操作


Posted in Python onJanuary 23, 2020

在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层。

本文主要介绍,使用tensorflow部分更新模型参数的方法。

1. 根据Variable scope剔除需要固定参数的变量

def get_variable_via_scope(scope_lst):
  vars = []
  for sc in scope_lst:
    sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope)
    vars.extend(sc_variable)
  return vars
 
trainable_vars = tf.trainable_variables()
no_change_scope = ['your_unchange_scope_name']
 
no_change_vars = get_variable_via_scope(no_change_scope)
 
for v in no_change_vars:
  trainable_vars.remove(v)
 
grads, _ = tf.gradients(loss, trainable_vars)
 
optimizer = tf.train.AdamOptimizer(lr)
 
train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)

2. 使用tf.stop_gradient()函数

在建立Graph过程中使用该函数,非常简洁地避免了使用scope获取参数

3. 一个矩阵中部分行或列参数更新

如果一个矩阵,只有部分行或列需要更新参数,其它保持不变,该场景很常见,例如word embedding中,一些预定义的领域相关词保持不变(使用领域相关word embedding初始化),而另一些通用词变化。

import tensorflow as tf
import numpy as np
 
def entry_stop_gradients(target, mask):
  mask_h = tf.abs(mask-1)
  return tf.stop_gradient(mask_h * target) + mask * target
 
mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)
 
emb = tf.constant(np.ones([10, 5]))
 
matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))
 
parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)
 
loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
print matrix
with tf.Session() as sess:
  print sess.run(loss)
  print sess.run([grad1, grad2])

以上这篇Tensorflow实现部分参数梯度更新操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的tab文件操作类分享
Nov 20 Python
利用Python画ROC曲线和AUC值计算
Sep 19 Python
python中实现数组和列表读取一列的方法
Apr 03 Python
python pexpect ssh 远程登录服务器的方法
Feb 14 Python
Python reshape的用法及多个二维数组合并为三维数组的实例
Feb 07 Python
matlab灰度图像调整及imadjust函数的用法详解
Feb 27 Python
使用python创建Excel工作簿及工作表过程图解
May 27 Python
virtualenv介绍及简明教程
Jun 23 Python
django前端页面下拉选择框默认值设置方式
Aug 09 Python
基于Python爬取搜狐证券股票过程解析
Nov 18 Python
Django与AJAX实现网页动态数据显示的示例代码
Feb 24 Python
Python进阶学习之带你探寻Python类的鼻祖-元类
May 08 Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 #Python
tensorflow 实现打印pb模型的所有节点
Jan 23 #Python
TensorFlow命名空间和TensorBoard图节点实例
Jan 23 #Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 #Python
如何定义TensorFlow输入节点
Jan 23 #Python
django 文件上传功能的相关实例代码(简单易懂)
Jan 22 #Python
You might like
模拟OICQ的实现思路和核心程序(一)
2006/10/09 PHP
PHP 模拟登陆MSN并获得用户信息
2009/05/16 PHP
PHP utf-8编码问题,utf8编码,数据库乱码,页面显示输出乱码
2013/04/08 PHP
关于PHP语言构造器介绍
2013/07/08 PHP
通过php添加xml文档内容的方法
2015/01/23 PHP
arguments对象
2006/11/20 Javascript
Javascript中的数学函数集合
2007/05/08 Javascript
JavaScript与C# Windows应用程序交互方法
2007/06/29 Javascript
利用jQuery插件扩展识别浏览器内核与外壳的类型和版本的实现代码
2011/10/22 Javascript
JS动态增删表格行的方法
2016/03/03 Javascript
jquery对dom节点的操作【推荐】
2016/04/15 Javascript
浅析JS动态创建元素【两种方法】
2016/04/20 Javascript
JS本地刷新返回上一页代码
2016/07/25 Javascript
js导出excel文件的简洁方法(推荐)
2016/11/02 Javascript
vue 和vue-touch 实现移动端左右导航效果(仿京东移动站导航)
2017/04/22 Javascript
说说AngularJS中的$parse和$eval的用法
2017/09/14 Javascript
vue中前进刷新、后退缓存用户浏览数据和浏览位置的实例讲解
2018/09/21 Javascript
在微信小程序中使用mqtt服务的方法
2019/12/13 Javascript
JS一次前端面试经历记录
2020/03/19 Javascript
对vue生命周期的深入理解
2020/12/03 Vue.js
Python中用Decorator来简化元编程的教程
2015/04/13 Python
python 使用poster模块进行http方式的文件传输到服务器的方法
2019/01/15 Python
在python image 中安装中文字体的实现方法
2019/08/22 Python
django框架单表操作之增删改实例分析
2019/12/16 Python
pytorch 实现查看网络中的参数
2020/01/06 Python
详解Pandas 处理缺失值指令大全
2020/07/30 Python
python 进程池pool使用详解
2020/10/15 Python
HTML5之SVG 2D入门7—SVG元素的重用与引用
2013/01/30 HTML / CSS
西尔斯百货官网:Sears
2016/09/06 全球购物
信息与计算科学专业推荐信
2014/02/23 职场文书
《猴子种果树》教学反思
2014/04/26 职场文书
承诺保证书格式
2015/02/28 职场文书
开业庆典嘉宾致辞
2015/08/01 职场文书
护士岗前培训心得体会
2016/01/08 职场文书
nginx网站服务如何配置防盗链(推荐)
2021/03/31 Servers
mysql的单列多值存储实例详解
2022/04/05 MySQL