浅谈keras2 predict和fit_generator的坑


Posted in Python onJune 17, 2020

1、使用predict时,必须设置batch_size,否则效率奇低。

查看keras文档中,predict函数原型:

predict(self, x, batch_size=32, verbose=0)

说明:

只使用batch_size=32,也就是说每次将batch_size=32的数据通过PCI总线传到GPU,然后进行预测。在一些问题中,batch_size=32明显是非常小的。而通过PCI传数据是非常耗时的。

所以,使用的时候会发现预测数据时效率奇低,其原因就是batch_size太小了。

经验:

使用predict时,必须人为设置好batch_size,否则PCI总线之间的数据传输次数过多,性能会非常低下。

2、fit_generator

说明:keras 中 fit_generator参数steps_per_epoch已经改变含义了,目前的含义是一个epoch分成多少个batch_size。旧版的含义是一个epoch的样本数目。

如果说训练样本树N=1000,steps_per_epoch = 10,那么相当于一个batch_size=100,如果还是按照旧版来设置,那么相当于

batch_size = 1,会性能非常低。

经验:

必须明确fit_generator参数steps_per_epoch

补充知识:Keras:创建自己的generator(适用于model.fit_generator),解决内存问题

为什么要使用model.fit_generator?

在现实的机器学习中,训练一个model往往需要数量巨大的数据,如果使用fit进行数据训练,很有可能导致内存不够,无法进行训练。

fit_generator的定义如下:

fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

其中各项的具体解释,请参考Keras中文文档

我们重点关注的是generator参数:

generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例, 以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:

一个 (inputs, targets) 元组

一个 (inputs, targets, sample_weights) 元组。

那么,问题来了,如何构建这个generator呢?有以下几种办法:

自己创建一个generator生成器

自己定义一个 Sequence (keras.utils.Sequence) 对象

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory来生成一个generator

1.自己创建一个generator生成器

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory 灵活度不高,只有当数据集满足一定格式(例如,按照分类文件夹存放)或者具备一定条件时,使用才使用才较为方便。

此时,自己创建一个generator就很重要了,关于python的generator是什么原理,怎么使用,就不加赘述,可以查看python的基本语法。

此处,我们用yield来返回数据组,标签组,从而使fit_generator可以调用我们的generator来成批处理数据。

具体实现如下:

def myGenerator(batch_size):
    # loading data
    X_train,Y_train=load_data(...)
    
    # data processing
    # ................
    
    total_size=X_train.size
    #batch_size means how many data you want to train one step
    
    while 1:
      for i in range(total_size//batch_size):
        yield x_train[i*batch_size:(i+1)*batch_size], y[i*batch_size:(i+1)*batch_size]
  return myGenerator

接着你可以调用该生成器:

self._model.fit_generator(myGenerator(batch_size),steps_per_epoch=total_size//batch_size, epochs=epoch_num)

以上这篇浅谈keras2 predict和fit_generator的坑就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python的Django框架完成视频处理任务的教程
Apr 02 Python
怎么使用pipenv管理你的python项目
Mar 12 Python
PYQT5设置textEdit自动滚屏的方法
Jun 14 Python
python自动化测试之如何解析excel文件
Jun 27 Python
Django多进程滚动日志问题解决方案
Dec 17 Python
用python3读取python2的pickle数据方式
Dec 25 Python
TensorFlow实现打印每一层的输出
Jan 21 Python
keras实现多种分类网络的方式
Jun 11 Python
Tensorflow--取tensorf指定列的操作方式
Jun 30 Python
浅谈anaconda python 版本对应关系
Oct 07 Python
python字典与json转换的方法总结
Dec 28 Python
Python list去重且保持原顺序不变的方法
Apr 03 Python
python能在浏览器能运行吗
Jun 17 #Python
python的pip有什么用
Jun 17 #Python
浅谈keras通过model.fit_generator训练模型(节省内存)
Jun 17 #Python
python用什么编辑器进行项目开发
Jun 17 #Python
在keras中model.fit_generator()和model.fit()的区别说明
Jun 17 #Python
python语言的优势是什么
Jun 17 #Python
python有几个版本
Jun 17 #Python
You might like
PHP Class&Object -- 解析PHP实现二叉树
2013/06/25 PHP
PHPMailer的主要功能特点和简单使用说明
2014/02/17 PHP
php使用for语句输出三角形的方法
2015/06/09 PHP
php判断对象是派生自哪个类的方法
2015/06/20 PHP
PHP foreach遍历多维数组实现方式
2016/11/16 PHP
php实现和c#一致的DES加密解密实例
2017/07/24 PHP
js arguments.callee的应用代码
2009/05/07 Javascript
jquery获取div宽度的实现思路与代码
2013/01/13 Javascript
jQuery学习之prop和attr的区别示例介绍
2013/11/15 Javascript
纯js和css实现渐变色包括静态渐变和动态渐变
2014/05/29 Javascript
jQuery实现炫酷的鼠标轨迹特效
2015/02/01 Javascript
jQuery插件Slider Revolution实现响应动画滑动图片切换效果
2015/06/05 Javascript
JavaScript+html5 canvas绘制的圆弧荡秋千效果完整实例
2016/01/26 Javascript
解决JS组件bootstrap table分页实现过程中遇到的问题
2016/04/21 Javascript
JS模仿腾讯图片站的图片翻页按钮效果完整实例
2016/06/21 Javascript
BOM系列第三篇之定时器应用(时钟、倒计时、秒表和闹钟)
2016/08/17 Javascript
Vue.js原理分析之observer模块详解
2017/02/17 Javascript
JS写XSS cookie stealer来窃取密码的步骤详解
2017/11/20 Javascript
JS面试题大坑之隐式类型转换实例代码
2018/10/14 Javascript
vue使用keep-alive实现组件切换时保存原组件数据方法
2020/10/30 Javascript
[56:56]VG vs LGD 2019国际邀请赛淘汰赛 胜者组 BO3 第一场 8.22
2019/09/05 DOTA
跟老齐学Python之关于循环的小伎俩
2014/10/02 Python
Python中使用scapy模拟数据包实现arp攻击、dns放大攻击例子
2014/10/23 Python
Python实现动态图解析、合成与倒放
2018/01/18 Python
Python计算公交发车时间的完整代码
2020/02/12 Python
python实现PolynomialFeatures多项式的方法
2021/01/06 Python
Html5 Canvas动画基础碰撞检测的实现
2018/12/06 HTML / CSS
老海军美国官网:Old Navy
2016/09/05 全球购物
实习医生自我评价
2013/09/22 职场文书
小学岗位竞聘方案
2014/01/22 职场文书
商业计算机应用专业自荐书
2014/06/09 职场文书
承诺保证书格式
2015/02/28 职场文书
2015年乡镇平安建设工作总结
2015/05/13 职场文书
庆七一晚会主持词
2015/06/30 职场文书
修改并编译golang源码的操作步骤
2021/07/25 Golang
世界十大儿童漫画书排名,法国国宝漫画排第五,第二是轰动日本连环
2022/03/18 欧美动漫