tensorflow estimator 使用hook实现finetune方式


Posted in Python onJanuary 21, 2020

为了实现finetune有如下两种解决方案:

model_fn里面定义好模型之后直接赋值

def model_fn(features, labels, mode, params):
 # .....
 # finetune
 if params.checkpoint_path and (not tf.train.latest_checkpoint(params.model_dir)):
 checkpoint_path = None
 if tf.gfile.IsDirectory(params.checkpoint_path):
  checkpoint_path = tf.train.latest_checkpoint(params.checkpoint_path)
 else:
  checkpoint_path = params.checkpoint_path

 tf.train.init_from_checkpoint(
  ckpt_dir_or_file=checkpoint_path,
  assignment_map={params.checkpoint_scope: params.checkpoint_scope} # 'OptimizeLoss/':'OptimizeLoss/'
 )

使用钩子 hooks。

可以在定义tf.contrib.learn.Experiment的时候通过train_monitors参数指定

# Define the experiment
 experiment = tf.contrib.learn.Experiment(
 estimator=estimator, # Estimator
 train_input_fn=train_input_fn, # First-class function
 eval_input_fn=eval_input_fn, # First-class function
 train_steps=params.train_steps, # Minibatch steps
 min_eval_frequency=params.eval_min_frequency, # Eval frequency
 # train_monitors=[], # Hooks for training
 # eval_hooks=[eval_input_hook], # Hooks for evaluation
 eval_steps=params.eval_steps # Use evaluation feeder until its empty
 )

也可以在定义tf.estimator.EstimatorSpec 的时候通过training_chief_hooks参数指定。

不过个人觉得最好还是在estimator中定义,让experiment只专注于控制实验的模式(训练次数,验证次数等等)。

def model_fn(features, labels, mode, params):

 # ....

 return tf.estimator.EstimatorSpec(
 mode=mode,
 predictions=predictions,
 loss=loss,
 train_op=train_op,
 eval_metric_ops=eval_metric_ops,
 # scaffold=get_scaffold(),
 # training_chief_hooks=None
 )

这里顺便解释以下tf.estimator.EstimatorSpec对像的作用。该对象描述来一个模型的方方面面。包括:

当前的模式:

mode: A ModeKeys. Specifies if this is training, evaluation or prediction.

计算图

predictions: Predictions Tensor or dict of Tensor.

loss: Training loss Tensor. Must be either scalar, or with shape [1].

train_op: Op for the training step.

eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the update_op or requires any input fetching.

导出策略

export_outputs: Describes the output signatures to be exported to

SavedModel and used during serving. A dict {name: output} where:

name: An arbitrary name for this output.

output: an ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.

chief钩子 训练时的模型保存策略钩子CheckpointSaverHook, 模型恢复等

training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.

worker钩子 训练时的监控策略钩子如: NanTensorHook LoggingTensorHook 等

training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.

指定初始化和saver

scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.

evaluation钩子

evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.

自定义的钩子如下:

class RestoreCheckpointHook(tf.train.SessionRunHook):
 def __init__(self,
   checkpoint_path,
   exclude_scope_patterns,
   include_scope_patterns
   ):
 tf.logging.info("Create RestoreCheckpointHook.")
 #super(IteratorInitializerHook, self).__init__()
 self.checkpoint_path = checkpoint_path

 self.exclude_scope_patterns = None if (not exclude_scope_patterns) else exclude_scope_patterns.split(',')
 self.include_scope_patterns = None if (not include_scope_patterns) else include_scope_patterns.split(',')


 def begin(self):
 # You can add ops to the graph here.
 print('Before starting the session.')

 # 1. Create saver

 #exclusions = []
 #if self.checkpoint_exclude_scopes:
 # exclusions = [scope.strip()
 #  for scope in self.checkpoint_exclude_scopes.split(',')]
 #
 #variables_to_restore = []
 #for var in slim.get_model_variables(): #tf.global_variables():
 # excluded = False
 # for exclusion in exclusions:
 # if var.op.name.startswith(exclusion):
 # excluded = True
 # break
 # if not excluded:
 # variables_to_restore.append(var)
 #inclusions
 #[var for var in tf.trainable_variables() if var.op.name.startswith('InceptionResnetV1')]

 variables_to_restore = tf.contrib.framework.filter_variables(
  slim.get_model_variables(),
  include_patterns=self.include_scope_patterns, # ['Conv'],
  exclude_patterns=self.exclude_scope_patterns, # ['biases', 'Logits'],

  # If True (default), performs re.search to find matches
  # (i.e. pattern can match any substring of the variable name).
  # If False, performs re.match (i.e. regexp should match from the beginning of the variable name).
  reg_search = True
 )
 self.saver = tf.train.Saver(variables_to_restore)


 def after_create_session(self, session, coord):
 # When this is called, the graph is finalized and
 # ops can no longer be added to the graph.

 print('Session created.')

 tf.logging.info('Fine-tuning from %s' % self.checkpoint_path)
 self.saver.restore(session, os.path.expanduser(self.checkpoint_path))
 tf.logging.info('End fineturn from %s' % self.checkpoint_path)

 def before_run(self, run_context):
 #print('Before calling session.run().')
 return None #SessionRunArgs(self.your_tensor)

 def after_run(self, run_context, run_values):
 #print('Done running one step. The value of my tensor: %s', run_values.results)
 #if you-need-to-stop-loop:
 # run_context.request_stop()
 pass


 def end(self, session):
 #print('Done with the session.')
 pass

