Tensorflow实现部分参数梯度更新操作


Posted in Python onJanuary 23, 2020

在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层。

本文主要介绍,使用tensorflow部分更新模型参数的方法。

1. 根据Variable scope剔除需要固定参数的变量

def get_variable_via_scope(scope_lst):
  vars = []
  for sc in scope_lst:
    sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope)
    vars.extend(sc_variable)
  return vars
 
trainable_vars = tf.trainable_variables()
no_change_scope = ['your_unchange_scope_name']
 
no_change_vars = get_variable_via_scope(no_change_scope)
 
for v in no_change_vars:
  trainable_vars.remove(v)
 
grads, _ = tf.gradients(loss, trainable_vars)
 
optimizer = tf.train.AdamOptimizer(lr)
 
train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)

2. 使用tf.stop_gradient()函数

在建立Graph过程中使用该函数,非常简洁地避免了使用scope获取参数

3. 一个矩阵中部分行或列参数更新

如果一个矩阵,只有部分行或列需要更新参数,其它保持不变,该场景很常见,例如word embedding中,一些预定义的领域相关词保持不变(使用领域相关word embedding初始化),而另一些通用词变化。

import tensorflow as tf
import numpy as np
 
def entry_stop_gradients(target, mask):
  mask_h = tf.abs(mask-1)
  return tf.stop_gradient(mask_h * target) + mask * target
 
mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)
 
emb = tf.constant(np.ones([10, 5]))
 
matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))
 
parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)
 
loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
print matrix
with tf.Session() as sess:
  print sess.run(loss)
  print sess.run([grad1, grad2])

以上这篇Tensorflow实现部分参数梯度更新操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python读取浮点数和读取文本文件示例
May 06 Python
pygame学习笔记(3):运动速率、时间、事件、文字
Apr 15 Python
python使用xpath中遇到:到底是什么?
Jan 04 Python
TensorFlow数据输入的方法示例
Jun 19 Python
Python中format()格式输出全解
Apr 12 Python
python读写csv文件方法详细总结
Jul 05 Python
Django 大文件下载实现过程解析
Aug 01 Python
python使用socket 先读取长度,在读取报文内容示例
Sep 26 Python
Python loguru日志库之高效输出控制台日志和日志记录
Mar 07 Python
你需要学会的8个Python列表技巧
Jun 24 Python
python 爬虫之selenium可视化爬虫的实现
Dec 04 Python
python批量创建变量并赋值操作
Jun 03 Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 #Python
tensorflow 实现打印pb模型的所有节点
Jan 23 #Python
TensorFlow命名空间和TensorBoard图节点实例
Jan 23 #Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 #Python
如何定义TensorFlow输入节点
Jan 23 #Python
django 文件上传功能的相关实例代码(简单易懂)
Jan 22 #Python
You might like
MySQL修改密码方法总结
2008/03/25 PHP
php使用curl发送json格式数据实例
2013/12/17 PHP
PHP实现显示照片exif信息的方法
2014/07/11 PHP
php采集内容中带有图片地址的远程图片并保存的方法
2015/01/03 PHP
PHP生成压缩文件实例
2015/02/07 PHP
PHP单例模式简单用法示例
2017/06/23 PHP
彪哥1.1(智能表格)提供下载
2006/09/07 Javascript
javascript 闭包
2011/09/15 Javascript
JavaScript实现当网页加载完成后执行指定函数的方法
2015/03/21 Javascript
Jsonp post 跨域方案
2015/07/06 Javascript
jQuery中$.ajax()和$.getJson()同步处理详解
2015/08/12 Javascript
JavaScript数组和对象的复制
2017/03/21 Javascript
vue实现留言板todolist功能
2017/08/16 Javascript
vue动画效果实现方法示例
2019/03/18 Javascript
layUI实现三级导航菜单效果
2019/07/26 Javascript
vue中的循环对象属性和属性值用法
2020/09/04 Javascript
python并发编程多进程之守护进程原理解析
2019/08/20 Python
在pycharm中实现删除bookmark
2020/02/14 Python
五款漂亮的纯CSS3动画按钮的实例教程
2014/11/21 HTML / CSS
英国领先的酒类网上商城:TheDrinkShop
2017/03/16 全球购物
EJB的激活机制
2013/10/25 面试题
《乞巧》教学反思
2014/02/27 职场文书
教师师德演讲稿
2014/05/06 职场文书
党课培训心得体会
2014/09/02 职场文书
2014年煤矿工作总结
2014/11/24 职场文书
工作保证书怎么写
2015/02/28 职场文书
红白喜事主持词
2015/07/06 职场文书
2015暑假社会调查报告
2015/07/13 职场文书
乡镇团代会开幕词
2016/03/04 职场文书
八年级作文之感悟亲情
2019/11/20 职场文书
利用python做表格数据处理
2021/04/13 Python
Python 用户输入和while循环的操作
2021/05/23 Python
Python的三个重要函数详解
2022/01/18 Python
SQL Server查询某个字段在哪些表中存在
2022/03/03 SQL Server
Win10 最新稳定版本 21H2开始推送
2022/04/19 数码科技
SpringBoot 集成短信和邮件 以阿里云短信服务为例
2022/04/22 Java/Android