详解tensorflow实现迁移学习实例


Posted in Python onFebruary 10, 2018

本文主要是总结利用tensorflow实现迁移学习的基本步骤。

所谓迁移学习,就是将上一个问题上训练好的模型通过简单的调整使其适用于一个新的问题。比如说,我们可以保留训练好的Inception-v3模型中所有的参数,只替换最后一层全连接层。在最后一层全连接层之前的网络称之为瓶颈层(bottleneck)。

持久化

首先需要简单介绍下tensorflow中的持久化:在tensorflow中提供了一个非常简单的API来保存和还原一个神经网络模型,这个API就是tf.train.Saver类。当采用该方法保存时会生成三个文件,一个文件是model.ckpt.meta,它保存了Tensorflow计算图的结构;第二个文件是model.ckpt,它保存了程序中每一个变量的取值;最后一个文件是checkpoint文件,这个文件中保存了一个目录下所有模型文件列表。

保存图

init_op = tf.initialize_all_variables()
with tf.Session() as sess:
  sess.run(init_op)
  saver.save(sess, "model.ckpt")

加载图

saver = tf.train.import_meta_graph("model.ckpt.meta")
with tf.Session() as sess:
  saver.restore(sess, "model.ckpt")

迁移学习

第一步: 读取加载已经训练好的模型

在inception-v3模型代表瓶颈层结果的张量名称是'pool3/_reshape:0',图像输入张量对应的名称'DecodeJpeg/contents:0'

BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
#读取已经训练好的模型
  with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

第二步:利用读取的模型,定义新的神经网络输入,这个输入就是新的图片经过Inception-v3模型前向传播到达瓶颈层的取值,是一种特征提取过程。

def run_bottlenect_on_images(sess, image_data, image_data_tensor, bottlenect_tensor):
  bottlenect_values = sess.run(bottlenect_tensor, {image_data_tensor: image_data})

  # 经过卷积网络处理后的是一个思维数组,压缩成一个特征,一维向量输出
  bottlenect_values = np.squeeze(bottlenect_values)
  return bottlenect_values

该过程实际上利用获取的tensor计算图片的特征向量,完成特征提取的过程。

第三步:利用获取的图像的特征向量完成接下来的任务(比如分类)

以上是仅关键代码。希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用any判断一个对象是否为空的方法
Nov 19 Python
Python3中多线程编程的队列运作示例
Apr 16 Python
Scrapy-redis爬虫分布式爬取的分析和实现
Feb 07 Python
Python中文分词工具之结巴分词用法实例总结【经典案例】
Apr 15 Python
Python 实现两个服务器之间文件的上传方法
Feb 13 Python
PyTorch和Keras计算模型参数的例子
Jan 02 Python
pytorch实现mnist分类的示例讲解
Jan 10 Python
如何基于Python代码实现高精度免费OCR工具
Jun 18 Python
Python描述数据结构学习之哈夫曼树篇
Sep 07 Python
Python将list元素转存为CSV文件的实现
Nov 16 Python
全面介绍python中很常用的单元测试框架unitest
Dec 14 Python
Python爬虫实战之爬取京东商品数据并实实现数据可视化
Jun 07 Python
Python学习之Django的管理界面代码示例
Feb 10 #Python
Tensorflow 自带可视化Tensorboard使用方法(附项目代码)
Feb 10 #Python
tensorflow训练中出现nan问题的解决
Feb 10 #Python
用Eclipse写python程序
Feb 10 #Python
tensorflow建立一个简单的神经网络的方法
Feb 10 #Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
You might like
php 判断字符串编码是utf-8 或gb2312实例
2016/11/01 PHP
JavaScript 三种不同位置代码的写法
2009/10/25 Javascript
boxy基于jquery的弹出层对话框插件扩展应用 弹出层选择器
2010/11/21 Javascript
Javascript 面试题随笔
2011/03/31 Javascript
jQuery中需要注意的细节问题小结
2011/12/06 Javascript
js判断客户端是iOS还是Android等移动终端的方法
2013/12/11 Javascript
jquery制作搜狐快站页面效果示例分享
2014/02/21 Javascript
table行随鼠标移动变色示例
2014/05/07 Javascript
JS验证IP,子网掩码,网关和MAC的方法
2015/07/02 Javascript
JavaScript setTimeout使用闭包功能实现定时打印数值
2015/12/18 Javascript
再次谈论React.js实现原生js拖拽效果引起的一系列问题
2016/04/03 Javascript
jquery实现点击弹出可放大居中及关闭的对话框(附demo源码下载)
2016/05/10 Javascript
js实现图片淡入淡出切换简易效果
2016/08/22 Javascript
vue在路由中验证token是否存在的简单实现
2019/11/11 Javascript
Vue路由管理器Vue-router的使用方法详解
2020/02/05 Javascript
JQuery复选框全选效果如何实现
2020/05/08 jQuery
vue-quill-editor插入图片路径太长问题解决方法
2021/01/08 Vue.js
[42:36]DOTA2上海特级锦标赛B组败者赛 VG VS Spirit第二局
2016/02/26 DOTA
[37:02]OG vs INfamous 2019国际邀请赛小组赛 BO2 第二场 8.15
2019/08/17 DOTA
Python中设置变量作为默认值时容易遇到的错误
2015/04/03 Python
使用Python脚本来控制Windows Azure的简单教程
2015/04/16 Python
python 地图经纬度转换、纠偏的实例代码
2018/08/06 Python
windows系统中Python多版本与jupyter notebook使用虚拟环境的过程
2019/05/15 Python
pytorch实现mnist分类的示例讲解
2020/01/10 Python
Python 如何定义匿名或内联函数
2020/08/01 Python
骆驼官方商城:CAMEL
2016/11/22 全球购物
Dogeared官网:在美国手工制作的珠宝
2019/08/24 全球购物
财务部岗位职责
2013/11/19 职场文书
单位消防安全制度
2014/01/12 职场文书
小学毕业典礼主持词
2014/03/27 职场文书
关于青春的演讲稿800字
2014/08/22 职场文书
装饰公司活动策划方案
2014/08/23 职场文书
售房委托书
2014/08/30 职场文书
工人先锋号事迹材料(2016精选版)
2016/03/01 职场文书
学生检讨书范文
2019/06/24 职场文书
python热力图实现的完整实例
2022/06/25 Python