详解tensorflow实现迁移学习实例


Posted in Python onFebruary 10, 2018

本文主要是总结利用tensorflow实现迁移学习的基本步骤。

所谓迁移学习,就是将上一个问题上训练好的模型通过简单的调整使其适用于一个新的问题。比如说,我们可以保留训练好的Inception-v3模型中所有的参数,只替换最后一层全连接层。在最后一层全连接层之前的网络称之为瓶颈层(bottleneck)。

持久化

首先需要简单介绍下tensorflow中的持久化:在tensorflow中提供了一个非常简单的API来保存和还原一个神经网络模型,这个API就是tf.train.Saver类。当采用该方法保存时会生成三个文件,一个文件是model.ckpt.meta,它保存了Tensorflow计算图的结构;第二个文件是model.ckpt,它保存了程序中每一个变量的取值;最后一个文件是checkpoint文件,这个文件中保存了一个目录下所有模型文件列表。

保存图

init_op = tf.initialize_all_variables()
with tf.Session() as sess:
  sess.run(init_op)
  saver.save(sess, "model.ckpt")

加载图

saver = tf.train.import_meta_graph("model.ckpt.meta")
with tf.Session() as sess:
  saver.restore(sess, "model.ckpt")

迁移学习

第一步: 读取加载已经训练好的模型

在inception-v3模型代表瓶颈层结果的张量名称是'pool3/_reshape:0',图像输入张量对应的名称'DecodeJpeg/contents:0'

BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
#读取已经训练好的模型
  with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

第二步:利用读取的模型,定义新的神经网络输入,这个输入就是新的图片经过Inception-v3模型前向传播到达瓶颈层的取值,是一种特征提取过程。

def run_bottlenect_on_images(sess, image_data, image_data_tensor, bottlenect_tensor):
  bottlenect_values = sess.run(bottlenect_tensor, {image_data_tensor: image_data})

  # 经过卷积网络处理后的是一个思维数组,压缩成一个特征,一维向量输出
  bottlenect_values = np.squeeze(bottlenect_values)
  return bottlenect_values

该过程实际上利用获取的tensor计算图片的特征向量,完成特征提取的过程。

第三步:利用获取的图像的特征向量完成接下来的任务(比如分类)

以上是仅关键代码。希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python编程中的文件读写及相关的文件对象方法讲解
Jan 19 Python
Python使用asyncio包处理并发详解
Sep 09 Python
Python数据分析之双色球统计两个红和蓝球哪组合比例高的方法
Feb 03 Python
Python学习_几种存取xls/xlsx文件的方法总结
May 03 Python
python写日志文件操作类与应用示例
Jul 01 Python
Django获取该数据的上一条和下一条方法
Aug 12 Python
解决python中的幂函数、指数函数问题
Nov 25 Python
pytorch下大型数据集(大型图片)的导入方式
Jan 08 Python
python字符串常用方法及文件简单读写的操作方法
Mar 04 Python
Python常用base64 md5 aes des crc32加密解密方法汇总
Nov 06 Python
Pycharm创建python文件自动添加日期作者等信息(步骤详解)
Feb 03 Python
Python绘制分类图的方法
Apr 20 Python
Python学习之Django的管理界面代码示例
Feb 10 #Python
Tensorflow 自带可视化Tensorboard使用方法(附项目代码)
Feb 10 #Python
tensorflow训练中出现nan问题的解决
Feb 10 #Python
用Eclipse写python程序
Feb 10 #Python
tensorflow建立一个简单的神经网络的方法
Feb 10 #Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
You might like
如何正确理解PHP的错误信息
2006/10/09 PHP
正则表达式语法
2006/10/09 Javascript
PHP 如何向 MySQL 发送数据
2006/10/09 PHP
phpMyAdmin 安装及问题总结
2009/05/28 PHP
php使用ob_start()实现图片存入变量的方法
2014/11/14 PHP
分享一则PHP定义函数代码
2015/02/26 PHP
php正则去除网页中所有的html,js,css,注释的实现方法
2016/11/03 PHP
PHP批量删除jQuery操作
2017/07/23 PHP
用js计算页面执行时间的函数
2006/12/07 Javascript
ajaxControlToolkit AutoCompleteExtender的用法
2008/10/30 Javascript
几种延迟加载JS代码的方法加快网页的访问速度
2013/10/12 Javascript
node.js中的path.basename方法使用说明
2014/12/09 Javascript
JavaScript探测CSS动画是否已经完成的方法
2016/08/30 Javascript
通过Ajax使用FormData对象无刷新上传文件方法
2016/12/08 Javascript
jQuery EasyUI 组件加上“清除”功能实例详解
2017/04/11 jQuery
JavaScript+CSS相册特效实例代码
2017/09/07 Javascript
微信小程序实现全国机场索引列表
2018/01/31 Javascript
详解node.js 下载图片的 2 种方式
2018/03/02 Javascript
对mac下nodejs 更新到最新版本的最新方法(推荐)
2018/05/17 NodeJs
详解Vue+axios+Node+express实现文件上传(用户头像上传)
2018/08/10 Javascript
Python的Flask框架中web表单的教程
2015/04/20 Python
Java多线程编程中ThreadLocal类的用法及深入
2016/06/21 Python
python连接PostgreSQL数据库的过程详解
2019/09/18 Python
HTML中使用SVG与SVG预定义形状元素介绍
2013/06/28 HTML / CSS
amazeui页面分析之登录页面的示例代码
2020/08/25 HTML / CSS
德国领先的大尺码和超大尺码男装在线零售商:Bigtex
2019/06/22 全球购物
信息管理专业学生自荐信格式
2013/09/22 职场文书
物流管理专业职业生涯规划书
2014/01/06 职场文书
会计学专业学生的求职信范文
2014/01/27 职场文书
运动会入场词60字
2014/02/15 职场文书
会计个人实习计划书
2014/08/15 职场文书
“四风”查摆问题自我剖析材料
2014/09/27 职场文书
2014年幼儿园个人工作总结
2014/11/10 职场文书
内乡县衙导游词
2015/02/05 职场文书
使用 CSS 轻松实现一些高频出现的奇形怪状按钮
2021/12/06 HTML / CSS
MySQL数据库简介与基本操作
2022/05/30 MySQL