浅谈Tensorflow加载Vgg预训练模型的几个注意事项


Posted in Python onMay 26, 2020

写这个博客的关键Bug: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64。本博客将围绕 加载图片 和 保存图片到本地 来详细解释和解决上述的Bug及其引出来的一系列Bug。

加载图片

首先,造成上述Bug的代码如下所示

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.image.decode_jpeg(image_raw)
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

上述代码是加载Vgg19预训练模型,并传入图片得到所有层的特征图,具体的代码实现和原理讲解可参考我的另一篇博客:Tensorflow加载Vgg预训练模型。那么,为什么代码会出现: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64,这个Bug呢?

这句英文翻译过来是指:传递的值类型是uint8,但是接受的参数类型必须是float的那几种。故原因就是传入值的数据类型错了,那么如何解决这个Bug呢,很简单

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw))
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

这两个代码块唯一的变动就是:image_decoded结果在输出前加了一个tf.float(),将其转换为float类型。

在tensorflow API中,tf.image.decode_jpeg()默认读取的图片数据格式为unit8,而不是float。uint8数据的范围在(0, 255)中,正好符合图片的像素范围(0, 255)。但是,保存在本地的Vgg19预训练模型的数据接口为float,所以才造成了本文开头的Bug。

这里还要提一点,若是使用PIL的方法来加载图片,则不会出现上述的Bug,因为通过PIL得到的图片格式是float,而不是uint8,故不需要转换。

很多同学可能会疑惑,若是强行改变了原图片的数据格式,从uint8类型转变成float,会不会导致数据改变或者出错?故我做了下面这个实验:

image_path = "data/3.jpg"
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_unit8 = tf.image.decode_jpeg(image_raw)
image_float = tf.to_float(image_unit8)
 
with tf.Session() as sess:
 image_unit8_, image_float_ = sess.run([image_unit8, image_float])
 
print("image_unit8_", image_unit8_)
print("image_float_ ", image_float_ )

代码结果如下:

image_unit8_
 [180, 192, 204],
 [183, 195, 207],
 [186, 198, 210],
 ...,
 [191, 205, 218],
 [191, 205, 218],
 [190, 204, 217]],
 
 image_float_ 
 [180., 192., 204.],
 [183., 195., 207.],
 [186., 198., 210.],
 ...,
 [191., 205., 218.],
 [191., 205., 218.],
 [190., 204., 217.]],

可以看到,数据根本没有变化,只是后面多加了个小数点,变得只有类型,而没有强制改变值,故同学们不需要过度担心。

保存图片到本地

在加载图片的时候,为了使用保存在本地的预训练Vgg19模型,我们需要将读取的图片由uint8格式转换成float格式。那若是我们想将已经转换为float格式的图片再保存到本地,该怎么做呢?

首先,我们根据上述的文字的意思读取图片,并且将其转换为float格式,在将读取的图片再次保存到本地之前,我们首先可视化一下转换格式后的图片,代码如下:

import tensorflow as tf
from matplotlib import pyplot as plt
image_path = "data/boat.jpg"
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_decoded = tf.image.decode_jpeg(image_raw)
image_decoded = tf.to_float(image_decoded)
 
with tf.Session() as sess:
 image_decoded_ = sess.run(image_decoded)
 plt.imshow(image_decoded_)
 plt.show()

生成的图片如下图所示:

浅谈Tensorflow加载Vgg预训练模型的几个注意事项

左边是原图,右边是转换为float格式的图片,可见将图片转换为float格式,虽然数值没有造成太大影响,但是若想将图片保存到本地就会出现问题。

说了这么多,只为了说一点,在保存图片到本地之前,需要将其格式从float转回uint8,否则会造成一系列错误:图片显示异常,API报错等。正确的保存代码如下:

save_path = "data/boat_copy.jpg"
image_uint = tf.cast(image_decoded, tf.uint8)
with tf.Session() as sess:
 with open(save_path, 'wb') as img:
 image_saved = sess.run(tf.image.encode_jpeg(image_uint))
 img.write(image_saved)

