Keras 数据增强ImageDataGenerator多输入多输出实例


Posted in Python onJuly 03, 2020

我就废话不多说了,大家还是直接看代码吧~

import os 
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"]=""
import sys
import gc
import time
import cv2
import random
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

from random_eraser import get_random_eraser
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img

datagen = ImageDataGenerator(
  rotation_range=20,   #旋转
  width_shift_range=0.1,  #水平位置平移
#   height_shift_range=0.2,  #上下位置平移
  shear_range=0.5,    #错切变换,让所有点的x坐标(或者y坐标)保持不变,而对应的y坐标(或者x坐标)则按比例发生平移
  zoom_range=[0.9,0.9],  # 单方向缩放,当一个数值时两个方向等比例缩放,参数为list时长宽不同程度缩放。参数大于0小于1时,执行的是放大操作,当参数大于1时,执行的是缩小操作。
  channel_shift_range = 40, #偏移通道数值,改变图片颜色,越大颜色越深
  horizontal_flip=True,  #水平翻转,垂直翻转vertical_flip
  fill_mode='nearest',   #操作导致图像缺失时填充方式。“constant”、“nearest”(默认)、“reflect”和“wrap”
  preprocessing_function = get_random_eraser(p=0.7,v_l=0,v_h=255,s_l=0.01,s_h=0.03,r_1=1,r_2=1.5,pixel_level=True)
  )

# train_generator = datagen.flow_from_directory(
#       'base/Images/',
#       save_to_dir = 'base/fake/',
#       batch_size=1
#       )
# for i in range(5):
#  train_generator.next()

# !
# df_train = pd.read_csv('base/Annotations/label.csv', header=None)
# df_train.columns = ['image_id', 'class', 'label']
# classes = ['collar_design_labels', 'neckline_design_labels', 'skirt_length_labels', 
#   'sleeve_length_labels', 'neck_design_labels', 'coat_length_labels', 'lapel_design_labels', 
#   'pant_length_labels']
# !

# classes = ['collar_design_labels']

# !
# for i in range(len(classes)):
#  gc.enable()

# #  单个分类
#  cur_class = classes[i]
#  df_load = df_train[(df_train['class'] == cur_class)].copy()
#  df_load.reset_index(inplace=True)
#  del df_load['index']

# #  print(cur_class)

# #  加载数据和label
#  n = len(df_load)
# #  n_class = len(df_load['label'][0])
# #  width = 256

# #  X = np.zeros((n,width, width, 3), dtype=np.uint8)
# #  y = np.zeros((n, n_class), dtype=np.uint8)

#  print(f'starting load trainset {cur_class} {n}')
#  sys.stdout.flush()
#  for i in tqdm(range(n)):
# #   tmp_label = df_load['label'][i]
#   img = load_img('base/{0}'.format(df_load['image_id'][i]))
#   x = img_to_array(img)
#   x = x.reshape((1,) + x.shape)
#   m=0
#   for batch in datagen.flow(x,batch_size=1):
# #    plt.imshow(array_to_img(batch[0]))
# #    print(batch)
#    array_to_img(batch[0]).save(f'base/fake/{format(df_load["image_id"][i])}-{m}.jpg')
#    m+=1
#    if m>3:
#     break
#  gc.collect()
# !  

img = load_img('base/Images/collar_design_labels/2f639f11de22076ead5fe1258eae024d.jpg')
plt.figure()
plt.imshow(img)
x = img_to_array(img)

x = x.reshape((1,) + x.shape)

i = 0
for batch in datagen.flow(x,batch_size=5):
 plt.figure()
 plt.imshow(array_to_img(batch[0]))
#  print(len(batch))
 i += 1
 if i >0:
  break
#多输入,设置随机种子
# Define the image transformations here
gen = ImageDataGenerator(horizontal_flip = True,
       vertical_flip = True,
       width_shift_range = 0.1,
       height_shift_range = 0.1,
       zoom_range = 0.1,
       rotation_range = 40)

