Keras 在fit_generator训练方式中加入图像random_crop操作


Posted in Python onJuly 03, 2020

使用Keras作前端写网络时,由于训练图像尺寸较大,需要做类似 tf.random_crop 图像裁剪操作。

为此研究了一番Keras下已封装的API。

Data Augmentation(数据扩充)

Data Aumentation 指使用下面或其他方法增加输入数据量。我们默认图像数据。

旋转&反射变换(Rotation/reflection): 随机旋转图像一定角度; 改变图像内容的朝向;

翻转变换(flip): 沿着水平或者垂直方向翻转图像;

缩放变换(zoom): 按照一定的比例放大或者缩小图像;

平移变换(shift): 在图像平面上对图像以一定方式进行平移;

可以采用随机或人为定义的方式指定平移范围和平移步长, 沿水平或竖直方向进行平移. 改变图像内容的位置;

尺度变换(scale): 对图像按照指定的尺度因子, 进行放大或缩小; 或者参照SIFT特征提取思想, 利用指定的尺度因子对图像滤波构造尺度空间. 改变图像内容的大小或模糊程度;

对比度变换(contrast): 在图像的HSV颜色空间,改变饱和度S和V亮度分量,保持色调H不变. 对每个像素的S和V分量进行指数运算(指数因子在0.25到4之间), 增加光照变化;

噪声扰动(noise): 对图像的每个像素RGB进行随机扰动, 常用的噪声模式是椒盐噪声和高斯噪声;

Data Aumentation 有很多好处,比如数据量较少时,用数据扩充来增加训练数据,防止过拟合。

ImageDataGenerator

在Keras中,ImageDataGenerator就是专门做数据扩充的。

from keras.preprocessing.image import ImageDataGenerator

注:Using TensorFlow backend.

官方写法如下:

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

datagen = ImageDataGenerator(
 featurewise_center=True,
 ...
 horizontal_flip=True)

# compute quantities required for featurewise normalization
datagen.fit(x_train)

# 使用fit_generator的【自动】训练方法: fits the model on batches with real-time data augmentation
model.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
   steps_per_epoch=len(x_train), epochs=epochs)

# 自己写range循环的【手动】训练方法
for e in range(epochs):
 print 'Epoch', e
 batches = 0
 for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
 loss = model.train(x_batch, y_batch)
 batches += 1
 if batches >= len(x_train) / 32:
  # we need to break the loop by hand because
  # the generator loops indefinitely
  break

ImageDataGenerator的参数说明见官网文档。

上面两种训练方法的差异不讨论,我们要关注的是:官方封装的训练集batch生成器是ImageDataGenerator对象的flow方法(或flow_from_directory),该函数返回一个和python定义相似的generator。在它前一步,数据变换是ImageDataGenerator对象的fit方法。

random_crop并未在ImageDataGenerator中内置,但参数中给了一个preprocessing_function,我们可以利用它自定义my_random_crop函数,像下面这样写:

def my_random_crop(image):
 random_arr = numpy.random.randint(img_sz-crop_sz+1, size=2)
 y = int(random_arr[0])
 x = int(random_arr[1])
 h = img_crop
 w = img_crop
 image_crop = image[y:y+h, x:x+w, :]
 return image_crop

datagen = ImageDataGenerator(
 featurewise_center=False,
 ···
 preprocessing_function=my_random_crop)

datagen.fit(x_train)

fit方法调用时将预设的变换应用到x_train的每张图上,包括图像crop,因为是单张依次处理,每张图的crop位置随机。

在训练数据(x=image, y=class_label)时这样写已满足要求;

但在(x=image, y=image_mask)时该方法就不成立了。图像单张处理的缘故,一对(image, image_mask)分别crop的位置无法保持一致。

虽然官网也给出了同时变换image和mask的写法,但它提出的方案能保证二者内置函数的变换一致,自定义函数的random变量仍是随机的。

fit_generator

既然ImageDataGenerator和flow方法不能满足我们的random_crop预处理要求,就在fit_generator函数处想方法修改。

先看它的定义:

def fit_generator(self, generator, samples_per_epoch, nb_epoch,
   verbose=1, callbacks=[],
   validation_data=None, nb_val_samples=None,
   class_weight=None, max_q_size=10, **kwargs):

第一个参数generator,可以传入一个方法,也可以直接传入数据集。前面的 datagen.flow() 即是Keras封装的批量数据传入方法。

显然,我们可以自定义。

def generate_batch_data_random(x, y, batch_size):
 """分批取batch数据加载到显存"""
 total_num = len(x)
 batches = total_num // batch_size
 while (True):
 i = randint(0, batches)
 x_batch = x[i*batch_size:(i+1)*batch_size]
 y_batch = y[i*batch_size:(i+1)*batch_size]
 random_arr = numpy.random.randint(img_sz-crop_sz+1, size=2)
 y_pos = int(random_arr[0])
 x_pos = int(random_arr[1])
 x_crop = x_batch[:, y_pos:y_pos+crop_sz, x_pos:x_pos+crop_sz, :]
 y_crop = y_batch[:, y_pos:y_pos+crop_sz, x_pos:x_pos+crop_sz, :]
 yield (x_crop, y_crop)

这样写就符合我们同组image和mask位置一致的random_crop要求。

注意:

由于没有使用ImageDataGenerator内置的数据变换方法,数据扩充则也需要自定义;由于没有使用flow(…, shuffle=True,)方法,每个epoch的数据打乱需要自定义。

generator自定义时要写成死循环,因为在每个epoch内,generate_batch_data_random是不会重复调用的。

补充知识:tensorflow中的随机裁剪函数random_crop

