keras的ImageDataGenerator和flow()的用法说明


Posted in Python onJuly 03, 2020

ImageDataGenerator的参数自己看文档

from keras.preprocessing import image
import numpy as np

X_train=np.ones((3,123,123,1))
Y_train=np.array([[1],[2],[2]])
generator=image.ImageDataGenerator(featurewise_center=False,
  samplewise_center=False,
  featurewise_std_normalization=False,
  samplewise_std_normalization=False,
  zca_whitening=False,
  zca_epsilon=1e-6,
  rotation_range=180,
  width_shift_range=0.2,
  height_shift_range=0.2,
  shear_range=0,
  zoom_range=0.001,
  channel_shift_range=0,
  fill_mode='nearest',
  cval=0.,
  horizontal_flip=True,
  vertical_flip=True,
  rescale=None,
  preprocessing_function=None,
  data_format='channels_last')

a=generator.flow(X_train,Y_train,batch_size=20)#生成的是一个迭代器,可直接用于for循环
'''
batch_size如果小于X的第一维m,next生成的多维矩阵的第一维是为batch_size,输出是从输入中随机选取batch_size个数据
batch_size如果大于X的第一维m,next生成的多维矩阵的第一维是m,输出是m个数据,不过顺序随机
,输出的X,Y是一一对对应的
如果要直接用于tf.placeholder(),要求生成的矩阵和要与tf.placeholder相匹配

'''
X,Y=next(a)

print(Y)
X,Y=next(a)

print(Y)
X,Y=next(a)

print(Y)
X,Y=next(a)

输出

[[2]
 [1]
 [2]]

[[2]
 [2]
 [1]]

[[2]
 [2]
 [1]]

[[2]
 [2]
 [1]]

补充知识:tensorflow 与keras 混用之坑

在使用tensorflow与keras混用是model.save 是正常的但是在load_model的时候报错了在这里mark 一下

其中错误为:TypeError: tuple indices must be integers, not list

再一一番百度后无结果,上谷歌后找到了类似的问题。但是是一对鸟文不知道什么东西(翻译后发现是俄文)。后来谷歌翻译了一下找到了解决方法。故将原始问题文章贴上来警示一下

原训练代码

from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
 
#Каталог с данными для обучения
train_dir = 'train'
# Каталог с данными для проверки
val_dir = 'val'
# Каталог с данными для тестирования
test_dir = 'val'
 
# Размеры изображения
img_width, img_height = 800, 800
# Размерность тензора на основе изображения для входных данных в нейронную сеть
# backend Tensorflow, channels_last
input_shape = (img_width, img_height, 3)
# Количество эпох
epochs = 1
# Размер мини-выборки
batch_size = 4
# Количество изображений для обучения
nb_train_samples = 300
# Количество изображений для проверки
nb_validation_samples = 25
# Количество изображений для тестирования
nb_test_samples = 25
 
model = Sequential()
 
model.add(Conv2D(32, (7, 7), padding="same", input_shape=input_shape))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(10, 10)))
 
model.add(Conv2D(64, (5, 5), padding="same"))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(10, 10)))
 
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))
 
model.compile(loss='categorical_crossentropy',
       optimizer="Nadam",
       metrics=['accuracy'])
print(model.summary())
datagen = ImageDataGenerator(rescale=1. / 255)
 
train_generator = datagen.flow_from_directory(
  train_dir,
  target_size=(img_width, img_height),
  batch_size=batch_size,
  class_mode='categorical')
 
val_generator = datagen.flow_from_directory(
  val_dir,
  target_size=(img_width, img_height),
  batch_size=batch_size,
  class_mode='categorical')
 
test_generator = datagen.flow_from_directory(
  test_dir,
  target_size=(img_width, img_height),
  batch_size=batch_size,
  class_mode='categorical')
 
model.fit_generator(
  train_generator,
  steps_per_epoch=nb_train_samples // batch_size,
  epochs=epochs,
  validation_data=val_generator,
  validation_steps=nb_validation_samples // batch_size)
 
print('Сохраняем сеть')
model.save("grib.h5")
print("Сохранение завершено!")

模型载入

from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
from keras.models import load_model
 
print("Загрузка сети")
model = load_model("grib.h5")
print("Загрузка завершена!")

报错

/usr/bin/python3.5 /home/disk2/py/neroset/do.py
/home/mama/.local/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
 from ._conv import register_converters as _register_converters
Using TensorFlow backend.
Загрузка сети
Traceback (most recent call last):
 File "/home/disk2/py/neroset/do.py", line 13, in <module>
  model = load_model("grib.h5")
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 243, in load_model
  model = model_from_config(model_config, custom_objects=custom_objects)
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 317, in model_from_config
  return layer_module.deserialize(config, custom_objects=custom_objects)
 File "/usr/local/lib/python3.5/dist-packages/keras/layers/__init__.py", line 55, in deserialize
  printable_module_name='layer')
 File "/usr/local/lib/python3.5/dist-packages/keras/utils/generic_utils.py", line 144, in deserialize_keras_object
  list(custom_objects.items())))
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 1350, in from_config
  model.add(layer)
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 492, in add
  output_tensor = layer(self.outputs[0])
 File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 590, in __call__
  self.build(input_shapes[0])
 File "/usr/local/lib/python3.5/dist-packages/keras/layers/normalization.py", line 92, in build
  dim = input_shape[self.axis]
