tensorflow estimator 使用hook实现finetune方式


Posted in Python onJanuary 21, 2020

为了实现finetune有如下两种解决方案:

model_fn里面定义好模型之后直接赋值

def model_fn(features, labels, mode, params):
 # .....
 # finetune
 if params.checkpoint_path and (not tf.train.latest_checkpoint(params.model_dir)):
 checkpoint_path = None
 if tf.gfile.IsDirectory(params.checkpoint_path):
  checkpoint_path = tf.train.latest_checkpoint(params.checkpoint_path)
 else:
  checkpoint_path = params.checkpoint_path

 tf.train.init_from_checkpoint(
  ckpt_dir_or_file=checkpoint_path,
  assignment_map={params.checkpoint_scope: params.checkpoint_scope} # 'OptimizeLoss/':'OptimizeLoss/'
 )

使用钩子 hooks。

可以在定义tf.contrib.learn.Experiment的时候通过train_monitors参数指定

# Define the experiment
 experiment = tf.contrib.learn.Experiment(
 estimator=estimator, # Estimator
 train_input_fn=train_input_fn, # First-class function
 eval_input_fn=eval_input_fn, # First-class function
 train_steps=params.train_steps, # Minibatch steps
 min_eval_frequency=params.eval_min_frequency, # Eval frequency
 # train_monitors=[], # Hooks for training
 # eval_hooks=[eval_input_hook], # Hooks for evaluation
 eval_steps=params.eval_steps # Use evaluation feeder until its empty
 )

也可以在定义tf.estimator.EstimatorSpec 的时候通过training_chief_hooks参数指定。

不过个人觉得最好还是在estimator中定义,让experiment只专注于控制实验的模式(训练次数,验证次数等等)。

def model_fn(features, labels, mode, params):

 # ....

 return tf.estimator.EstimatorSpec(
 mode=mode,
 predictions=predictions,
 loss=loss,
 train_op=train_op,
 eval_metric_ops=eval_metric_ops,
 # scaffold=get_scaffold(),
 # training_chief_hooks=None
 )

这里顺便解释以下tf.estimator.EstimatorSpec对像的作用。该对象描述来一个模型的方方面面。包括:

当前的模式:

mode: A ModeKeys. Specifies if this is training, evaluation or prediction.

计算图

predictions: Predictions Tensor or dict of Tensor.

loss: Training loss Tensor. Must be either scalar, or with shape [1].

train_op: Op for the training step.

eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the update_op or requires any input fetching.

导出策略

export_outputs: Describes the output signatures to be exported to

SavedModel and used during serving. A dict {name: output} where:

name: An arbitrary name for this output.

output: an ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.

chief钩子 训练时的模型保存策略钩子CheckpointSaverHook, 模型恢复等

training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.

worker钩子 训练时的监控策略钩子如: NanTensorHook LoggingTensorHook 等

training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.

指定初始化和saver

scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.

evaluation钩子

evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.

自定义的钩子如下:

class RestoreCheckpointHook(tf.train.SessionRunHook):
 def __init__(self,
   checkpoint_path,
   exclude_scope_patterns,
   include_scope_patterns
   ):
 tf.logging.info("Create RestoreCheckpointHook.")
 #super(IteratorInitializerHook, self).__init__()
 self.checkpoint_path = checkpoint_path

 self.exclude_scope_patterns = None if (not exclude_scope_patterns) else exclude_scope_patterns.split(',')
 self.include_scope_patterns = None if (not include_scope_patterns) else include_scope_patterns.split(',')


 def begin(self):
 # You can add ops to the graph here.
 print('Before starting the session.')

 # 1. Create saver

 #exclusions = []
 #if self.checkpoint_exclude_scopes:
 # exclusions = [scope.strip()
 #  for scope in self.checkpoint_exclude_scopes.split(',')]
 #
 #variables_to_restore = []
 #for var in slim.get_model_variables(): #tf.global_variables():
 # excluded = False
 # for exclusion in exclusions:
 # if var.op.name.startswith(exclusion):
 # excluded = True
 # break
 # if not excluded:
 # variables_to_restore.append(var)
 #inclusions
 #[var for var in tf.trainable_variables() if var.op.name.startswith('InceptionResnetV1')]

 variables_to_restore = tf.contrib.framework.filter_variables(
  slim.get_model_variables(),
  include_patterns=self.include_scope_patterns, # ['Conv'],
  exclude_patterns=self.exclude_scope_patterns, # ['biases', 'Logits'],

  # If True (default), performs re.search to find matches
  # (i.e. pattern can match any substring of the variable name).
  # If False, performs re.match (i.e. regexp should match from the beginning of the variable name).
  reg_search = True
 )
 self.saver = tf.train.Saver(variables_to_restore)


 def after_create_session(self, session, coord):
 # When this is called, the graph is finalized and
 # ops can no longer be added to the graph.

 print('Session created.')

 tf.logging.info('Fine-tuning from %s' % self.checkpoint_path)
 self.saver.restore(session, os.path.expanduser(self.checkpoint_path))
 tf.logging.info('End fineturn from %s' % self.checkpoint_path)

 def before_run(self, run_context):
 #print('Before calling session.run().')
 return None #SessionRunArgs(self.your_tensor)

 def after_run(self, run_context, run_values):
 #print('Done running one step. The value of my tensor: %s', run_values.results)
 #if you-need-to-stop-loop:
 # run_context.request_stop()
 pass


 def end(self, session):
 #print('Done with the session.')
 pass

