keras的ImageDataGenerator和flow()的用法说明


Posted in Python onJuly 03, 2020

ImageDataGenerator的参数自己看文档

from keras.preprocessing import image
import numpy as np

X_train=np.ones((3,123,123,1))
Y_train=np.array([[1],[2],[2]])
generator=image.ImageDataGenerator(featurewise_center=False,
  samplewise_center=False,
  featurewise_std_normalization=False,
  samplewise_std_normalization=False,
  zca_whitening=False,
  zca_epsilon=1e-6,
  rotation_range=180,
  width_shift_range=0.2,
  height_shift_range=0.2,
  shear_range=0,
  zoom_range=0.001,
  channel_shift_range=0,
  fill_mode='nearest',
  cval=0.,
  horizontal_flip=True,
  vertical_flip=True,
  rescale=None,
  preprocessing_function=None,
  data_format='channels_last')

a=generator.flow(X_train,Y_train,batch_size=20)#生成的是一个迭代器,可直接用于for循环
'''
batch_size如果小于X的第一维m,next生成的多维矩阵的第一维是为batch_size,输出是从输入中随机选取batch_size个数据
batch_size如果大于X的第一维m,next生成的多维矩阵的第一维是m,输出是m个数据,不过顺序随机
,输出的X,Y是一一对对应的
如果要直接用于tf.placeholder(),要求生成的矩阵和要与tf.placeholder相匹配

'''
X,Y=next(a)

print(Y)
X,Y=next(a)

print(Y)
X,Y=next(a)

print(Y)
X,Y=next(a)

输出

[[2]
 [1]
 [2]]

[[2]
 [2]
 [1]]

[[2]
 [2]
 [1]]

[[2]
 [2]
 [1]]

补充知识:tensorflow 与keras 混用之坑

在使用tensorflow与keras混用是model.save 是正常的但是在load_model的时候报错了在这里mark 一下

其中错误为:TypeError: tuple indices must be integers, not list

再一一番百度后无结果,上谷歌后找到了类似的问题。但是是一对鸟文不知道什么东西(翻译后发现是俄文)。后来谷歌翻译了一下找到了解决方法。故将原始问题文章贴上来警示一下

原训练代码

from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
 
#Каталог с данными для обучения
train_dir = 'train'
# Каталог с данными для проверки
val_dir = 'val'
# Каталог с данными для тестирования
test_dir = 'val'
 
# Размеры изображения
img_width, img_height = 800, 800
# Размерность тензора на основе изображения для входных данных в нейронную сеть
# backend Tensorflow, channels_last
input_shape = (img_width, img_height, 3)
# Количество эпох
epochs = 1
# Размер мини-выборки
batch_size = 4
# Количество изображений для обучения
nb_train_samples = 300
# Количество изображений для проверки
nb_validation_samples = 25
# Количество изображений для тестирования
nb_test_samples = 25
 
model = Sequential()
 
model.add(Conv2D(32, (7, 7), padding="same", input_shape=input_shape))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(10, 10)))
 
model.add(Conv2D(64, (5, 5), padding="same"))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(10, 10)))
 
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))
 
model.compile(loss='categorical_crossentropy',
       optimizer="Nadam",
       metrics=['accuracy'])
print(model.summary())
datagen = ImageDataGenerator(rescale=1. / 255)
 
train_generator = datagen.flow_from_directory(
  train_dir,
  target_size=(img_width, img_height),
  batch_size=batch_size,
  class_mode='categorical')
 
val_generator = datagen.flow_from_directory(
  val_dir,
  target_size=(img_width, img_height),
  batch_size=batch_size,
  class_mode='categorical')
 
test_generator = datagen.flow_from_directory(
  test_dir,
  target_size=(img_width, img_height),
  batch_size=batch_size,
  class_mode='categorical')
 
model.fit_generator(
  train_generator,
  steps_per_epoch=nb_train_samples // batch_size,
  epochs=epochs,
  validation_data=val_generator,
  validation_steps=nb_validation_samples // batch_size)
 
print('Сохраняем сеть')
model.save("grib.h5")
print("Сохранение завершено!")

模型载入

from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
from keras.models import load_model
 
print("Загрузка сети")
model = load_model("grib.h5")
print("Загрузка завершена!")

报错

/usr/bin/python3.5 /home/disk2/py/neroset/do.py
/home/mama/.local/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
 from ._conv import register_converters as _register_converters
Using TensorFlow backend.
Загрузка сети
Traceback (most recent call last):
 File "/home/disk2/py/neroset/do.py", line 13, in <module>
  model = load_model("grib.h5")
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 243, in load_model
  model = model_from_config(model_config, custom_objects=custom_objects)
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 317, in model_from_config
  return layer_module.deserialize(config, custom_objects=custom_objects)
 File "/usr/local/lib/python3.5/dist-packages/keras/layers/__init__.py", line 55, in deserialize
  printable_module_name='layer')
 File "/usr/local/lib/python3.5/dist-packages/keras/utils/generic_utils.py", line 144, in deserialize_keras_object
  list(custom_objects.items())))
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 1350, in from_config
  model.add(layer)
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 492, in add
  output_tensor = layer(self.outputs[0])
 File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 590, in __call__
  self.build(input_shapes[0])
 File "/usr/local/lib/python3.5/dist-packages/keras/layers/normalization.py", line 92, in build
  dim = input_shape[self.axis]
