keras的ImageDataGenerator和flow()的用法说明


Posted in Python onJuly 03, 2020

ImageDataGenerator的参数自己看文档

from keras.preprocessing import image
import numpy as np

X_train=np.ones((3,123,123,1))
Y_train=np.array([[1],[2],[2]])
generator=image.ImageDataGenerator(featurewise_center=False,
  samplewise_center=False,
  featurewise_std_normalization=False,
  samplewise_std_normalization=False,
  zca_whitening=False,
  zca_epsilon=1e-6,
  rotation_range=180,
  width_shift_range=0.2,
  height_shift_range=0.2,
  shear_range=0,
  zoom_range=0.001,
  channel_shift_range=0,
  fill_mode='nearest',
  cval=0.,
  horizontal_flip=True,
  vertical_flip=True,
  rescale=None,
  preprocessing_function=None,
  data_format='channels_last')

a=generator.flow(X_train,Y_train,batch_size=20)#生成的是一个迭代器,可直接用于for循环
'''
batch_size如果小于X的第一维m,next生成的多维矩阵的第一维是为batch_size,输出是从输入中随机选取batch_size个数据
batch_size如果大于X的第一维m,next生成的多维矩阵的第一维是m,输出是m个数据,不过顺序随机
,输出的X,Y是一一对对应的
如果要直接用于tf.placeholder(),要求生成的矩阵和要与tf.placeholder相匹配

'''
X,Y=next(a)

print(Y)
X,Y=next(a)

print(Y)
X,Y=next(a)

print(Y)
X,Y=next(a)

输出

[[2]
 [1]
 [2]]

[[2]
 [2]
 [1]]

[[2]
 [2]
 [1]]

[[2]
 [2]
 [1]]

补充知识:tensorflow 与keras 混用之坑

在使用tensorflow与keras混用是model.save 是正常的但是在load_model的时候报错了在这里mark 一下

其中错误为:TypeError: tuple indices must be integers, not list

再一一番百度后无结果,上谷歌后找到了类似的问题。但是是一对鸟文不知道什么东西(翻译后发现是俄文)。后来谷歌翻译了一下找到了解决方法。故将原始问题文章贴上来警示一下

原训练代码

from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
 
#Каталог с данными для обучения
train_dir = 'train'
# Каталог с данными для проверки
val_dir = 'val'
# Каталог с данными для тестирования
test_dir = 'val'
 
# Размеры изображения
img_width, img_height = 800, 800
# Размерность тензора на основе изображения для входных данных в нейронную сеть
# backend Tensorflow, channels_last
input_shape = (img_width, img_height, 3)
# Количество эпох
epochs = 1
# Размер мини-выборки
batch_size = 4
# Количество изображений для обучения
nb_train_samples = 300
# Количество изображений для проверки
nb_validation_samples = 25
# Количество изображений для тестирования
nb_test_samples = 25
 
model = Sequential()
 
model.add(Conv2D(32, (7, 7), padding="same", input_shape=input_shape))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(10, 10)))
 
model.add(Conv2D(64, (5, 5), padding="same"))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(10, 10)))
 
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))
 
model.compile(loss='categorical_crossentropy',
       optimizer="Nadam",
       metrics=['accuracy'])
print(model.summary())
datagen = ImageDataGenerator(rescale=1. / 255)
 
train_generator = datagen.flow_from_directory(
  train_dir,
  target_size=(img_width, img_height),
  batch_size=batch_size,
  class_mode='categorical')
 
val_generator = datagen.flow_from_directory(
  val_dir,
  target_size=(img_width, img_height),
  batch_size=batch_size,
  class_mode='categorical')
 
test_generator = datagen.flow_from_directory(
  test_dir,
  target_size=(img_width, img_height),
  batch_size=batch_size,
  class_mode='categorical')
 
model.fit_generator(
  train_generator,
  steps_per_epoch=nb_train_samples // batch_size,
  epochs=epochs,
  validation_data=val_generator,
  validation_steps=nb_validation_samples // batch_size)
 
print('Сохраняем сеть')
model.save("grib.h5")
print("Сохранение завершено!")

模型载入

from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
from keras.models import load_model
 
print("Загрузка сети")
model = load_model("grib.h5")
print("Загрузка завершена!")

报错

/usr/bin/python3.5 /home/disk2/py/neroset/do.py
/home/mama/.local/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
 from ._conv import register_converters as _register_converters
Using TensorFlow backend.
Загрузка сети
Traceback (most recent call last):
 File "/home/disk2/py/neroset/do.py", line 13, in <module>
  model = load_model("grib.h5")
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 243, in load_model
  model = model_from_config(model_config, custom_objects=custom_objects)
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 317, in model_from_config
  return layer_module.deserialize(config, custom_objects=custom_objects)
 File "/usr/local/lib/python3.5/dist-packages/keras/layers/__init__.py", line 55, in deserialize
  printable_module_name='layer')
 File "/usr/local/lib/python3.5/dist-packages/keras/utils/generic_utils.py", line 144, in deserialize_keras_object
  list(custom_objects.items())))
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 1350, in from_config
  model.add(layer)
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 492, in add
  output_tensor = layer(self.outputs[0])
 File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 590, in __call__
  self.build(input_shapes[0])
 File "/usr/local/lib/python3.5/dist-packages/keras/layers/normalization.py", line 92, in build
  dim = input_shape[self.axis]
