keras的ImageDataGenerator和flow()的用法说明


Posted in Python onJuly 03, 2020

ImageDataGenerator的参数自己看文档

from keras.preprocessing import image
import numpy as np

X_train=np.ones((3,123,123,1))
Y_train=np.array([[1],[2],[2]])
generator=image.ImageDataGenerator(featurewise_center=False,
  samplewise_center=False,
  featurewise_std_normalization=False,
  samplewise_std_normalization=False,
  zca_whitening=False,
  zca_epsilon=1e-6,
  rotation_range=180,
  width_shift_range=0.2,
  height_shift_range=0.2,
  shear_range=0,
  zoom_range=0.001,
  channel_shift_range=0,
  fill_mode='nearest',
  cval=0.,
  horizontal_flip=True,
  vertical_flip=True,
  rescale=None,
  preprocessing_function=None,
  data_format='channels_last')

a=generator.flow(X_train,Y_train,batch_size=20)#生成的是一个迭代器,可直接用于for循环
'''
batch_size如果小于X的第一维m,next生成的多维矩阵的第一维是为batch_size,输出是从输入中随机选取batch_size个数据
batch_size如果大于X的第一维m,next生成的多维矩阵的第一维是m,输出是m个数据,不过顺序随机
,输出的X,Y是一一对对应的
如果要直接用于tf.placeholder(),要求生成的矩阵和要与tf.placeholder相匹配

'''
X,Y=next(a)

print(Y)
X,Y=next(a)

print(Y)
X,Y=next(a)

print(Y)
X,Y=next(a)

输出

[[2]
 [1]
 [2]]

[[2]
 [2]
 [1]]

[[2]
 [2]
 [1]]

[[2]
 [2]
 [1]]

补充知识:tensorflow 与keras 混用之坑

在使用tensorflow与keras混用是model.save 是正常的但是在load_model的时候报错了在这里mark 一下

其中错误为:TypeError: tuple indices must be integers, not list

再一一番百度后无结果,上谷歌后找到了类似的问题。但是是一对鸟文不知道什么东西(翻译后发现是俄文)。后来谷歌翻译了一下找到了解决方法。故将原始问题文章贴上来警示一下

原训练代码

from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
 
#Каталог с данными для обучения
train_dir = 'train'
# Каталог с данными для проверки
val_dir = 'val'
# Каталог с данными для тестирования
test_dir = 'val'
 
# Размеры изображения
img_width, img_height = 800, 800
# Размерность тензора на основе изображения для входных данных в нейронную сеть
# backend Tensorflow, channels_last
input_shape = (img_width, img_height, 3)
# Количество эпох
epochs = 1
# Размер мини-выборки
batch_size = 4
# Количество изображений для обучения
nb_train_samples = 300
# Количество изображений для проверки
nb_validation_samples = 25
# Количество изображений для тестирования
nb_test_samples = 25
 
model = Sequential()
 
model.add(Conv2D(32, (7, 7), padding="same", input_shape=input_shape))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(10, 10)))
 
model.add(Conv2D(64, (5, 5), padding="same"))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(10, 10)))
 
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))
 
model.compile(loss='categorical_crossentropy',
       optimizer="Nadam",
       metrics=['accuracy'])
print(model.summary())
datagen = ImageDataGenerator(rescale=1. / 255)
 
train_generator = datagen.flow_from_directory(
  train_dir,
  target_size=(img_width, img_height),
  batch_size=batch_size,
  class_mode='categorical')
 
val_generator = datagen.flow_from_directory(
  val_dir,
  target_size=(img_width, img_height),
  batch_size=batch_size,
  class_mode='categorical')
 
test_generator = datagen.flow_from_directory(
  test_dir,
  target_size=(img_width, img_height),
  batch_size=batch_size,
  class_mode='categorical')
 
model.fit_generator(
  train_generator,
  steps_per_epoch=nb_train_samples // batch_size,
  epochs=epochs,
  validation_data=val_generator,
  validation_steps=nb_validation_samples // batch_size)
 
print('Сохраняем сеть')
model.save("grib.h5")
print("Сохранение завершено!")

模型载入

from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
from keras.models import load_model
 
print("Загрузка сети")
model = load_model("grib.h5")
print("Загрузка завершена!")

报错

/usr/bin/python3.5 /home/disk2/py/neroset/do.py
/home/mama/.local/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
 from ._conv import register_converters as _register_converters
Using TensorFlow backend.
Загрузка сети
Traceback (most recent call last):
 File "/home/disk2/py/neroset/do.py", line 13, in <module>
  model = load_model("grib.h5")
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 243, in load_model
  model = model_from_config(model_config, custom_objects=custom_objects)
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 317, in model_from_config
  return layer_module.deserialize(config, custom_objects=custom_objects)
 File "/usr/local/lib/python3.5/dist-packages/keras/layers/__init__.py", line 55, in deserialize
  printable_module_name='layer')
 File "/usr/local/lib/python3.5/dist-packages/keras/utils/generic_utils.py", line 144, in deserialize_keras_object
  list(custom_objects.items())))
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 1350, in from_config
  model.add(layer)
 File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 492, in add
  output_tensor = layer(self.outputs[0])
 File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 590, in __call__
  self.build(input_shapes[0])
 File "/usr/local/lib/python3.5/dist-packages/keras/layers/normalization.py", line 92, in build
  dim = input_shape[self.axis]
