Keras 数据增强ImageDataGenerator多输入多输出实例


Posted in Python onJuly 03, 2020

我就废话不多说了,大家还是直接看代码吧~

import os 
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"]=""
import sys
import gc
import time
import cv2
import random
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

from random_eraser import get_random_eraser
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img

datagen = ImageDataGenerator(
  rotation_range=20,   #旋转
  width_shift_range=0.1,  #水平位置平移
#   height_shift_range=0.2,  #上下位置平移
  shear_range=0.5,    #错切变换,让所有点的x坐标(或者y坐标)保持不变,而对应的y坐标(或者x坐标)则按比例发生平移
  zoom_range=[0.9,0.9],  # 单方向缩放,当一个数值时两个方向等比例缩放,参数为list时长宽不同程度缩放。参数大于0小于1时,执行的是放大操作,当参数大于1时,执行的是缩小操作。
  channel_shift_range = 40, #偏移通道数值,改变图片颜色,越大颜色越深
  horizontal_flip=True,  #水平翻转,垂直翻转vertical_flip
  fill_mode='nearest',   #操作导致图像缺失时填充方式。“constant”、“nearest”(默认)、“reflect”和“wrap”
  preprocessing_function = get_random_eraser(p=0.7,v_l=0,v_h=255,s_l=0.01,s_h=0.03,r_1=1,r_2=1.5,pixel_level=True)
  )

# train_generator = datagen.flow_from_directory(
#       'base/Images/',
#       save_to_dir = 'base/fake/',
#       batch_size=1
#       )
# for i in range(5):
#  train_generator.next()

# !
# df_train = pd.read_csv('base/Annotations/label.csv', header=None)
# df_train.columns = ['image_id', 'class', 'label']
# classes = ['collar_design_labels', 'neckline_design_labels', 'skirt_length_labels', 
#   'sleeve_length_labels', 'neck_design_labels', 'coat_length_labels', 'lapel_design_labels', 
#   'pant_length_labels']
# !

# classes = ['collar_design_labels']

# !
# for i in range(len(classes)):
#  gc.enable()

# #  单个分类
#  cur_class = classes[i]
#  df_load = df_train[(df_train['class'] == cur_class)].copy()
#  df_load.reset_index(inplace=True)
#  del df_load['index']

# #  print(cur_class)

# #  加载数据和label
#  n = len(df_load)
# #  n_class = len(df_load['label'][0])
# #  width = 256

# #  X = np.zeros((n,width, width, 3), dtype=np.uint8)
# #  y = np.zeros((n, n_class), dtype=np.uint8)

#  print(f'starting load trainset {cur_class} {n}')
#  sys.stdout.flush()
#  for i in tqdm(range(n)):
# #   tmp_label = df_load['label'][i]
#   img = load_img('base/{0}'.format(df_load['image_id'][i]))
#   x = img_to_array(img)
#   x = x.reshape((1,) + x.shape)
#   m=0
#   for batch in datagen.flow(x,batch_size=1):
# #    plt.imshow(array_to_img(batch[0]))
# #    print(batch)
#    array_to_img(batch[0]).save(f'base/fake/{format(df_load["image_id"][i])}-{m}.jpg')
#    m+=1
#    if m>3:
#     break
#  gc.collect()
# !  

img = load_img('base/Images/collar_design_labels/2f639f11de22076ead5fe1258eae024d.jpg')
plt.figure()
plt.imshow(img)
x = img_to_array(img)

x = x.reshape((1,) + x.shape)

i = 0
for batch in datagen.flow(x,batch_size=5):
 plt.figure()
 plt.imshow(array_to_img(batch[0]))
#  print(len(batch))
 i += 1
 if i >0:
  break
#多输入,设置随机种子
# Define the image transformations here
gen = ImageDataGenerator(horizontal_flip = True,
       vertical_flip = True,
       width_shift_range = 0.1,
       height_shift_range = 0.1,
       zoom_range = 0.1,
       rotation_range = 40)

# Here is the function that merges our two generators
# We use the exact same generator with the same random seed for both the y and angle arrays
def gen_flow_for_two_inputs(X1, X2, y):
 genX1 = gen.flow(X1,y, batch_size=batch_size,seed=666)
 genX2 = gen.flow(X1,X2, batch_size=batch_size,seed=666)
 while True:
   X1i = genX1.next()
   X2i = genX2.next()
   #Assert arrays are equal - this was for peace of mind, but slows down training
   #np.testing.assert_array_equal(X1i[0],X2i[0])
   yield [X1i[0], X2i[1]], X1i[1]