tf.random_crop是tensorflow中的随机裁剪函数,可以用来裁剪图片。我采用如下图片进行随机裁剪,裁剪大小为原图的一半。

Keras 在fit_generator训练方式中加入图像random_crop操作

如下是实验代码

import tensorflow as tf
import matplotlib.image as img
import matplotlib.pyplot as plt
sess = tf.InteractiveSession()
image = img.imread('D:/Documents/Pictures/logo3.jpg')

reshaped_image = tf.cast(image,tf.float32)
size = tf.cast(tf.shape(reshaped_image).eval(),tf.int32)
height = sess.run(size[0]//2)
width = sess.run(size[1]//2)
distorted_image = tf.random_crop(reshaped_image,[height,width,3])
print(tf.shape(reshaped_image).eval())
print(tf.shape(distorted_image).eval())

fig = plt.figure()
fig1 = plt.figure()
ax = fig.add_subplot(111)
ax1 = fig1.add_subplot(111)
ax.imshow(sess.run(tf.cast(reshaped_image,tf.uint8)))
ax1.imshow(sess.run(tf.cast(distorted_image,tf.uint8)))
plt.show()

如下是随机实验两次的结果

Keras 在fit_generator训练方式中加入图像random_crop操作

Keras 在fit_generator训练方式中加入图像random_crop操作

以上这篇Keras 在fit_generator训练方式中加入图像random_crop操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python的chardet库获得文件编码并修改编码
Jan 22 Python
Python中使用MELIAE分析程序内存占用实例
Feb 18 Python
Python列表推导式、字典推导式与集合推导式用法实例分析
Feb 07 Python
用Python写一段用户登录的程序代码
Apr 22 Python
Python实现合并同一个文件夹下所有txt文件的方法示例
Apr 26 Python
浅谈pandas.cut与pandas.qcut的使用方法及区别
Mar 03 Python
Python2.6版本pip安装步骤解析
Aug 17 Python
详解python with 上下文管理器
Sep 02 Python
Python jieba库分词模式实例用法
Jan 13 Python
Matlab使用Plot函数实现数据动态显示方法总结
Feb 25 Python
python使用openpyxl库读写Excel表格的方法(增删改查操作)
May 02 Python
OpenCV-Python实现人脸磨皮算法
Jun 07 Python
keras的三种模型实现与区别说明
Jul 03 #Python
Keras中 ImageDataGenerator函数的参数用法
Jul 03 #Python
python程序如何进行保存
Jul 03 #Python
keras的ImageDataGenerator和flow()的用法说明
Jul 03 #Python
python如何安装下载后的模块
Jul 03 #Python
python中id函数运行方式
Jul 03 #Python
Keras 数据增强ImageDataGenerator多输入多输出实例
Jul 03 #Python
You might like
PHP函数之error_reporting(E_ALL ^ E_NOTICE)详细说明
2011/07/01 PHP
根据key删除数组中指定的元素实现方法
2017/03/02 PHP
yii插入数据库防并发的简单代码
2017/05/27 PHP
php输出反斜杠的实例方法
2019/09/19 PHP
JavaScript 事件参考手册
2008/12/24 Javascript
jQuery+JSON+jPlayer实现QQ空间音乐查询功能示例
2013/06/17 Javascript
addEventListener()第三个参数useCapture (Boolean)详细解析
2013/11/07 Javascript
Jquery实现自定义弹窗示例
2014/03/12 Javascript
使用纯javascript实现放大镜效果
2015/03/18 Javascript
使用AngularJS和PHP的Laravel实现单页评论的方法
2015/06/19 Javascript
JavaScript之AOP编程实例
2015/07/17 Javascript
jQuery获取父元素节点、子元素节点及兄弟元素节点的方法
2016/04/14 Javascript
JavaScript通过filereader接口读取文件
2017/05/10 Javascript
用node和express连接mysql实现登录注册的实现代码
2017/07/05 Javascript
Angular实现的table表格排序功能完整示例
2017/12/22 Javascript
vue实现跳转接口push 转场动画示例
2019/11/01 Javascript
[55:45]DOTA2上海特级锦标赛D组败者赛 Liquid VS COL第一局
2016/02/28 DOTA
[48:28]完美世界DOTA2联赛循环赛FTD vs Magma第二场 10月30日
2020/10/31 DOTA
[01:15]PWL S2开团时刻第二期——他们杀 我就白给
2020/11/25 DOTA
浅谈Python的Django框架中的缓存控制
2015/07/24 Python
在Python的Flask框架中验证注册用户的Email的方法
2015/09/02 Python
Python中遇到的小问题及解决方法汇总
2017/01/11 Python
Python 专题三 字符串的基础知识
2017/03/19 Python
Pandas 数据框增、删、改、查、去重、抽样基本操作方法
2018/04/12 Python
pycharm远程linux开发和调试代码的方法
2018/07/17 Python
对pandas中Series的map函数详解
2018/07/25 Python
Python面向对象总结及类与正则表达式详解
2019/04/18 Python
python中 * 的用法详解
2019/07/10 Python
如何安装并使用conda指令管理python环境
2019/07/10 Python
Python读取文件内容为字符串的方法(多种方法详解)
2020/03/04 Python
python实现逢七拍腿小游戏的思路详解
2020/05/26 Python
CAT鞋美国官网:CAT Footwear
2017/11/27 全球购物
Canal官网:巴西女性时尚品牌
2019/10/16 全球购物
幼儿教师个人求职信范文
2013/09/21 职场文书
公司宣传语大全
2015/07/13 职场文书
python中取整数的几种方法
2021/11/07 Python