详解tensorflow实现迁移学习实例


Posted in Python onFebruary 10, 2018

本文主要是总结利用tensorflow实现迁移学习的基本步骤。

所谓迁移学习,就是将上一个问题上训练好的模型通过简单的调整使其适用于一个新的问题。比如说,我们可以保留训练好的Inception-v3模型中所有的参数,只替换最后一层全连接层。在最后一层全连接层之前的网络称之为瓶颈层(bottleneck)。

持久化

首先需要简单介绍下tensorflow中的持久化:在tensorflow中提供了一个非常简单的API来保存和还原一个神经网络模型,这个API就是tf.train.Saver类。当采用该方法保存时会生成三个文件,一个文件是model.ckpt.meta,它保存了Tensorflow计算图的结构;第二个文件是model.ckpt,它保存了程序中每一个变量的取值;最后一个文件是checkpoint文件,这个文件中保存了一个目录下所有模型文件列表。

保存图

init_op = tf.initialize_all_variables()
with tf.Session() as sess:
  sess.run(init_op)
  saver.save(sess, "model.ckpt")

加载图

saver = tf.train.import_meta_graph("model.ckpt.meta")
with tf.Session() as sess:
  saver.restore(sess, "model.ckpt")

迁移学习

第一步: 读取加载已经训练好的模型

在inception-v3模型代表瓶颈层结果的张量名称是'pool3/_reshape:0',图像输入张量对应的名称'DecodeJpeg/contents:0'

BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
#读取已经训练好的模型
  with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

第二步:利用读取的模型,定义新的神经网络输入,这个输入就是新的图片经过Inception-v3模型前向传播到达瓶颈层的取值,是一种特征提取过程。

def run_bottlenect_on_images(sess, image_data, image_data_tensor, bottlenect_tensor):
  bottlenect_values = sess.run(bottlenect_tensor, {image_data_tensor: image_data})

  # 经过卷积网络处理后的是一个思维数组,压缩成一个特征,一维向量输出
  bottlenect_values = np.squeeze(bottlenect_values)
  return bottlenect_values

该过程实际上利用获取的tensor计算图片的特征向量,完成特征提取的过程。

第三步:利用获取的图像的特征向量完成接下来的任务(比如分类)

以上是仅关键代码。希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现模拟时钟代码推荐
Nov 08 Python
Python爬虫动态ip代理防止被封的方法
Jul 07 Python
详解基于python-django框架的支付宝支付案例
Sep 23 Python
Pycharm+Python+PyQt5使用详解
Sep 25 Python
Python如何计算语句执行时间
Nov 22 Python
pytorch forward两个参数实例
Jan 17 Python
python3 正则表达式基础廖雪峰
Mar 25 Python
Numpy 理解ndarray对象的示例代码
Apr 03 Python
Linux安装Python3如何和系统自带的Python2并存
Jul 23 Python
python利用递归方法实现求集合的幂集
Sep 07 Python
PyTorch 实现L2正则化以及Dropout的操作
May 27 Python
详解Python常用的魔法方法
Jun 03 Python
Python学习之Django的管理界面代码示例
Feb 10 #Python
Tensorflow 自带可视化Tensorboard使用方法(附项目代码)
Feb 10 #Python
tensorflow训练中出现nan问题的解决
Feb 10 #Python
用Eclipse写python程序
Feb 10 #Python
tensorflow建立一个简单的神经网络的方法
Feb 10 #Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
You might like
法国:浪漫之都的咖啡文化
2021/03/03 咖啡文化
两个强悍的php 图像处理类1
2009/06/15 PHP
ecshop 订单确认中显示省市地址信息的方法
2010/03/15 PHP
php缩放图片(根据宽高的等比例缩放)实例介绍
2013/06/09 PHP
php实现图形显示Ip地址的代码及注释
2014/01/20 PHP
php+javascript实现的动态显示服务器运行程序进度条功能示例
2017/08/07 PHP
php表单处理操作
2017/11/16 PHP
PHPMAILER实现PHP发邮件功能
2018/04/18 PHP
PHP实现无限极分类的两种方式示例【递归和引用方式】
2019/03/25 PHP
JavaScript 程序编码规范
2010/11/23 Javascript
JS插件overlib用法实例详解
2015/12/26 Javascript
基于javascript实现listbox左右移动
2016/01/29 Javascript
JS中跨页面调用变量和函数的方法(例如a.js 和 b.js中互相调用)
2016/11/01 Javascript
详解vue项目构建与实战
2017/06/27 Javascript
Angular4开发解决跨域问题详解
2017/08/28 Javascript
360提示[高危]使用存在漏洞的JQuery版本的解决方法
2017/10/27 jQuery
基于vue-cli vue-router搭建底部导航栏移动前端项目
2018/02/28 Javascript
详解vue-cli 3.0 build包太大导致首屏过长的解决方案
2018/11/10 Javascript
微信小程序实现省市区三级地址选择
2020/06/21 Javascript
详解ES6 export default 和 import语句中的解构赋值
2019/05/28 Javascript
JS实现动态无缝轮播
2020/01/11 Javascript
Jquery滑动门/tab切换实现方法完整示例
2020/06/05 jQuery
three.js 将图片马赛克化的示例代码
2020/07/31 Javascript
python实现的阳历转阴历(农历)算法
2014/04/25 Python
python 接口_从协议到抽象基类详解
2017/08/24 Python
python实现单向链表详解
2018/02/08 Python
Centos 升级到python3后pip 无法使用的解决方法
2018/06/12 Python
详解Django模版中加载静态文件配置方法
2019/07/21 Python
Python全面分析系统的时域特性和频率域特性
2020/02/26 Python
配置python的编程环境之Anaconda + VSCode的教程
2020/03/29 Python
小天鹅官方商城:LittleSwan
2017/06/16 全球购物
June Jacobs尊积帕官网:知名的spa水疗护肤品牌
2019/03/21 全球购物
幼儿园父亲节活动方案
2014/03/11 职场文书
奥巴马当选演讲稿
2014/09/10 职场文书
幼儿园教师节活动总结
2015/03/23 职场文书
2015年语文教师工作总结
2015/05/25 职场文书