# Here is the function that merges our two generators
# We use the exact same generator with the same random seed for both the y and angle arrays
def gen_flow_for_two_inputs(X1, X2, y):
 genX1 = gen.flow(X1,y, batch_size=batch_size,seed=666)
 genX2 = gen.flow(X1,X2, batch_size=batch_size,seed=666)
 while True:
   X1i = genX1.next()
   X2i = genX2.next()
   #Assert arrays are equal - this was for peace of mind, but slows down training
   #np.testing.assert_array_equal(X1i[0],X2i[0])
   yield [X1i[0], X2i[1]], X1i[1]
#手动构造,直接输出多label
generator = ImageDataGenerator(rotation_range=5.,
        width_shift_range=0.1, 
        height_shift_range=0.1, 
        horizontal_flip=True, 
        vertical_flip=True)

def generate_data_generator(generator, X, Y1, Y2):
 genX = generator.flow(X, seed=7)
 genY1 = generator.flow(Y1, seed=7)
 while True:
   Xi = genX.next()
   Yi1 = genY1.next()
   Yi2 = function(Y2)
   yield Xi, [Yi1, Yi2]
model.fit_generator(generate_data_generator(generator, X, Y1, Y2),
    epochs=epochs)
def batch_generator(generator,X,Y):
 Xgen = generator.flow(X)
 while True:
  yield Xgen.next(),Y
h = model.fit_generator(batch_generator(datagen, X_all, y_all), 
       steps_per_epoch=len(X_all)//32+1,
       epochs=80,workers=3,
       callbacks=[EarlyStopping(patience=3), checkpointer,ReduceLROnPlateau(monitor='val_loss',factor=0.5,patience=1)], 
       validation_data=(X_val,y_val))

补充知识:读取图片成numpy数组,裁剪并保存 和 数据增强(ImageDataGenerator)

我就废话不多说了,大家还是直接看代码吧~

from PIL import Image
import numpy as np
from PIL import Image
from keras.preprocessing import image
import matplotlib.pyplot as plt
import os
import cv2
# from scipy.misc import toimage
import matplotlib
# 生成图片地址和对应标签
file_dir = '../train/'
image_list = []
label_list = []
cate = [file_dir + x for x in os.listdir(file_dir) if os.path.isdir(file_dir + x)]
for name in cate:
 temp = name.split('/')
 path = '../train_new/' + temp[-1]
 isExists = os.path.exists(path)
 if not isExists:
  os.makedirs(path) # 目录不存在则创建
 class_path = name + "/"

 for file in os.listdir(class_path):
  print(file)
  img_obj = Image.open(class_path + file) # 读取图片
  img_array = np.array(img_obj)
  resized = cv2.resize(img_array, (256, 256)) # 裁剪
  resized = resized.astype('float32')
  resized /= 255.
  # plt.imshow(resized)
  # plt.show()
  save_path = path + '/' + file
  matplotlib.image.imsave(save_path, resized) # 保存

keras之数据增强

from PIL import Image
import numpy as np
from PIL import Image
from keras.preprocessing import image
import os
import cv2
# 生成图片地址和对应标签
file_dir = '../train/'

label_list = []
cate = [file_dir + x for x in os.listdir(file_dir) if os.path.isdir(file_dir + x)]
for name in cate:
 image_list = []
 class_path = name + "/"
 for file in os.listdir(class_path):
  image_list.append(class_path + file)
 batch_size = 64
 if len(image_list) < 10000:
  num = int(10000 / len(image_list))
 else:
  num = 0
 # 设置生成器参数
 datagen = image.ImageDataGenerator(fill_mode='wrap', # 填充模式
          rotation_range=40, # 指定旋转角度范围
          width_shift_range=0.2, # 水平位置平移
          height_shift_range=0.2, # 上下位置平移
          horizontal_flip=True, # 随机对图片执行水平翻转操作
          vertical_flip=True, # 对图片执行上下翻转操作
          shear_range=0.2,
          rescale=1./255, # 缩放
          data_format='channels_last')
 if num > 0:
  temp = name.split('/')
  path = '../train_datage/' + temp[-1]
  isExists = os.path.exists(path)
  if not isExists:
   os.makedirs(path)

  for image_path in image_list:
   i = 1
   img_obj = Image.open(image_path) # 读取图片
   img_array = np.array(img_obj)
   x = img_array.reshape((1,) + img_array.shape)  #要求为4维
   name_image = image_path.split('/')
   print(name_image)
   for batch in datagen.flow(x,
        batch_size=1,
        save_to_dir=path,
        save_prefix=name_image[-1][:-4] + '_',
        save_format='jpg'):
    i += 1
    if i > num:
     break

