Keras 数据增强ImageDataGenerator多输入多输出实例


Posted in Python onJuly 03, 2020

我就废话不多说了,大家还是直接看代码吧~

import os 
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"]=""
import sys
import gc
import time
import cv2
import random
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

from random_eraser import get_random_eraser
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img

datagen = ImageDataGenerator(
  rotation_range=20,   #旋转
  width_shift_range=0.1,  #水平位置平移
#   height_shift_range=0.2,  #上下位置平移
  shear_range=0.5,    #错切变换,让所有点的x坐标(或者y坐标)保持不变,而对应的y坐标(或者x坐标)则按比例发生平移
  zoom_range=[0.9,0.9],  # 单方向缩放,当一个数值时两个方向等比例缩放,参数为list时长宽不同程度缩放。参数大于0小于1时,执行的是放大操作,当参数大于1时,执行的是缩小操作。
  channel_shift_range = 40, #偏移通道数值,改变图片颜色,越大颜色越深
  horizontal_flip=True,  #水平翻转,垂直翻转vertical_flip
  fill_mode='nearest',   #操作导致图像缺失时填充方式。“constant”、“nearest”(默认)、“reflect”和“wrap”
  preprocessing_function = get_random_eraser(p=0.7,v_l=0,v_h=255,s_l=0.01,s_h=0.03,r_1=1,r_2=1.5,pixel_level=True)
  )

# train_generator = datagen.flow_from_directory(
#       'base/Images/',
#       save_to_dir = 'base/fake/',
#       batch_size=1
#       )
# for i in range(5):
#  train_generator.next()

# !
# df_train = pd.read_csv('base/Annotations/label.csv', header=None)
# df_train.columns = ['image_id', 'class', 'label']
# classes = ['collar_design_labels', 'neckline_design_labels', 'skirt_length_labels', 
#   'sleeve_length_labels', 'neck_design_labels', 'coat_length_labels', 'lapel_design_labels', 
#   'pant_length_labels']
# !

# classes = ['collar_design_labels']

# !
# for i in range(len(classes)):
#  gc.enable()

# #  单个分类
#  cur_class = classes[i]
#  df_load = df_train[(df_train['class'] == cur_class)].copy()
#  df_load.reset_index(inplace=True)
#  del df_load['index']

# #  print(cur_class)

# #  加载数据和label
#  n = len(df_load)
# #  n_class = len(df_load['label'][0])
# #  width = 256

# #  X = np.zeros((n,width, width, 3), dtype=np.uint8)
# #  y = np.zeros((n, n_class), dtype=np.uint8)

#  print(f'starting load trainset {cur_class} {n}')
#  sys.stdout.flush()
#  for i in tqdm(range(n)):
# #   tmp_label = df_load['label'][i]
#   img = load_img('base/{0}'.format(df_load['image_id'][i]))
#   x = img_to_array(img)
#   x = x.reshape((1,) + x.shape)
#   m=0
#   for batch in datagen.flow(x,batch_size=1):
# #    plt.imshow(array_to_img(batch[0]))
# #    print(batch)
#    array_to_img(batch[0]).save(f'base/fake/{format(df_load["image_id"][i])}-{m}.jpg')
#    m+=1
#    if m>3:
#     break
#  gc.collect()
# !  

img = load_img('base/Images/collar_design_labels/2f639f11de22076ead5fe1258eae024d.jpg')
plt.figure()
plt.imshow(img)
x = img_to_array(img)

x = x.reshape((1,) + x.shape)

i = 0
for batch in datagen.flow(x,batch_size=5):
 plt.figure()
 plt.imshow(array_to_img(batch[0]))
#  print(len(batch))
 i += 1
 if i >0:
  break
#多输入,设置随机种子
# Define the image transformations here
gen = ImageDataGenerator(horizontal_flip = True,
       vertical_flip = True,
       width_shift_range = 0.1,
       height_shift_range = 0.1,
       zoom_range = 0.1,
       rotation_range = 40)

