tensorflow estimator 使用hook实现finetune方式


Posted in Python onJanuary 21, 2020

为了实现finetune有如下两种解决方案:

model_fn里面定义好模型之后直接赋值

def model_fn(features, labels, mode, params):
 # .....
 # finetune
 if params.checkpoint_path and (not tf.train.latest_checkpoint(params.model_dir)):
 checkpoint_path = None
 if tf.gfile.IsDirectory(params.checkpoint_path):
  checkpoint_path = tf.train.latest_checkpoint(params.checkpoint_path)
 else:
  checkpoint_path = params.checkpoint_path

 tf.train.init_from_checkpoint(
  ckpt_dir_or_file=checkpoint_path,
  assignment_map={params.checkpoint_scope: params.checkpoint_scope} # 'OptimizeLoss/':'OptimizeLoss/'
 )

使用钩子 hooks。

可以在定义tf.contrib.learn.Experiment的时候通过train_monitors参数指定

# Define the experiment
 experiment = tf.contrib.learn.Experiment(
 estimator=estimator, # Estimator
 train_input_fn=train_input_fn, # First-class function
 eval_input_fn=eval_input_fn, # First-class function
 train_steps=params.train_steps, # Minibatch steps
 min_eval_frequency=params.eval_min_frequency, # Eval frequency
 # train_monitors=[], # Hooks for training
 # eval_hooks=[eval_input_hook], # Hooks for evaluation
 eval_steps=params.eval_steps # Use evaluation feeder until its empty
 )

也可以在定义tf.estimator.EstimatorSpec 的时候通过training_chief_hooks参数指定。

不过个人觉得最好还是在estimator中定义,让experiment只专注于控制实验的模式(训练次数,验证次数等等)。

def model_fn(features, labels, mode, params):

 # ....

 return tf.estimator.EstimatorSpec(
 mode=mode,
 predictions=predictions,
 loss=loss,
 train_op=train_op,
 eval_metric_ops=eval_metric_ops,
 # scaffold=get_scaffold(),
 # training_chief_hooks=None
 )

这里顺便解释以下tf.estimator.EstimatorSpec对像的作用。该对象描述来一个模型的方方面面。包括:

当前的模式:

mode: A ModeKeys. Specifies if this is training, evaluation or prediction.

计算图

predictions: Predictions Tensor or dict of Tensor.

loss: Training loss Tensor. Must be either scalar, or with shape [1].

train_op: Op for the training step.

eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the update_op or requires any input fetching.

导出策略

export_outputs: Describes the output signatures to be exported to

SavedModel and used during serving. A dict {name: output} where:

name: An arbitrary name for this output.

output: an ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.

chief钩子 训练时的模型保存策略钩子CheckpointSaverHook, 模型恢复等

training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.

worker钩子 训练时的监控策略钩子如: NanTensorHook LoggingTensorHook 等

training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.

指定初始化和saver

scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.

evaluation钩子

evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.

自定义的钩子如下:

class RestoreCheckpointHook(tf.train.SessionRunHook):
 def __init__(self,
   checkpoint_path,
   exclude_scope_patterns,
   include_scope_patterns
   ):
 tf.logging.info("Create RestoreCheckpointHook.")
 #super(IteratorInitializerHook, self).__init__()
 self.checkpoint_path = checkpoint_path

 self.exclude_scope_patterns = None if (not exclude_scope_patterns) else exclude_scope_patterns.split(',')
 self.include_scope_patterns = None if (not include_scope_patterns) else include_scope_patterns.split(',')


 def begin(self):
 # You can add ops to the graph here.
 print('Before starting the session.')

 # 1. Create saver

 #exclusions = []
 #if self.checkpoint_exclude_scopes:
 # exclusions = [scope.strip()
 #  for scope in self.checkpoint_exclude_scopes.split(',')]
 #
 #variables_to_restore = []
 #for var in slim.get_model_variables(): #tf.global_variables():
 # excluded = False
 # for exclusion in exclusions:
 # if var.op.name.startswith(exclusion):
 # excluded = True
 # break
 # if not excluded:
 # variables_to_restore.append(var)
 #inclusions
 #[var for var in tf.trainable_variables() if var.op.name.startswith('InceptionResnetV1')]

 variables_to_restore = tf.contrib.framework.filter_variables(
  slim.get_model_variables(),
  include_patterns=self.include_scope_patterns, # ['Conv'],
  exclude_patterns=self.exclude_scope_patterns, # ['biases', 'Logits'],

  # If True (default), performs re.search to find matches
  # (i.e. pattern can match any substring of the variable name).
  # If False, performs re.match (i.e. regexp should match from the beginning of the variable name).
  reg_search = True
 )
 self.saver = tf.train.Saver(variables_to_restore)


 def after_create_session(self, session, coord):
 # When this is called, the graph is finalized and
 # ops can no longer be added to the graph.

 print('Session created.')

 tf.logging.info('Fine-tuning from %s' % self.checkpoint_path)
 self.saver.restore(session, os.path.expanduser(self.checkpoint_path))
 tf.logging.info('End fineturn from %s' % self.checkpoint_path)

 def before_run(self, run_context):
 #print('Before calling session.run().')
 return None #SessionRunArgs(self.your_tensor)

 def after_run(self, run_context, run_values):
 #print('Done running one step. The value of my tensor: %s', run_values.results)
 #if you-need-to-stop-loop:
 # run_context.request_stop()
 pass


 def end(self, session):
 #print('Done with the session.')
 pass