其中只有一句话最关键,即 tf.cast(image_decoded, tf.uint8)。

以上这篇浅谈Tensorflow加载Vgg预训练模型的几个注意事项就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python PyQt5实现的简易计算器功能示例
Aug 23 Python
Python实现的多线程同步与互斥锁功能示例
Nov 30 Python
让Python更加充分的使用Sqlite3
Dec 11 Python
Python基于lxml模块解析html获取页面内所有叶子节点xpath路径功能示例
May 16 Python
python 中字典嵌套列表的方法
Jul 03 Python
mac下如何将python2.7改为python3
Jul 13 Python
Pandas分组与排序的实现
Jul 23 Python
使用Python给头像戴上圣诞帽的图像操作过程解析
Sep 20 Python
Python基础之字典常见操作经典实例详解
Feb 26 Python
Python 统计位数为偶数的数字代码详解
Mar 15 Python
利用Python判断你的密码难度等级
Jun 02 Python
一起来学习Python的元组和列表
Mar 13 Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
You might like
JAVA/JSP学习系列之六
2006/10/09 PHP
php visitFile()遍历指定文件夹函数
2010/08/21 PHP
Zend Framework实现多服务器共享SESSION数据的方法
2016/03/22 PHP
PHP 应用容器化以及部署方法
2018/02/12 PHP
PHP htmlspecialchars() 函数实例代码及用法大全
2018/09/18 PHP
laravel5.1框架基础之Blade模板继承简单使用方法分析
2019/09/05 PHP
PHP程序员简单的开展服务治理架构操作详解(三)
2020/05/14 PHP
添加到收藏夹代码(兼容几乎所有的浏览器)
2007/01/09 Javascript
Javascript-Mozilla和IE中的一个函数直接量的问题分析
2007/08/12 Javascript
JavaScript实现快速排序(自已编写)
2012/12/19 Javascript
javascript打印html内容功能的方法示例
2013/11/28 Javascript
ie 7/8不支持trim的属性的解决方案
2014/05/23 Javascript
JavaScript中的函数嵌套使用
2015/06/04 Javascript
jQuery each函数源码分析
2016/05/25 Javascript
Angularjs中使用轮播图指令swiper
2017/05/30 Javascript
Vue+Vux项目实践完整代码
2017/11/30 Javascript
webpack组织模块打包Library的原理及实现
2018/03/10 Javascript
使用Vue如何写一个双向数据绑定(面试常见)
2018/04/20 Javascript
jquery实现自定义树形表格的方法【自定义树形结构table】
2019/07/12 jQuery
JS函数进阶之prototy用法实例分析
2020/01/15 Javascript
JS数组方法concat()用法实例分析
2020/01/18 Javascript
vue实现下载文件流完整前后端代码
2020/11/17 Vue.js
[41:11]完美世界DOTA2联赛PWL S2 Inki vs Magma 第一场 11.22
2020/11/24 DOTA
跟老齐学Python之变量和参数
2014/10/10 Python
使用Python的Scrapy框架十分钟爬取美女图
2016/12/26 Python
django的登录注册系统的示例代码
2018/05/14 Python
Python实现统计英文文章词频的方法分析
2019/01/28 Python
详解python中的hashlib模块的使用
2019/04/22 Python
python自动化测试之如何解析excel文件
2019/06/27 Python
HTML5 拖放功能实现代码
2016/07/14 HTML / CSS
台湾乐天市场:日本No.1的网路购物网站
2017/03/22 全球购物
文化建设工作方案
2014/05/12 职场文书
低碳日宣传活动总结
2014/07/09 职场文书
幼儿生日活动方案
2014/08/27 职场文书
群众路线教育实践活动心得体会(四风)
2014/11/03 职场文书
Mysql中where与on的区别及何时使用详析
2021/08/04 MySQL