浅谈Tensorflow加载Vgg预训练模型的几个注意事项


Posted in Python onMay 26, 2020

写这个博客的关键Bug: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64。本博客将围绕 加载图片 和 保存图片到本地 来详细解释和解决上述的Bug及其引出来的一系列Bug。

加载图片

首先,造成上述Bug的代码如下所示

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.image.decode_jpeg(image_raw)
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

上述代码是加载Vgg19预训练模型,并传入图片得到所有层的特征图,具体的代码实现和原理讲解可参考我的另一篇博客:Tensorflow加载Vgg预训练模型。那么,为什么代码会出现: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64,这个Bug呢?

这句英文翻译过来是指:传递的值类型是uint8,但是接受的参数类型必须是float的那几种。故原因就是传入值的数据类型错了,那么如何解决这个Bug呢,很简单

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw))
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

这两个代码块唯一的变动就是:image_decoded结果在输出前加了一个tf.float(),将其转换为float类型。

在tensorflow API中,tf.image.decode_jpeg()默认读取的图片数据格式为unit8,而不是float。uint8数据的范围在(0, 255)中,正好符合图片的像素范围(0, 255)。但是,保存在本地的Vgg19预训练模型的数据接口为float,所以才造成了本文开头的Bug。

这里还要提一点,若是使用PIL的方法来加载图片,则不会出现上述的Bug,因为通过PIL得到的图片格式是float,而不是uint8,故不需要转换。

很多同学可能会疑惑,若是强行改变了原图片的数据格式,从uint8类型转变成float,会不会导致数据改变或者出错?故我做了下面这个实验:

image_path = "data/3.jpg"
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_unit8 = tf.image.decode_jpeg(image_raw)
image_float = tf.to_float(image_unit8)
 
with tf.Session() as sess:
 image_unit8_, image_float_ = sess.run([image_unit8, image_float])
 
print("image_unit8_", image_unit8_)
print("image_float_ ", image_float_ )

代码结果如下:

image_unit8_
 [180, 192, 204],
 [183, 195, 207],
 [186, 198, 210],
 ...,
 [191, 205, 218],
 [191, 205, 218],
 [190, 204, 217]],
 
 image_float_ 
 [180., 192., 204.],
 [183., 195., 207.],
 [186., 198., 210.],
 ...,
 [191., 205., 218.],
 [191., 205., 218.],
 [190., 204., 217.]],

可以看到,数据根本没有变化,只是后面多加了个小数点,变得只有类型,而没有强制改变值,故同学们不需要过度担心。

保存图片到本地

在加载图片的时候,为了使用保存在本地的预训练Vgg19模型,我们需要将读取的图片由uint8格式转换成float格式。那若是我们想将已经转换为float格式的图片再保存到本地,该怎么做呢?

首先,我们根据上述的文字的意思读取图片,并且将其转换为float格式,在将读取的图片再次保存到本地之前,我们首先可视化一下转换格式后的图片,代码如下:

import tensorflow as tf
from matplotlib import pyplot as plt
image_path = "data/boat.jpg"
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_decoded = tf.image.decode_jpeg(image_raw)
image_decoded = tf.to_float(image_decoded)
 
with tf.Session() as sess:
 image_decoded_ = sess.run(image_decoded)
 plt.imshow(image_decoded_)
 plt.show()

生成的图片如下图所示:

浅谈Tensorflow加载Vgg预训练模型的几个注意事项

左边是原图,右边是转换为float格式的图片,可见将图片转换为float格式,虽然数值没有造成太大影响,但是若想将图片保存到本地就会出现问题。

说了这么多,只为了说一点,在保存图片到本地之前,需要将其格式从float转回uint8,否则会造成一系列错误:图片显示异常,API报错等。正确的保存代码如下:

save_path = "data/boat_copy.jpg"
image_uint = tf.cast(image_decoded, tf.uint8)
with tf.Session() as sess:
 with open(save_path, 'wb') as img:
 image_saved = sess.run(tf.image.encode_jpeg(image_uint))
 img.write(image_saved)

其中只有一句话最关键,即 tf.cast(image_decoded, tf.uint8)。

以上这篇浅谈Tensorflow加载Vgg预训练模型的几个注意事项就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 正则表达式 概述及常用字符
May 04 Python
分享一个常用的Python模拟登陆类
Mar 29 Python
Python使用剪切板的方法
Jun 06 Python
Tensorflow 合并通道及加载子模型的方法
Jul 26 Python
详解Python发送email的三种方式
Oct 18 Python
python pandas读取csv后,获取列标签的方法
Nov 12 Python
Python基础教程之异常详解
Jan 10 Python
python实现多层感知器
Jan 18 Python
django框架forms组件用法实例详解
Dec 10 Python
详解python内置常用高阶函数(列出了5个常用的)
Feb 21 Python
Python中的面向接口编程示例详解
Jan 17 Python
Python 中Operator模块的使用
Jan 30 Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
You might like
日本收入最高的漫画家:海贼王作者版税年收入高达8.45亿元
2020/03/04 日漫
Terran历史背景
2020/03/14 星际争霸
PHP类的使用 实例代码讲解
2009/12/28 PHP
一步一步学习PHP(2)――PHP类型
2010/02/15 PHP
PHP中的一些常用函数收集
2015/05/26 PHP
HTML中嵌入PHP的简单方法
2016/02/16 PHP
Yii框架弹出窗口组件CJuiDialog用法分析
2017/01/07 PHP
PHP开发中解决并发问题的几种实现方法分析
2017/11/13 PHP
JavaScript游戏之优化篇
2010/11/08 Javascript
VBS通过WMI监视注册表变动的代码
2011/10/27 Javascript
推荐40个非常优秀的jQuery插件和教程【系列三】
2011/11/09 Javascript
遮罩层点击按钮弹出并且具有拖动和关闭效果(两种方法)
2015/08/20 Javascript
jQuery实现二维码扫描功能
2017/01/09 Javascript
JS常见简单正则表达式验证功能小结【手机,地址,企业税号,金额,身份证等】
2017/01/22 Javascript
小程序日历控件使用方法详解
2018/12/29 Javascript
jQuery实现高度灵活的表单验证功能示例【无UI】
2020/04/30 jQuery
vue项目打包为APP,静态资源正常显示,但API请求不到数据的操作
2020/09/12 Javascript
python根据给定文件返回文件名和扩展名的方法
2015/03/27 Python
python实现redis三种cas事务操作
2017/12/19 Python
Python Logging 日志记录入门学习
2018/06/02 Python
python跳出双层for循环的解决方法
2019/06/24 Python
flask/django 动态查询表结构相同表名不同数据的Model实现方法
2019/08/29 Python
Python中如何添加自定义模块
2020/06/09 Python
Python OpenCV读取中文路径图像的方法
2020/07/02 Python
python3 googletrans超时报错问题及翻译工具优化方案 附源码
2020/12/23 Python
HTML5自定义属性前缀data-及dataset的使用方法(html5 新特性)
2017/08/24 HTML / CSS
Anya Hindmarch官网:奢侈设计师手袋及配饰
2018/11/15 全球购物
开放系统互连参考模型
2016/06/29 面试题
应届毕业生求职信范例分享
2013/12/17 职场文书
学年自我鉴定
2014/01/16 职场文书
暑期研修感言
2014/02/17 职场文书
爱心募捐通知范文
2015/04/27 职场文书
九九重阳节致辞
2015/07/31 职场文书
详解 TypeScript 枚举类型
2021/11/02 Javascript
Redis基本数据类型List常用操作命令
2022/06/01 Redis
django项目、vue项目部署云服务器的详细过程
2022/07/23 Servers