浅谈Tensorflow加载Vgg预训练模型的几个注意事项


Posted in Python onMay 26, 2020

写这个博客的关键Bug: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64。本博客将围绕 加载图片 和 保存图片到本地 来详细解释和解决上述的Bug及其引出来的一系列Bug。

加载图片

首先,造成上述Bug的代码如下所示

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.image.decode_jpeg(image_raw)
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

上述代码是加载Vgg19预训练模型,并传入图片得到所有层的特征图,具体的代码实现和原理讲解可参考我的另一篇博客:Tensorflow加载Vgg预训练模型。那么,为什么代码会出现: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64,这个Bug呢?

这句英文翻译过来是指:传递的值类型是uint8,但是接受的参数类型必须是float的那几种。故原因就是传入值的数据类型错了,那么如何解决这个Bug呢,很简单

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw))
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

这两个代码块唯一的变动就是:image_decoded结果在输出前加了一个tf.float(),将其转换为float类型。

在tensorflow API中,tf.image.decode_jpeg()默认读取的图片数据格式为unit8,而不是float。uint8数据的范围在(0, 255)中,正好符合图片的像素范围(0, 255)。但是,保存在本地的Vgg19预训练模型的数据接口为float,所以才造成了本文开头的Bug。

这里还要提一点,若是使用PIL的方法来加载图片,则不会出现上述的Bug,因为通过PIL得到的图片格式是float,而不是uint8,故不需要转换。

很多同学可能会疑惑,若是强行改变了原图片的数据格式,从uint8类型转变成float,会不会导致数据改变或者出错?故我做了下面这个实验:

image_path = "data/3.jpg"
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_unit8 = tf.image.decode_jpeg(image_raw)
image_float = tf.to_float(image_unit8)
 
with tf.Session() as sess:
 image_unit8_, image_float_ = sess.run([image_unit8, image_float])
 
print("image_unit8_", image_unit8_)
print("image_float_ ", image_float_ )

代码结果如下:

image_unit8_
 [180, 192, 204],
 [183, 195, 207],
 [186, 198, 210],
 ...,
 [191, 205, 218],
 [191, 205, 218],
 [190, 204, 217]],
 
 image_float_ 
 [180., 192., 204.],
 [183., 195., 207.],
 [186., 198., 210.],
 ...,
 [191., 205., 218.],
 [191., 205., 218.],
 [190., 204., 217.]],

可以看到,数据根本没有变化,只是后面多加了个小数点,变得只有类型,而没有强制改变值,故同学们不需要过度担心。

保存图片到本地

在加载图片的时候,为了使用保存在本地的预训练Vgg19模型,我们需要将读取的图片由uint8格式转换成float格式。那若是我们想将已经转换为float格式的图片再保存到本地,该怎么做呢?

首先,我们根据上述的文字的意思读取图片,并且将其转换为float格式,在将读取的图片再次保存到本地之前,我们首先可视化一下转换格式后的图片,代码如下:

import tensorflow as tf
from matplotlib import pyplot as plt
image_path = "data/boat.jpg"
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_decoded = tf.image.decode_jpeg(image_raw)
image_decoded = tf.to_float(image_decoded)
 
with tf.Session() as sess:
 image_decoded_ = sess.run(image_decoded)
 plt.imshow(image_decoded_)
 plt.show()

生成的图片如下图所示:

浅谈Tensorflow加载Vgg预训练模型的几个注意事项

左边是原图,右边是转换为float格式的图片,可见将图片转换为float格式,虽然数值没有造成太大影响,但是若想将图片保存到本地就会出现问题。

说了这么多,只为了说一点,在保存图片到本地之前,需要将其格式从float转回uint8,否则会造成一系列错误:图片显示异常,API报错等。正确的保存代码如下:

save_path = "data/boat_copy.jpg"
image_uint = tf.cast(image_decoded, tf.uint8)
with tf.Session() as sess:
 with open(save_path, 'wb') as img:
 image_saved = sess.run(tf.image.encode_jpeg(image_uint))
 img.write(image_saved)

