浅谈Tensorflow加载Vgg预训练模型的几个注意事项


Posted in Python onMay 26, 2020

写这个博客的关键Bug: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64。本博客将围绕 加载图片 和 保存图片到本地 来详细解释和解决上述的Bug及其引出来的一系列Bug。

加载图片

首先,造成上述Bug的代码如下所示

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.image.decode_jpeg(image_raw)
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

上述代码是加载Vgg19预训练模型,并传入图片得到所有层的特征图,具体的代码实现和原理讲解可参考我的另一篇博客:Tensorflow加载Vgg预训练模型。那么,为什么代码会出现: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64,这个Bug呢?

这句英文翻译过来是指:传递的值类型是uint8,但是接受的参数类型必须是float的那几种。故原因就是传入值的数据类型错了,那么如何解决这个Bug呢,很简单

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw))
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

这两个代码块唯一的变动就是:image_decoded结果在输出前加了一个tf.float(),将其转换为float类型。

在tensorflow API中,tf.image.decode_jpeg()默认读取的图片数据格式为unit8,而不是float。uint8数据的范围在(0, 255)中,正好符合图片的像素范围(0, 255)。但是,保存在本地的Vgg19预训练模型的数据接口为float,所以才造成了本文开头的Bug。

这里还要提一点,若是使用PIL的方法来加载图片,则不会出现上述的Bug,因为通过PIL得到的图片格式是float,而不是uint8,故不需要转换。

很多同学可能会疑惑,若是强行改变了原图片的数据格式,从uint8类型转变成float,会不会导致数据改变或者出错?故我做了下面这个实验:

image_path = "data/3.jpg"
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_unit8 = tf.image.decode_jpeg(image_raw)
image_float = tf.to_float(image_unit8)
 
with tf.Session() as sess:
 image_unit8_, image_float_ = sess.run([image_unit8, image_float])
 
print("image_unit8_", image_unit8_)
print("image_float_ ", image_float_ )

代码结果如下:

image_unit8_
 [180, 192, 204],
 [183, 195, 207],
 [186, 198, 210],
 ...,
 [191, 205, 218],
 [191, 205, 218],
 [190, 204, 217]],
 
 image_float_ 
 [180., 192., 204.],
 [183., 195., 207.],
 [186., 198., 210.],
 ...,
 [191., 205., 218.],
 [191., 205., 218.],
 [190., 204., 217.]],

可以看到,数据根本没有变化,只是后面多加了个小数点,变得只有类型,而没有强制改变值,故同学们不需要过度担心。

保存图片到本地

在加载图片的时候,为了使用保存在本地的预训练Vgg19模型,我们需要将读取的图片由uint8格式转换成float格式。那若是我们想将已经转换为float格式的图片再保存到本地,该怎么做呢?

首先,我们根据上述的文字的意思读取图片,并且将其转换为float格式,在将读取的图片再次保存到本地之前,我们首先可视化一下转换格式后的图片,代码如下:

import tensorflow as tf
from matplotlib import pyplot as plt
image_path = "data/boat.jpg"
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_decoded = tf.image.decode_jpeg(image_raw)
image_decoded = tf.to_float(image_decoded)
 
with tf.Session() as sess:
 image_decoded_ = sess.run(image_decoded)
 plt.imshow(image_decoded_)
 plt.show()

生成的图片如下图所示:

浅谈Tensorflow加载Vgg预训练模型的几个注意事项

左边是原图,右边是转换为float格式的图片,可见将图片转换为float格式,虽然数值没有造成太大影响,但是若想将图片保存到本地就会出现问题。

说了这么多,只为了说一点,在保存图片到本地之前,需要将其格式从float转回uint8,否则会造成一系列错误:图片显示异常,API报错等。正确的保存代码如下:

save_path = "data/boat_copy.jpg"
image_uint = tf.cast(image_decoded, tf.uint8)
with tf.Session() as sess:
 with open(save_path, 'wb') as img:
 image_saved = sess.run(tf.image.encode_jpeg(image_uint))
 img.write(image_saved)

