Fine-tuning with hooks in a TensorFlow Estimator


Posted in Python on January 21, 2020

There are two ways to implement fine-tuning:

First, assign the weights directly inside model_fn, once the model has been defined:

import tensorflow as tf  # TensorFlow 1.x API

def model_fn(features, labels, mode, params):
    # .....
    # Fine-tune only when model_dir does not yet hold a checkpoint of its own,
    # so that a resumed run is not reset to the pretrained weights.
    if params.checkpoint_path and (not tf.train.latest_checkpoint(params.model_dir)):
        checkpoint_path = None
        if tf.gfile.IsDirectory(params.checkpoint_path):
            # A directory was given: pick its most recent checkpoint.
            checkpoint_path = tf.train.latest_checkpoint(params.checkpoint_path)
        else:
            checkpoint_path = params.checkpoint_path

        tf.train.init_from_checkpoint(
            ckpt_dir_or_file=checkpoint_path,
            # e.g. {'OptimizeLoss/': 'OptimizeLoss/'}
            assignment_map={params.checkpoint_scope: params.checkpoint_scope}
        )
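
To make the assignment_map argument more concrete, here is a minimal sketch of the naming rule behind a scope-to-scope entry such as 'OptimizeLoss/': 'OptimizeLoss/'. It is plain Python on name strings, not TensorFlow's implementation; pair_by_scope and the variable names are hypothetical and only illustrate which checkpoint name each graph variable would be read from.

```python
def pair_by_scope(assignment_map, graph_var_names):
    """Map each graph variable name to the checkpoint name it would load from.

    assignment_map has the form {checkpoint_scope: graph_scope}; both scopes
    end with '/'. Hypothetical helper, for illustration only.
    """
    pairs = {}
    for ckpt_scope, graph_scope in assignment_map.items():
        for name in graph_var_names:
            if name.startswith(graph_scope):
                # Swap the graph scope prefix for the checkpoint scope prefix.
                pairs[name] = ckpt_scope + name[len(graph_scope):]
    return pairs


# Hypothetical variable names in the current graph.
graph_vars = ['OptimizeLoss/learning_rate', 'dense/kernel']

same_scope = pair_by_scope({'OptimizeLoss/': 'OptimizeLoss/'}, graph_vars)
renamed = pair_by_scope({'old_model/': 'dense/'}, graph_vars)
print(same_scope)
print(renamed)
```

Variables outside the mapped scope ('dense/kernel' in the first call) are left alone and keep their normal initializers.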

Second, use hooks.

A hook can be passed through the train_monitors argument when constructing tf.contrib.learn.Experiment:

# Define the experiment
experiment = tf.contrib.learn.Experiment(
    estimator=estimator,  # Estimator
    train_input_fn=train_input_fn,  # First-class function
    eval_input_fn=eval_input_fn,  # First-class function
    train_steps=params.train_steps,  # Minibatch steps
    min_eval_frequency=params.eval_min_frequency,  # Eval frequency
    # train_monitors=[],  # Hooks for training
    # eval_hooks=[eval_input_hook],  # Hooks for evaluation
    eval_steps=params.eval_steps  # Use evaluation feeder until it is empty
)

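A hook can also be passed through the training_chief_hooks argument when constructing tf.estimator.EstimatorSpec.

Personally, though, I think it is best to define hooks in the estimator, so that the experiment only has to control how the experiment runs (number of training steps, number of evaluation steps, and so on).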

def model_fn(features, labels, mode, params):

    # ....

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        # scaffold=get_scaffold(),
        # training_chief_hooks=None
    )
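
As a rough sketch of how the list for training_chief_hooks could be built, the helper below decides whether a restore hook is needed at all. It is plain Python so it runs without TensorFlow. build_chief_hooks, hook_factory and has_own_checkpoint are hypothetical names, and the 'train' string is assumed to match tf.estimator.ModeKeys.TRAIN. It reuses the guard from the first approach: skip fine-tuning once model_dir holds a checkpoint of its own.

```python
def build_chief_hooks(mode, checkpoint_path, has_own_checkpoint, hook_factory,
                      train_mode='train'):
    """Return the hooks to pass as training_chief_hooks (hypothetical helper).

    mode:               current mode of model_fn
    checkpoint_path:    pretrained checkpoint to fine-tune from, or None/''
    has_own_checkpoint: True if model_dir already contains a checkpoint
    hook_factory:       callable that builds the restore hook from a path
    """
    if mode != train_mode:
        return []  # hooks for training are irrelevant in eval/predict
    if not checkpoint_path or has_own_checkpoint:
        return []  # nothing to fine-tune from, or the run is being resumed
    return [hook_factory(checkpoint_path)]


# Stand-in factory so the example runs without TensorFlow.
fake_factory = lambda path: ('restore_hook', path)

fresh_run = build_chief_hooks('train', '/tmp/pretrained', False, fake_factory)
resumed_run = build_chief_hooks('train', '/tmp/pretrained', True, fake_factory)
eval_run = build_chief_hooks('eval', '/tmp/pretrained', False, fake_factory)
print(fresh_run, resumed_run, eval_run)
```

Inside a real model_fn, has_own_checkpoint would presumably come from tf.train.latest_checkpoint(params.model_dir), and hook_factory would construct the restore hook.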

As an aside, here is what the tf.estimator.EstimatorSpec object does. It describes every aspect of a model, including:

The current mode:

mode: A ModeKeys. Specifies if this is training, evaluation or prediction.

The computation graph:

predictions: Predictions Tensor or dict of Tensor.

loss: Training loss Tensor. Must be either scalar, or with shape [1].

train_op: Op for the training step.

eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically it is a pure computation based on variables). For example, it should not trigger the update_op or require any input fetching.

The export strategy:

export_outputs: Describes the output signatures to be exported to SavedModel and used during serving. A dict {name: output} where:

name: An arbitrary name for this output.

output: an ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.

Chief hooks, such as the CheckpointSaverHook that controls how the model is saved during training, model restoration, and so on:

training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.

Worker hooks, such as the monitoring hooks NanTensorHook and LoggingTensorHook used during training:

training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.

Initialization and the saver:

scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.

Evaluation hooks:

evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.
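
All of these hook arguments take tf.train.SessionRunHook objects, which expose the callbacks begin, after_create_session, before_run, after_run and end. The sketch below imitates the order in which a training loop calls them. It is a simplified plain-Python stand-in, not TensorFlow's monitored session; RecordingHook and run_training_loop are hypothetical names.

```python
class RecordingHook(object):
    """Records the name of every callback it receives (hypothetical)."""

    def __init__(self):
        self.calls = []

    def begin(self):
        self.calls.append('begin')

    def after_create_session(self, session, coord):
        self.calls.append('after_create_session')

    def before_run(self, run_context):
        self.calls.append('before_run')

    def after_run(self, run_context, run_values):
        self.calls.append('after_run')

    def end(self, session):
        self.calls.append('end')


def run_training_loop(hooks, num_steps):
    """Simplified imitation of the callback order during training."""
    for hook in hooks:
        hook.begin()  # graph can still be modified here
    session = object()  # stand-in for a real session
    for hook in hooks:
        hook.after_create_session(session, None)  # graph is finalized
    for _ in range(num_steps):
        for hook in hooks:
            hook.before_run(None)
        # ... the training step itself would run here ...
        for hook in hooks:
            hook.after_run(None, None)
    for hook in hooks:
        hook.end(session)


hook = RecordingHook()
run_training_loop([hook], num_steps=2)
print(hook.calls)
```

For fine-tuning, the two callbacks that matter are begin, where the saver is built, and after_create_session, where the weights are restored.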

The custom hook is as follows:

import os

import tensorflow as tf  # TensorFlow 1.x API

slim = tf.contrib.slim


class RestoreCheckpointHook(tf.train.SessionRunHook):
    def __init__(self,
                 checkpoint_path,
                 exclude_scope_patterns,
                 include_scope_patterns):
        tf.logging.info("Create RestoreCheckpointHook.")
        super(RestoreCheckpointHook, self).__init__()
        self.checkpoint_path = checkpoint_path

        # Patterns arrive as comma-separated strings; None or '' means "no filter".
        self.exclude_scope_patterns = None if (not exclude_scope_patterns) else exclude_scope_patterns.split(',')
        self.include_scope_patterns = None if (not include_scope_patterns) else include_scope_patterns.split(',')

    def begin(self):
        # You can add ops to the graph here.
        print('Before starting the session.')

        # 1. Create saver
        #
        # Manual alternative to filter_variables, kept for reference:
        # exclusions = []
        # if self.checkpoint_exclude_scopes:
        #     exclusions = [scope.strip()
        #                   for scope in self.checkpoint_exclude_scopes.split(',')]
        #
        # variables_to_restore = []
        # for var in slim.get_model_variables():  # tf.global_variables():
        #     excluded = False
        #     for exclusion in exclusions:
        #         if var.op.name.startswith(exclusion):
        #             excluded = True
        #             break
        #     if not excluded:
        #         variables_to_restore.append(var)
        # inclusions
        # [var for var in tf.trainable_variables() if var.op.name.startswith('InceptionResnetV1')]

        variables_to_restore = tf.contrib.framework.filter_variables(
            slim.get_model_variables(),
            include_patterns=self.include_scope_patterns,  # ['Conv'],
            exclude_patterns=self.exclude_scope_patterns,  # ['biases', 'Logits'],

            # If True (default), performs re.search to find matches
            # (i.e. pattern can match any substring of the variable name).
            # If False, performs re.match (i.e. regexp should match from
            # the beginning of the variable name).
            reg_search=True
        )
        self.saver = tf.train.Saver(variables_to_restore)

    def after_create_session(self, session, coord):
        # When this is called, the graph is finalized and
        # ops can no longer be added to the graph.
        print('Session created.')

        tf.logging.info('Fine-tuning from %s' % self.checkpoint_path)
        self.saver.restore(session, os.path.expanduser(self.checkpoint_path))
        tf.logging.info('Finished fine-tuning from %s' % self.checkpoint_path)

    def before_run(self, run_context):
        # print('Before calling session.run().')
        return None  # SessionRunArgs(self.your_tensor)

    def after_run(self, run_context, run_values):
        # print('Done running one step. The value of my tensor: %s', run_values.results)
        # if you-need-to-stop-loop:
        #     run_context.request_stop()
        pass

    def end(self, session):
        # print('Done with the session.')
        pass
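
To see which variables such include and exclude patterns select, here is a small sketch of the filtering rule on plain name strings: keep the names that match any include pattern, then drop those that match any exclude pattern. It follows the behaviour described in the comments above but is not the TensorFlow function itself; filter_names and the variable names are hypothetical.

```python
import re


def filter_names(names, include_patterns=None, exclude_patterns=None,
                 reg_search=True):
    """Filter variable names by regex patterns (hypothetical helper)."""
    # re.search matches anywhere in the name, re.match only at its start.
    matcher = re.search if reg_search else re.match

    if include_patterns is None:
        kept = list(names)
    else:
        kept = [n for n in names
                if any(matcher(p, n) for p in include_patterns)]

    if exclude_patterns is not None:
        kept = [n for n in kept
                if not any(matcher(p, n) for p in exclude_patterns)]
    return kept


# Hypothetical variable names.
names = [
    'InceptionResnetV1/Conv2d_1a/weights',
    'InceptionResnetV1/Conv2d_1a/biases',
    'InceptionResnetV1/Logits/weights',
    'OptimizeLoss/learning_rate',
]

# Patterns are split from comma-separated strings, as in the hook.
restored = filter_names(
    names,
    include_patterns='InceptionResnetV1'.split(','),
    exclude_patterns='biases,Logits'.split(','),
)
anchored = filter_names(names, include_patterns=['Conv2d'], reg_search=False)
print(restored)
print(anchored)
```

Excluding the final classification layer this way is the usual reason for the exclude list: those weights depend on the number of classes and are trained from scratch on the new task.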