其中只有一句话最关键,即 tf.cast(image_decoded, tf.uint8)。

以上这篇浅谈Tensorflow加载Vgg预训练模型的几个注意事项就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的包和模块实例
Nov 22 Python
Python for Informatics 第11章之正则表达式(二)
Apr 21 Python
python魔法方法-自定义序列详解
Jul 21 Python
使用pyecharts在jupyter notebook上绘图
Apr 23 Python
Python反射的用法实例分析
Feb 11 Python
解决pycharm 误删掉项目文件的处理方法
Oct 22 Python
Python实现操纵控制windows注册表的方法分析
May 24 Python
keras 使用Lambda 快速新建层 添加多个参数操作
Jun 10 Python
keras自动编码器实现系列之卷积自动编码器操作
Jul 03 Python
利用python批量爬取百度任意类别的图片的实现方法
Oct 07 Python
tensorflow2.0教程之Keras快速入门
Feb 20 Python
python3实现常见的排序算法(示例代码)
Jul 04 Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
You might like
CI框架安全类Security.php源码分析
2014/11/04 PHP
php常用hash加密函数
2014/11/22 PHP
php实现跨域提交form表单的方法【2种方法】
2016/10/17 PHP
PHP实现笛卡尔积算法的实例讲解
2019/12/22 PHP
js substring从右边获取指定长度字符串(示例代码)
2013/12/23 Javascript
详解AngularJS中的表格使用
2015/06/16 Javascript
分享纯手写漂亮的表单验证
2015/11/19 Javascript
浅谈时钟的生成(js手写简洁代码)
2016/08/20 Javascript
详细谈谈AngularJS的子级作用域问题
2016/09/05 Javascript
利用jquery如何从json中读取数据追加到html中
2017/12/01 jQuery
vue项目中api接口管理总结
2018/04/20 Javascript
JS实现快递单打印功能【推荐】
2018/06/21 Javascript
原生js实现Flappy Bird小游戏
2018/12/24 Javascript
jQuery实现当拉动滚动条到底部加载数据的方法分析
2019/01/24 jQuery
vue鼠标悬停事件实例详解
2019/04/01 Javascript
js编写简易的计算器
2020/07/29 Javascript
在vant 中使用cell组件 定义图标该图片和位置操作
2020/11/02 Javascript
[01:03:47]VP vs NewBee Supermajor 胜者组 BO3 第一场 6.5
2018/06/06 DOTA
python使用wmi模块获取windows下硬盘信息的方法
2015/05/15 Python
基于python神经卷积网络的人脸识别
2018/05/24 Python
numpy中矩阵合并的实例
2018/06/15 Python
python opencv旋转图像(保持图像不被裁减)
2018/07/26 Python
啥是佩奇?使用Python自动绘画小猪佩奇的代码实例
2019/02/20 Python
python通过http下载文件的方法详解
2019/07/26 Python
详解python中的数据类型和控制流
2019/08/08 Python
使用python-pptx包批量修改ppt格式的实现
2020/02/14 Python
opencv 图像加法与图像融合的实现代码
2020/07/08 Python
HTML5中5个简单实用的API(第二篇,含全屏、可见性、拍照、预加载、电池状态)
2014/05/07 HTML / CSS
复古服装:RetroStage
2019/05/10 全球购物
定义一结构体变量,用其表示点坐标,并输入两点坐标,求两点之间的距离
2015/08/17 面试题
UNIX命令速查表
2012/03/10 面试题
一封普通求职者的求职信
2013/11/20 职场文书
公司门卫的岗位职责
2014/02/19 职场文书
事业单位工作人员年度考核个人总结
2015/02/12 职场文书
高三化学教学反思
2016/02/22 职场文书
电脑无法安装Windows 11怎么办?无法安装Win11的解决方法
2021/11/21 数码科技