tensorflow estimator 使用hook实现finetune方式


Posted in Python onJanuary 21, 2020

为了实现finetune有如下两种解决方案:

model_fn里面定义好模型之后直接赋值

def model_fn(features, labels, mode, params):
 # .....
 # finetune
 if params.checkpoint_path and (not tf.train.latest_checkpoint(params.model_dir)):
 checkpoint_path = None
 if tf.gfile.IsDirectory(params.checkpoint_path):
  checkpoint_path = tf.train.latest_checkpoint(params.checkpoint_path)
 else:
  checkpoint_path = params.checkpoint_path

 tf.train.init_from_checkpoint(
  ckpt_dir_or_file=checkpoint_path,
  assignment_map={params.checkpoint_scope: params.checkpoint_scope} # 'OptimizeLoss/':'OptimizeLoss/'
 )

使用钩子 hooks。

可以在定义tf.contrib.learn.Experiment的时候通过train_monitors参数指定

# Define the experiment
 experiment = tf.contrib.learn.Experiment(
 estimator=estimator, # Estimator
 train_input_fn=train_input_fn, # First-class function
 eval_input_fn=eval_input_fn, # First-class function
 train_steps=params.train_steps, # Minibatch steps
 min_eval_frequency=params.eval_min_frequency, # Eval frequency
 # train_monitors=[], # Hooks for training
 # eval_hooks=[eval_input_hook], # Hooks for evaluation
 eval_steps=params.eval_steps # Use evaluation feeder until its empty
 )

也可以在定义tf.estimator.EstimatorSpec 的时候通过training_chief_hooks参数指定。

不过个人觉得最好还是在estimator中定义,让experiment只专注于控制实验的模式(训练次数,验证次数等等)。

def model_fn(features, labels, mode, params):

 # ....

 return tf.estimator.EstimatorSpec(
 mode=mode,
 predictions=predictions,
 loss=loss,
 train_op=train_op,
 eval_metric_ops=eval_metric_ops,
 # scaffold=get_scaffold(),
 # training_chief_hooks=None
 )

这里顺便解释以下tf.estimator.EstimatorSpec对像的作用。该对象描述来一个模型的方方面面。包括:

当前的模式:

mode: A ModeKeys. Specifies if this is training, evaluation or prediction.

计算图

predictions: Predictions Tensor or dict of Tensor.

loss: Training loss Tensor. Must be either scalar, or with shape [1].

train_op: Op for the training step.

eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the update_op or requires any input fetching.

导出策略

export_outputs: Describes the output signatures to be exported to

SavedModel and used during serving. A dict {name: output} where:

name: An arbitrary name for this output.

output: an ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.

chief钩子 训练时的模型保存策略钩子CheckpointSaverHook, 模型恢复等

training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.

worker钩子 训练时的监控策略钩子如: NanTensorHook LoggingTensorHook 等

training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.

指定初始化和saver

scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.

evaluation钩子

evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.

自定义的钩子如下:

class RestoreCheckpointHook(tf.train.SessionRunHook):
 def __init__(self,
   checkpoint_path,
   exclude_scope_patterns,
   include_scope_patterns
   ):
 tf.logging.info("Create RestoreCheckpointHook.")
 #super(IteratorInitializerHook, self).__init__()
 self.checkpoint_path = checkpoint_path

 self.exclude_scope_patterns = None if (not exclude_scope_patterns) else exclude_scope_patterns.split(',')
 self.include_scope_patterns = None if (not include_scope_patterns) else include_scope_patterns.split(',')


 def begin(self):
 # You can add ops to the graph here.
 print('Before starting the session.')

 # 1. Create saver

 #exclusions = []
 #if self.checkpoint_exclude_scopes:
 # exclusions = [scope.strip()
 #  for scope in self.checkpoint_exclude_scopes.split(',')]
 #
 #variables_to_restore = []
 #for var in slim.get_model_variables(): #tf.global_variables():
 # excluded = False
 # for exclusion in exclusions:
 # if var.op.name.startswith(exclusion):
 # excluded = True
 # break
 # if not excluded:
 # variables_to_restore.append(var)
 #inclusions
 #[var for var in tf.trainable_variables() if var.op.name.startswith('InceptionResnetV1')]

 variables_to_restore = tf.contrib.framework.filter_variables(
  slim.get_model_variables(),
  include_patterns=self.include_scope_patterns, # ['Conv'],
  exclude_patterns=self.exclude_scope_patterns, # ['biases', 'Logits'],

  # If True (default), performs re.search to find matches
  # (i.e. pattern can match any substring of the variable name).
  # If False, performs re.match (i.e. regexp should match from the beginning of the variable name).
  reg_search = True
 )
 self.saver = tf.train.Saver(variables_to_restore)


 def after_create_session(self, session, coord):
 # When this is called, the graph is finalized and
 # ops can no longer be added to the graph.

 print('Session created.')

 tf.logging.info('Fine-tuning from %s' % self.checkpoint_path)
 self.saver.restore(session, os.path.expanduser(self.checkpoint_path))
 tf.logging.info('End fineturn from %s' % self.checkpoint_path)

 def before_run(self, run_context):
 #print('Before calling session.run().')
 return None #SessionRunArgs(self.your_tensor)

 def after_run(self, run_context, run_values):
 #print('Done running one step. The value of my tensor: %s', run_values.results)
 #if you-need-to-stop-loop:
 # run_context.request_stop()
 pass


 def end(self, session):
 #print('Done with the session.')
 pass

