keras的ImageDataGenerator和flow()的用法说明


Posted in Python onJuly 03, 2020

ImageDataGenerator的参数自己看文档

from keras.preprocessing import image
import numpy as np

X_train=np.ones((3,123,123,1))
Y_train=np.array([[1],[2],[2]])
generator=image.ImageDataGenerator(featurewise_center=False,
  samplewise_center=False,
  featurewise_std_normalization=False,
  samplewise_std_normalization=False,
  zca_whitening=False,
  zca_epsilon=1e-6,
  rotation_range=180,
  width_shift_range=0.2,
  height_shift_range=0.2,
  shear_range=0,
  zoom_range=0.001,
  channel_shift_range=0,
  fill_mode='nearest',
  cval=0.,
  horizontal_flip=True,
  vertical_flip=True,
  rescale=None,
  preprocessing_function=None,
  data_format='channels_last')

a=generator.flow(X_train,Y_train,batch_size=20)#生成的是一个迭代器,可直接用于for循环
'''
batch_size如果小于X的第一维m,next生成的多维矩阵的第一维是为batch_size,输出是从输入中随机选取batch_size个数据
batch_size如果大于X的第一维m,next生成的多维矩阵的第一维是m,输出是m个数据,不过顺序随机
,输出的X,Y是一一对对应的
如果要直接用于tf.placeholder(),要求生成的矩阵和要与tf.placeholder相匹配

'''
X,Y=next(a)

print(Y)
X,Y=next(a)

print(Y)
X,Y=next(a)

print(Y)
X,Y=next(a)

输出

[[2]
 [1]
 [2]]

[[2]
 [2]
 [1]]

[[2]
 [2]
 [1]]

[[2]
 [2]
 [1]]

补充知识:tensorflow 与keras 混用之坑

在使用tensorflow与keras混用是model.save 是正常的但是在load_model的时候报错了在这里mark 一下

其中错误为:TypeError: tuple indices must be integers, not list

再一一番百度后无结果,上谷歌后找到了类似的问题。但是是一对鸟文不知道什么东西(翻译后发现是俄文)。后来谷歌翻译了一下找到了解决方法。故将原始问题文章贴上来警示一下

原训练代码

from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
 
#Каталог с данными для обучения
train_dir = 'train'
# Каталог с данными для проверки
val_dir = 'val'
# Каталог с данными для тестирования
test_dir = 'val'
 
# Размеры изображения
img_width, img_height = 800, 800
# Размерность тензора на основе изображения для входных данных в нейронную сеть
# backend Tensorflow, channels_last
input_shape = (img_width, img_height, 3)
# Количество эпох
epochs = 1
# Размер мини-выборки
batch_size = 4
# Количество изображений для обучения
nb_train_samples = 300
# Количество изображений для проверки
nb_validation_samples = 25
# Количество изображений для тестирования
nb_test_samples = 25
 
model = Sequential()
 
model.add(Conv2D(32, (7, 7), padding="same", input_shape=input_shape))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(10, 10)))
 
model.add(Conv2D(64, (5, 5), padding="same"))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(10, 10)))
 
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))
 
model.compile(loss='categorical_crossentropy',
       optimizer="Nadam",
       metrics=['accuracy'])
print(model.summary())
datagen = ImageDataGenerator(rescale=1. / 255)
 
train_generator = datagen.flow_from_directory(
  train_dir,
  target_size=(img_width, img_height),
  batch_size=batch_size,
  class_mode='categorical')
 
val_generator = datagen.flow_from_directory(
  val_dir,
  target_size=(img_width, img_height),
  batch_size=batch_size,
  class_mode='categorical')
 
test_generator = datagen.flow_from_directory(
  test_dir,
  target_size=(img_width, img_height),
  batch_size=batch_size,
  class_mode='categorical')
 
model.fit_generator(
  train_generator,
  steps_per_epoch=nb_train_samples // batch_size,
  epochs=epochs,
  validation_data=val_generator,
  validation_steps=nb_validation_samples // batch_size)
 
print('Сохраняем сеть')
model.save("grib.h5")
print("Сохранение завершено!")

模型载入

from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
from keras.models import load_model
 
print("Загрузка сети")
model = load_model("grib.h5")
print("Загрузка завершена!")

报错

/usr/bin/python3.5 /home/disk2/py/neroset/do.py
/home/mama/.local/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
 from ._conv import register_converters as _register_converters
Using TensorFlow backend.
Загрузка сети
Traceback (most recent call last):
 File "/home/disk2/py/neroset/do.py", line 13, in <module>
  model = load_model("grib.h5")
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 243, in load_model
  model = model_from_config(model_config, custom_objects=custom_objects)
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 317, in model_from_config
  return layer_module.deserialize(config, custom_objects=custom_objects)
 File "/usr/local/lib/python3.5/dist-packages/keras/layers/__init__.py", line 55, in deserialize
  printable_module_name='layer')
 File "/usr/local/lib/python3.5/dist-packages/keras/utils/generic_utils.py", line 144, in deserialize_keras_object
  list(custom_objects.items())))
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 1350, in from_config
  model.add(layer)
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 492, in add
  output_tensor = layer(self.outputs[0])
 File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 590, in __call__
  self.build(input_shapes[0])
 File "/usr/local/lib/python3.5/dist-packages/keras/layers/normalization.py", line 92, in build
  dim = input_shape[self.axis]
