tensorflow estimator 使用hook实现finetune方式


Posted in Python onJanuary 21, 2020

为了实现finetune有如下两种解决方案:

model_fn里面定义好模型之后直接赋值

def model_fn(features, labels, mode, params):
 # .....
 # finetune
 if params.checkpoint_path and (not tf.train.latest_checkpoint(params.model_dir)):
 checkpoint_path = None
 if tf.gfile.IsDirectory(params.checkpoint_path):
  checkpoint_path = tf.train.latest_checkpoint(params.checkpoint_path)
 else:
  checkpoint_path = params.checkpoint_path

 tf.train.init_from_checkpoint(
  ckpt_dir_or_file=checkpoint_path,
  assignment_map={params.checkpoint_scope: params.checkpoint_scope} # 'OptimizeLoss/':'OptimizeLoss/'
 )

使用钩子 hooks。

可以在定义tf.contrib.learn.Experiment的时候通过train_monitors参数指定

# Define the experiment
 experiment = tf.contrib.learn.Experiment(
 estimator=estimator, # Estimator
 train_input_fn=train_input_fn, # First-class function
 eval_input_fn=eval_input_fn, # First-class function
 train_steps=params.train_steps, # Minibatch steps
 min_eval_frequency=params.eval_min_frequency, # Eval frequency
 # train_monitors=[], # Hooks for training
 # eval_hooks=[eval_input_hook], # Hooks for evaluation
 eval_steps=params.eval_steps # Use evaluation feeder until its empty
 )

也可以在定义tf.estimator.EstimatorSpec 的时候通过training_chief_hooks参数指定。

不过个人觉得最好还是在estimator中定义,让experiment只专注于控制实验的模式(训练次数,验证次数等等)。

def model_fn(features, labels, mode, params):

 # ....

 return tf.estimator.EstimatorSpec(
 mode=mode,
 predictions=predictions,
 loss=loss,
 train_op=train_op,
 eval_metric_ops=eval_metric_ops,
 # scaffold=get_scaffold(),
 # training_chief_hooks=None
 )

这里顺便解释以下tf.estimator.EstimatorSpec对像的作用。该对象描述来一个模型的方方面面。包括:

当前的模式:

mode: A ModeKeys. Specifies if this is training, evaluation or prediction.

计算图

predictions: Predictions Tensor or dict of Tensor.

loss: Training loss Tensor. Must be either scalar, or with shape [1].

train_op: Op for the training step.

eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the update_op or requires any input fetching.

导出策略

export_outputs: Describes the output signatures to be exported to

SavedModel and used during serving. A dict {name: output} where:

name: An arbitrary name for this output.

output: an ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.

chief钩子 训练时的模型保存策略钩子CheckpointSaverHook, 模型恢复等

training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.

worker钩子 训练时的监控策略钩子如: NanTensorHook LoggingTensorHook 等

training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.

指定初始化和saver

scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.

evaluation钩子

evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.

自定义的钩子如下:

class RestoreCheckpointHook(tf.train.SessionRunHook):
 def __init__(self,
   checkpoint_path,
   exclude_scope_patterns,
   include_scope_patterns
   ):
 tf.logging.info("Create RestoreCheckpointHook.")
 #super(IteratorInitializerHook, self).__init__()
 self.checkpoint_path = checkpoint_path

 self.exclude_scope_patterns = None if (not exclude_scope_patterns) else exclude_scope_patterns.split(',')
 self.include_scope_patterns = None if (not include_scope_patterns) else include_scope_patterns.split(',')


 def begin(self):
 # You can add ops to the graph here.
 print('Before starting the session.')

 # 1. Create saver

 #exclusions = []
 #if self.checkpoint_exclude_scopes:
 # exclusions = [scope.strip()
 #  for scope in self.checkpoint_exclude_scopes.split(',')]
 #
 #variables_to_restore = []
 #for var in slim.get_model_variables(): #tf.global_variables():
 # excluded = False
 # for exclusion in exclusions:
 # if var.op.name.startswith(exclusion):
 # excluded = True
 # break
 # if not excluded:
 # variables_to_restore.append(var)
 #inclusions
 #[var for var in tf.trainable_variables() if var.op.name.startswith('InceptionResnetV1')]

 variables_to_restore = tf.contrib.framework.filter_variables(
  slim.get_model_variables(),
  include_patterns=self.include_scope_patterns, # ['Conv'],
  exclude_patterns=self.exclude_scope_patterns, # ['biases', 'Logits'],

  # If True (default), performs re.search to find matches
  # (i.e. pattern can match any substring of the variable name).
  # If False, performs re.match (i.e. regexp should match from the beginning of the variable name).
  reg_search = True
 )
 self.saver = tf.train.Saver(variables_to_restore)


 def after_create_session(self, session, coord):
 # When this is called, the graph is finalized and
 # ops can no longer be added to the graph.

 print('Session created.')

 tf.logging.info('Fine-tuning from %s' % self.checkpoint_path)
 self.saver.restore(session, os.path.expanduser(self.checkpoint_path))
 tf.logging.info('End fineturn from %s' % self.checkpoint_path)

 def before_run(self, run_context):
 #print('Before calling session.run().')
 return None #SessionRunArgs(self.your_tensor)

 def after_run(self, run_context, run_values):
 #print('Done running one step. The value of my tensor: %s', run_values.results)
 #if you-need-to-stop-loop:
 # run_context.request_stop()
 pass


 def end(self, session):
 #print('Done with the session.')
 pass

