浅谈Tensorflow加载Vgg预训练模型的几个注意事项


Posted in Python onMay 26, 2020

写这个博客的关键Bug: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64。本博客将围绕 加载图片 和 保存图片到本地 来详细解释和解决上述的Bug及其引出来的一系列Bug。

加载图片

首先,造成上述Bug的代码如下所示

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.image.decode_jpeg(image_raw)
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

上述代码是加载Vgg19预训练模型,并传入图片得到所有层的特征图,具体的代码实现和原理讲解可参考我的另一篇博客:Tensorflow加载Vgg预训练模型。那么,为什么代码会出现: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64,这个Bug呢?

这句英文翻译过来是指:传递的值类型是uint8,但是接受的参数类型必须是float的那几种。故原因就是传入值的数据类型错了,那么如何解决这个Bug呢,很简单

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw))
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

这两个代码块唯一的变动就是:image_decoded结果在输出前加了一个tf.float(),将其转换为float类型。

在tensorflow API中,tf.image.decode_jpeg()默认读取的图片数据格式为unit8,而不是float。uint8数据的范围在(0, 255)中,正好符合图片的像素范围(0, 255)。但是,保存在本地的Vgg19预训练模型的数据接口为float,所以才造成了本文开头的Bug。

这里还要提一点,若是使用PIL的方法来加载图片,则不会出现上述的Bug,因为通过PIL得到的图片格式是float,而不是uint8,故不需要转换。

很多同学可能会疑惑,若是强行改变了原图片的数据格式,从uint8类型转变成float,会不会导致数据改变或者出错?故我做了下面这个实验:

image_path = "data/3.jpg"
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_unit8 = tf.image.decode_jpeg(image_raw)
image_float = tf.to_float(image_unit8)
 
with tf.Session() as sess:
 image_unit8_, image_float_ = sess.run([image_unit8, image_float])
 
print("image_unit8_", image_unit8_)
print("image_float_ ", image_float_ )

代码结果如下:

image_unit8_
 [180, 192, 204],
 [183, 195, 207],
 [186, 198, 210],
 ...,
 [191, 205, 218],
 [191, 205, 218],
 [190, 204, 217]],
 
 image_float_ 
 [180., 192., 204.],
 [183., 195., 207.],
 [186., 198., 210.],
 ...,
 [191., 205., 218.],
 [191., 205., 218.],
 [190., 204., 217.]],

可以看到,数据根本没有变化,只是后面多加了个小数点,变得只有类型,而没有强制改变值,故同学们不需要过度担心。

保存图片到本地

在加载图片的时候,为了使用保存在本地的预训练Vgg19模型,我们需要将读取的图片由uint8格式转换成float格式。那若是我们想将已经转换为float格式的图片再保存到本地,该怎么做呢?

首先,我们根据上述的文字的意思读取图片,并且将其转换为float格式,在将读取的图片再次保存到本地之前,我们首先可视化一下转换格式后的图片,代码如下:

import tensorflow as tf
from matplotlib import pyplot as plt
image_path = "data/boat.jpg"
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_decoded = tf.image.decode_jpeg(image_raw)
image_decoded = tf.to_float(image_decoded)
 
with tf.Session() as sess:
 image_decoded_ = sess.run(image_decoded)
 plt.imshow(image_decoded_)
 plt.show()

生成的图片如下图所示:

浅谈Tensorflow加载Vgg预训练模型的几个注意事项

左边是原图,右边是转换为float格式的图片,可见将图片转换为float格式,虽然数值没有造成太大影响,但是若想将图片保存到本地就会出现问题。

说了这么多,只为了说一点,在保存图片到本地之前,需要将其格式从float转回uint8,否则会造成一系列错误:图片显示异常,API报错等。正确的保存代码如下:

save_path = "data/boat_copy.jpg"
image_uint = tf.cast(image_decoded, tf.uint8)
with tf.Session() as sess:
 with open(save_path, 'wb') as img:
 image_saved = sess.run(tf.image.encode_jpeg(image_uint))
 img.write(image_saved)

