Tensorflow实现部分参数梯度更新操作


Posted in Python onJanuary 23, 2020

在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层。

本文主要介绍,使用tensorflow部分更新模型参数的方法。

1. 根据Variable scope剔除需要固定参数的变量

def get_variable_via_scope(scope_lst):
  vars = []
  for sc in scope_lst:
    sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope)
    vars.extend(sc_variable)
  return vars
 
trainable_vars = tf.trainable_variables()
no_change_scope = ['your_unchange_scope_name']
 
no_change_vars = get_variable_via_scope(no_change_scope)
 
for v in no_change_vars:
  trainable_vars.remove(v)
 
grads, _ = tf.gradients(loss, trainable_vars)
 
optimizer = tf.train.AdamOptimizer(lr)
 
train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)

2. 使用tf.stop_gradient()函数

在建立Graph过程中使用该函数,非常简洁地避免了使用scope获取参数

3. 一个矩阵中部分行或列参数更新

如果一个矩阵,只有部分行或列需要更新参数,其它保持不变,该场景很常见,例如word embedding中,一些预定义的领域相关词保持不变(使用领域相关word embedding初始化),而另一些通用词变化。

import tensorflow as tf
import numpy as np
 
def entry_stop_gradients(target, mask):
  mask_h = tf.abs(mask-1)
  return tf.stop_gradient(mask_h * target) + mask * target
 
mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)
 
emb = tf.constant(np.ones([10, 5]))
 
matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))
 
parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)
 
loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
print matrix
with tf.Session() as sess:
  print sess.run(loss)
  print sess.run([grad1, grad2])

以上这篇Tensorflow实现部分参数梯度更新操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现超简单端口转发的方法
Mar 13 Python
解决Scrapy安装错误:Microsoft Visual C++ 14.0 is required...
Oct 01 Python
python实现简单中文词频统计示例
Nov 08 Python
python实现ID3决策树算法
Dec 20 Python
python 获取当天凌晨零点的时间戳方法
May 22 Python
Python3实现将本地JSON大数据文件写入MySQL数据库的方法
Jun 13 Python
pyqt5实现按钮添加背景图片以及背景图片的切换方法
Jun 13 Python
anaconda3安装及jupyter环境配置全教程
Aug 24 Python
15个应该掌握的Jupyter Notebook使用技巧(小结)
Sep 23 Python
BeautifulSoup中find和find_all的使用详解
Dec 07 Python
python基础之类方法和静态方法
Oct 24 Python
python缺失值填充方法示例代码
Dec 24 Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 #Python
tensorflow 实现打印pb模型的所有节点
Jan 23 #Python
TensorFlow命名空间和TensorBoard图节点实例
Jan 23 #Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 #Python
如何定义TensorFlow输入节点
Jan 23 #Python
django 文件上传功能的相关实例代码(简单易懂)
Jan 22 #Python
You might like
什么是PHP7中的孤儿进程与僵尸进程
2019/04/14 PHP
php和asp语法上的区别总结
2019/05/12 PHP
基于php解决json_encode中文UNICODE转码问题
2020/11/10 PHP
FireFox中textNode分片的问题
2007/04/10 Javascript
jquery click([data],fn)使用方法实例介绍
2013/07/08 Javascript
判断js对象是否拥有某一个属性的js代码
2013/08/16 Javascript
百度移动版的url编码解码示例
2014/04/29 Javascript
js+css实现导航效果实例
2015/02/10 Javascript
JavaScript使用ActiveXObject访问Access和SQL Server数据库
2015/04/02 Javascript
JavaScript弹出新窗口并控制窗口移动到指定位置的方法
2015/04/06 Javascript
javascript实现unicode与ASCII相互转换的方法
2015/12/10 Javascript
详解Bootstrap的iCheck插件checkbox和radio
2016/08/24 Javascript
javascript删除html标签函数cIsHTML
2017/01/09 Javascript
Django中使用jquery的ajax进行数据交互的实例代码
2017/10/15 jQuery
原生JS实现的多个彩色小球跟随鼠标移动动画效果示例
2018/02/01 Javascript
React项目动态设置title标题的方法示例
2018/09/26 Javascript
详解a标签添加onclick事件的几种方式
2019/03/29 Javascript
JavaScript中的垃圾回收与内存泄漏示例详解
2019/05/02 Javascript
webpack-mvc 传统多页面组件化开发详解
2019/05/07 Javascript
Vuex的各个模块封装的实现
2020/06/05 Javascript
[55:44]完美世界DOTA2联赛决赛 FTD vs Phoenix 第二场 11.08
2020/11/11 DOTA
Tornado 多进程实现分析详解
2018/01/12 Python
Python3实现zip分卷压缩过程解析
2019/10/09 Python
提升python处理速度原理及方法实例
2019/12/25 Python
Scrapy框架介绍之Puppeteer渲染的使用
2020/06/19 Python
matplotlib之pyplot模块坐标轴标签设置使用(xlabel()、ylabel())
2021/02/22 Python
Pureology官网:为染色头发打造最好的产品
2019/09/13 全球购物
酒店节能减排方案
2014/05/26 职场文书
转让协议书范本
2014/09/13 职场文书
2015公务员试用期工作总结
2014/12/12 职场文书
中学生自我评价2015
2015/03/03 职场文书
2015年“七七卢沟桥事变”纪念活动总结
2015/03/24 职场文书
辩论会主持词
2015/07/03 职场文书
优秀党员主要事迹材料
2015/11/04 职场文书
python模块与C和C++动态库相互调用实现过程示例
2021/11/02 Python
MySQL数据库索引的最左匹配原则
2021/11/20 MySQL