详解tensorflow实现迁移学习实例


Posted in Python onFebruary 10, 2018

本文主要是总结利用tensorflow实现迁移学习的基本步骤。

所谓迁移学习,就是将上一个问题上训练好的模型通过简单的调整使其适用于一个新的问题。比如说,我们可以保留训练好的Inception-v3模型中所有的参数,只替换最后一层全连接层。在最后一层全连接层之前的网络称之为瓶颈层(bottleneck)。

持久化

首先需要简单介绍下tensorflow中的持久化:在tensorflow中提供了一个非常简单的API来保存和还原一个神经网络模型,这个API就是tf.train.Saver类。当采用该方法保存时会生成三个文件,一个文件是model.ckpt.meta,它保存了Tensorflow计算图的结构;第二个文件是model.ckpt,它保存了程序中每一个变量的取值;最后一个文件是checkpoint文件,这个文件中保存了一个目录下所有模型文件列表。

保存图

init_op = tf.initialize_all_variables()
with tf.Session() as sess:
  sess.run(init_op)
  saver.save(sess, "model.ckpt")

加载图

saver = tf.train.import_meta_graph("model.ckpt.meta")
with tf.Session() as sess:
  saver.restore(sess, "model.ckpt")

迁移学习

第一步: 读取加载已经训练好的模型

在inception-v3模型代表瓶颈层结果的张量名称是'pool3/_reshape:0',图像输入张量对应的名称'DecodeJpeg/contents:0'

BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
#读取已经训练好的模型
  with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

第二步:利用读取的模型,定义新的神经网络输入,这个输入就是新的图片经过Inception-v3模型前向传播到达瓶颈层的取值,是一种特征提取过程。

def run_bottlenect_on_images(sess, image_data, image_data_tensor, bottlenect_tensor):
  bottlenect_values = sess.run(bottlenect_tensor, {image_data_tensor: image_data})

  # 经过卷积网络处理后的是一个思维数组,压缩成一个特征,一维向量输出
  bottlenect_values = np.squeeze(bottlenect_values)
  return bottlenect_values

该过程实际上利用获取的tensor计算图片的特征向量,完成特征提取的过程。

第三步:利用获取的图像的特征向量完成接下来的任务(比如分类)

以上是仅关键代码。希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python处理文本文件并生成指定格式的文件
Jul 31 Python
python之wxPython应用实例
Sep 28 Python
Python中遇到的小问题及解决方法汇总
Jan 11 Python
详谈python在windows中的文件路径问题
Apr 28 Python
numpy 计算两个数组重复程度的方法
Nov 07 Python
使用Python向C语言的链接库传递数组、结构体、指针类型的数据
Jan 29 Python
python代码编写计算器小程序
Mar 30 Python
python单元测试框架pytest的使用示例
Oct 07 Python
python smtplib发送多个email联系人的实现
Oct 09 Python
python实现局部图像放大
Nov 17 Python
利用Python实现模拟登录知乎
May 25 Python
使用pd.merge表连接出现多余行的问题解决
Jun 16 Python
Python学习之Django的管理界面代码示例
Feb 10 #Python
Tensorflow 自带可视化Tensorboard使用方法(附项目代码)
Feb 10 #Python
tensorflow训练中出现nan问题的解决
Feb 10 #Python
用Eclipse写python程序
Feb 10 #Python
tensorflow建立一个简单的神经网络的方法
Feb 10 #Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
You might like
PHP删除非空目录的函数代码小结
2013/02/28 PHP
php导出word文档与excel电子表格的简单示例代码
2014/03/08 PHP
基于Laravel实现的用户动态模块开发
2017/09/21 PHP
DOM 事件流详解
2015/01/20 Javascript
cookie的secure属性详解
2015/04/08 Javascript
以JavaScript来实现WordPress中的二级导航菜单的方法
2015/12/14 Javascript
jquery模拟实现鼠标指针停止运动事件
2016/01/12 Javascript
Javascript实现图片加载从模糊到清晰显示的方法
2016/06/21 Javascript
JS高仿抛物线加入购物车特效实现代码
2017/02/20 Javascript
DVA框架统一处理所有页面的loading状态
2017/08/25 Javascript
微信小程序实现轮播图效果
2017/09/07 Javascript
javascript将list转换成树状结构的实例
2017/09/08 Javascript
浅析Vue自定义组件的v-model
2017/11/26 Javascript
如何快速解决JS或Jquery ajax异步跨域的问题
2018/01/08 jQuery
详解Chai.js断言库API中文文档
2018/01/31 Javascript
vue生命周期与钩子函数简单示例
2019/03/13 Javascript
js实现二级联动简单实例
2020/01/11 Javascript
JavaScript中reduce()的5个基本用法示例
2020/07/19 Javascript
javascript实现页面的实时时钟显示示例
2020/08/06 Javascript
Vue实现腾讯云点播视频上传功能的实现代码
2020/08/17 Javascript
Python 列表(List)操作方法详解
2014/03/11 Python
Python中%r和%s的详解及区别
2017/03/16 Python
Apache如何部署django项目
2017/05/21 Python
Python3中的bytes和str类型详解
2019/05/02 Python
python join方法使用详解
2019/07/30 Python
TensorFlow内存管理bfc算法实例
2020/02/03 Python
Python安装whl文件过程图解
2020/02/18 Python
Grid 宫格常用布局的实现
2020/01/10 HTML / CSS
canvas之自定义头像功能实现代码示例
2017/09/29 HTML / CSS
Chemist Warehouse官方海外旗舰店:澳洲第一连锁大药房
2017/08/25 全球购物
澳大利亚领先的在线机械五金、园艺和存储专家:Edisons
2018/03/24 全球购物
prAna官网:瑜伽、旅行和冒险服装
2019/03/10 全球购物
建筑学推荐信
2013/11/03 职场文书
个人剖析材料及整改措施
2014/10/07 职场文书
2015年计划生育协会工作总结
2015/05/13 职场文书
办公室主任岗位竞聘书
2015/09/15 职场文书