以上这篇tensorflow estimator 使用hook实现finetune方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python发送邮件的实例代码(支持html、图片、附件)
Mar 04 Python
收集的几个Python小技巧分享
Nov 22 Python
Python数据类型详解(三)元祖:tuple
May 08 Python
python简单读取大文件的方法
Jul 01 Python
解决Python 遍历字典时删除元素报异常的问题
Sep 11 Python
浅谈pytorch和Numpy的区别以及相互转换方法
Jul 26 Python
Python设计模式之迭代器模式原理与用法实例分析
Jan 10 Python
Pandas分组与排序的实现
Jul 23 Python
django 微信网页授权认证api的步骤详解
Jul 30 Python
Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】
Dec 19 Python
Python中pyecharts安装及安装失败的解决方法
Feb 18 Python
详解python第三方库的安装、PyInstaller库、random库
Mar 03 Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 #Python
You might like
开源php中文分词系统SCWS安装和使用实例
2014/04/11 PHP
Smarty高级应用之缓存操作技巧分析
2016/05/14 PHP
php生成无限栏目树
2017/03/16 PHP
PHP实现的随机红包算法示例
2017/08/14 PHP
php如何比较两个浮点数是否相等详解
2019/02/12 PHP
PHP中关于php.ini参数优化详解
2020/02/28 PHP
javascript 一个自定义长度的文本自动换行的函数
2007/08/19 Javascript
jQueryUI的Dialog的简单封装
2010/06/07 Javascript
jquery动画2.元素坐标动画效果(创建一个图片走廊)
2012/08/24 Javascript
重写javascript中window.confirm的行为
2012/10/21 Javascript
Three.js源码阅读笔记(光照部分)
2012/12/27 Javascript
Jquery chosen动态设置值实例介绍
2013/08/08 Javascript
js字符串转成JSON
2013/11/07 Javascript
javascript日期格式化示例分享
2014/03/05 Javascript
JavaScript实现将xml转换成html table表格的方法
2015/04/17 Javascript
jquery实现点击查看更多内容控制段落文字展开折叠效果
2015/08/06 Javascript
javascript 四十条常用技巧大全
2016/09/09 Javascript
JavaScript制作颜色反转小游戏
2016/09/25 Javascript
JS实现上传图片实时预览功能
2017/05/22 Javascript
vue路由拦截及页面跳转的设置方法
2018/05/24 Javascript
layer.js之回调销毁对话框的例子
2019/09/11 Javascript
JavaScript实现10秒后再次获取验证码
2020/12/02 Javascript
[01:13:01]2018DOTA2亚洲邀请赛 4.4 淘汰赛 TNC vs VG 第三场
2018/04/05 DOTA
pycharm 使用心得(三)Hello world!
2014/06/05 Python
python简单实现基于SSL的IRC bot实例
2015/06/15 Python
python编码总结(编码类型、格式、转码)
2016/07/01 Python
Python使用matplotlib填充图形指定区域代码示例
2018/01/16 Python
Python实现繁?转为简体的方法示例
2018/12/18 Python
关于windows下Tensorflow和pytorch安装教程
2020/02/04 Python
英国最大的女士服装零售商:Bonmarché
2017/08/17 全球购物
阿姆斯特丹城市卡:Amsterdam Pass
2019/12/01 全球购物
Linux内核的同步机制是什么?主要有哪几种内核锁
2013/01/03 面试题
简短证婚人证婚词
2014/01/09 职场文书
文化活动实施方案
2014/03/28 职场文书
读鲁迅先生的经典名言
2019/08/20 职场文书
Windows10安装Apache2.4的方法步骤
2022/06/25 Servers