#手动构造,直接输出多label
generator = ImageDataGenerator(rotation_range=5.,
        width_shift_range=0.1, 
        height_shift_range=0.1, 
        horizontal_flip=True, 
        vertical_flip=True)

def generate_data_generator(generator, X, Y1, Y2):
 genX = generator.flow(X, seed=7)
 genY1 = generator.flow(Y1, seed=7)
 while True:
   Xi = genX.next()
   Yi1 = genY1.next()
   Yi2 = function(Y2)
   yield Xi, [Yi1, Yi2]
model.fit_generator(generate_data_generator(generator, X, Y1, Y2),
    epochs=epochs)
def batch_generator(generator,X,Y):
 Xgen = generator.flow(X)
 while True:
  yield Xgen.next(),Y
h = model.fit_generator(batch_generator(datagen, X_all, y_all), 
       steps_per_epoch=len(X_all)//32+1,
       epochs=80,workers=3,
       callbacks=[EarlyStopping(patience=3), checkpointer,ReduceLROnPlateau(monitor='val_loss',factor=0.5,patience=1)], 
       validation_data=(X_val,y_val))

补充知识:读取图片成numpy数组,裁剪并保存 和 数据增强(ImageDataGenerator)

我就废话不多说了,大家还是直接看代码吧~

from PIL import Image
import numpy as np
from PIL import Image
from keras.preprocessing import image
import matplotlib.pyplot as plt
import os
import cv2
# from scipy.misc import toimage
import matplotlib
# 生成图片地址和对应标签
file_dir = '../train/'
image_list = []
label_list = []
cate = [file_dir + x for x in os.listdir(file_dir) if os.path.isdir(file_dir + x)]
for name in cate:
 temp = name.split('/')
 path = '../train_new/' + temp[-1]
 isExists = os.path.exists(path)
 if not isExists:
  os.makedirs(path) # 目录不存在则创建
 class_path = name + "/"

 for file in os.listdir(class_path):
  print(file)
  img_obj = Image.open(class_path + file) # 读取图片
  img_array = np.array(img_obj)
  resized = cv2.resize(img_array, (256, 256)) # 裁剪
  resized = resized.astype('float32')
  resized /= 255.
  # plt.imshow(resized)
  # plt.show()
  save_path = path + '/' + file
  matplotlib.image.imsave(save_path, resized) # 保存

keras之数据增强

from PIL import Image
import numpy as np
from PIL import Image
from keras.preprocessing import image
import os
import cv2
# 生成图片地址和对应标签
file_dir = '../train/'

label_list = []
cate = [file_dir + x for x in os.listdir(file_dir) if os.path.isdir(file_dir + x)]
for name in cate:
 image_list = []
 class_path = name + "/"
 for file in os.listdir(class_path):
  image_list.append(class_path + file)
 batch_size = 64
 if len(image_list) < 10000:
  num = int(10000 / len(image_list))
 else:
  num = 0
 # 设置生成器参数
 datagen = image.ImageDataGenerator(fill_mode='wrap', # 填充模式
          rotation_range=40, # 指定旋转角度范围
          width_shift_range=0.2, # 水平位置平移
          height_shift_range=0.2, # 上下位置平移
          horizontal_flip=True, # 随机对图片执行水平翻转操作
          vertical_flip=True, # 对图片执行上下翻转操作
          shear_range=0.2,
          rescale=1./255, # 缩放
          data_format='channels_last')
 if num > 0:
  temp = name.split('/')
  path = '../train_datage/' + temp[-1]
  isExists = os.path.exists(path)
  if not isExists:
   os.makedirs(path)

  for image_path in image_list:
   i = 1
   img_obj = Image.open(image_path) # 读取图片
   img_array = np.array(img_obj)
   x = img_array.reshape((1,) + img_array.shape)  #要求为4维
   name_image = image_path.split('/')
   print(name_image)
   for batch in datagen.flow(x,
        batch_size=1,
        save_to_dir=path,
        save_prefix=name_image[-1][:-4] + '_',
        save_format='jpg'):
    i += 1
    if i > num:
     break

