浅谈Tensorflow加载Vgg预训练模型的几个注意事项


Posted in Python onMay 26, 2020

写这个博客的关键Bug: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64。本博客将围绕 加载图片 和 保存图片到本地 来详细解释和解决上述的Bug及其引出来的一系列Bug。

加载图片

首先,造成上述Bug的代码如下所示

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.image.decode_jpeg(image_raw)
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

上述代码是加载Vgg19预训练模型,并传入图片得到所有层的特征图,具体的代码实现和原理讲解可参考我的另一篇博客:Tensorflow加载Vgg预训练模型。那么,为什么代码会出现: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64,这个Bug呢?

这句英文翻译过来是指:传递的值类型是uint8,但是接受的参数类型必须是float的那几种。故原因就是传入值的数据类型错了,那么如何解决这个Bug呢,很简单

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw))
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

这两个代码块唯一的变动就是:image_decoded结果在输出前加了一个tf.float(),将其转换为float类型。

在tensorflow API中,tf.image.decode_jpeg()默认读取的图片数据格式为unit8,而不是float。uint8数据的范围在(0, 255)中,正好符合图片的像素范围(0, 255)。但是,保存在本地的Vgg19预训练模型的数据接口为float,所以才造成了本文开头的Bug。

这里还要提一点,若是使用PIL的方法来加载图片,则不会出现上述的Bug,因为通过PIL得到的图片格式是float,而不是uint8,故不需要转换。

很多同学可能会疑惑,若是强行改变了原图片的数据格式,从uint8类型转变成float,会不会导致数据改变或者出错?故我做了下面这个实验:

image_path = "data/3.jpg"
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_unit8 = tf.image.decode_jpeg(image_raw)
image_float = tf.to_float(image_unit8)
 
with tf.Session() as sess:
 image_unit8_, image_float_ = sess.run([image_unit8, image_float])
 
print("image_unit8_", image_unit8_)
print("image_float_ ", image_float_ )

代码结果如下:

image_unit8_
 [180, 192, 204],
 [183, 195, 207],
 [186, 198, 210],
 ...,
 [191, 205, 218],
 [191, 205, 218],
 [190, 204, 217]],
 
 image_float_ 
 [180., 192., 204.],
 [183., 195., 207.],
 [186., 198., 210.],
 ...,
 [191., 205., 218.],
 [191., 205., 218.],
 [190., 204., 217.]],

可以看到,数据根本没有变化,只是后面多加了个小数点,变得只有类型,而没有强制改变值,故同学们不需要过度担心。

保存图片到本地

在加载图片的时候,为了使用保存在本地的预训练Vgg19模型,我们需要将读取的图片由uint8格式转换成float格式。那若是我们想将已经转换为float格式的图片再保存到本地,该怎么做呢?

首先,我们根据上述的文字的意思读取图片,并且将其转换为float格式,在将读取的图片再次保存到本地之前,我们首先可视化一下转换格式后的图片,代码如下:

import tensorflow as tf
from matplotlib import pyplot as plt
image_path = "data/boat.jpg"
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_decoded = tf.image.decode_jpeg(image_raw)
image_decoded = tf.to_float(image_decoded)
 
with tf.Session() as sess:
 image_decoded_ = sess.run(image_decoded)
 plt.imshow(image_decoded_)
 plt.show()

生成的图片如下图所示:

浅谈Tensorflow加载Vgg预训练模型的几个注意事项

左边是原图,右边是转换为float格式的图片,可见将图片转换为float格式,虽然数值没有造成太大影响,但是若想将图片保存到本地就会出现问题。

说了这么多,只为了说一点,在保存图片到本地之前,需要将其格式从float转回uint8,否则会造成一系列错误:图片显示异常,API报错等。正确的保存代码如下:

