tensorflow estimator 使用hook实现finetune方式


Posted in Python onJanuary 21, 2020

为了实现finetune有如下两种解决方案:

model_fn里面定义好模型之后直接赋值

def model_fn(features, labels, mode, params):
 # .....
 # finetune
 if params.checkpoint_path and (not tf.train.latest_checkpoint(params.model_dir)):
 checkpoint_path = None
 if tf.gfile.IsDirectory(params.checkpoint_path):
  checkpoint_path = tf.train.latest_checkpoint(params.checkpoint_path)
 else:
  checkpoint_path = params.checkpoint_path

 tf.train.init_from_checkpoint(
  ckpt_dir_or_file=checkpoint_path,
  assignment_map={params.checkpoint_scope: params.checkpoint_scope} # 'OptimizeLoss/':'OptimizeLoss/'
 )

使用钩子 hooks。

可以在定义tf.contrib.learn.Experiment的时候通过train_monitors参数指定

# Define the experiment
 experiment = tf.contrib.learn.Experiment(
 estimator=estimator, # Estimator
 train_input_fn=train_input_fn, # First-class function
 eval_input_fn=eval_input_fn, # First-class function
 train_steps=params.train_steps, # Minibatch steps
 min_eval_frequency=params.eval_min_frequency, # Eval frequency
 # train_monitors=[], # Hooks for training
 # eval_hooks=[eval_input_hook], # Hooks for evaluation
 eval_steps=params.eval_steps # Use evaluation feeder until its empty
 )

也可以在定义tf.estimator.EstimatorSpec 的时候通过training_chief_hooks参数指定。

不过个人觉得最好还是在estimator中定义,让experiment只专注于控制实验的模式(训练次数,验证次数等等)。

def model_fn(features, labels, mode, params):

 # ....

 return tf.estimator.EstimatorSpec(
 mode=mode,
 predictions=predictions,
 loss=loss,
 train_op=train_op,
 eval_metric_ops=eval_metric_ops,
 # scaffold=get_scaffold(),
 # training_chief_hooks=None
 )

这里顺便解释以下tf.estimator.EstimatorSpec对像的作用。该对象描述来一个模型的方方面面。包括:

当前的模式:

mode: A ModeKeys. Specifies if this is training, evaluation or prediction.

计算图

predictions: Predictions Tensor or dict of Tensor.

loss: Training loss Tensor. Must be either scalar, or with shape [1].

train_op: Op for the training step.

eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the update_op or requires any input fetching.

导出策略

export_outputs: Describes the output signatures to be exported to

SavedModel and used during serving. A dict {name: output} where:

name: An arbitrary name for this output.

output: an ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.

chief钩子 训练时的模型保存策略钩子CheckpointSaverHook, 模型恢复等

training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.

worker钩子 训练时的监控策略钩子如: NanTensorHook LoggingTensorHook 等

training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.

指定初始化和saver

scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.

evaluation钩子

evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.

自定义的钩子如下:

class RestoreCheckpointHook(tf.train.SessionRunHook):
 def __init__(self,
   checkpoint_path,
   exclude_scope_patterns,
   include_scope_patterns
   ):
 tf.logging.info("Create RestoreCheckpointHook.")
 #super(IteratorInitializerHook, self).__init__()
 self.checkpoint_path = checkpoint_path

 self.exclude_scope_patterns = None if (not exclude_scope_patterns) else exclude_scope_patterns.split(',')
 self.include_scope_patterns = None if (not include_scope_patterns) else include_scope_patterns.split(',')


 def begin(self):
 # You can add ops to the graph here.
 print('Before starting the session.')

 # 1. Create saver

 #exclusions = []
 #if self.checkpoint_exclude_scopes:
 # exclusions = [scope.strip()
 #  for scope in self.checkpoint_exclude_scopes.split(',')]
 #
 #variables_to_restore = []
 #for var in slim.get_model_variables(): #tf.global_variables():
 # excluded = False
 # for exclusion in exclusions:
 # if var.op.name.startswith(exclusion):
 # excluded = True
 # break
 # if not excluded:
 # variables_to_restore.append(var)
 #inclusions
 #[var for var in tf.trainable_variables() if var.op.name.startswith('InceptionResnetV1')]

 variables_to_restore = tf.contrib.framework.filter_variables(
  slim.get_model_variables(),
  include_patterns=self.include_scope_patterns, # ['Conv'],
  exclude_patterns=self.exclude_scope_patterns, # ['biases', 'Logits'],

  # If True (default), performs re.search to find matches
  # (i.e. pattern can match any substring of the variable name).
  # If False, performs re.match (i.e. regexp should match from the beginning of the variable name).
  reg_search = True
 )
 self.saver = tf.train.Saver(variables_to_restore)


 def after_create_session(self, session, coord):
 # When this is called, the graph is finalized and
 # ops can no longer be added to the graph.

 print('Session created.')

 tf.logging.info('Fine-tuning from %s' % self.checkpoint_path)
 self.saver.restore(session, os.path.expanduser(self.checkpoint_path))
 tf.logging.info('End fineturn from %s' % self.checkpoint_path)

 def before_run(self, run_context):
 #print('Before calling session.run().')
 return None #SessionRunArgs(self.your_tensor)

 def after_run(self, run_context, run_values):
 #print('Done running one step. The value of my tensor: %s', run_values.results)
 #if you-need-to-stop-loop:
 # run_context.request_stop()
 pass


 def end(self, session):
 #print('Done with the session.')
 pass