其中只有一句话最关键,即 tf.cast(image_decoded, tf.uint8)。

以上这篇浅谈Tensorflow加载Vgg预训练模型的几个注意事项就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python开发WebService系列教程之REST,web.py,eurasia,Django
Jun 30 Python
python通过colorama模块在控制台输出彩色文字的方法
Mar 19 Python
Python素数检测实例分析
Jun 15 Python
Python实现简单的代理服务器
Jul 25 Python
分析Python中解析构建数据知识
Jan 20 Python
python操作excel让工作自动化
Aug 09 Python
Python3常用内置方法代码实例
Nov 18 Python
python logging设置level失败的解决方法
Feb 19 Python
python GUI库图形界面开发之PyQt5布局控件QGridLayout详细使用方法与实例
Mar 06 Python
使用python处理题库表格并转化为word形式的实现
Apr 14 Python
Python基于argparse与ConfigParser库进行入参解析与ini parser
Feb 02 Python
Python语言规范之Pylint的详细用法
Jun 24 Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
You might like
PHP Pear 安装及使用
2009/03/19 PHP
PHP URL地址获取函数代码(端口等) 推荐
2010/05/15 PHP
PHP开发中常用的字符串操作函数
2011/02/08 PHP
搭建基于Docker的PHP开发环境的详细教程
2015/07/01 PHP
关于PHP开发的9条建议
2015/07/27 PHP
PHP+sqlite数据库操作示例(创建/打开/插入/检索)
2016/05/26 PHP
从javascript语言本身谈项目实战
2006/12/27 Javascript
Jquery操作Select 简单方便 一个js插件搞定
2009/11/12 Javascript
cnblogs TagCloud基于jquery的实现代码
2010/06/11 Javascript
兼容IE和FF的js脚本代码小结(比较常用)
2010/12/06 Javascript
分享XmlHttpRequest调用Webservice的一点心得
2012/07/20 Javascript
node.js 开发指南 ? Node.js 连接 MySQL 并进行数据库操作
2014/07/29 Javascript
jquery中push()的用法(数组添加元素)
2014/11/25 Javascript
jQuery插件slicebox实现3D动画图片轮播切换特效
2015/04/12 Javascript
谈谈AngularJs中的隐藏和显示
2015/12/09 Javascript
JS实现简单的二维矩阵乘积运算
2016/01/26 Javascript
基于JS设计12306登录页面
2016/12/28 Javascript
基于input框覆盖掉数字英文的实例讲解
2017/07/21 Javascript
细说webpack源码之compile流程-rules参数处理技巧(1)
2017/12/26 Javascript
Vue之mixin全局的用法详解
2018/08/22 Javascript
vuex如何重置所有state(可定制)
2019/01/17 Javascript
JS实现的定时器展示简单秒表、页面弹框及跳转操作完整示例
2020/01/26 Javascript
js实现抽奖的两种方法
2020/03/19 Javascript
pandas 实现将重复表格去重,并重新转换为表格的方法
2018/04/18 Python
简单了解Python3 bytes和str类型的区别和联系
2019/12/19 Python
django之导入并执行自定义的函数模块图解
2020/04/01 Python
将不规则的Python多维数组拉平到一维的方法实现
2021/01/11 Python
多重CSS背景动画实现方法示例
2014/04/04 HTML / CSS
BONIA波尼亚新加坡官网:皮革手袋,鞋类和配件
2016/08/25 全球购物
Gap加拿大官网:Gap Canada
2017/08/24 全球购物
Veronica Beard官网:在酷、经典和别致之间找到了平衡
2018/01/11 全球购物
Aeropostale官网:美国著名校园品牌及青少年服饰品牌
2019/03/21 全球购物
少先队学雷锋活动月总结
2014/03/09 职场文书
2014年综治宣传月活动总结
2014/04/28 职场文书
北京大学中文系教授推荐的10本小说
2019/08/08 职场文书
CSS link与@import的区别和用法解析
2023/05/07 HTML / CSS