TypeError: tuple indices must be integers or slices, not list
 
Process finished with exit code 1

战斗种族解释

убераю BatchNormalization всё работает хорошо. Не подскажите в чём ошибка?Выяснил что сохранение keras и нормализация tensorflow не работают вместе нужно просто изменить строку импорта.(译文:整理BatchNormalization一切正常。 不要告诉我错误是什么?我发现保存keras和规范化tensorflow不能一起工作;只需更改导入字符串即可。)

强调文本 强调文本

keras.preprocessing.image import ImageDataGenerator
keras.models import Sequential
keras.layers import Conv2D, MaxPooling2D, BatchNormalization
keras.layers import Activation, Dropout, Flatten, Dense

##完美解决

##附上原文链接

https://qa-help.ru/questions/keras-batchnormalization

以上这篇keras的ImageDataGenerator和flow()的用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python的列表中利用remove()方法删除元素的教程
May 21 Python
Python实现图片转字符画的示例
Aug 22 Python
用tensorflow搭建CNN的方法
Mar 05 Python
Python使用Windows API创建窗口示例【基于win32gui模块】
May 09 Python
python实现停车管理系统
Nov 30 Python
python匹配两个短语之间的字符实例
Dec 25 Python
python getpass模块用法及实例详解
Oct 07 Python
python运用pygame库实现双人弹球小游戏
Nov 25 Python
计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
Jan 15 Python
Python 模拟生成动态产生验证码图片的方法
Feb 01 Python
Python 发送邮件方法总结
Aug 10 Python
图文详解matlab原始处理图像几何变换
Jul 09 Python
python如何安装下载后的模块
Jul 03 #Python
python中id函数运行方式
Jul 03 #Python
Keras 数据增强ImageDataGenerator多输入多输出实例
Jul 03 #Python
keras和tensorflow使用fit_generator 批次训练操作
Jul 03 #Python
基于Python+QT的gui程序开发实现
Jul 03 #Python
keras 两种训练模型方式详解fit和fit_generator(节省内存)
Jul 03 #Python
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
Jul 03 #Python
You might like
生成缩略图
2006/10/09 PHP
玩转虚拟域名◎+ .
2006/10/09 PHP
简单的php写入数据库类代码分享
2011/07/26 PHP
PHP5.4中json_encode中文转码的变化小结
2013/01/30 PHP
php中使用preg_replace函数匹配图片并加上链接的方法
2013/02/06 PHP
CodeIgniter实现从网站抓取图片并自动下载到文件夹里的方法
2015/06/17 PHP
jQuery 获取对象 根据属性、内容匹配, 还有表单元素匹配
2010/05/31 Javascript
js如何调用qq互联api实现第三方登录
2014/03/28 Javascript
JavaScript实现身份证验证代码
2016/02/17 Javascript
聊一聊JavaScript作用域和作用域链
2016/05/03 Javascript
简单谈谈Vue 模板各类数据绑定
2016/09/25 Javascript
JS中BOM相关知识点总结(必看篇)
2016/11/22 Javascript
学习 NodeJS 第八天:Socket 通讯实例
2016/12/21 NodeJs
JS对象是否拥有某属性如何判断
2017/02/03 Javascript
angularjs点击图片放大实现上传图片预览
2017/02/24 Javascript
兼容浏览器的js事件绑定函数(详解)
2017/05/09 Javascript
详解Node.js access_token的获取、存储及更新
2017/06/20 Javascript
JavaScript 中Date对象的格式化代码方法汇总
2017/09/06 Javascript
Vue与Node.js通过socket.io通信的示例代码
2018/07/25 Javascript
arctext.js实现文字平滑弯曲弧形效果的插件
2019/05/13 Javascript
Vue实现购物车实例代码两则
2020/05/30 Javascript
vue+element使用动态加载路由方式实现三级菜单页面显示的操作
2020/08/04 Javascript
vue将文件/图片批量打包下载zip的教程
2020/10/21 Javascript
[06:45]2018DOTA2亚洲邀请赛 4.5 SOLO赛 Sccc vs Maybe
2018/04/06 DOTA
在Python中使用mechanize模块模拟浏览器功能
2015/05/05 Python
Python+Django搭建自己的blog网站
2018/03/13 Python
pandas Dataframe行列读取的实例
2018/06/08 Python
Flask框架Flask-Login用法分析
2018/07/23 Python
Python设计模式之状态模式原理与用法详解
2019/01/15 Python
如何基于python操作json文件获取内容
2019/12/24 Python
HTML5 Canvas标签使用收录
2009/07/07 HTML / CSS
canvas仿写贝塞尔曲线的示例代码
2017/12/29 HTML / CSS
html5调用app分享功能示例(WebViewJavascriptBridge)
2018/03/21 HTML / CSS
MSC邮轮官方网站:加勒比海、地中海和世界各地的假期
2018/08/27 全球购物
网站创业计划书
2014/04/30 职场文书
学校教师培训工作总结
2015/10/14 职场文书