tensorflow estimator 使用hook实现finetune方式


Posted in Python onJanuary 21, 2020

为了实现finetune有如下两种解决方案:

model_fn里面定义好模型之后直接赋值

def model_fn(features, labels, mode, params):
 # .....
 # finetune
 if params.checkpoint_path and (not tf.train.latest_checkpoint(params.model_dir)):
 checkpoint_path = None
 if tf.gfile.IsDirectory(params.checkpoint_path):
  checkpoint_path = tf.train.latest_checkpoint(params.checkpoint_path)
 else:
  checkpoint_path = params.checkpoint_path

 tf.train.init_from_checkpoint(
  ckpt_dir_or_file=checkpoint_path,
  assignment_map={params.checkpoint_scope: params.checkpoint_scope} # 'OptimizeLoss/':'OptimizeLoss/'
 )

使用钩子 hooks。

可以在定义tf.contrib.learn.Experiment的时候通过train_monitors参数指定

# Define the experiment
 experiment = tf.contrib.learn.Experiment(
 estimator=estimator, # Estimator
 train_input_fn=train_input_fn, # First-class function
 eval_input_fn=eval_input_fn, # First-class function
 train_steps=params.train_steps, # Minibatch steps
 min_eval_frequency=params.eval_min_frequency, # Eval frequency
 # train_monitors=[], # Hooks for training
 # eval_hooks=[eval_input_hook], # Hooks for evaluation
 eval_steps=params.eval_steps # Use evaluation feeder until its empty
 )

也可以在定义tf.estimator.EstimatorSpec 的时候通过training_chief_hooks参数指定。

不过个人觉得最好还是在estimator中定义,让experiment只专注于控制实验的模式(训练次数,验证次数等等)。

def model_fn(features, labels, mode, params):

 # ....

 return tf.estimator.EstimatorSpec(
 mode=mode,
 predictions=predictions,
 loss=loss,
 train_op=train_op,
 eval_metric_ops=eval_metric_ops,
 # scaffold=get_scaffold(),
 # training_chief_hooks=None
 )

这里顺便解释以下tf.estimator.EstimatorSpec对像的作用。该对象描述来一个模型的方方面面。包括:

当前的模式:

mode: A ModeKeys. Specifies if this is training, evaluation or prediction.

计算图

predictions: Predictions Tensor or dict of Tensor.

loss: Training loss Tensor. Must be either scalar, or with shape [1].

train_op: Op for the training step.

eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the update_op or requires any input fetching.

导出策略

export_outputs: Describes the output signatures to be exported to

SavedModel and used during serving. A dict {name: output} where:

name: An arbitrary name for this output.

output: an ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.

chief钩子 训练时的模型保存策略钩子CheckpointSaverHook, 模型恢复等

training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.

worker钩子 训练时的监控策略钩子如: NanTensorHook LoggingTensorHook 等

training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.

指定初始化和saver

scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.

evaluation钩子

evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.

自定义的钩子如下:

class RestoreCheckpointHook(tf.train.SessionRunHook):
 def __init__(self,
   checkpoint_path,
   exclude_scope_patterns,
   include_scope_patterns
   ):
 tf.logging.info("Create RestoreCheckpointHook.")
 #super(IteratorInitializerHook, self).__init__()
 self.checkpoint_path = checkpoint_path

 self.exclude_scope_patterns = None if (not exclude_scope_patterns) else exclude_scope_patterns.split(',')
 self.include_scope_patterns = None if (not include_scope_patterns) else include_scope_patterns.split(',')


 def begin(self):
 # You can add ops to the graph here.
 print('Before starting the session.')

 # 1. Create saver

 #exclusions = []
 #if self.checkpoint_exclude_scopes:
 # exclusions = [scope.strip()
 #  for scope in self.checkpoint_exclude_scopes.split(',')]
 #
 #variables_to_restore = []
 #for var in slim.get_model_variables(): #tf.global_variables():
 # excluded = False
 # for exclusion in exclusions:
 # if var.op.name.startswith(exclusion):
 # excluded = True
 # break
 # if not excluded:
 # variables_to_restore.append(var)
 #inclusions
 #[var for var in tf.trainable_variables() if var.op.name.startswith('InceptionResnetV1')]

 variables_to_restore = tf.contrib.framework.filter_variables(
  slim.get_model_variables(),
  include_patterns=self.include_scope_patterns, # ['Conv'],
  exclude_patterns=self.exclude_scope_patterns, # ['biases', 'Logits'],

  # If True (default), performs re.search to find matches
  # (i.e. pattern can match any substring of the variable name).
  # If False, performs re.match (i.e. regexp should match from the beginning of the variable name).
  reg_search = True
 )
 self.saver = tf.train.Saver(variables_to_restore)


 def after_create_session(self, session, coord):
 # When this is called, the graph is finalized and
 # ops can no longer be added to the graph.

 print('Session created.')

 tf.logging.info('Fine-tuning from %s' % self.checkpoint_path)
 self.saver.restore(session, os.path.expanduser(self.checkpoint_path))
 tf.logging.info('End fineturn from %s' % self.checkpoint_path)

 def before_run(self, run_context):
 #print('Before calling session.run().')
 return None #SessionRunArgs(self.your_tensor)

 def after_run(self, run_context, run_values):
 #print('Done running one step. The value of my tensor: %s', run_values.results)
 #if you-need-to-stop-loop:
 # run_context.request_stop()
 pass


 def end(self, session):
 #print('Done with the session.')
 pass

