tensorflow 恢复指定层与不同层指定不同学习率的方法


Posted in Python onJuly 26, 2018

如下所示:

#tensorflow 中从ckpt文件中恢复指定的层或将指定的层不进行恢复:
#tensorflow 中不同的layer指定不同的学习率
 
with tf.Graph().as_default():
		#存放的是需要恢复的层参数
	 variables_to_restore = []
	 #存放的是需要训练的层参数名,这里是没恢复的需要进行重新训练,实际上恢复了的参数也可以训练
  variables_to_train = []
  for var in slim.get_model_variables():
   excluded = False
   for exclusion in fine_tune_layers:
   #比如fine tune layer中包含logits,bottleneck
    if var.op.name.startswith(exclusion):
     excluded = True
     break
   if not excluded:
    variables_to_restore.append(var)
    #print('var to restore :',var)
   else:
    variables_to_train.append(var)
    #print('var to train: ',var)
 
 
  #这里省略掉一些步骤,进入训练步骤:
  #将variables_to_train,需要训练的参数给optimizer 的compute_gradients函数
  grads = opt.compute_gradients(total_loss, variables_to_train)
  #这个函数将只计算variables_to_train中的梯度
  #然后将梯度进行应用:
  apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
  #也可以直接调用opt.minimize(total_loss,variables_to_train)
  #minimize只是将compute_gradients与apply_gradients封装成了一个函数,实际上还是调用的这两个函数
  #如果在梯度里面不同的参数需要不同的学习率,那么可以:
 
  capped_grads_and_vars = []#[(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
  #update_gradient_vars是需要更新的参数,使用的是全局学习率
  #对于不是update_gradient_vars的参数,将其梯度更新乘以0.0001,使用基本上不动
 	for grad in grads:
 		for update_vars in update_gradient_vars:
 			if grad[1]==update_vars:
 				capped_grads_and_vars.append((grad[0],grad[1]))
 			else:
 				capped_grads_and_vars.append((0.0001*grad[0],grad[1]))
 
 	apply_gradient_op = opt.apply_gradients(capped_grads_and_vars, global_step=global_step)
 
 	#在恢复模型时:
 
  with sess.as_default():
 
   if pretrained_model:
    print('Restoring pretrained model: %s' % pretrained_model)
    init_fn = slim.assign_from_checkpoint_fn(
    pretrained_model,
    variables_to_restore)
    init_fn(sess)
   #这样就将指定的层参数没有恢复

以上这篇tensorflow 恢复指定层与不同层指定不同学习率的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python和C语言混合编程实例
Jun 04 Python
Python+django实现文件上传
Jan 17 Python
Python实现简单多线程任务队列
Feb 27 Python
python print 按逗号或空格分隔的方法
May 02 Python
python高级特性和高阶函数及使用详解
Oct 17 Python
Python turtle绘画象棋棋盘
Aug 21 Python
关于python中plt.hist参数的使用详解
Nov 28 Python
python dict如何定义
Sep 02 Python
python和node.js生成当前时间戳的示例
Sep 29 Python
详解Python中__new__方法的作用
Mar 31 Python
Python语法学习之进程的创建与常用方法详解
Apr 08 Python
使用python绘制横竖条形图
Apr 21 Python
kaggle+mnist实现手写字体识别
Jul 26 #Python
解决tensorflow模型参数保存和加载的问题
Jul 26 #Python
解决tensorflow1.x版本加载saver.restore目录报错的问题
Jul 26 #Python
Flask web开发处理POST请求实现(登录案例)
Jul 26 #Python
基于tensorflow加载部分层的方法
Jul 26 #Python
利用python画出折线图
Jul 26 #Python
浅谈flask源码之请求过程
Jul 26 #Python
You might like
php数组查找函数总结
2014/11/18 PHP
yii2实现分页,带搜索的分页功能示例
2017/01/07 PHP
php下载远程大文件(获取远程文件大小)的实例
2017/06/17 PHP
推荐dojo学习笔记
2007/03/24 Javascript
JS刷新当前页面的几种方法总结
2013/12/24 Javascript
node.js中的fs.lchown方法使用说明
2014/12/16 Javascript
javascript制作的cookie封装及使用指南
2015/01/02 Javascript
在JavaScript的jQuery库中操作AJAX的方法讲解
2015/08/15 Javascript
js随机生成字母数字组合的字符串 随机动画数字
2015/09/02 Javascript
基于jQuery实现仿搜狐辩论投票动画代码(附源码下载)
2016/02/18 Javascript
深入理解jquery中的事件与动画
2016/05/24 Javascript
详解angular2封装material2对话框组件
2017/03/03 Javascript
小程序获取周围IBeacon设备的方法
2018/10/31 Javascript
vue中监听返回键问题
2019/08/28 Javascript
layui form.render('select', 'test2') 更新渲染的方法
2019/09/27 Javascript
vue限制输入框只能输入8位整数和2位小数的代码
2019/11/06 Javascript
云服务器部署Node.js项目的方法步骤(小白系列)
2020/03/23 Javascript
vue fetch中的.then()的正确使用方法
2020/04/17 Javascript
react基本安装与测试示例
2020/04/27 Javascript
Python模块学习 datetime介绍
2012/08/27 Python
Python判断变量是否已经定义的方法
2014/08/18 Python
Python中的进程分支fork和exec详解
2015/04/11 Python
python基于xmlrpc实现二进制文件传输的方法
2015/06/02 Python
TensorFlow如何实现反向传播
2018/02/06 Python
python检测空间储存剩余大小和指定文件夹内存占用的实例
2018/06/11 Python
使用Django开发简单接口实现文章增删改查
2019/05/09 Python
Python基本语法之运算符功能与用法详解
2019/10/22 Python
使用Tensorflow将自己的数据分割成batch训练实例
2020/01/20 Python
科室工作的个人自我评价
2013/10/30 职场文书
2014年会演讲稿范文
2014/01/06 职场文书
委托书英文
2015/01/28 职场文书
工程部主管岗位职责
2015/02/12 职场文书
军训新闻稿范文
2015/07/17 职场文书
2019学校运动会开幕词
2019/05/13 职场文书
简单介绍Python的第三方库yaml
2021/06/18 Python
JAVA SpringMVC实现自定义拦截器
2022/03/16 Python