# Here is the function that merges our two generators
# We use the exact same generator with the same random seed for both the y and angle arrays
def gen_flow_for_two_inputs(X1, X2, y):
 genX1 = gen.flow(X1,y, batch_size=batch_size,seed=666)
 genX2 = gen.flow(X1,X2, batch_size=batch_size,seed=666)
 while True:
   X1i = genX1.next()
   X2i = genX2.next()
   #Assert arrays are equal - this was for peace of mind, but slows down training
   #np.testing.assert_array_equal(X1i[0],X2i[0])
   yield [X1i[0], X2i[1]], X1i[1]
#手动构造,直接输出多label
generator = ImageDataGenerator(rotation_range=5.,
        width_shift_range=0.1, 
        height_shift_range=0.1, 
        horizontal_flip=True, 
        vertical_flip=True)

def generate_data_generator(generator, X, Y1, Y2):
 genX = generator.flow(X, seed=7)
 genY1 = generator.flow(Y1, seed=7)
 while True:
   Xi = genX.next()
   Yi1 = genY1.next()
   Yi2 = function(Y2)
   yield Xi, [Yi1, Yi2]
model.fit_generator(generate_data_generator(generator, X, Y1, Y2),
    epochs=epochs)
def batch_generator(generator,X,Y):
 Xgen = generator.flow(X)
 while True:
  yield Xgen.next(),Y
h = model.fit_generator(batch_generator(datagen, X_all, y_all), 
       steps_per_epoch=len(X_all)//32+1,
       epochs=80,workers=3,
       callbacks=[EarlyStopping(patience=3), checkpointer,ReduceLROnPlateau(monitor='val_loss',factor=0.5,patience=1)], 
       validation_data=(X_val,y_val))

补充知识:读取图片成numpy数组,裁剪并保存 和 数据增强(ImageDataGenerator)

我就废话不多说了,大家还是直接看代码吧~

from PIL import Image
import numpy as np
from PIL import Image
from keras.preprocessing import image
import matplotlib.pyplot as plt
import os
import cv2
# from scipy.misc import toimage
import matplotlib
# 生成图片地址和对应标签
file_dir = '../train/'
image_list = []
label_list = []
cate = [file_dir + x for x in os.listdir(file_dir) if os.path.isdir(file_dir + x)]
for name in cate:
 temp = name.split('/')
 path = '../train_new/' + temp[-1]
 isExists = os.path.exists(path)
 if not isExists:
  os.makedirs(path) # 目录不存在则创建
 class_path = name + "/"

 for file in os.listdir(class_path):
  print(file)
  img_obj = Image.open(class_path + file) # 读取图片
  img_array = np.array(img_obj)
  resized = cv2.resize(img_array, (256, 256)) # 裁剪
  resized = resized.astype('float32')
  resized /= 255.
  # plt.imshow(resized)
  # plt.show()
  save_path = path + '/' + file
  matplotlib.image.imsave(save_path, resized) # 保存

keras之数据增强

from PIL import Image
import numpy as np
from PIL import Image
from keras.preprocessing import image
import os
import cv2
# 生成图片地址和对应标签
file_dir = '../train/'

label_list = []
cate = [file_dir + x for x in os.listdir(file_dir) if os.path.isdir(file_dir + x)]
for name in cate:
 image_list = []
 class_path = name + "/"
 for file in os.listdir(class_path):
  image_list.append(class_path + file)
 batch_size = 64
 if len(image_list) < 10000:
  num = int(10000 / len(image_list))
 else:
  num = 0
 # 设置生成器参数
 datagen = image.ImageDataGenerator(fill_mode='wrap', # 填充模式
          rotation_range=40, # 指定旋转角度范围
          width_shift_range=0.2, # 水平位置平移
          height_shift_range=0.2, # 上下位置平移
          horizontal_flip=True, # 随机对图片执行水平翻转操作
          vertical_flip=True, # 对图片执行上下翻转操作
          shear_range=0.2,
          rescale=1./255, # 缩放
          data_format='channels_last')
 if num > 0:
  temp = name.split('/')
  path = '../train_datage/' + temp[-1]
  isExists = os.path.exists(path)
  if not isExists:
   os.makedirs(path)

  for image_path in image_list:
   i = 1
   img_obj = Image.open(image_path) # 读取图片
   img_array = np.array(img_obj)
   x = img_array.reshape((1,) + img_array.shape)  #要求为4维
   name_image = image_path.split('/')
   print(name_image)
   for batch in datagen.flow(x,
        batch_size=1,
        save_to_dir=path,
        save_prefix=name_image[-1][:-4] + '_',
        save_format='jpg'):
    i += 1
    if i > num:
     break

