详解tensorflow实现迁移学习实例


Posted in Python onFebruary 10, 2018

本文主要是总结利用tensorflow实现迁移学习的基本步骤。

所谓迁移学习,就是将上一个问题上训练好的模型通过简单的调整使其适用于一个新的问题。比如说,我们可以保留训练好的Inception-v3模型中所有的参数,只替换最后一层全连接层。在最后一层全连接层之前的网络称之为瓶颈层(bottleneck)。

持久化

首先需要简单介绍下tensorflow中的持久化:在tensorflow中提供了一个非常简单的API来保存和还原一个神经网络模型,这个API就是tf.train.Saver类。当采用该方法保存时会生成三个文件,一个文件是model.ckpt.meta,它保存了Tensorflow计算图的结构;第二个文件是model.ckpt,它保存了程序中每一个变量的取值;最后一个文件是checkpoint文件,这个文件中保存了一个目录下所有模型文件列表。

保存图

init_op = tf.initialize_all_variables()
with tf.Session() as sess:
  sess.run(init_op)
  saver.save(sess, "model.ckpt")

加载图

saver = tf.train.import_meta_graph("model.ckpt.meta")
with tf.Session() as sess:
  saver.restore(sess, "model.ckpt")

迁移学习

第一步: 读取加载已经训练好的模型

在inception-v3模型代表瓶颈层结果的张量名称是'pool3/_reshape:0',图像输入张量对应的名称'DecodeJpeg/contents:0'

BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
#读取已经训练好的模型
  with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

第二步:利用读取的模型,定义新的神经网络输入,这个输入就是新的图片经过Inception-v3模型前向传播到达瓶颈层的取值,是一种特征提取过程。

def run_bottlenect_on_images(sess, image_data, image_data_tensor, bottlenect_tensor):
  bottlenect_values = sess.run(bottlenect_tensor, {image_data_tensor: image_data})

  # 经过卷积网络处理后的是一个思维数组,压缩成一个特征,一维向量输出
  bottlenect_values = np.squeeze(bottlenect_values)
  return bottlenect_values

该过程实际上利用获取的tensor计算图片的特征向量,完成特征提取的过程。

第三步:利用获取的图像的特征向量完成接下来的任务(比如分类)

以上是仅关键代码。希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python压缩和解压缩zip文件
Feb 14 Python
深入探究Python中变量的拷贝和作用域问题
May 05 Python
在python win系统下 打开TXT文件的实例
Apr 29 Python
Python中实现单例模式的n种方式和原理
Nov 14 Python
Python爬虫动态ip代理防止被封的方法
Jul 07 Python
Python scipy的二维图像卷积运算与图像模糊处理操作示例
Sep 06 Python
win10下安装Anaconda的教程(python环境+jupyter_notebook)
Oct 23 Python
Python简易计算器制作方法代码详解
Oct 31 Python
调用其他python脚本文件里面的类和方法过程解析
Nov 15 Python
python3利用Axes3D库画3D模型图
Mar 25 Python
python怎么提高计算速度
Jun 11 Python
python 装饰器重要在哪
Feb 14 Python
Python学习之Django的管理界面代码示例
Feb 10 #Python
Tensorflow 自带可视化Tensorboard使用方法(附项目代码)
Feb 10 #Python
tensorflow训练中出现nan问题的解决
Feb 10 #Python
用Eclipse写python程序
Feb 10 #Python
tensorflow建立一个简单的神经网络的方法
Feb 10 #Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
You might like
用PHP和MySQL保存和输出图片
2006/10/09 PHP
基于PHP常用函数的用法详解
2013/05/10 PHP
php防止SQL注入详解及防范
2013/11/12 PHP
php解析json数据实例
2014/08/19 PHP
php实现登录tplink WR882N获取IP和重启的方法
2016/07/20 PHP
php插入含有特殊符号数据的处理方法
2016/11/24 PHP
因str_replace导致的注入问题总结
2019/08/08 PHP
jQuery Tips 为AJAX回调函数传递额外参数的方法
2010/12/28 Javascript
jquery中的查找parents与closest方法之间的区别
2013/12/02 Javascript
jquery实现页面虚拟键盘特效
2015/08/08 Javascript
JavaScript实现弹出模态窗体并接受传值的方法
2016/02/12 Javascript
javascript中json基础知识详解
2017/01/19 Javascript
利用JS实现文字的聚合动画效果
2017/01/22 Javascript
在 Angular 中使用Chart.js 和 ng2-charts的示例代码
2017/08/17 Javascript
详解JS中的this、apply、call、bind(经典面试题)
2017/09/19 Javascript
JS运动特效之任意值添加运动的方法分析
2018/01/24 Javascript
vue-cli3项目配置eslint代码规范的完整步骤
2020/09/10 Javascript
python 数据加密代码
2008/12/24 Python
Python中使用tarfile压缩、解压tar归档文件示例
2015/04/05 Python
介绍Python中的__future__模块
2015/04/27 Python
Python中对元组和列表按条件进行排序的方法示例
2015/11/10 Python
Python开发如何在ubuntu 15.10 上配置vim
2016/01/25 Python
Python计算字符宽度的方法
2016/06/14 Python
python处理按钮消息的实例详解
2017/07/11 Python
运动检测ViBe算法python实现代码
2018/01/09 Python
Python 画出来六维图
2019/07/26 Python
Win10下python 2.7与python 3.7双环境安装教程图解
2019/10/12 Python
解锁canvas导出图片跨域的N种姿势小结
2019/01/24 HTML / CSS
Senreve官网:美国旧金山的奢侈手袋品牌
2019/03/21 全球购物
拾金不昧表扬信范文
2014/01/11 职场文书
石油大学毕业生自荐信
2014/01/28 职场文书
七匹狼男装广告词
2014/03/21 职场文书
支部书记四风对照材料
2014/08/28 职场文书
《女娲补天》教学反思
2016/02/20 职场文书
Python如何将list中的string转换为int
2022/07/15 Ruby
CSS 实现磨砂玻璃(毛玻璃)效果样式
2023/05/21 HTML / CSS