TypeError: tuple indices must be integers or slices, not list
 
Process finished with exit code 1

战斗种族解释

убераю BatchNormalization всё работает хорошо. Не подскажите в чём ошибка?Выяснил что сохранение keras и нормализация tensorflow не работают вместе нужно просто изменить строку импорта.(译文:整理BatchNormalization一切正常。 不要告诉我错误是什么?我发现保存keras和规范化tensorflow不能一起工作;只需更改导入字符串即可。)

强调文本 强调文本

keras.preprocessing.image import ImageDataGenerator
keras.models import Sequential
keras.layers import Conv2D, MaxPooling2D, BatchNormalization
keras.layers import Activation, Dropout, Flatten, Dense

##完美解决

##附上原文链接

https://qa-help.ru/questions/keras-batchnormalization

以上这篇keras的ImageDataGenerator和flow()的用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Pyhton中防止SQL注入的方法
Feb 05 Python
Python的装饰器模式与面向切面编程详解
Jun 21 Python
使用Python的PIL模块来进行图片对比
Feb 18 Python
python 安装virtualenv和virtualenvwrapper的方法
Jan 13 Python
用python3 urllib破解有道翻译反爬虫机制详解
Aug 14 Python
python3.7将代码打包成exe程序并添加图标的方法
Oct 11 Python
django实现web接口 python3模拟Post请求方式
Nov 19 Python
python 读写文件包含多种编码格式的解决方式
Dec 20 Python
Python 实现Image和Ndarray互相转换
Feb 19 Python
Python生成器实现简单&quot;生产者消费者&quot;模型代码实例
Mar 27 Python
python 双循环遍历list 变量判断代码
May 04 Python
用Python实现定时备份Mongodb数据并上传到FTP服务器
Jan 27 Python
python如何安装下载后的模块
Jul 03 #Python
python中id函数运行方式
Jul 03 #Python
Keras 数据增强ImageDataGenerator多输入多输出实例
Jul 03 #Python
keras和tensorflow使用fit_generator 批次训练操作
Jul 03 #Python
基于Python+QT的gui程序开发实现
Jul 03 #Python
keras 两种训练模型方式详解fit和fit_generator(节省内存)
Jul 03 #Python
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
Jul 03 #Python
You might like
Cakephp 执行主要流程
2010/03/24 PHP
php学习之简单计算器实现代码
2011/06/09 PHP
基于Zend的Config机制的应用分析
2013/05/02 PHP
基于PHP magic_quotes_gpc的使用方法详解
2013/06/24 PHP
php实现12306火车票余票查询和价格查询(12306火车票查询)
2014/01/14 PHP
PHP使用SWOOLE扩展实现定时同步 MySQL 数据
2017/04/09 PHP
Javascript入门学习资料收集整理篇
2008/07/06 Javascript
初学JavaScript第二章
2008/09/30 Javascript
JavaScript高级程序设计 XML、Ajax 学习笔记
2011/09/10 Javascript
jQuery带箭头提示框tooltips插件集锦
2014/11/17 Javascript
超实用的JavaScript表单代码段
2016/02/26 Javascript
jQuery根据name属性进行查找的用法分析
2016/06/23 Javascript
AngularJS打开页面隐藏显示表达式用法示例
2016/12/25 Javascript
jQuery简单绑定单个事件的方法示例
2017/06/10 jQuery
js循环map 获取所有的key和value的实现代码(json)
2018/05/09 Javascript
在Vue 中使用Typescript的示例代码
2018/09/10 Javascript
JS数组中对象去重操作示例
2019/06/04 Javascript
VUE写一个简单的表格实例
2019/08/06 Javascript
JavaScript函数重载操作实例浅析
2020/05/02 Javascript
Python切片工具pillow用法示例
2018/03/30 Python
python编辑用户登入界面的实现代码
2018/07/16 Python
Python如何发布程序的详细教程
2018/10/09 Python
Django 外键的使用方法详解
2019/07/19 Python
python3.7 openpyxl 删除指定一列或者一行的代码
2019/10/08 Python
Python for循环搭配else常见问题解决
2020/02/11 Python
Python通过2种方法输出带颜色字体
2020/03/02 Python
微软开源最强Python自动化神器Playwright(不用写一行代码)
2021/01/05 Python
html5 桌面提醒:Notifycations应用介绍
2012/11/27 HTML / CSS
YOINS官网:时尚女装网上购物
2017/03/17 全球购物
Clarria化妆品官方网站:购买天然和有机化妆品系列
2018/04/08 全球购物
《母亲的恩情》教学反思
2014/02/13 职场文书
人力资源经理的岗位职责
2014/03/02 职场文书
家长会演讲稿
2014/04/26 职场文书
读后感作文评语
2014/12/25 职场文书
煤矿安全保证书
2015/02/27 职场文书
总经理聘用协议书
2015/09/21 职场文书