以上这篇tensorflow estimator 使用hook实现finetune方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python利用Guetzli批量压缩图片
Mar 23 Python
python实现多线程网页下载器
Apr 15 Python
django缓存配置的几种方法详解
Jul 16 Python
解决Python3.5+OpenCV3.2读取图像的问题
Dec 05 Python
Python二叉搜索树与双向链表转换算法示例
Mar 02 Python
Python实现从SQL型数据库读写dataframe型数据的方法【基于pandas】
Mar 18 Python
django 使用全局搜索功能的实例详解
Jul 18 Python
Python学习笔记之While循环用法分析
Aug 14 Python
wxpython多线程防假死与线程间传递消息实例详解
Dec 13 Python
python GUI库图形界面开发之PyQt5信号与槽基础使用方法与实例
Mar 06 Python
python为QT程序添加图标的方法详解
Mar 09 Python
如何教少儿学习Python编程
Jul 10 Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 #Python
You might like
分割GBK中文遭遇乱码的解决方法
2013/08/09 PHP
6种php上传图片重命名的方法实例
2013/11/04 PHP
smarty简单入门实例
2014/11/28 PHP
ecshop添加菜单及权限分配问题
2017/11/21 PHP
PHP面向对象程序设计之接口的继承定义与用法详解
2018/12/20 PHP
Apache+PHP+MySQL搭建PHP开发环境图文教程
2020/08/06 PHP
地震发生中逃生十大法则
2008/05/12 Javascript
jquery中选择块并改变属性值的方法
2013/07/31 Javascript
jQuery插件pagination实现分页特效
2015/04/12 Javascript
简介JavaScript中的getUTCFullYear()方法的使用
2015/06/10 Javascript
使用vue.js制作分页组件
2016/06/27 Javascript
JavaScript控制输入框中只能输入中文、数字和英文的方法【基于正则实现】
2017/03/03 Javascript
Vuex利用state保存新闻数据实例
2017/06/28 Javascript
浅谈Webpack多页应用HMR卡住问题
2019/04/24 Javascript
[44:15]国士无双DOTA2 6.82版本详解(上)
2014/09/28 DOTA
Python xlrd读取excel日期类型的2种方法
2015/04/28 Python
Python可变参数用法实例分析
2017/04/02 Python
Python实现输入二叉树的先序和中序遍历,再输出后序遍历操作示例
2018/07/27 Python
python中 * 的用法详解
2019/07/10 Python
Python 处理文件的几种方式
2019/08/23 Python
pyftplib中文乱码问题解决方案
2020/01/11 Python
Python如何使用vars返回对象的属性列表
2020/10/17 Python
python 如何对logging日志封装
2020/12/02 Python
使用CSS3制作倾斜导航条和毛玻璃效果
2017/09/12 HTML / CSS
中国最大的团购网站:聚划算
2016/09/21 全球购物
Clarks英国官方网站:全球领军鞋履品牌
2016/11/26 全球购物
美国专营婴幼儿用品的购物网站:buybuy BABY
2017/01/01 全球购物
考博自荐信
2013/10/25 职场文书
毕业自我评价范文
2013/11/17 职场文书
初婚未育证明
2014/01/15 职场文书
年会主持词结束语
2014/03/27 职场文书
保护黄河倡议书
2014/05/16 职场文书
中队活动总结
2014/08/27 职场文书
2014年体育教学工作总结
2014/12/09 职场文书
2014年法务工作总结
2014/12/11 职场文书
2015年度员工自我评价范文
2015/03/11 职场文书