save_path = "data/boat_copy.jpg"
image_uint = tf.cast(image_decoded, tf.uint8)
with tf.Session() as sess:
 with open(save_path, 'wb') as img:
 image_saved = sess.run(tf.image.encode_jpeg(image_uint))
 img.write(image_saved)

其中只有一句话最关键,即 tf.cast(image_decoded, tf.uint8)。

以上这篇浅谈Tensorflow加载Vgg预训练模型的几个注意事项就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用Matplotlib实现Logos设计代码
Dec 25 Python
教你用一行Python代码实现并行任务(附代码)
Feb 02 Python
django数据库migrate失败的解决方法解析
Feb 08 Python
django用户登录和注销的实现方法
Jul 16 Python
pycharm执行python时,填写参数的方法
Oct 29 Python
Flask框架学习笔记之消息提示与异常处理操作详解
Aug 15 Python
解决python 上传图片限制格式问题
Oct 30 Python
Python如何把十进制数转换成ip地址
May 25 Python
python判断一个变量是否已经设置的方法
Aug 13 Python
使用python把xmind转换成excel测试用例的实现代码
Oct 12 Python
python 使用xlsxwriter循环向excel中插入数据和图片的操作
Jan 01 Python
python前后端自定义分页器
Apr 13 Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
You might like
php知道与问问的采集插件代码
2010/10/12 PHP
php设计模式之单例、多例设计模式的应用分析
2013/06/30 PHP
详解WordPress中过滤链接与过滤SQL语句的方法
2015/12/18 PHP
PHP 实现公历日期与农历日期的互转换
2017/09/13 PHP
基于swoole实现多人聊天室
2018/06/14 PHP
JavaScript高级程序设计 阅读笔记(四) ECMAScript中的类型转换
2012/02/27 Javascript
5个最佳的Javascript日期处理类库分享
2012/04/15 Javascript
javascript 中String.match()与RegExp.exec()的区别说明
2013/01/10 Javascript
javascript中全局对象的parseInt()方法使用介绍
2013/12/19 Javascript
JS实现简单的顶部定时关闭层效果
2014/06/15 Javascript
使用jQuery Mobile框架开发移动端Web App的入门教程
2016/05/17 Javascript
javascript的document中的动态添加标签实现方法
2016/10/24 Javascript
深入理解react-router@4.0 使用和源码解析
2017/05/23 Javascript
php结合js实现多条件组合查询
2019/05/28 Javascript
Vue中实现权限控制的方法示例
2019/06/07 Javascript
nodejs和react实现即时通讯简易聊天室功能
2019/08/21 NodeJs
[01:33:30]DOTA2-DPC中国联赛 正赛 RNG vs Phoenix BO3 第二场 2月5日
2021/03/11 DOTA
初学Python实用技巧两则
2014/08/29 Python
Python爬取成语接龙类网站
2018/10/19 Python
pyqt5 获取显示器的分辨率的方法
2019/06/18 Python
使用 Supervisor 监控 Python3 进程方式
2019/12/05 Python
Python run()函数和start()函数的比较和差别介绍
2020/05/03 Python
Python filter过滤器原理及实例应用
2020/08/18 Python
Python 中 sorted 如何自定义比较逻辑
2021/02/02 Python
Ramy Brook官网:美国现代女装品牌
2019/06/18 全球购物
远程调用的原理
2014/07/05 面试题
致铅球运动员广播稿精选
2014/01/12 职场文书
村官学习十八大感想
2014/01/15 职场文书
庆祝教师节活动方案
2014/01/31 职场文书
优秀教师事迹简介
2014/02/02 职场文书
《日月潭》教学反思
2014/02/28 职场文书
乡镇三项教育实施方案
2014/03/30 职场文书
四风批评与自我批评范文
2014/10/14 职场文书
生鲜超市—未来中国最具有潜力零售业态
2019/08/02 职场文书
自己搭建resnet18网络并加载torchvision自带权重的操作
2021/05/13 Python
js中Object.create实例用法详解
2021/10/05 Javascript