keras实现图像预处理并生成一个generator的案例


Posted in Python onJune 17, 2020

如下所示:

keras实现图像预处理并生成一个generator的案例

接下来,给出我自己目前积累的代码,从目录中自动读取图像,并产生generator:

第一步:建立好目录结构和图像

keras实现图像预处理并生成一个generator的案例

可以看到目录images_keras_dict下有次级目录,次级目录下就直接包含照片了

**第二步:写代码建立预处理程序

# 先进行预处理图像
train_datagen = ImageDataGenerator(rescale=1./255, 
                  rotation_range=50,
                  height_shift_range=[-0.005, 0, 0.005],
                  width_shift_range=[-0.005, 0, 0.005],
                  horizontal_flip=True, 
                  fill_mode='reflect')
#再对预处理图像指定从目录中读取数据,可以看到我的目录最核心的地方是images_keras_dict(可以对照上一张图片)
train_generator = train_datagen.flow_from_directory('AgriculturalDisease_trainingset/images_keras_dict',
                          target_size=(height, width), batch_size=16)

val_datagen = ImageDataGenerator(rescale=1./255)
val_generator = val_datagen.flow_from_directory('AgriculturalDisease_validationset/images_keras_dict', target_size=(height, width),
                        batch_size=64)

save_weights = ModelCheckpoint(filepath='models/best_weights.hdf5',monitor='val_loss', verbose=1, save_best_only=True)

# 最后在fit_generator 中放入生成器的函数train_generator
model.fit_generator(train_generator,
          steps_per_epoch=times_train,
          verbose=1,
          epochs=300,
          initial_epoch=0,
          validation_data=val_generator,
          validation_steps=times_val,
          callbacks=[save_weights, TrainValTensorBoard(write_graph=False)])

第三步:写入fit_generator进行训练

已经写在上一个代码中。

第四步:写predict_generator进行预测**

首先我们需要建立同样的目录结构。把包含预测图片的次级目录放在一个文件夹下,这个文件夹名就是关键文件夹。

这里我的关键文件夹是test文件夹

# 建立预处理
predict_datagen = ImageDataGenerator(rescale=1./255)
predict_generator = predict_datagen.flow_from_directory('AgriculturalDisease_validationset/test',
                            target_size=(height, width), batch_size=128)
# predict_generator.reset()
# 利用predict_generator进行预测
pred = model.predict_generator(predict_generator, max_queue_size=10, workers=1, verbose=1)

# 利用几个属性来读取文件夹和对应的分类
train_datagen = ImageDataGenerator(rescale=1./255, rotation_range=40, fill_mode='wrap')
train_generator = train_datagen.flow_from_directory('new_images', target_size=(height, width), batch_size=96)
labels = (train_generator.class_indices)
labels = dict((v,k) for k,v in labels.items())
predictions = [labels[k] for k in predicted_class_indices]

# 还可以知道图片的名字
filenames = predict_generator.filenames

补充知识:[TensorFlow 2] [Keras] fit()、fit_generator() 和 train_on_batch() 分析与应用

前言

是的,除了水报错文,我也来写点其他的。本文主要介绍Keras中以下三个函数的用法:

1、fit()

2、fit_generator()

3、train_on_batch()

当然,与上述三个函数相似的evaluate、predict、test_on_batch、predict_on_batch、evaluate_generator和predict_generator等就不详细说了,举一反三嘛。

环境

本文的代码是在以下环境下进行测试的:

Windows 10

Python 3.6

TensorFlow 2.0 Alpha

异同

大家用Keras也就图个简单快捷,但是在享受简单快捷的时候,也常常需要些定制化需求,除了model.fit(),有时候model.fit_generator()和model.train_on_batch()也很重要。

那么,这三个函数有什么异同呢?Adrian Rosebrock [1] 有如下总结:

当你使用.fit()函数时,意味着如下两个假设:

训练数据可以 完整地 放入到内存(RAM)里

数据已经不需要再进行任何处理了

这两个原因解释的非常好,之前我运行程序的时候,由于数据集太大(实际中的数据集显然不会都像 TensorFlow 官方教程里经常使用的 MNIST 数据集那样小),一次性加载训练数据到fit()函数里根本行不通:

history = model.fit(train_data, train_label) // Bomb!!!

于是我想,能不能先加载一个batch训练,然后再加载一个batch,如此往复。于是我就注意到了fit_generator()函数。什么时候该使用fit_generator函数呢?Adrian Rosebrock 的总结道:

内存不足以一次性加载整个训练数据的时候

需要一些数据预处理(例如旋转和平移图片、增加噪音、扩大数据集等操作)

在生成batch的时候需要更多的处理

对于我自己来说,除了数据集太大的缘故之外,我需要在生成batch的时候,对输入数据进行padding,所以fit_generator()就派上了用场。下面介绍如何使用这三种函数。

fit()函数

fit()函数其实没什么好说的,大家在看TensorFlow教程的时候已经见识过了。此外插一句话,tf.data.Dataset对不规则的序列数据真是不友好。

import tensorflow as tf
model = tf.keras.models.Sequential([
 ... // 你的模型
])
model.fit(train_x, // 训练输入
  train_y, // 训练标签
  epochs=5 // 训练5轮
)

fit_generator()函数

fit_generator()函数就比较重要了,也是本文讨论的重点。fit_generator()与fit()的主要区别就在一个generator上。之前,我们把整个训练数据都输入到fit()里,我们也不需要考虑batch的细节;现在,我们使用一个generator,每次生成一个batch送给fit_generator()训练。

def generator(x, y, b_size):
 ... // 处理函数

model.fit_generator(generator(train_x, train_y, batch_size), 
   step_per_epochs=np.ceil(len(train_x)/batch_size), 
   epochs=5
)

从上述代码中,我们发现有两处不同:

一个我们自定义的generator()函数,作为fit_generator()函数的第一个参数;

fit_generator()函数的step_per_epochs参数

自定义的generator()函数

该函数即是我们数据的生成器,在训练的时候,fit_generator()函数会不断地执行generator()函数,获取一个个的batch。

def generator(x, y, b_size):
 """Generates batch and batch and batch then feed into models.
 Args:
 x: input data;
 y: input labels;
 b_size: batch_size.
 Yield:
 (batch_x, batch_label): batched x and y.
 """
 while 1: // 死循环
 idx = ...
 batch_x = ...
 batch_y = ...
 ... // 任何你想要对这个`batch`中的数据执行的操作
 yield (batch_x, batch_y)

需要注意的是,不要使用return或者exit。

step_per_epochs参数

由于generator()函数的循环没有终止条件,fit_generator也不知道一个epoch什么时候结束,所以我们需要手动指定step_per_epochs参数,一般的数值即为len(y)//batch_size。如果数据集大小不能整除batch_size,而且你打算使用最后一个batch的数据(该batch比batch_size要小),此时使用np.ceil(len(y)/batch_size)。

keras.utils.Sequence类(2019年6月10日更新)

除了写generator()函数,我们还可以利用keras.utils.Sequence类来生成batch。先扔代码:

class Generator(keras.utils.Sequence):
 def __init__(self, x, y, b_size):
 self.x, self.y = x, y
 self.batch_size = b_size
 
 def __len__(self):
 return math.ceil(len(self.y)/self.batch_size

 def __getitem__(self, idx):
 b_x = self.x[idx*self.batch_size:(idx+1)*self.batch_size]
 b_y = self.y[idx*self.batch_size:(idx+1)*self.batch_size]
 ... // 对`batch`的其余操作
 return np.array(b_x), np.array(b_y)
 
 def on_epoch_end(self):
 """执行完一个`epoch`之后,还可以做一些其他的事情!"""
 ...

我们首先定义__init__函数,读取训练集数据,然后定义__len__函数,返回一个epoch中需要执行的step数(此时在fit_generator()函数中就不需要指定steps_per_epoch参数了),最后定义__getitem__函数,返回一个batch的数据。代码如下:

train_generator = Generator(train_x, train_y, batch_size)
val_generator = Generator(val_x, val_y, batch_size)

model.fit_generator(generator=train_generator, 
   epochs=3197747, 
   validation_data=val_generator
   )

根据官方 [2] 的说法,使用Sequence类可以保证在多进程的情况下,每个epoch中的样本只会被训练一次。总之,使用keras.utils.Sequence也是很方便的啦!

train_on_batch()函数

train_on_batch()函数接受一个batch的输入和标签,然后开始反向传播,更新参数等。大部分情况下你都不需要用到train_on_batch()函数,除非你有着充足的理由去定制化你的模型的训练流程。

结语

本文到此结束啦!希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python写的Discuz7.2版faq.php注入漏洞工具
Aug 06 Python
python循环监控远程端口的方法
Mar 14 Python
Python中的urllib模块使用详解
Jul 07 Python
Python实现HTTP协议下的文件下载方法总结
Apr 20 Python
深入理解Python3中的http.client模块
Mar 29 Python
python实现在pandas.DataFrame添加一行
Apr 04 Python
python向已存在的excel中新增表,不覆盖原数据的实例
May 02 Python
我用Python抓取了7000 多本电子书案例详解
Mar 25 Python
深入了解Python枚举类型的相关知识
Jul 09 Python
python按行读取文件并找出其中指定字符串
Aug 08 Python
解决python中的幂函数、指数函数问题
Nov 25 Python
python小技巧——将变量保存在本地及读取
Nov 13 Python
pytorch快速搭建神经网络_Sequential操作
Jun 17 #Python
浅谈Keras的Sequential与PyTorch的Sequential的区别
Jun 17 #Python
Keras之fit_generator与train_on_batch用法
Jun 17 #Python
基于Keras的格式化输出Loss实现方式
Jun 17 #Python
Tensorflow之MNIST CNN实现并保存、加载模型
Jun 17 #Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
You might like
通用PHP动态生成静态HTML网页的代码
2010/03/04 PHP
apache+php完美解决301重定向的两种方法
2011/06/08 PHP
php实现根据IP地址获取其所在省市的方法
2015/04/30 PHP
ThinkPHP路由详解
2015/07/27 PHP
在网页中屏蔽快捷键
2006/09/06 Javascript
jquery.hotkeys监听键盘按下事件keydown插件
2014/05/11 Javascript
jQuery的each循环用法简单示例
2016/06/12 Javascript
Javascript实现登录记住用户名和密码功能
2017/03/22 Javascript
微信小程序上滑加载下拉刷新(onscrollLower)分批加载数据(一)
2017/05/11 Javascript
跨域请求两种方法 jsonp和cors的实现
2018/11/11 Javascript
ES6知识点整理之Proxy的应用实例详解
2019/04/16 Javascript
vue实现分页栏效果
2019/06/28 Javascript
解决layui弹出层layer的area过大被遮挡的问题
2019/09/21 Javascript
这15个Vue指令,让你的项目开发爽到爆
2019/10/11 Javascript
Vue实现返回顶部按钮实例代码
2020/10/21 Javascript
python基础入门学习笔记(Python环境搭建)
2016/01/13 Python
Python外星人入侵游戏编程完整版
2020/03/30 Python
python OpenCV学习笔记之绘制直方图的方法
2018/02/08 Python
Python中elasticsearch插入和更新数据的实现方法
2018/04/01 Python
python爬虫基础教程:requests库(二)代码实例
2019/04/09 Python
解决python flask中config配置管理的问题
2019/07/26 Python
Python学习笔记之Zip和Enumerate用法实例分析
2019/08/14 Python
纯css3实现效果超级炫的checkbox复选框和radio单选框
2014/09/01 HTML / CSS
HTML5计时器小例子
2013/10/15 HTML / CSS
HTML5里的placeholder属性使用实例和美化显示效果的方法
2014/04/23 HTML / CSS
深圳-东方伟业笔试部分
2015/02/11 面试题
自我鉴定范文200字
2013/10/02 职场文书
小学生获奖感言范文
2014/02/02 职场文书
翻译学院毕业生自荐书
2014/02/02 职场文书
安卓程序员求职信
2014/02/28 职场文书
2014年元旦感言
2014/03/06 职场文书
房务中心文员岗位职责
2014/04/16 职场文书
活动总结格式范文
2014/04/26 职场文书
2015教师个人工作总结范文
2015/03/31 职场文书
jQuery ajax - getScript() 方法和getJSON方法
2021/05/14 jQuery
Android Flutter实现图片滑动切换效果
2022/04/07 Java/Android