以上这篇tensorflow estimator 使用hook实现finetune方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 从远程服务器下载东西的代码
Feb 10 Python
python海龟绘图实例教程
Jul 24 Python
Python常用正则表达式符号浅析
Aug 13 Python
python使用Pycharm创建一个Django项目
Mar 05 Python
Python Requests库基本用法示例
Aug 20 Python
Django框架视图函数设计示例
Jul 29 Python
Python中注释(多行注释和单行注释)的用法实例
Aug 28 Python
手把手教你进行Python虚拟环境配置教程
Feb 03 Python
pycharm设置python文件模板信息过程图解
Mar 10 Python
Python爬取股票信息,并可视化数据的示例
Sep 26 Python
编写python代码实现简单抽奖器
Oct 20 Python
详细介绍python类及类的用法
May 31 Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 #Python
You might like
php park、unpark、ord 函数使用方法(二进制流接口应用实例)
2010/10/19 PHP
微信access_token的获取开发示例
2015/04/16 PHP
浅谈PHP Cookie处理函数
2016/06/10 PHP
PHP页面输出时js设置input框的选中值
2016/09/30 PHP
URI、URL和URN之间的区别与联系
2006/12/20 Javascript
通过 Dom 方法提高 innerHTML 性能
2008/03/26 Javascript
Javascript 学习笔记 错误处理
2009/07/30 Javascript
javascript 面向对象全新理练之数据的封装
2009/12/03 Javascript
jQuery AJAX实现调用页面后台方法和web服务定义的方法分享
2012/03/01 Javascript
解决Jquery load()加载GB2312页面时出现乱码的两种方案
2013/09/10 Javascript
jquerymobile局部渲染的各种刷新方法小结
2014/03/05 Javascript
JS中attr和prop属性的区别以及优先选择示例介绍
2014/06/30 Javascript
JavaScript动态插入CSS的方法
2015/12/10 Javascript
深入浅析Vue不同场景下组件间的数据交流
2017/08/15 Javascript
Node.js Buffer用法解读
2018/05/18 Javascript
基于JavaScript实现瀑布流布局
2018/08/15 Javascript
微信小程序结合Storage实现搜索历史效果
2019/05/18 Javascript
七行JSON代码把你的网站变成移动应用过程详解
2019/07/09 Javascript
jQuery实现轮播图源码
2019/10/23 jQuery
解决微信小程序scroll-view组件无横向滚动的问题
2020/02/04 Javascript
Anaconda2 5.2.0安装使用图文教程
2018/09/19 Python
python中类的属性和方法介绍
2018/11/27 Python
pytorch实现onehot编码转为普通label标签
2020/01/02 Python
python 实现线程之间的通信示例
2020/02/14 Python
python tkinter 设置窗口大小不可缩放实例
2020/03/04 Python
python标准库OS模块详解
2020/03/10 Python
python软件测试Jmeter性能测试JDBC Request(结合数据库)的使用详解
2021/01/26 Python
美国家具网站:Cymax
2016/09/17 全球购物
共产党员岗位承诺书
2014/05/29 职场文书
中学生运动会口号
2014/06/07 职场文书
人力资源管理专业求职信
2014/07/23 职场文书
网球场地租赁协议范本
2014/10/07 职场文书
目标责任书格式范文
2015/05/11 职场文书
2015年暑期实践报告范文
2015/07/13 职场文书
《去年的树》教学反思
2016/02/18 职场文书
golang操作redis的客户端包有多个比如redigo、go-redis
2022/04/14 Golang