其中只有一句话最关键,即 tf.cast(image_decoded, tf.uint8)。

以上这篇浅谈Tensorflow加载Vgg预训练模型的几个注意事项就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python字典的常用操作方法小结
May 16 Python
Python中使用platform模块获取系统信息的用法教程
Jul 08 Python
python 实现自动远程登陆scp文件实例代码
Mar 13 Python
TF-IDF算法解析与Python实现方法详解
Nov 16 Python
基于DATAFRAME中元素的读取与修改方法
Jun 08 Python
Python实现对文件进行单词划分并去重排序操作示例
Jul 10 Python
python 实现一次性在文件中写入多行的方法
Jan 28 Python
Python中使用双下划线防止类属性被覆盖问题
Jun 27 Python
Pycharm+Python工程,引用子模块的实现
Mar 09 Python
Pycharm中切换pytorch的环境和配置的教程详解
Mar 13 Python
Python第三方库的几种安装方式(小结)
Apr 03 Python
Keras 切换后端方式(Theano和TensorFlow)
Jun 19 Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
You might like
php array_pop()数组函数将数组最后一个单元弹出(出栈)
2011/07/12 PHP
php操作xml
2013/10/27 PHP
PHP函数in_array()使用详解
2014/08/20 PHP
PHP随机生成信用卡卡号的方法
2015/03/23 PHP
django中的ajax组件教程详解
2018/10/18 PHP
经典的解除许多网站无法复制文字的绝招
2006/12/31 Javascript
JQuery 动态扩展对象之另类视角
2010/05/25 Javascript
ff下JQuery无法监听input的keyup事件的解决方法
2013/12/12 Javascript
js获取select标签选中值的两种方式
2014/01/09 Javascript
javascript函数声明和函数表达式区别分析
2014/12/02 Javascript
AspNet中使用JQuery上传插件Uploadify详解
2015/05/20 Javascript
简述AngularJS相关的一些编程思想
2015/06/23 Javascript
JS随机打乱数组的方法小结
2016/06/22 Javascript
纯JavaScript 实现flappy bird小游戏实例代码
2016/09/27 Javascript
利用vue.js实现被选中状态的改变方法
2018/02/08 Javascript
vue-rx的初步使用教程
2018/09/21 Javascript
详解如何在Node.js的httpServer中接收前端发送的arraybuffer数据
2018/11/11 Javascript
vue-cli3 项目优化之通过 node 自动生成组件模板 generate View、Component
2019/04/30 Javascript
vue elementUI 表单校验功能之数组多层嵌套
2019/06/04 Javascript
详解小程序如何动态绑定点击的执行方法
2019/11/26 Javascript
[05:17]DOTA2睡衣妹卖萌求签名 CJ第二天全明星影像
2013/07/28 DOTA
[05:39]2014DOTA2国际邀请赛 DK晋级胜者组专访战队国士无双
2014/07/14 DOTA
详解Python中find()方法的使用
2015/05/18 Python
python中dir()与__dict__属性的区别浅析
2018/12/10 Python
Python实现监控Nginx配置文件的不同并发送邮件报警功能示例
2019/02/26 Python
Anaconda之conda常用命令介绍(安装、更新、删除)
2019/10/06 Python
欧克利英国官网:Oakley英国
2019/08/24 全球购物
介绍一下XMLHttpRequest对象的常用方法和属性
2013/05/24 面试题
理工学院学生自我鉴定
2014/02/23 职场文书
中秋寄语大全
2014/04/11 职场文书
党的群众路线对照检查材料
2014/08/27 职场文书
升学宴学生答谢词
2015/01/05 职场文书
2015年度考核个人工作总结
2015/10/24 职场文书
《去年的树》教学反思
2016/02/18 职场文书
golang为什么要统一错误处理
2022/04/03 Golang
SQL Server #{}可以防止SQL注入
2022/05/11 SQL Server