以上这篇Keras 数据增强ImageDataGenerator多输入多输出实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现哈希表
Feb 07 Python
Python yield 使用浅析
May 28 Python
win系统下为Python3.5安装flask-mongoengine 库
Dec 20 Python
Python中pandas模块DataFrame创建方法示例
Jun 20 Python
Python快速查找list中相同部分的方法
Jun 27 Python
Django实现支付宝付款和微信支付的示例代码
Jul 25 Python
python机器人运动范围问题的解答
Apr 29 Python
Python3.5迭代器与生成器用法实例分析
Apr 30 Python
python实现图片转字符小工具
Apr 30 Python
Python基于gevent实现高并发代码实例
May 15 Python
使用Pycharm在运行过程中,查看每个变量的操作(show variables)
Jun 08 Python
详解Python模块化编程与装饰器
Jan 16 Python
keras和tensorflow使用fit_generator 批次训练操作
Jul 03 #Python
基于Python+QT的gui程序开发实现
Jul 03 #Python
keras 两种训练模型方式详解fit和fit_generator(节省内存)
Jul 03 #Python
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
Jul 03 #Python
keras分类模型中的输入数据与标签的维度实例
Jul 03 #Python
keras自动编码器实现系列之卷积自动编码器操作
Jul 03 #Python
Python with语句用法原理详解
Jul 03 #Python
You might like
php,ajax实现分页
2008/03/27 PHP
php gd2 上传图片/文字水印/图片水印/等比例缩略图/实现代码
2010/05/15 PHP
深入PHP FTP类的详解
2013/06/13 PHP
php之curl实现http与https请求的方法
2014/10/21 PHP
php中print(),print_r(),echo()的区别详解
2014/12/01 PHP
php的4种常见运行方式
2015/03/20 PHP
laravel 5.1下php artisan migrate的使用注意事项总结
2017/06/07 PHP
laravel接管Dingo-api和默认的错误处理方式
2019/10/25 PHP
XP折叠菜单&amp;仿QQ2006菜单
2006/12/16 Javascript
jQuery 表单验证插件formValidation实现个性化错误提示
2009/06/23 Javascript
javascript折半查找详解
2015/01/26 Javascript
Bootstrap实现响应式导航栏效果
2015/12/28 Javascript
Vue.js快速入门实例教程
2016/10/15 Javascript
jQuery设置和获取select、checkbox、radio的选中值方法
2017/01/01 Javascript
微信小程序地图(map)组件点击(tap)获取经纬度的方法
2019/01/10 Javascript
在微信小程序中使用图表的方法示例
2019/04/25 Javascript
Vue.js实现备忘录功能
2019/06/26 Javascript
使用layer弹窗提交表单时判断表单是否输入为空的例子
2019/09/26 Javascript
JavaScript中arguments的使用方法详解
2020/12/20 Javascript
Python中文编码那些事
2014/06/25 Python
由Python运算π的值深入Python中科学计算的实现
2015/04/17 Python
使用python编写简单的小程序编译成exe跑在win10上
2018/01/15 Python
Python生成器以及应用实例解析
2018/02/08 Python
python数据预处理之数据标准化的几种处理方式
2019/07/17 Python
Python将列表中的元素转化为数字并排序的示例
2019/12/25 Python
完美解决pycharm导入自己写的py文件爆红问题
2020/02/12 Python
解决python Jupyter不能导入外部包问题
2020/04/15 Python
Python dict的常用方法示例代码
2020/06/23 Python
Python 中的函数装饰器和闭包详解
2021/02/06 Python
微软香港官网及网上商店:Microsoft HK
2016/09/01 全球购物
Tiqets英国:智能手机上的文化和娱乐门票
2019/07/10 全球购物
文明工地标语
2014/06/16 职场文书
司机岗位职责
2015/02/04 职场文书
单位更名证明
2015/06/18 职场文书
基于PyTorch实现一个简单的CNN图像分类器
2021/05/29 Python
MySQL事务操作的四大特性以及并发事务问题
2022/04/12 MySQL