以上这篇Keras 数据增强ImageDataGenerator多输入多输出实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python持久性管理pickle模块详细介绍
Feb 18 Python
numpy找出array中的最大值,最小值实例
Apr 03 Python
Numpy掩码式数组详解
Apr 17 Python
Python运维自动化之nginx配置文件对比操作示例
Aug 29 Python
Python实现登陆文件验证方法
Oct 06 Python
使用Python的turtle模块画国旗
Sep 24 Python
详解python itertools功能
Feb 07 Python
python_mask_array的用法
Feb 18 Python
python实现猜数游戏
Mar 27 Python
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
May 26 Python
Python偏函数Partial function使用方法实例详解
Jun 17 Python
pytorch简介
Nov 11 Python
keras和tensorflow使用fit_generator 批次训练操作
Jul 03 #Python
基于Python+QT的gui程序开发实现
Jul 03 #Python
keras 两种训练模型方式详解fit和fit_generator(节省内存)
Jul 03 #Python
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
Jul 03 #Python
keras分类模型中的输入数据与标签的维度实例
Jul 03 #Python
keras自动编码器实现系列之卷积自动编码器操作
Jul 03 #Python
Python with语句用法原理详解
Jul 03 #Python
You might like
GBK的页面输出JSON格式的php函数
2010/02/16 PHP
PHP 获取远程文件大小的3种解决方法
2013/07/11 PHP
php生成txt文件标题及内容的方法
2014/01/16 PHP
使用php转义输出HTML到JavaScript
2015/03/27 PHP
PHP中多线程的两个实现方法
2016/10/14 PHP
JavaScript 学习笔记一些小技巧
2010/03/28 Javascript
一些实用的jQuery代码片段收集
2011/07/12 Javascript
《JavaScript高级程序设计》阅读笔记(二) ECMAScript中的原始类型
2012/02/27 Javascript
jQuery得到多个值只能用取Class ,不能用取ID的方法
2016/12/04 Javascript
纯js实现悬浮按钮组件
2016/12/17 Javascript
微信小程序 监听手势滑动切换页面实例详解
2017/06/15 Javascript
js屏蔽退格键(backspace或者叫后退键与F5)
2019/02/10 Javascript
vue子路由跳转实现tab选项卡
2019/07/24 Javascript
vue.js+elementUI实现点击左右箭头切换头像功能(类似轮播图效果)
2019/09/05 Javascript
Vue修改项目启动端口号方法
2019/11/07 Javascript
vant IndexBar实现的城市列表的示例代码
2019/11/20 Javascript
uin-app+mockjs实现本地数据模拟
2020/08/26 Javascript
[53:21]2014 DOTA2国际邀请赛中国区预选赛5.21 DT VS LGD-CDEC
2014/05/22 DOTA
[08:53]DOTA2每周TOP10 精彩击杀集锦vol.9
2014/06/26 DOTA
python用户管理系统
2018/03/13 Python
通过Python 接口使用OpenCV的方法
2018/04/02 Python
pandas.loc 选取指定列进行操作的实例
2018/05/18 Python
Python实现DDos攻击实例详解
2019/02/02 Python
python中单下划线(_)和双下划线(__)的特殊用法
2019/08/29 Python
Django中的FBV和CBV用法详解
2019/09/15 Python
Pytorch实现神经网络的分类方式
2020/01/08 Python
鲜为人知的HTML5语音合成功能
2019/05/17 HTML / CSS
学生会主席事迹材料
2014/01/28 职场文书
安全教育实施方案
2014/03/02 职场文书
解除劳动合同协议书范本
2014/09/13 职场文书
2014镇副书记群众路线专题民主生活会思想汇报
2014/09/23 职场文书
2014年连锁店圣诞节活动方案
2014/12/09 职场文书
2015年世界无烟日演讲稿
2015/03/18 职场文书
Python如何使用logging为Flask增加logid
2021/03/30 Python
Python控制台输出俄罗斯方块移动和旋转功能
2021/04/18 Python
go select编译期的优化处理逻辑使用场景分析
2021/06/28 Golang