浅谈Tensorflow加载Vgg预训练模型的几个注意事项


Posted in Python onMay 26, 2020

写这个博客的关键Bug: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64。本博客将围绕 加载图片 和 保存图片到本地 来详细解释和解决上述的Bug及其引出来的一系列Bug。

加载图片

首先,造成上述Bug的代码如下所示

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.image.decode_jpeg(image_raw)
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

上述代码是加载Vgg19预训练模型,并传入图片得到所有层的特征图,具体的代码实现和原理讲解可参考我的另一篇博客:Tensorflow加载Vgg预训练模型。那么,为什么代码会出现: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64,这个Bug呢?

这句英文翻译过来是指:传递的值类型是uint8,但是接受的参数类型必须是float的那几种。故原因就是传入值的数据类型错了,那么如何解决这个Bug呢,很简单

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw))
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

这两个代码块唯一的变动就是:image_decoded结果在输出前加了一个tf.float(),将其转换为float类型。

在tensorflow API中,tf.image.decode_jpeg()默认读取的图片数据格式为unit8,而不是float。uint8数据的范围在(0, 255)中,正好符合图片的像素范围(0, 255)。但是,保存在本地的Vgg19预训练模型的数据接口为float,所以才造成了本文开头的Bug。

这里还要提一点,若是使用PIL的方法来加载图片,则不会出现上述的Bug,因为通过PIL得到的图片格式是float,而不是uint8,故不需要转换。

很多同学可能会疑惑,若是强行改变了原图片的数据格式,从uint8类型转变成float,会不会导致数据改变或者出错?故我做了下面这个实验:

image_path = "data/3.jpg"
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_unit8 = tf.image.decode_jpeg(image_raw)
image_float = tf.to_float(image_unit8)
 
with tf.Session() as sess:
 image_unit8_, image_float_ = sess.run([image_unit8, image_float])
 
print("image_unit8_", image_unit8_)
print("image_float_ ", image_float_ )

代码结果如下:

image_unit8_
 [180, 192, 204],
 [183, 195, 207],
 [186, 198, 210],
 ...,
 [191, 205, 218],
 [191, 205, 218],
 [190, 204, 217]],
 
 image_float_ 
 [180., 192., 204.],
 [183., 195., 207.],
 [186., 198., 210.],
 ...,
 [191., 205., 218.],
 [191., 205., 218.],
 [190., 204., 217.]],

可以看到,数据根本没有变化,只是后面多加了个小数点,变得只有类型,而没有强制改变值,故同学们不需要过度担心。

保存图片到本地

在加载图片的时候,为了使用保存在本地的预训练Vgg19模型,我们需要将读取的图片由uint8格式转换成float格式。那若是我们想将已经转换为float格式的图片再保存到本地,该怎么做呢?

首先,我们根据上述的文字的意思读取图片,并且将其转换为float格式,在将读取的图片再次保存到本地之前,我们首先可视化一下转换格式后的图片,代码如下:

import tensorflow as tf
from matplotlib import pyplot as plt
image_path = "data/boat.jpg"
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_decoded = tf.image.decode_jpeg(image_raw)
image_decoded = tf.to_float(image_decoded)
 
with tf.Session() as sess:
 image_decoded_ = sess.run(image_decoded)
 plt.imshow(image_decoded_)
 plt.show()

生成的图片如下图所示:

浅谈Tensorflow加载Vgg预训练模型的几个注意事项

左边是原图,右边是转换为float格式的图片,可见将图片转换为float格式,虽然数值没有造成太大影响,但是若想将图片保存到本地就会出现问题。

说了这么多,只为了说一点,在保存图片到本地之前,需要将其格式从float转回uint8,否则会造成一系列错误:图片显示异常,API报错等。正确的保存代码如下:

save_path = "data/boat_copy.jpg"
image_uint = tf.cast(image_decoded, tf.uint8)
with tf.Session() as sess:
 with open(save_path, 'wb') as img:
 image_saved = sess.run(tf.image.encode_jpeg(image_uint))
 img.write(image_saved)

其中只有一句话最关键,即 tf.cast(image_decoded, tf.uint8)。

以上这篇浅谈Tensorflow加载Vgg预训练模型的几个注意事项就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python脚本实现数据导出excel格式的简单方法(推荐)
Dec 30 Python
windows下python安装paramiko模块和pycrypto模块(简单三步)
Jul 06 Python
python 实现在txt指定行追加文本的方法
Apr 29 Python
python基于物品协同过滤算法实现代码
May 31 Python
使用python进行文本预处理和提取特征的实例
Jun 05 Python
基于python log取对数详解
Jun 08 Python
python快排算法详解
Mar 04 Python
Python对象转换为json的方法步骤
Apr 25 Python
python控制nao机器人身体动作实例详解
Apr 29 Python
python画图的函数用法以及技巧
Jun 28 Python
如何利用Python开发一个简单的猜数字游戏
Sep 22 Python
如何用python免费看美剧
Aug 11 Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
You might like
phpwind中的数据库操作类
2007/01/02 PHP
PHP PDOStatement对象bindpram()、bindvalue()和bindcolumn之间的区别
2014/11/20 PHP
php将图片文件转换成二进制输出的方法
2015/06/10 PHP
CI映射(加载)数据到view层的方法
2016/03/28 PHP
Laravel的throttle中间件失效问题解决方法
2016/10/09 PHP
php使用parse_str实现查询字符串解析到变量中的方法
2017/02/17 PHP
[原创]php token使用与验证示例【测试可用】
2017/08/30 PHP
解决thinkphp5未定义变量会抛出异常,页面错误,请稍后再试的问题
2019/10/16 PHP
图片格式的JavaScript和CSS速查手册
2007/08/20 Javascript
JavaScript中的this关键字介绍与使用实例
2013/06/21 Javascript
网站接入QQ登录的两种方法
2014/07/22 Javascript
JavaScript中Number.MAX_VALUE属性的使用方法
2015/06/04 Javascript
JavaScript中String.match()方法的使用详解
2015/06/06 Javascript
JS基于面向对象实现的拖拽库实例
2015/09/24 Javascript
JS+CSS实现的蓝色table选项卡效果
2015/10/08 Javascript
JavaScript实现页面定时刷新(定时器,meta)
2016/10/12 Javascript
JavaScript实现翻页功能(附效果图)
2017/02/16 Javascript
Angular实现下拉框模糊查询功能示例
2018/01/03 Javascript
详解Angular调试技巧之报错404(not found)
2018/01/31 Javascript
Json实现传值到后台代码实例
2020/06/30 Javascript
[02:37]2015国际邀请赛选手档案—LGD.Xiao8
2015/07/28 DOTA
详解Python中contextlib上下文管理模块的用法
2016/06/28 Python
老生常谈Python进阶之装饰器
2017/05/11 Python
python3+PyQt5重新实现QT事件处理程序
2018/04/19 Python
python3下使用cv2.imwrite存储带有中文路径图片的方法
2018/05/10 Python
Python3读写ini配置文件的示例
2020/11/06 Python
Baracuta官方网站:Harrington夹克,G9,G4,G10等
2018/03/06 全球购物
意大利一家专营包包和配饰的网上商店:Borse Last Minute
2019/08/26 全球购物
什么是测试驱动开发(TDD)
2012/02/15 面试题
秘书英文求职信范文
2014/01/31 职场文书
简历中的自我评价范文
2014/02/05 职场文书
预防艾滋病宣传标语
2014/06/25 职场文书
2014年教师教学工作总结
2014/11/08 职场文书
办公用房租赁协议书
2014/11/29 职场文书
军训后的感想
2015/08/07 职场文书
家庭聚会祝酒词
2015/08/11 职场文书