以上这篇tensorflow estimator 使用hook实现finetune方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现冒泡,插入,选择排序简单实例
Aug 18 Python
用Python实现换行符转换的脚本的教程
Apr 16 Python
python使用PyGame播放Midi和Mp3文件的方法
Apr 24 Python
在Django的URLconf中使用命名组的方法
Jul 18 Python
Python使用pylab库实现画线功能的方法详解
Jun 08 Python
Python 使用PIL numpy 实现拼接图片的示例
May 08 Python
idea创建springMVC框架和配置小文件的教程图解
Sep 18 Python
Pycharm设置utf-8自动显示方法
Jan 17 Python
python执行scp命令拷贝文件及文件夹到远程主机的目录方法
Jul 08 Python
python爬虫之爬取百度音乐的实现方法
Aug 24 Python
用python中的matplotlib绘制方程图像代码
Nov 21 Python
pytorch标签转onehot形式实例
Jan 02 Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 #Python
You might like
坏狼的PHP学习教程之第1天
2008/06/15 PHP
ThinkPHP控制器里javascript代码不能执行的解决方法
2014/11/22 PHP
PHP使用DOM对XML解析处理操作示例
2019/07/04 PHP
Laravel 手动开关 Eloquent 修改器的操作方法
2019/12/30 PHP
利用PHP内置SERVER开启web服务(本地开发使用)
2020/01/22 PHP
php+mysql+ajax 局部刷新点赞/取消点赞功能(每个账号只点赞一次)
2020/07/24 PHP
模仿百度三维地图的js数据分享
2011/05/12 Javascript
页面图片浮动左右滑动效果的简单实现案例
2014/02/10 Javascript
jquery控制显示服务器生成的图片流
2015/08/04 Javascript
js贪吃蛇网页版游戏特效代码分享(挑战十关)
2015/08/24 Javascript
Angular2入门--架构总览
2017/03/29 Javascript
ES6新特性二:Iterator(遍历器)和for-of循环详解
2017/04/20 Javascript
通过构造函数实例化对象的方法
2017/06/28 Javascript
分享Bootstrap简单表格、表单、登录页面
2017/08/04 Javascript
angularjs实现分页和搜索功能
2018/01/03 Javascript
微信小程序中上传图片并进行压缩的实现代码
2018/08/28 Javascript
vue-cli中vue本地实现跨域调试接口
2019/01/16 Javascript
jQuery实现为table表格动态添加或删除tr功能示例
2019/02/19 jQuery
合并Excel工作薄中成绩表的VBA代码,非常适合教育一线的朋友
2009/04/09 Python
在CentOS上配置Nginx+Gunicorn+Python+Flask环境的教程
2016/06/07 Python
Python实现合并excel表格的方法分析
2019/04/13 Python
python3 批量获取对应端口服务的实例
2019/07/25 Python
tensorflow 获取所有variable或tensor的name示例
2020/01/04 Python
浅谈Keras中shuffle和validation_split的顺序
2020/06/19 Python
python如何进入交互模式
2020/07/06 Python
城市观光通行证:The Sightseeing Pass
2018/04/28 全球购物
Wiggle新西兰:自行车、跑步、游泳
2020/05/06 全球购物
最新会计专业求职信范文
2014/01/28 职场文书
企业管理毕业生求职信
2014/03/11 职场文书
升职演讲稿范文
2014/05/23 职场文书
幼儿园家长安全责任书
2014/07/22 职场文书
党校个人总结
2015/03/04 职场文书
2015年营销工作总结范文
2015/04/23 职场文书
2015中秋祝酒词
2015/08/12 职场文书
《普罗米修斯》教学反思
2016/02/22 职场文书
浅谈node.js中间件有哪些类型
2021/04/29 Javascript