Fine-tuning with a hook in a TensorFlow Estimator


Posted in Python on January 21, 2020

There are two ways to implement fine-tuning:

The first is to initialize the variables directly inside model_fn, once the model has been defined.

import tensorflow as tf  # TensorFlow 1.x

def model_fn(features, labels, mode, params):
    # .....
    # finetune: only when a pretrained checkpoint is given and model_dir
    # does not yet contain a checkpoint of its own
    if params.checkpoint_path and (not tf.train.latest_checkpoint(params.model_dir)):
        checkpoint_path = None
        if tf.gfile.IsDirectory(params.checkpoint_path):
            # A directory was given: use the newest checkpoint inside it
            checkpoint_path = tf.train.latest_checkpoint(params.checkpoint_path)
        else:
            checkpoint_path = params.checkpoint_path

        tf.train.init_from_checkpoint(
            ckpt_dir_or_file=checkpoint_path,
            assignment_map={params.checkpoint_scope: params.checkpoint_scope}  # 'OptimizeLoss/':'OptimizeLoss/'
        )
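To make the assignment_map argument more concrete, here is a minimal pure-Python sketch of the name matching that a scope-to-scope entry such as 'OptimizeLoss/': 'OptimizeLoss/' implies. It does not call TensorFlow. The function name plan_restore and all variable names are hypothetical, and only the scope-prefix form of the map (keys and values ending in '/') is modeled.

```python
def plan_restore(assignment_map, graph_variable_names):
    """Map each graph variable to the checkpoint tensor it would be read from.

    Only models scope-to-scope entries ('ckpt_scope/' -> 'graph_scope/').
    This illustrates the naming rule; it is not TensorFlow's own code.
    """
    plan = {}
    for ckpt_scope, graph_scope in assignment_map.items():
        for name in graph_variable_names:
            if name.startswith(graph_scope):
                # Swap the graph scope prefix for the checkpoint scope prefix.
                plan[name] = ckpt_scope + name[len(graph_scope):]
    return plan


# Hypothetical variable names, for illustration only.
graph_vars = ["encoder/dense/kernel", "encoder/dense/bias", "logits/kernel"]
plan = plan_restore({"pretrained/": "encoder/"}, graph_vars)
for graph_name, ckpt_name in sorted(plan.items()):
    print(graph_name, "<-", ckpt_name)
```

Variables outside the mapped scope (logits/kernel here) do not appear in the plan, so they keep their ordinary initializer. That is what makes this form convenient when a pretrained body is combined with a freshly initialized head.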

The second is to use hooks.

Hooks can be passed through the train_monitors argument when defining tf.contrib.learn.Experiment:

# Define the experiment
experiment = tf.contrib.learn.Experiment(
    estimator=estimator,  # Estimator
    train_input_fn=train_input_fn,  # First-class function
    eval_input_fn=eval_input_fn,  # First-class function
    train_steps=params.train_steps,  # Minibatch steps
    min_eval_frequency=params.eval_min_frequency,  # Eval frequency
    # train_monitors=[],  # Hooks for training
    # eval_hooks=[eval_input_hook],  # Hooks for evaluation
    eval_steps=params.eval_steps  # Use evaluation feeder until it's empty
)

They can also be passed through the training_chief_hooks argument when defining tf.estimator.EstimatorSpec.

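Personally, I think it is better to define them in the estimator, so that the experiment only has to control how the experiment is run (number of training steps, number of evaluation steps, and so on).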

def model_fn(features, labels, mode, params):

    # ....

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        # scaffold=get_scaffold(),
        # training_chief_hooks=None
    )
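One thing to watch when attaching a restore hook through training_chief_hooks: a hook that restores pretrained weights after the session is created (such as the custom hook shown later in this post) will do so on every run, including runs that resume from model_dir. A minimal sketch of guarding against that, mirroring the latest_checkpoint check from the first approach, follows. The names has_checkpoint and chief_hooks_for are hypothetical, and the hook is passed in as an opaque object so that the snippet runs without TensorFlow.

```python
import os
import tempfile


def has_checkpoint(model_dir):
    """Return True if model_dir already holds a checkpoint state file.

    Assumption: TensorFlow writes a file named 'checkpoint' into model_dir
    when it saves, and tf.train.latest_checkpoint reads that file.
    """
    return os.path.isfile(os.path.join(model_dir, "checkpoint"))


def chief_hooks_for(model_dir, pretrained_path, restore_hook):
    """Return the list to pass as training_chief_hooks."""
    if pretrained_path and not has_checkpoint(model_dir):
        return [restore_hook]  # fresh run: start from the pretrained weights
    return []  # resuming, or nothing to restore: leave the weights alone


# Demonstration with a temporary directory standing in for model_dir.
with tempfile.TemporaryDirectory() as model_dir:
    hook = object()  # placeholder for a real SessionRunHook instance
    fresh = chief_hooks_for(model_dir, "/path/to/pretrained", hook)
    open(os.path.join(model_dir, "checkpoint"), "w").close()
    resumed = chief_hooks_for(model_dir, "/path/to/pretrained", hook)

print(len(fresh), len(resumed))

# Inside model_fn (TensorFlow 1.x), the result would be used roughly as:
#   tf.estimator.EstimatorSpec(..., training_chief_hooks=chief_hooks_for(...))
```

Inside a real model_fn, tf.train.latest_checkpoint(params.model_dir) could replace has_checkpoint; the file check is used here only to keep the sketch free of TensorFlow.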

While we are here, a word on what the tf.estimator.EstimatorSpec object does. It describes every aspect of a model, including:

The current mode:

mode: A ModeKeys. Specifies if this is training, evaluation or prediction.

The computation graph:

predictions: Predictions Tensor or dict of Tensor.

loss: Training loss Tensor. Must be either scalar, or with shape [1].

train_op: Op for the training step.

eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the update_op or requires any input fetching.

The export strategy:

export_outputs: Describes the output signatures to be exported to SavedModel and used during serving. A dict {name: output} where:

name: An arbitrary name for this output.

output: an ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.

Chief hooks, which cover the model-saving strategy during training (CheckpointSaverHook), model restoring, and so on:

training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.

Worker hooks, which cover monitoring during training, such as NanTensorHook and LoggingTensorHook:

training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.

Initialization and the saver:

scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.

Evaluation hooks:

evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.
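The three hook fields above all take SessionRunHook objects, so it helps to know the order in which their callbacks fire. The following is a rough sketch with a stand-in class and a hypothetical driver, run_with_hook, that imitates the sequence a monitored training session follows. It does not use TensorFlow, and the placeholder arguments (None, a bare object) are assumptions made for illustration.

```python
class RecordingHook:
    """Stand-in with the same method names as tf.train.SessionRunHook."""

    def __init__(self):
        self.calls = []

    def begin(self):
        self.calls.append("begin")

    def after_create_session(self, session, coord):
        self.calls.append("after_create_session")

    def before_run(self, run_context):
        self.calls.append("before_run")

    def after_run(self, run_context, run_values):
        self.calls.append("after_run")

    def end(self, session):
        self.calls.append("end")


def run_with_hook(hook, num_steps):
    """Hypothetical driver imitating the callback order of a monitored session."""
    hook.begin()  # the graph can still be modified here
    session = object()  # placeholder for a real tf.Session
    hook.after_create_session(session, coord=None)  # graph is finalized by now
    for _ in range(num_steps):
        hook.before_run(run_context=None)
        # session.run(train_op) would happen here
        hook.after_run(run_context=None, run_values=None)
    hook.end(session)


hook = RecordingHook()
run_with_hook(hook, num_steps=2)
print(hook.calls)
```

The split between begin and after_create_session is the part that matters for fine-tuning: anything that adds ops to the graph, such as building a Saver, belongs in begin, while the actual restore needs a live session and therefore belongs in after_create_session.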

The custom hook looks like this:

import os

import tensorflow as tf  # TensorFlow 1.x

slim = tf.contrib.slim


class RestoreCheckpointHook(tf.train.SessionRunHook):
    """Restores a filtered set of model variables from a checkpoint."""

    def __init__(self,
                 checkpoint_path,
                 exclude_scope_patterns,
                 include_scope_patterns
                 ):
        tf.logging.info("Create RestoreCheckpointHook.")
        super(RestoreCheckpointHook, self).__init__()
        self.checkpoint_path = checkpoint_path

        # Both pattern arguments are comma-separated strings; empty means "no filter".
        self.exclude_scope_patterns = None if (not exclude_scope_patterns) else exclude_scope_patterns.split(',')
        self.include_scope_patterns = None if (not include_scope_patterns) else include_scope_patterns.split(',')

    def begin(self):
        # You can add ops to the graph here.
        print('Before starting the session.')

        # 1. Create saver

        # Alternative: build the variable list by hand (kept for reference).
        #exclusions = []
        #if self.checkpoint_exclude_scopes:
        #    exclusions = [scope.strip()
        #                  for scope in self.checkpoint_exclude_scopes.split(',')]
        #
        #variables_to_restore = []
        #for var in slim.get_model_variables(): #tf.global_variables():
        #    excluded = False
        #    for exclusion in exclusions:
        #        if var.op.name.startswith(exclusion):
        #            excluded = True
        #            break
        #    if not excluded:
        #        variables_to_restore.append(var)
        #inclusions
        #[var for var in tf.trainable_variables() if var.op.name.startswith('InceptionResnetV1')]

        variables_to_restore = tf.contrib.framework.filter_variables(
            slim.get_model_variables(),
            include_patterns=self.include_scope_patterns,  # ['Conv'],
            exclude_patterns=self.exclude_scope_patterns,  # ['biases', 'Logits'],

            # If True (default), performs re.search to find matches
            # (i.e. pattern can match any substring of the variable name).
            # If False, performs re.match (i.e. regexp should match from the beginning of the variable name).
            reg_search=True
        )
        self.saver = tf.train.Saver(variables_to_restore)

    def after_create_session(self, session, coord):
        # When this is called, the graph is finalized and
        # ops can no longer be added to the graph.

        print('Session created.')

        tf.logging.info('Fine-tuning from %s' % self.checkpoint_path)
        self.saver.restore(session, os.path.expanduser(self.checkpoint_path))
        tf.logging.info('End fine-tuning from %s' % self.checkpoint_path)

    def before_run(self, run_context):
        #print('Before calling session.run().')
        return None  #SessionRunArgs(self.your_tensor)

    def after_run(self, run_context, run_values):
        #print('Done running one step. The value of my tensor: %s', run_values.results)
        #if you-need-to-stop-loop:
        #    run_context.request_stop()
        pass

    def end(self, session):
        #print('Done with the session.')
        pass
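The include and exclude patterns decide which variables the hook restores, so it is worth seeing how they interact. The sketch below works on plain name strings rather than TensorFlow variables and approximates the documented behavior of filter_variables: includes are applied first, excludes second, with re.search or re.match depending on reg_search. The functions parse_patterns and filter_names are hypothetical helpers, and stripping whitespace around each pattern is an addition that the hook above does not do.

```python
import re


def parse_patterns(csv):
    """Split 'a, b,c' into ['a', 'b', 'c']; return None for empty input."""
    if not csv:
        return None
    patterns = [p.strip() for p in csv.split(",") if p.strip()]
    return patterns or None


def filter_names(names, include_patterns=None, exclude_patterns=None,
                 reg_search=True):
    """Keep names matching any include pattern and no exclude pattern."""
    matcher = re.search if reg_search else re.match
    if include_patterns is None:
        kept = list(names)  # no include filter: start from everything
    else:
        kept = [n for n in names
                if any(matcher(p, n) for p in include_patterns)]
    if exclude_patterns is not None:
        kept = [n for n in kept
                if not any(matcher(p, n) for p in exclude_patterns)]
    return kept


# Hypothetical variable names, for illustration only.
names = [
    "InceptionResnetV1/Conv2d_1a/weights",
    "InceptionResnetV1/Conv2d_1a/biases",
    "InceptionResnetV1/Logits/weights",
]
selected = filter_names(
    names,
    include_patterns=parse_patterns("Conv"),
    exclude_patterns=parse_patterns("biases, Logits"),
)
print(selected)
```

With reg_search=False the same include pattern selects nothing here, because re.match anchors at the start of the name and every name begins with InceptionResnetV1. A pattern such as 'InceptionResnetV1/Conv' would be needed in that mode.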