TypeError: tuple indices must be integers or slices, not list
 
Process finished with exit code 1

战斗种族解释

убераю BatchNormalization всё работает хорошо. Не подскажите в чём ошибка?Выяснил что сохранение keras и нормализация tensorflow не работают вместе нужно просто изменить строку импорта.(译文:整理BatchNormalization一切正常。 不要告诉我错误是什么?我发现保存keras和规范化tensorflow不能一起工作;只需更改导入字符串即可。)

强调文本 强调文本

keras.preprocessing.image import ImageDataGenerator
keras.models import Sequential
keras.layers import Conv2D, MaxPooling2D, BatchNormalization
keras.layers import Activation, Dropout, Flatten, Dense

##完美解决

##附上原文链接

https://qa-help.ru/questions/keras-batchnormalization

以上这篇keras的ImageDataGenerator和flow()的用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python函数学习笔记
Oct 07 Python
详解Python中的strftime()方法的使用
May 22 Python
Python运算符重载用法实例分析
Jun 01 Python
python并发2之使用asyncio处理并发
Dec 21 Python
Python 实现使用dict 创建二维数据、DataFrame
Apr 13 Python
PyQT实现多窗口切换
Apr 20 Python
Python实现快速计算词频功能示例
Jun 25 Python
Python实现购物评论文本情感分析操作【基于中文文本挖掘库snownlp】
Aug 07 Python
一文了解Python并发编程的工程实现方法
May 31 Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 Python
Python Pillow(PIL)库的用法详解
Sep 19 Python
python读取mnist数据集方法案例详解
Sep 04 Python
python如何安装下载后的模块
Jul 03 #Python
python中id函数运行方式
Jul 03 #Python
Keras 数据增强ImageDataGenerator多输入多输出实例
Jul 03 #Python
keras和tensorflow使用fit_generator 批次训练操作
Jul 03 #Python
基于Python+QT的gui程序开发实现
Jul 03 #Python
keras 两种训练模型方式详解fit和fit_generator(节省内存)
Jul 03 #Python
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
Jul 03 #Python
You might like
php分页函数
2006/07/08 PHP
PHP获取网卡地址的代码
2008/04/09 PHP
根据ip调用新浪api获取城市名并转成拼音
2014/03/07 PHP
使用PHP和HTML5 FormData实现无刷新文件上传教程
2014/09/06 PHP
PHP检测用户语言的方法
2015/06/15 PHP
PHP入门教程之自定义函数用法详解(创建,调用,变量,参数,返回值等)
2016/09/11 PHP
php获取文件名称和扩展名的方法
2017/02/07 PHP
PHP简单实现合并2个数字键数组值的方法
2017/05/30 PHP
PHP针对伪静态的注入总结【附asp与Python相关代码】
2017/08/01 PHP
thinkPHP5分页功能实现方法分析
2017/10/25 PHP
JQuery处理json与ajax返回JSON实例代码
2014/01/03 Javascript
js中哈希表的几种用法总结
2014/01/28 Javascript
javascript中expression的用法整理
2014/05/13 Javascript
jQuery setTimeout传递字符串参数报错的解决方法
2014/06/09 Javascript
JS+CSS实现自动改变切换方向图片幻灯切换效果的方法
2015/03/02 Javascript
jquery特效 点击展示与隐藏全文
2015/12/09 Javascript
js防阻塞加载的实现方法
2016/09/09 Javascript
Bootstrap轮播插件使用代码
2016/10/11 Javascript
Javascript中this绑定的3种方法与比较
2016/10/13 Javascript
js移动端图片压缩上传功能
2020/08/18 Javascript
Vue.js 图标选择组件实践详解
2018/12/03 Javascript
解决layui追加或者动态修改的表单元素“没效果”的问题
2019/09/18 Javascript
vue.js iview打包上线后字体图标不显示解决办法
2020/01/20 Javascript
Vue使用预渲染代替SSR的方法
2020/07/02 Javascript
从0到1学习JavaScript编写贪吃蛇游戏
2020/07/28 Javascript
Vue + ts实现轮播插件的示例
2020/11/10 Javascript
Python出现segfault错误解决方法
2016/04/16 Python
Django接受前端数据的几种方法总结
2016/11/04 Python
Python实现针对给定字符串寻找最长非重复子串的方法
2018/04/21 Python
python 实现屏幕录制示例
2019/12/23 Python
Django框架之中间件MiddleWare的实现
2019/12/30 Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
2020/06/29 Python
西雅图电动自行车公司:Rad Power Bikes
2020/02/02 全球购物
纬创Java面试题笔试题
2014/10/02 面试题
2015年行政执法工作总结
2015/05/23 职场文书
2015年学生资助工作总结
2015/05/25 职场文书