TypeError: tuple indices must be integers or slices, not list
 
Process finished with exit code 1

战斗种族解释

убераю BatchNormalization всё работает хорошо. Не подскажите в чём ошибка?Выяснил что сохранение keras и нормализация tensorflow не работают вместе нужно просто изменить строку импорта.(译文:整理BatchNormalization一切正常。 不要告诉我错误是什么?我发现保存keras和规范化tensorflow不能一起工作;只需更改导入字符串即可。)

强调文本 强调文本

keras.preprocessing.image import ImageDataGenerator
keras.models import Sequential
keras.layers import Conv2D, MaxPooling2D, BatchNormalization
keras.layers import Activation, Dropout, Flatten, Dense

##完美解决

##附上原文链接

https://qa-help.ru/questions/keras-batchnormalization

以上这篇keras的ImageDataGenerator和flow()的用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中xrange和range的区别
May 13 Python
python实现模拟按键,自动翻页看u17漫画
Mar 17 Python
使用Python写一个小游戏
Apr 02 Python
Python基于opencv的图像压缩算法实例分析
May 03 Python
20行python代码实现人脸识别
May 05 Python
python实现视频分帧效果
May 31 Python
python numpy 常用随机数的产生方法的实现
Aug 21 Python
基于python 等频分箱qcut问题的解决
Mar 03 Python
Django Model层F,Q对象和聚合函数原理解析
Nov 12 Python
next在python中返回迭代器的实例方法
Dec 15 Python
python调用百度AI接口实现人流量统计
Feb 03 Python
Python中三种花式打印的示例详解
Mar 19 Python
python如何安装下载后的模块
Jul 03 #Python
python中id函数运行方式
Jul 03 #Python
Keras 数据增强ImageDataGenerator多输入多输出实例
Jul 03 #Python
keras和tensorflow使用fit_generator 批次训练操作
Jul 03 #Python
基于Python+QT的gui程序开发实现
Jul 03 #Python
keras 两种训练模型方式详解fit和fit_generator(节省内存)
Jul 03 #Python
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
Jul 03 #Python
You might like
PHP对象克隆clone用法示例
2016/09/28 PHP
Laravel构建即时应用的一种实现方法详解
2017/08/31 PHP
PHP通过get方法获得form表单数据方法总结
2018/09/12 PHP
JavaScript 实现模态对话框 源代码大全
2009/05/02 Javascript
Javascript String对象扩展HTML编码和解码的方法
2009/06/02 Javascript
js同比例缩放图片的小例子
2013/10/30 Javascript
jQuery实现Twitter的自动文字补齐特效
2014/11/28 Javascript
js实现左侧网页tab滑动门效果代码
2015/09/06 Javascript
jQuery+PHP实现微信转盘抽奖功能的方法
2016/05/25 Javascript
js中用cssText设置css样式的简单方法
2016/09/19 Javascript
Bootstrap栅格系统学习笔记
2016/11/25 Javascript
jQuery实现鼠标滑过图片移动特效
2016/12/08 Javascript
angularjs 实现带查找筛选功能的select下拉框实例
2017/01/11 Javascript
javascript 单例模式详解及简单实例
2017/02/14 Javascript
Angular之指令Directive用法详解
2017/03/01 Javascript
vue.js表单验证插件(vee-validate)的使用教程详解
2019/05/23 Javascript
layui文件上传控件带更改后数据传值的方法
2019/09/23 Javascript
js实现登录拖拽窗口
2020/02/10 Javascript
浅谈Vue 自动化部署打包上线
2020/06/14 Javascript
[02:07]2018DOTA2亚洲邀请赛主赛事第三日五佳镜头 fy极限反杀
2018/04/06 DOTA
[51:52]Liquid vs Secret 2019国际邀请赛淘汰赛 败者组 BO3 第二场 8.24
2019/09/10 DOTA
[43:03]完美世界DOTA2联赛PWL S2 PXG vs Magma 第二场 11.21
2020/11/24 DOTA
[04:59]DOTA2-DPC中国联赛 正赛 Ehome vs iG 选手采访
2021/03/11 DOTA
linux系统使用python监控apache服务器进程脚本分享
2014/01/15 Python
浅谈Python 的枚举 Enum
2017/06/12 Python
python实战之实现excel读取、统计、写入的示例讲解
2018/05/02 Python
Python之修改图片像素值的方法
2019/07/03 Python
python循环输出三角形图案的例子
2019/11/22 Python
UGG美国官网:购买UGG雪地靴、拖鞋和鞋子
2017/12/31 全球购物
Larsson & Jennings官网:现代瑞士钟表匠
2018/03/20 全球购物
职高毕业生自我鉴定
2013/10/21 职场文书
小学毕业感言50字
2014/02/16 职场文书
厨师个人自我鉴定范文
2014/04/19 职场文书
实验室标语
2014/06/21 职场文书
人力资源部岗位职责
2015/02/11 职场文书
Go语言带缓冲的通道实现
2021/04/26 Golang