详解tensorflow实现迁移学习实例


Posted in Python onFebruary 10, 2018

本文主要是总结利用tensorflow实现迁移学习的基本步骤。

所谓迁移学习,就是将上一个问题上训练好的模型通过简单的调整使其适用于一个新的问题。比如说,我们可以保留训练好的Inception-v3模型中所有的参数,只替换最后一层全连接层。在最后一层全连接层之前的网络称之为瓶颈层(bottleneck)。

持久化

首先需要简单介绍下tensorflow中的持久化:在tensorflow中提供了一个非常简单的API来保存和还原一个神经网络模型,这个API就是tf.train.Saver类。当采用该方法保存时会生成三个文件,一个文件是model.ckpt.meta,它保存了Tensorflow计算图的结构;第二个文件是model.ckpt,它保存了程序中每一个变量的取值;最后一个文件是checkpoint文件,这个文件中保存了一个目录下所有模型文件列表。

保存图

init_op = tf.initialize_all_variables()
with tf.Session() as sess:
  sess.run(init_op)
  saver.save(sess, "model.ckpt")

加载图

saver = tf.train.import_meta_graph("model.ckpt.meta")
with tf.Session() as sess:
  saver.restore(sess, "model.ckpt")

迁移学习

第一步: 读取加载已经训练好的模型

在inception-v3模型代表瓶颈层结果的张量名称是'pool3/_reshape:0',图像输入张量对应的名称'DecodeJpeg/contents:0'

BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
#读取已经训练好的模型
  with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

第二步:利用读取的模型,定义新的神经网络输入,这个输入就是新的图片经过Inception-v3模型前向传播到达瓶颈层的取值,是一种特征提取过程。

def run_bottlenect_on_images(sess, image_data, image_data_tensor, bottlenect_tensor):
  bottlenect_values = sess.run(bottlenect_tensor, {image_data_tensor: image_data})

  # 经过卷积网络处理后的是一个思维数组,压缩成一个特征,一维向量输出
  bottlenect_values = np.squeeze(bottlenect_values)
  return bottlenect_values

该过程实际上利用获取的tensor计算图片的特征向量,完成特征提取的过程。

第三步:利用获取的图像的特征向量完成接下来的任务(比如分类)

以上是仅关键代码。希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python获取apk文件URL地址实例
Nov 01 Python
python中sets模块的用法实例
Sep 30 Python
Python中使用haystack实现django全文检索搜索引擎功能
Aug 26 Python
Python 修改列表中的元素方法
Jun 26 Python
Python创建一个空的dataframe,并循环赋值的方法
Nov 08 Python
浅谈Pandas:Series和DataFrame间的算术元素
Dec 22 Python
使用python读取.text文件特定行的数据方法
Jan 28 Python
使用TensorFlow直接获取处理MNIST数据方式
Feb 10 Python
Python selenium使用autoIT上传附件过程详解
May 26 Python
PyCharm MySQL可视化Database配置过程图解
Jun 09 Python
Python监听剪切板实现方法代码实例
Nov 11 Python
Python图像读写方法对比
Nov 16 Python
Python学习之Django的管理界面代码示例
Feb 10 #Python
Tensorflow 自带可视化Tensorboard使用方法(附项目代码)
Feb 10 #Python
tensorflow训练中出现nan问题的解决
Feb 10 #Python
用Eclipse写python程序
Feb 10 #Python
tensorflow建立一个简单的神经网络的方法
Feb 10 #Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
You might like
php删除与复制文件夹及其文件夹下所有文件的实现代码
2013/01/23 PHP
CodeIgniter安全相关设置汇总
2014/07/03 PHP
Fleaphp常见函数功能与用法示例
2016/11/15 PHP
Gambit vs ForZe BO3 第二场 2.13
2021/03/10 DOTA
悄悄用脚本检查你访问过哪些网站的代码
2010/12/04 Javascript
改善用户体验的五款jQuery插件分享
2011/05/22 Javascript
JQuery选中checkbox方法代码实例(全选、反选、全不选)
2015/04/27 Javascript
用js动态添加html元素,以及属性的简单实例
2016/07/19 Javascript
微信小程序 Record API详解及实例代码
2016/09/30 Javascript
利用jQuery异步上传文件的插件用法详解
2017/07/19 jQuery
nodejs动态创建二维码的方法
2017/08/12 NodeJs
js实现一个简单的MVVM框架示例
2018/01/15 Javascript
vue实现Excel文件的上传与下载功能的两种方式
2019/06/28 Javascript
微信小程序跨页面数据传递事件响应实现过程解析
2019/12/19 Javascript
Python和Ruby中each循环引用变量问题(一个隐秘BUG?)
2014/06/04 Python
Python对象体系深入分析
2014/10/28 Python
Python and、or以及and-or语法总结
2015/04/14 Python
利用python写个下载teahour音频的小脚本
2017/05/08 Python
python 字典中取值的两种方法小结
2018/08/02 Python
Python产生一个数值范围内的不重复的随机数的实现方法
2019/08/21 Python
python实现两个字典合并,两个list合并
2019/12/02 Python
python GUI库图形界面开发之PyQt5多线程中信号与槽的详细使用方法与实例
2020/03/08 Python
python FTP编程基础入门
2021/02/27 Python
马来西亚银饰品牌:JEOEL
2017/12/15 全球购物
好莱坞百老汇御用王牌美妆:Koh Gen Do 江原道
2018/04/03 全球购物
香港士多网上超级市场:Ztore
2021/01/09 全球购物
ShellScript面试题一则-ShellScript编程
2014/06/24 面试题
六一儿童节活动策划方案
2014/01/27 职场文书
宿舍保安职务说明书
2014/02/25 职场文书
电力培训心得体会
2014/09/02 职场文书
带香烟到学校抽的检讨书
2014/09/25 职场文书
公司委托书格式范文
2014/10/09 职场文书
《夹竹桃》教学反思
2016/02/23 职场文书
python人工智能human learn绘图可创建机器学习模型
2021/11/23 Python
MySQL的prepare使用以及遇到的bug
2022/05/11 MySQL
windows server 2012安装FTP并配置被动模式指定开放端口
2022/06/10 Servers