TypeError: tuple indices must be integers or slices, not list
 
Process finished with exit code 1

战斗种族解释

убераю BatchNormalization всё работает хорошо. Не подскажите в чём ошибка?Выяснил что сохранение keras и нормализация tensorflow не работают вместе нужно просто изменить строку импорта.(译文:整理BatchNormalization一切正常。 不要告诉我错误是什么?我发现保存keras和规范化tensorflow不能一起工作;只需更改导入字符串即可。)

强调文本 强调文本

keras.preprocessing.image import ImageDataGenerator
keras.models import Sequential
keras.layers import Conv2D, MaxPooling2D, BatchNormalization
keras.layers import Activation, Dropout, Flatten, Dense

##完美解决

##附上原文链接

https://qa-help.ru/questions/keras-batchnormalization

以上这篇keras的ImageDataGenerator和flow()的用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
探究Python中isalnum()方法的使用
May 18 Python
Django rest framework基本介绍与代码示例
Jan 26 Python
将pandas.dataframe的数据写入到文件中的方法
Dec 07 Python
python根据文章标题内容自动生成摘要的实例
Feb 21 Python
pow在python中的含义及用法
Jul 11 Python
Python django框架输入汉字,数字,字符生成二维码实现详解
Sep 24 Python
Python命令行click参数用法解析
Dec 19 Python
python GUI库图形界面开发之PyQt5简单绘图板实例与代码分析
Mar 08 Python
django haystack实现全文检索的示例代码
Jun 24 Python
keras的ImageDataGenerator和flow()的用法说明
Jul 03 Python
selenium与xpath之获取指定位置的元素的实现
Jan 26 Python
Python万能模板案例之matplotlib绘制甘特图
Apr 13 Python
python如何安装下载后的模块
Jul 03 #Python
python中id函数运行方式
Jul 03 #Python
Keras 数据增强ImageDataGenerator多输入多输出实例
Jul 03 #Python
keras和tensorflow使用fit_generator 批次训练操作
Jul 03 #Python
基于Python+QT的gui程序开发实现
Jul 03 #Python
keras 两种训练模型方式详解fit和fit_generator(节省内存)
Jul 03 #Python
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
Jul 03 #Python
You might like
codeigniter自带数据库类使用方法说明
2014/03/25 PHP
微信网页授权(OAuth2.0) PHP 源码简单实现
2016/08/29 PHP
php中使用websocket详解
2016/09/23 PHP
Thinkphp 空操作、空控制器、命名空间(详解)
2017/05/05 PHP
Laravel中七个非常有用但很少人知道的Carbon方法
2017/09/21 PHP
番茄的表单验证类代码修改版
2008/07/18 Javascript
基于jQuery的淡入淡出可自动切换的幻灯插件打包下载
2010/09/15 Javascript
silverlight线程与基于事件驱动javascript引擎(实现轨迹回放功能)
2011/08/09 Javascript
jQuery(1.6.3) 中css方法对浮动的实现缺陷分析
2011/09/09 Javascript
让复选框只能选择一项的方法
2013/10/08 Javascript
js实现编辑div节点名称的方法
2014/12/17 Javascript
javascript数组输出的两种方式
2015/01/13 Javascript
BootStrap实现邮件列表的分页和模态框添加邮件的功能
2016/10/13 Javascript
微信小程序 数组(增,删,改,查)等操作实例详解
2017/01/05 Javascript
JavaScript无阻塞加载和defer、async详解
2017/02/26 Javascript
JS仿JQuery选择器功能
2017/03/08 Javascript
jQuery实现鼠标经过显示动画边框特效
2017/03/24 jQuery
一些你可能不熟悉的JS知识点总结
2019/03/15 Javascript
vue+elementUI实现表单和图片上传及验证功能示例
2019/05/14 Javascript
JQuery中DOM节点的操作与访问方法实例分析
2019/12/23 jQuery
Vue微信公众号网页分享的示例代码
2020/05/28 Javascript
[01:29:46]DOTA2上海特级锦标赛C组资格赛#1 OG VS LGD第二局
2016/02/27 DOTA
[02:52]2017DOTA2国际邀请赛中国区预选赛晋级之路
2017/07/03 DOTA
[55:42]VG vs VGJ.T 2018国际邀请赛淘汰赛BO1 8.21
2018/08/22 DOTA
浅谈python中的面向对象和类的基本语法
2016/06/13 Python
利用Python将时间或时间间隔转为ISO 8601格式方法示例
2017/09/05 Python
使用NumPy和pandas对CSV文件进行写操作的实例
2018/06/14 Python
Laravel+Dingo/Api 自定义响应的实现
2019/02/17 Python
如何基于Python批量下载音乐
2019/11/11 Python
Numpy与Pytorch 矩阵操作方式
2019/12/27 Python
阿拉伯时尚购物网站:Nisnass
2021/02/07 全球购物
信息管理应届生求职信
2014/03/07 职场文书
安全教育主题班会总结
2015/08/14 职场文书
你会写请假条吗?
2019/06/26 职场文书
Python - 10行代码集2000张美女图
2021/05/23 Python
最新最全的手机号验证正则表达式
2022/02/24 Javascript