浅谈keras2 predict和fit_generator的坑


Posted in Python onJune 17, 2020

1、使用predict时,必须设置batch_size,否则效率奇低。

查看keras文档中,predict函数原型:

predict(self, x, batch_size=32, verbose=0)

说明:

只使用batch_size=32,也就是说每次将batch_size=32的数据通过PCI总线传到GPU,然后进行预测。在一些问题中,batch_size=32明显是非常小的。而通过PCI传数据是非常耗时的。

所以,使用的时候会发现预测数据时效率奇低,其原因就是batch_size太小了。

经验:

使用predict时,必须人为设置好batch_size,否则PCI总线之间的数据传输次数过多,性能会非常低下。

2、fit_generator

说明:keras 中 fit_generator参数steps_per_epoch已经改变含义了,目前的含义是一个epoch分成多少个batch_size。旧版的含义是一个epoch的样本数目。

如果说训练样本树N=1000,steps_per_epoch = 10,那么相当于一个batch_size=100,如果还是按照旧版来设置,那么相当于

batch_size = 1,会性能非常低。

经验:

必须明确fit_generator参数steps_per_epoch

补充知识:Keras:创建自己的generator(适用于model.fit_generator),解决内存问题

为什么要使用model.fit_generator?

在现实的机器学习中,训练一个model往往需要数量巨大的数据,如果使用fit进行数据训练,很有可能导致内存不够,无法进行训练。

fit_generator的定义如下:

fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

其中各项的具体解释,请参考Keras中文文档

我们重点关注的是generator参数:

generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例, 以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:

一个 (inputs, targets) 元组

一个 (inputs, targets, sample_weights) 元组。

那么,问题来了,如何构建这个generator呢?有以下几种办法:

自己创建一个generator生成器

自己定义一个 Sequence (keras.utils.Sequence) 对象

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory来生成一个generator

1.自己创建一个generator生成器

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory 灵活度不高,只有当数据集满足一定格式(例如,按照分类文件夹存放)或者具备一定条件时,使用才使用才较为方便。

此时,自己创建一个generator就很重要了,关于python的generator是什么原理,怎么使用,就不加赘述,可以查看python的基本语法。

此处,我们用yield来返回数据组,标签组,从而使fit_generator可以调用我们的generator来成批处理数据。

具体实现如下:

def myGenerator(batch_size):
    # loading data
    X_train,Y_train=load_data(...)
    
    # data processing
    # ................
    
    total_size=X_train.size
    #batch_size means how many data you want to train one step
    
    while 1:
      for i in range(total_size//batch_size):
        yield x_train[i*batch_size:(i+1)*batch_size], y[i*batch_size:(i+1)*batch_size]
  return myGenerator

接着你可以调用该生成器:

self._model.fit_generator(myGenerator(batch_size),steps_per_epoch=total_size//batch_size, epochs=epoch_num)

以上这篇浅谈keras2 predict和fit_generator的坑就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中AND、OR的一个使用小技巧
Feb 18 Python
Python 实现简单的shell sed替换功能(实例讲解)
Sep 29 Python
Python面向对象之类的定义与继承用法示例
Jan 14 Python
关于Python作用域自学总结
Jun 10 Python
Python变量访问权限控制详解
Jun 29 Python
OpenCV 边缘检测
Jul 10 Python
Python线上环境使用日志的及配置文件
Jul 28 Python
Python threading.local代码实例及原理解析
Mar 16 Python
Python求解排列中的逆序数个数实例
May 03 Python
Expected conditions模块使用方法汇总代码解析
Aug 13 Python
python 用Matplotlib作图中有多个Y轴
Nov 28 Python
Ubuntu20.04环境安装tensorflow2的方法步骤
Jan 29 Python
python能在浏览器能运行吗
Jun 17 #Python
python的pip有什么用
Jun 17 #Python
浅谈keras通过model.fit_generator训练模型(节省内存)
Jun 17 #Python
python用什么编辑器进行项目开发
Jun 17 #Python
在keras中model.fit_generator()和model.fit()的区别说明
Jun 17 #Python
python语言的优势是什么
Jun 17 #Python
python有几个版本
Jun 17 #Python
You might like
PHP求最大子序列和的算法实现
2011/06/24 PHP
PHP二维数组的去重问题解析
2011/07/17 PHP
利用yahoo汇率接口实现实时汇率转换示例 汇率转换器
2014/01/14 PHP
jQuery插件原来如此简单 jQuery插件的机制及实战
2012/02/07 Javascript
基于mootools插件实现遮罩层新手引导
2012/05/24 Javascript
Js,alert出现乱码问题的解决方法
2013/06/19 Javascript
jquery进行数组遍历如何跳出当前的each循环
2014/06/05 Javascript
angular中使用路由和$location切换视图
2015/01/23 Javascript
jQuery获取页面元素绝对与相对位置的方法
2015/06/10 Javascript
jquery判断复选框选中状态以及区分attr和prop
2015/12/18 Javascript
在AngularJS中如何使用谷歌地图把当前位置显示出来
2016/01/25 Javascript
使用基于Node.js的构建工具Grunt来发布ASP.NET MVC项目
2016/02/15 Javascript
简单分析javascript中的函数
2016/09/10 Javascript
实例浅析js的this
2016/12/11 Javascript
vue父组件中获取子组件中的数据(实例讲解)
2017/09/27 Javascript
js与jQuery实现的用户注册协议倒计时功能实例【三种方法】
2017/11/09 jQuery
Webpack中雪碧图插件使用详解
2018/05/25 Javascript
elementUI 动态生成几行几列的方法示例
2019/07/11 Javascript
layui操作列按钮个数和文字颜色的判断实例
2019/09/11 Javascript
[01:04:30]Fnatic vs Mineski 2018国际邀请赛小组赛BO2 第二场 8.17
2018/08/18 DOTA
对于Python的Django框架使用的一些实用建议
2015/04/03 Python
详解使用python的logging模块在stdout输出的两种方法
2017/05/17 Python
浅谈django开发者模式中的autoreload是如何实现的
2017/08/18 Python
python中从str中提取元素到list以及将list转换为str的方法
2018/06/26 Python
Python 将Matrix、Dict保存到文件的方法
2018/10/30 Python
HTML5 CSS3给网站设计带来出色效果
2009/07/16 HTML / CSS
Banana Republic欧盟:美国都市简约风格的代表品牌
2018/05/09 全球购物
花卉与景观设计系大学生求职信
2013/10/01 职场文书
求职信怎么写
2014/05/23 职场文书
趣味运动会广播稿
2014/09/13 职场文书
致800米运动员广播稿(10篇)
2014/10/17 职场文书
2014年学生会部门工作总结
2014/11/07 职场文书
党员转正意见怎么写
2015/06/03 职场文书
建立共青团委员会的请示
2019/04/02 职场文书
68句权威创业名言
2019/08/26 职场文书
vue3中provide && inject的使用
2021/07/01 Vue.js