以上这篇tensorflow estimator 使用hook实现finetune方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在python的WEB框架Flask中使用多个配置文件的解决方法
Apr 18 Python
Python常见数据结构详解
Jul 24 Python
Python中返回字典键的值的values()方法使用
May 22 Python
详解tensorflow实现迁移学习实例
Feb 10 Python
python检测主机的连通性并记录到文件的实例
Jun 21 Python
一篇文章搞懂Python的类与对象名称空间
Dec 10 Python
python 图像平移和旋转的实例
Jan 10 Python
Python爬虫实现验证码登录代码实例
May 10 Python
基于python-opencv3的图像显示和保存操作
Jun 27 Python
Selenium向iframe富文本框输入内容过程图解
Apr 10 Python
Python timeit模块原理及使用方法
Oct 10 Python
Numpy中np.random.rand()和np.random.randn() 用法和区别详解
Oct 23 Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 #Python
You might like
解析php中var_dump,var_export,print_r三个函数的区别
2013/06/21 PHP
解析php防止form重复提交的方法
2013/07/01 PHP
Laravel 5.5基于内置的Auth模块实现前后台登陆详解
2017/12/21 PHP
Laravel实现短信注册的示例代码
2018/05/29 PHP
thinkphp5框架API token身份验证功能示例
2019/05/21 PHP
用js实现下载远程文件并保存在本地的脚本
2008/05/06 Javascript
ajax页面无刷新 IE下遭遇Ajax缓存导致数据不更新的问题
2012/12/11 Javascript
让jQuery Mobile不显示讨厌loading界面的方法
2014/02/19 Javascript
如何将php数组或者对象传递给javascript
2014/03/20 Javascript
jQuery往textarea中光标所在位置插入文本的方法
2015/06/26 Javascript
整理Javascript基础语法学习笔记
2015/11/29 Javascript
jQuery取得iframe中元素的常用方法详解
2016/01/14 Javascript
JS实现的四级密码强度检测功能示例
2017/05/11 Javascript
angular ng-click防止重复提交实例
2017/06/16 Javascript
使用Node.js实现简易MVC框架的方法
2017/08/07 Javascript
微信小程序实现下载进度条的方法
2017/12/08 Javascript
JS实现的简单折叠展开动画效果示例
2018/04/28 Javascript
微信小程序实现批量倒计时功能
2020/11/01 Javascript
nodejs实现UDP组播示例方法
2019/11/04 NodeJs
vue+echarts实现中国地图流动效果(步骤详解)
2021/01/27 Vue.js
[01:28]2014DOTA2国际邀请赛中国区预选赛四大豪门直升机抵达会场
2014/05/24 DOTA
[00:08]DOTA2勇士令状等级奖励“天外飞星”
2019/05/24 DOTA
聊聊Python中的pypy
2018/01/12 Python
python从入门到精通 windows安装python图文教程
2019/05/18 Python
Python二元赋值实用技巧解析
2019/10/25 Python
python实现打砖块游戏
2020/02/25 Python
sqlalchemy实现时间列自动更新教程
2020/09/02 Python
解决python3.x安装numpy成功但import出错的问题
2020/11/17 Python
Html5内唤醒百度、高德APP的实现示例
2019/05/20 HTML / CSS
越南综合购物网站:Lazada越南
2019/06/10 全球购物
最新远光软件笔试题面试题内容
2013/11/08 面试题
介绍一下你对SOA的认识
2016/04/24 面试题
甲方资料员岗位职责
2013/12/13 职场文书
乡镇干部先进性教育活动个人整改措施
2014/09/16 职场文书
抗洪救灾感谢信
2015/01/22 职场文书
2016小学新学期寄语
2015/12/04 职场文书