Tensorflow实现部分参数梯度更新操作


Posted in Python onJanuary 23, 2020

在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层。

本文主要介绍,使用tensorflow部分更新模型参数的方法。

1. 根据Variable scope剔除需要固定参数的变量

def get_variable_via_scope(scope_lst):
  vars = []
  for sc in scope_lst:
    sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope)
    vars.extend(sc_variable)
  return vars
 
trainable_vars = tf.trainable_variables()
no_change_scope = ['your_unchange_scope_name']
 
no_change_vars = get_variable_via_scope(no_change_scope)
 
for v in no_change_vars:
  trainable_vars.remove(v)
 
grads, _ = tf.gradients(loss, trainable_vars)
 
optimizer = tf.train.AdamOptimizer(lr)
 
train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)

2. 使用tf.stop_gradient()函数

在建立Graph过程中使用该函数,非常简洁地避免了使用scope获取参数

3. 一个矩阵中部分行或列参数更新

如果一个矩阵,只有部分行或列需要更新参数,其它保持不变,该场景很常见,例如word embedding中,一些预定义的领域相关词保持不变(使用领域相关word embedding初始化),而另一些通用词变化。

import tensorflow as tf
import numpy as np
 
def entry_stop_gradients(target, mask):
  mask_h = tf.abs(mask-1)
  return tf.stop_gradient(mask_h * target) + mask * target
 
mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)
 
emb = tf.constant(np.ones([10, 5]))
 
matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))
 
parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)
 
loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
print matrix
with tf.Session() as sess:
  print sess.run(loss)
  print sess.run([grad1, grad2])

以上这篇Tensorflow实现部分参数梯度更新操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
linux下python抓屏实现方法
May 22 Python
pandas将DataFrame的列变成行索引的方法
Apr 10 Python
Flask模拟实现CSRF攻击的方法
Jul 24 Python
python整小时 整天时间戳获取算法示例
Feb 20 Python
Python容器使用的5个技巧和2个误区总结
Sep 26 Python
python中@property和property函数常见使用方法示例
Oct 21 Python
Django 简单实现分页与搜索功能的示例代码
Nov 07 Python
python pygame实现挡板弹球游戏
Nov 25 Python
python sitk.show()与imageJ结合使用常见的问题
Apr 20 Python
python 实现数据库中数据添加、查询与更新的示例代码
Dec 07 Python
python常量折叠基础知识点讲解
Feb 28 Python
使用Python webdriver图书馆抢座自动预约的正确方法
Mar 04 Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 #Python
tensorflow 实现打印pb模型的所有节点
Jan 23 #Python
TensorFlow命名空间和TensorBoard图节点实例
Jan 23 #Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 #Python
如何定义TensorFlow输入节点
Jan 23 #Python
django 文件上传功能的相关实例代码(简单易懂)
Jan 22 #Python
You might like
基于jQuery的仿flash的广告轮播代码
2010/11/04 Javascript
JS 自定义带默认值的函数
2011/07/21 Javascript
jQuery实现加入购物车飞入动画效果
2015/03/14 Javascript
充分发挥Node.js程序性能的一些方法介绍
2015/06/23 Javascript
复杂的javascript窗口分帧解析
2016/02/19 Javascript
JavaScript html5 canvas绘制时钟效果
2016/03/01 Javascript
Javascript中的几种继承方式对比分析
2016/03/22 Javascript
javascript设计模式之策略模式学习笔记
2017/02/15 Javascript
JavaScript实现两个select下拉框选项左移右移
2017/03/09 Javascript
Nodejs实现多房间简易聊天室功能
2017/06/20 NodeJs
AngularJS通过ng-Img-Crop实现头像截取的示例
2017/08/17 Javascript
基于Vue实现微信小程序的图文编辑器
2018/07/25 Javascript
JavaScript继承的特性与实践应用深入详解
2018/12/30 Javascript
VUE搭建手机商城心得和遇到的坑
2019/02/21 Javascript
vue视图不更新情况详解
2019/05/16 Javascript
策略模式实现 Vue 动态表单验证的方法
2019/09/16 Javascript
js实现自定义滚动条的示例
2020/10/27 Javascript
Python数组条件过滤filter函数使用示例
2014/07/22 Python
Python中map,reduce,filter和sorted函数的使用方法
2015/08/17 Python
Python递归函数定义与用法示例
2017/06/02 Python
Python基于pyCUDA实现GPU加速并行计算功能入门教程
2018/06/19 Python
基于python实现KNN分类算法
2020/04/23 Python
python调用动态链接库的基本过程详解
2019/06/19 Python
python数据归一化及三种方法详解
2019/08/06 Python
用Python画一个LinkinPark的logo代码实例
2019/09/10 Python
python传到前端的数据,双引号被转义的问题
2020/04/03 Python
Jupyter notebook如何实现指定浏览器打开
2020/05/13 Python
python删除某个目录文件夹的方法
2020/05/26 Python
利用Storage Event实现页面间通信的示例代码
2018/07/26 HTML / CSS
西班牙太阳镜品牌:Hawkers
2018/03/11 全球购物
美国马匹用品和马钉购物网站:State Line Tack
2018/08/05 全球购物
英国网上自行车商店:Tredz Bikes
2019/10/29 全球购物
宗教学大学生职业生涯规划范文
2014/02/08 职场文书
初三学生评语大全
2014/04/24 职场文书
2015年入党积极分子评语
2015/03/26 职场文书
Oracle锁表解决方法的详细记录
2022/06/05 Oracle