以上这篇tensorflow estimator 使用hook实现finetune方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的各种排序算法代码
Mar 04 Python
Python协程的用法和例子详解
Sep 09 Python
python列表生成式与列表生成器的使用
Feb 23 Python
Python文本处理之按行处理大文件的方法
Apr 09 Python
Python 进程之间共享数据(全局变量)的方法
Jul 16 Python
python将时分秒转换成秒的实例
Dec 07 Python
pytorch 自定义卷积核进行卷积操作方式
Dec 30 Python
opencv 实现特定颜色线条提取与定位操作
Jun 02 Python
Python3.9 beta2版本发布了,看看这7个新的PEP都是什么
Jun 10 Python
如何在scrapy中捕获并处理各种异常
Sep 28 Python
基于python实现监听Rabbitmq系统日志代码示例
Nov 28 Python
python使用Windows的wmic命令监控文件运行状况,如有异常发送邮件报警
Jan 30 Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 #Python
You might like
php实现将字符串按照指定距离进行分割的方法
2015/03/14 PHP
基于PHP实现的多元线性回归模拟曲线算法
2018/01/30 PHP
Ubuntu中支持PHP5与PHP7双版本的简单实现
2018/08/19 PHP
一个很简单的办法实现TD的加亮效果.
2006/06/29 Javascript
jquery上传插件fineuploader上传文件使用方法(jquery图片上传插件)
2013/12/05 Javascript
jQuery中:last-child选择器用法实例
2014/12/31 Javascript
jquery插件validation实现验证身份证号等
2015/06/04 Javascript
javascript去掉代码里面的注释
2015/07/24 Javascript
js实现模拟银行卡账号输入显示效果
2015/11/18 Javascript
AngularJS在IE8的不支持的解决方法
2016/05/13 Javascript
BootStrap智能表单实战系列(四)表单布局介绍
2016/06/13 Javascript
浅析JavaScript函数的调用模式
2016/08/10 Javascript
使用JQ完成表格隔行换色的简单实例
2017/08/25 Javascript
jQuery选择器之子元素选择器详解
2017/09/18 jQuery
angular2模块和共享模块详解
2018/04/08 Javascript
Python使用smtplib模块发送电子邮件的流程详解
2016/06/27 Python
python 全文检索引擎详解
2017/04/25 Python
Python实现的将文件每一列写入列表功能示例【测试可用】
2018/03/19 Python
在python中实现对list求和及求积
2018/11/14 Python
python自定义函数实现最大值的输出方法
2019/07/09 Python
如何在python中实现随机选择
2019/11/02 Python
Python装饰器结合递归原理解析
2020/07/02 Python
Python 的 f-string 可以连接字符串与数字的原因解析
2021/02/20 Python
浅谈Python xlwings 读取Excel文件的正确姿势
2021/02/26 Python
创业计划书——互联网商机
2014/01/12 职场文书
大课间活动制度
2014/01/18 职场文书
护士毕业自我鉴定
2014/02/07 职场文书
爱心活动计划书
2014/04/26 职场文书
家长学校工作方案
2014/05/07 职场文书
保护环境倡议书500字
2014/05/19 职场文书
大学第二课堂活动总结
2014/07/08 职场文书
2014年教师节演讲稿
2014/09/03 职场文书
职工食堂管理制度
2015/08/06 职场文书
mysql的MVCC多版本并发控制的实现
2021/04/14 MySQL
Python机器学习之逻辑回归
2021/05/11 Python
SQLServer权限之只开启创建表权限
2022/04/12 SQL Server