浅谈Tensorflow加载Vgg预训练模型的几个注意事项


Posted in Python onMay 26, 2020

写这个博客的关键Bug: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64。本博客将围绕 加载图片 和 保存图片到本地 来详细解释和解决上述的Bug及其引出来的一系列Bug。

加载图片

首先,造成上述Bug的代码如下所示

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.image.decode_jpeg(image_raw)
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

上述代码是加载Vgg19预训练模型,并传入图片得到所有层的特征图,具体的代码实现和原理讲解可参考我的另一篇博客:Tensorflow加载Vgg预训练模型。那么,为什么代码会出现: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64,这个Bug呢?

这句英文翻译过来是指:传递的值类型是uint8,但是接受的参数类型必须是float的那几种。故原因就是传入值的数据类型错了,那么如何解决这个Bug呢,很简单

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw))
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

这两个代码块唯一的变动就是:image_decoded结果在输出前加了一个tf.float(),将其转换为float类型。

在tensorflow API中,tf.image.decode_jpeg()默认读取的图片数据格式为unit8,而不是float。uint8数据的范围在(0, 255)中,正好符合图片的像素范围(0, 255)。但是,保存在本地的Vgg19预训练模型的数据接口为float,所以才造成了本文开头的Bug。

这里还要提一点,若是使用PIL的方法来加载图片,则不会出现上述的Bug,因为通过PIL得到的图片格式是float,而不是uint8,故不需要转换。

很多同学可能会疑惑,若是强行改变了原图片的数据格式,从uint8类型转变成float,会不会导致数据改变或者出错?故我做了下面这个实验:

image_path = "data/3.jpg"
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_unit8 = tf.image.decode_jpeg(image_raw)
image_float = tf.to_float(image_unit8)
 
with tf.Session() as sess:
 image_unit8_, image_float_ = sess.run([image_unit8, image_float])
 
print("image_unit8_", image_unit8_)
print("image_float_ ", image_float_ )

代码结果如下:

image_unit8_
 [180, 192, 204],
 [183, 195, 207],
 [186, 198, 210],
 ...,
 [191, 205, 218],
 [191, 205, 218],
 [190, 204, 217]],
 
 image_float_ 
 [180., 192., 204.],
 [183., 195., 207.],
 [186., 198., 210.],
 ...,
 [191., 205., 218.],
 [191., 205., 218.],
 [190., 204., 217.]],

可以看到,数据根本没有变化,只是后面多加了个小数点,变得只有类型,而没有强制改变值,故同学们不需要过度担心。

保存图片到本地

在加载图片的时候,为了使用保存在本地的预训练Vgg19模型,我们需要将读取的图片由uint8格式转换成float格式。那若是我们想将已经转换为float格式的图片再保存到本地,该怎么做呢?

首先,我们根据上述的文字的意思读取图片,并且将其转换为float格式,在将读取的图片再次保存到本地之前,我们首先可视化一下转换格式后的图片,代码如下:

import tensorflow as tf
from matplotlib import pyplot as plt
image_path = "data/boat.jpg"
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_decoded = tf.image.decode_jpeg(image_raw)
image_decoded = tf.to_float(image_decoded)
 
with tf.Session() as sess:
 image_decoded_ = sess.run(image_decoded)
 plt.imshow(image_decoded_)
 plt.show()

生成的图片如下图所示:

浅谈Tensorflow加载Vgg预训练模型的几个注意事项

左边是原图,右边是转换为float格式的图片,可见将图片转换为float格式,虽然数值没有造成太大影响,但是若想将图片保存到本地就会出现问题。

说了这么多,只为了说一点,在保存图片到本地之前,需要将其格式从float转回uint8,否则会造成一系列错误:图片显示异常,API报错等。正确的保存代码如下:

save_path = "data/boat_copy.jpg"
image_uint = tf.cast(image_decoded, tf.uint8)
with tf.Session() as sess:
 with open(save_path, 'wb') as img:
 image_saved = sess.run(tf.image.encode_jpeg(image_uint))
 img.write(image_saved)

其中只有一句话最关键,即 tf.cast(image_decoded, tf.uint8)。

以上这篇浅谈Tensorflow加载Vgg预训练模型的几个注意事项就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python统计日志ip访问数的方法
Jul 06 Python
python实现机械分词之逆向最大匹配算法代码示例
Dec 13 Python
解决pycharm回车之后不能换行或不能缩进的问题
Jan 16 Python
详解Python3 对象组合zip()和回退方式*zip
May 15 Python
python面向对象 反射原理解析
Aug 12 Python
关于Python形参打包与解包小技巧分享
Aug 24 Python
Python有参函数使用代码实例
Jan 06 Python
Pyecharts 动态地图 geo()和map()的安装与用法详解
Mar 25 Python
python 用opencv实现图像修复和图像金字塔
Nov 27 Python
Python 实现一个简单的web服务器
Jan 03 Python
python中HTMLParser模块知识点总结
Jan 25 Python
Python 爬取淘宝商品信息栏目的实现
Feb 06 Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
You might like
使用apache模块rewrite_module (转)
2007/02/14 PHP
如何使用“PHP” 彩蛋进行敏感信息获取
2013/08/07 PHP
解决php接收shell返回的结果中文乱码问题
2014/01/23 PHP
php抽象类用法实例分析
2015/07/07 PHP
PHP中PDO事务处理操作示例
2018/05/02 PHP
Ajax::prototype 源码解读
2007/01/22 Javascript
Gird事件机制初级读本
2007/03/10 Javascript
JavaScript面向对象编程
2008/03/02 Javascript
jQuery学习笔记 操作jQuery对象 文档处理
2012/09/19 Javascript
Jquery动态更改一张位图的src与Attr的使用
2013/07/31 Javascript
js中substr,substring,indexOf,lastIndexOf的用法小结
2013/12/27 Javascript
Js获取下拉框选定项的值和文本的实现代码
2014/02/26 Javascript
JavaScript中5种调用函数的方法
2015/03/12 Javascript
悬浮广告方法日常收集整理
2016/03/18 Javascript
node.js 中国天气预报 简单实现
2016/06/06 Javascript
Bootstrap响应式侧边栏改进版
2016/09/17 Javascript
JavaScript实现横线提示输入验证码随输入验证码输入消失的方法
2016/09/24 Javascript
Nodejs之TCP服务端与客户端聊天程序详解
2017/07/07 NodeJs
vue-resource post数据时碰到Django csrf问题的解决
2020/03/13 Javascript
jQuery实现日历效果
2020/09/11 jQuery
利用node.js开发cli的完整步骤
2020/12/29 Javascript
[01:05:30]VP vs TNC 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/20 DOTA
[01:03:33]Alliance vs TNC 2019国际邀请赛小组赛 BO2 第一场 8.16
2019/08/18 DOTA
tensorflow入门之训练简单的神经网络方法
2018/02/26 Python
Python读取txt某几列绘图的方法
2018/10/14 Python
python 获取页面表格数据存放到csv中的方法
2018/12/26 Python
英国现代家具和照明购物网站:Heal’s
2019/10/30 全球购物
水污染治理专业毕业生推荐信
2013/11/14 职场文书
公积金转移接收函
2014/01/11 职场文书
老同学聚会感言
2014/02/23 职场文书
2014年端午节演讲稿范文
2014/05/23 职场文书
维稳工作情况汇报
2014/10/27 职场文书
12.4全国法制宣传日活动方案
2014/11/02 职场文书
文艺委员竞选稿
2015/11/19 职场文书
react国际化react-intl的使用
2021/05/06 Javascript
python数字图像处理:图像简单滤波
2022/06/28 Python