以上这篇tensorflow estimator 使用hook实现finetune方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python运行的17个时新手常见错误小结
Aug 07 Python
c++生成dll使用python调用dll的方法
Jan 20 Python
Python中用Spark模块的使用教程
Apr 13 Python
Python部署web开发程序的几种方法
May 05 Python
Django如何实现内容缓存示例详解
Sep 24 Python
pandas groupby 分组取每组的前几行记录方法
Apr 20 Python
Python实现正弦信号的时域波形和频谱图示例【基于matplotlib】
May 04 Python
使用Python实现租车计费系统的两种方法
Sep 29 Python
对Python定时任务的启动和停止方法详解
Feb 19 Python
PyQt5显示GIF图片的方法
Jun 17 Python
Python基础学习之基本数据结构详解【数字、字符串、列表、元组、集合、字典】
Jun 18 Python
python、PyTorch图像读取与numpy转换实例
Jan 13 Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 #Python
You might like
虫族 Zerg 热键控制
2020/03/14 星际争霸
ThinkPHP3.2框架自定义配置和加载用法示例
2018/06/14 PHP
PHP设计模式之PHP迭代器模式讲解
2019/03/22 PHP
Laravel使用swoole实现websocket主动消息推送的方法介绍
2019/10/20 PHP
使用JavaScript检测Firefox浏览器是否启用了Firebug的代码
2010/12/28 Javascript
javascript实现checkBox的全选,反选与赋值
2015/03/12 Javascript
jQuery实现购物车数字加减效果
2015/03/14 Javascript
JavaScript返回上一页的三种方法及区别介绍
2015/07/04 Javascript
网页从弹窗页面单选框传值至父页面代码分享
2015/09/29 Javascript
javascript基础知识之html5轮播图实例讲解(44)
2017/02/17 Javascript
微信小程序 开发之全局配置
2017/05/05 Javascript
vue页面切换到滚动页面显示顶部的实例
2018/03/13 Javascript
NodeJs生成sitemap站点地图的方法示例
2019/06/11 NodeJs
javascript中this的用法实践分析
2019/07/29 Javascript
浅谈layui里的上传控件问题
2019/09/26 Javascript
vue.js实现简单的计算器功能
2020/02/22 Javascript
如何在JavaScript中使用localStorage详情
2021/02/04 Javascript
Vue包大小优化的实现(从1.72M到94K)
2021/02/18 Vue.js
[02:38]DOTA2超级联赛专访Loda 认为IG世界最强
2013/05/27 DOTA
Python输出汉字字库及将文字转换为图片的方法
2016/06/04 Python
python实现下载整个ftp目录的方法
2017/01/17 Python
pyqt5中QThread在使用时出现重复emit的实例
2019/06/21 Python
Python转换时间的图文方法
2019/07/01 Python
Python多重继承之菱形继承的实例详解
2020/02/12 Python
python小白学习包管理器pip安装
2020/06/09 Python
python中doctest库实例用法
2020/12/31 Python
简单几步用纯CSS3实现3D翻转效果
2019/01/17 HTML / CSS
印度尼西亚电子产品购物网站:Kliknklik
2018/06/05 全球购物
德国网上超市:myTime.de
2019/08/26 全球购物
七年级政治教学反思
2014/02/03 职场文书
《我的信念》教学反思
2014/02/15 职场文书
应届毕业生应聘自荐信范文
2014/02/26 职场文书
党员个人整改方案及措施
2014/10/25 职场文书
2014年班组长工作总结
2014/11/20 职场文书
国庆节新闻稿
2015/07/17 职场文书
草系十大最强宝可梦,纸片人上榜,榜首大家最熟悉
2022/03/18 日漫