以上这篇Keras 数据增强ImageDataGenerator多输入多输出实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中删除文件的程序代码
Mar 13 Python
python 从远程服务器下载东西的代码
Feb 10 Python
Python lxml模块安装教程
Jun 02 Python
对python周期性定时器的示例详解
Feb 19 Python
详解爬虫被封的问题
Apr 23 Python
django-rest-framework 自定义swagger过程详解
Jul 18 Python
Python CVXOPT模块安装及使用解析
Aug 01 Python
详解如何用TensorFlow训练和识别/分类自定义图片
Aug 05 Python
jupyter 导入csv文件方式
Apr 21 Python
pycharm开发一个简单界面和通用mvc模板(操作方法图解)
May 27 Python
Python函数参数定义及传递方式解析
Jun 10 Python
十个Python自动化常用操作,即拿即用
May 10 Python
keras和tensorflow使用fit_generator 批次训练操作
Jul 03 #Python
基于Python+QT的gui程序开发实现
Jul 03 #Python
keras 两种训练模型方式详解fit和fit_generator(节省内存)
Jul 03 #Python
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
Jul 03 #Python
keras分类模型中的输入数据与标签的维度实例
Jul 03 #Python
keras自动编码器实现系列之卷积自动编码器操作
Jul 03 #Python
Python with语句用法原理详解
Jul 03 #Python
You might like
apache mysql php 源码编译使用方法
2012/05/03 PHP
PHP CodeBase:将时间显示为&quot;刚刚&quot;&quot;n分钟/小时前&quot;的方法详解
2013/06/06 PHP
深入解析fsockopen与pfsockopen的区别
2013/07/05 PHP
destoon首页调用求购供应信息的地区名称的方法
2014/08/21 PHP
PHP基于phpqrcode生成带LOGO图像的二维码实例
2015/07/10 PHP
Ajax PHP JavaScript MySQL实现简易无刷新在线聊天室
2016/08/17 PHP
PHP文件操作详解
2016/12/30 PHP
关于php支持的协议与封装协议总结(推荐)
2017/11/17 PHP
图片在浏览器中底部对齐 解决方法之一
2011/11/30 Javascript
Javascript 鼠标移动上去 滑块跟随效果代码分享
2013/11/23 Javascript
JQuery的$命名冲突详细解析
2013/12/28 Javascript
jQuery学习笔记之总体架构
2014/06/03 Javascript
jQuery学习笔记之 Ajax操作篇(二) - 数据传递
2014/06/23 Javascript
详解JavaScript中常用的函数类型
2015/11/18 Javascript
移动端 一个简单易懂的弹出框
2016/07/06 Javascript
xmlplus组件设计系列之路由(ViewStack)(7)
2017/05/02 Javascript
NodeJs 模仿SIP话机注册的方法
2019/06/21 NodeJs
JavaScript算法学习之冒泡排序和选择排序
2019/11/02 Javascript
vue中英文切换实例代码
2020/01/21 Javascript
[01:02:38]DOTA2-DPC中国联赛定级赛 LBZS vs Phoenix BO3第二场 1月10日
2021/03/11 DOTA
Python获取Linux系统下的本机IP地址代码分享
2014/11/07 Python
python3实现读取chrome浏览器cookie
2016/06/19 Python
python中的二维列表实例详解
2018/06/19 Python
python判断数字是否是超级素数幂
2018/09/27 Python
python制作简单五子棋游戏
2019/06/18 Python
Python astype(np.float)函数使用方法解析
2020/06/08 Python
利用Python实现朋友圈中的九宫格图片效果
2020/09/03 Python
python hmac模块验证客户端的合法性
2020/11/07 Python
Volcom英国官方商店:美国殿堂级滑板、冲浪、滑雪服装品牌
2019/03/13 全球购物
世界上最大的冷却器制造商:Igloo Coolers
2019/07/23 全球购物
公安个人四风问题对照检查及整改措施
2014/10/28 职场文书
出差报告格式模板
2014/11/06 职场文书
毕业生个人总结
2015/02/28 职场文书
2016年世界艾滋病日宣传活动总结
2016/04/01 职场文书
pytorch MSELoss计算平均的实现方法
2021/05/12 Python
Opencv实现二维直方图的计算及绘制
2021/07/21 Python