浅谈keras2 predict和fit_generator的坑


Posted in Python onJune 17, 2020

1、使用predict时,必须设置batch_size,否则效率奇低。

查看keras文档中,predict函数原型:

predict(self, x, batch_size=32, verbose=0)

说明:

只使用batch_size=32,也就是说每次将batch_size=32的数据通过PCI总线传到GPU,然后进行预测。在一些问题中,batch_size=32明显是非常小的。而通过PCI传数据是非常耗时的。

所以,使用的时候会发现预测数据时效率奇低,其原因就是batch_size太小了。

经验:

使用predict时,必须人为设置好batch_size,否则PCI总线之间的数据传输次数过多,性能会非常低下。

2、fit_generator

说明:keras 中 fit_generator参数steps_per_epoch已经改变含义了,目前的含义是一个epoch分成多少个batch_size。旧版的含义是一个epoch的样本数目。

如果说训练样本树N=1000,steps_per_epoch = 10,那么相当于一个batch_size=100,如果还是按照旧版来设置,那么相当于

batch_size = 1,会性能非常低。

经验:

必须明确fit_generator参数steps_per_epoch

补充知识:Keras:创建自己的generator(适用于model.fit_generator),解决内存问题

为什么要使用model.fit_generator?

在现实的机器学习中,训练一个model往往需要数量巨大的数据,如果使用fit进行数据训练,很有可能导致内存不够,无法进行训练。

fit_generator的定义如下:

fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

其中各项的具体解释,请参考Keras中文文档

我们重点关注的是generator参数:

generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例, 以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:

一个 (inputs, targets) 元组

一个 (inputs, targets, sample_weights) 元组。

那么,问题来了,如何构建这个generator呢?有以下几种办法:

自己创建一个generator生成器

自己定义一个 Sequence (keras.utils.Sequence) 对象

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory来生成一个generator

1.自己创建一个generator生成器

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory 灵活度不高,只有当数据集满足一定格式(例如,按照分类文件夹存放)或者具备一定条件时,使用才使用才较为方便。

此时,自己创建一个generator就很重要了,关于python的generator是什么原理,怎么使用,就不加赘述,可以查看python的基本语法。

此处,我们用yield来返回数据组,标签组,从而使fit_generator可以调用我们的generator来成批处理数据。

具体实现如下:

def myGenerator(batch_size):
    # loading data
    X_train,Y_train=load_data(...)
    
    # data processing
    # ................
    
    total_size=X_train.size
    #batch_size means how many data you want to train one step
    
    while 1:
      for i in range(total_size//batch_size):
        yield x_train[i*batch_size:(i+1)*batch_size], y[i*batch_size:(i+1)*batch_size]
  return myGenerator

接着你可以调用该生成器:

self._model.fit_generator(myGenerator(batch_size),steps_per_epoch=total_size//batch_size, epochs=epoch_num)

以上这篇浅谈keras2 predict和fit_generator的坑就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中__call__用法实例
Aug 29 Python
利用python批量修改word文件名的方法示例
Oct 17 Python
Python3多线程基础知识点
Feb 19 Python
numpy.where() 用法详解
May 27 Python
python实现数据分析与建模
Jul 11 Python
pygame实现烟雨蒙蒙下彩虹雨
Nov 11 Python
解决django后台管理界面添加中文内容乱码问题
Nov 15 Python
如何使用python3获取当前路径及os.path.dirname的使用
Dec 13 Python
python实现爱奇艺登陆密码RSA加密的方法示例详解
May 27 Python
python+flask编写一个简单的登录接口
Nov 13 Python
完美解决Pycharm中matplotlib画图中文乱码问题
Jan 11 Python
Python基础之赋值,浅拷贝,深拷贝的区别
Apr 30 Python
python能在浏览器能运行吗
Jun 17 #Python
python的pip有什么用
Jun 17 #Python
浅谈keras通过model.fit_generator训练模型(节省内存)
Jun 17 #Python
python用什么编辑器进行项目开发
Jun 17 #Python
在keras中model.fit_generator()和model.fit()的区别说明
Jun 17 #Python
python语言的优势是什么
Jun 17 #Python
python有几个版本
Jun 17 #Python
You might like
IIS下配置Php+Mysql+zend的图文教程
2006/12/08 PHP
PHP 超链接 抓取实现代码
2009/06/29 PHP
php通用防注入程序 推荐
2011/02/26 PHP
PHP抓取淘宝商品的用户晒单评论+图片+搜索商品列表实例
2016/04/14 PHP
ThinkPHP中limit()使用方法详解
2016/04/19 PHP
PHP仿微信发红包领红包效果
2016/10/30 PHP
提高网站信任度的技巧
2008/10/17 Javascript
Javascript Boolean、Nnumber、String 强制类型转换的区别详细介绍
2012/12/13 Javascript
JS简单实现登陆验证附效果图
2013/11/19 Javascript
document节点对象的获取方式示例介绍
2013/12/24 Javascript
jQuery设置单选按钮radio选中/不可用的实例代码
2016/06/24 Javascript
jQuery实现级联下拉框实战(5)
2017/02/08 Javascript
ES5学习教程之Array对象
2017/04/01 Javascript
JS实现简单的浮动碰撞效果示例
2017/12/28 Javascript
vue2.0使用swiper组件实现轮播的示例代码
2018/03/03 Javascript
微信小程序引入模块中wxml、wxss、js的方法示例
2019/08/09 Javascript
VUE实现密码验证与提示功能
2019/10/18 Javascript
解决vue-cli项目开发运行时内存暴涨卡死电脑问题
2019/10/29 Javascript
Angular之jwt令牌身份验证的实现
2020/02/14 Javascript
python调用java的Webservice示例
2014/03/10 Python
python中的闭包函数
2018/02/09 Python
python实现基于朴素贝叶斯的垃圾分类算法
2019/07/09 Python
使用Python和OpenCV检测图像中的物体并将物体裁剪下来
2019/10/30 Python
python 如何去除字符串头尾的多余符号
2019/11/19 Python
python django中8000端口被占用的解决
2019/12/17 Python
Python实现密钥密码(加解密)实例详解
2020/04/26 Python
python中pivot()函数基础知识点
2021/01/03 Python
Agoda西班牙:全球特价酒店预订
2017/06/03 全球购物
纪律教育学习心得体会
2014/09/02 职场文书
商超业务员岗位职责
2015/02/13 职场文书
2015年学校办公室主任工作总结
2015/07/20 职场文书
《植物妈妈有办法》教学反思
2016/02/23 职场文书
python第三方网页解析器 lxml 扩展库与 xpath 的使用方法
2021/04/06 Python
4种非常实用的python内置数据结构
2021/04/28 Python
解决Laravel使用验证时跳转到首页的问题
2021/11/17 PHP
tomcat默认最大连接数及相关调整方法
2022/05/06 Servers