浅谈keras2 predict和fit_generator的坑


Posted in Python onJune 17, 2020

1、使用predict时,必须设置batch_size,否则效率奇低。

查看keras文档中,predict函数原型:

predict(self, x, batch_size=32, verbose=0)

说明:

只使用batch_size=32,也就是说每次将batch_size=32的数据通过PCI总线传到GPU,然后进行预测。在一些问题中,batch_size=32明显是非常小的。而通过PCI传数据是非常耗时的。

所以,使用的时候会发现预测数据时效率奇低,其原因就是batch_size太小了。

经验:

使用predict时,必须人为设置好batch_size,否则PCI总线之间的数据传输次数过多,性能会非常低下。

2、fit_generator

说明:keras 中 fit_generator参数steps_per_epoch已经改变含义了,目前的含义是一个epoch分成多少个batch_size。旧版的含义是一个epoch的样本数目。

如果说训练样本树N=1000,steps_per_epoch = 10,那么相当于一个batch_size=100,如果还是按照旧版来设置,那么相当于

batch_size = 1,会性能非常低。

经验:

必须明确fit_generator参数steps_per_epoch

补充知识:Keras:创建自己的generator(适用于model.fit_generator),解决内存问题

为什么要使用model.fit_generator?

在现实的机器学习中,训练一个model往往需要数量巨大的数据,如果使用fit进行数据训练,很有可能导致内存不够,无法进行训练。

fit_generator的定义如下:

fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

其中各项的具体解释,请参考Keras中文文档

我们重点关注的是generator参数:

generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例, 以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:

一个 (inputs, targets) 元组

一个 (inputs, targets, sample_weights) 元组。

那么,问题来了,如何构建这个generator呢?有以下几种办法:

自己创建一个generator生成器

自己定义一个 Sequence (keras.utils.Sequence) 对象

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory来生成一个generator

1.自己创建一个generator生成器

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory 灵活度不高,只有当数据集满足一定格式(例如,按照分类文件夹存放)或者具备一定条件时,使用才使用才较为方便。

此时,自己创建一个generator就很重要了,关于python的generator是什么原理,怎么使用,就不加赘述,可以查看python的基本语法。

此处,我们用yield来返回数据组,标签组,从而使fit_generator可以调用我们的generator来成批处理数据。

具体实现如下:

def myGenerator(batch_size):
    # loading data
    X_train,Y_train=load_data(...)
    
    # data processing
    # ................
    
    total_size=X_train.size
    #batch_size means how many data you want to train one step
    
    while 1:
      for i in range(total_size//batch_size):
        yield x_train[i*batch_size:(i+1)*batch_size], y[i*batch_size:(i+1)*batch_size]
  return myGenerator

接着你可以调用该生成器:

self._model.fit_generator(myGenerator(batch_size),steps_per_epoch=total_size//batch_size, epochs=epoch_num)

以上这篇浅谈keras2 predict和fit_generator的坑就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python的setuptools框架下生成egg的教程
Apr 13 Python
Python实现对PPT文件进行截图操作的方法
Apr 28 Python
总结Python中逻辑运算符的使用
May 13 Python
Python中文竖排显示的方法
Jul 28 Python
python实现简单socket通信的方法
Apr 19 Python
使用python画社交网络图实例代码
Jul 10 Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 Python
基于Python数据分析之pandas统计分析
Mar 03 Python
windows上彻底删除jupyter notebook的实现
Apr 13 Python
Python try except异常捕获机制原理解析
Apr 18 Python
Django生成数据库及添加用户报错解决方案
Oct 09 Python
Python何绘制带有背景色块的折线图
Apr 23 Python
python能在浏览器能运行吗
Jun 17 #Python
python的pip有什么用
Jun 17 #Python
浅谈keras通过model.fit_generator训练模型(节省内存)
Jun 17 #Python
python用什么编辑器进行项目开发
Jun 17 #Python
在keras中model.fit_generator()和model.fit()的区别说明
Jun 17 #Python
python语言的优势是什么
Jun 17 #Python
python有几个版本
Jun 17 #Python
You might like
SONY SRF-22W(33W)的电路分析和维修案例
2021/03/02 无线电
cakephp2.X多表联合查询join及使用分页查询的方法
2017/02/23 PHP
PHP面向对象多态性实现方法简单示例
2017/09/27 PHP
PHP的垃圾回收机制代码实例讲解
2021/02/27 PHP
JavaScript修改css样式style
2008/04/15 Javascript
document.compatMode介绍
2009/05/21 Javascript
ExtJs Excel导出并下载IIS服务器端遇到的问题
2011/09/16 Javascript
js 加密压缩出现bug解决方案
2014/11/25 Javascript
js控制输入框获得和失去焦点时状态显示的方法
2015/01/30 Javascript
JS实现点击文字对应DIV层不停闪动效果的方法
2015/03/02 Javascript
JavaScript使用replace函数替换字符串的方法
2015/04/06 Javascript
js简单实现竖向tab选项卡的方法
2015/05/04 Javascript
JavaScript String(字符串)对象的简单实例(推荐)
2016/08/31 Javascript
js简单正则验证汉字英文及下划线的方法
2016/11/28 Javascript
原生js实现回复评论功能
2017/01/18 Javascript
JavaScript运动框架 多值运动(四)
2017/05/18 Javascript
AngularJS模糊查询功能实现代码(过滤内容下拉菜单排序过滤敏感字符验证判断后添加表格信息)
2017/10/24 Javascript
vue.js实现标签页切换效果
2018/06/07 Javascript
前端防止用户重复提交js实现代码示例
2018/09/07 Javascript
小程序scroll-view组件实现滚动的示例代码
2018/09/20 Javascript
jQuery点击页面其他部分隐藏下拉菜单功能
2018/11/27 jQuery
[01:08:56]DOTA2-DPC中国联赛 正赛 Magma vs LBZS BO3 第一场 2月7日
2021/03/11 DOTA
Python读取一个目录下所有目录和文件的方法
2016/07/15 Python
python中利用zfill方法自动给数字前面补0
2018/04/10 Python
python向已存在的excel中新增表,不覆盖原数据的实例
2018/05/02 Python
使用CodeMirror实现Python3在线编辑器的示例代码
2019/01/14 Python
Django csrf 两种方法设置form的实例
2019/02/03 Python
Python面向对象程序设计之私有属性及私有方法示例
2019/04/08 Python
Python3之外部文件调用Django程序操作model等文件实现方式
2020/04/07 Python
css3的focus-within选择器的使用
2020/05/11 HTML / CSS
中职生求职信
2014/07/01 职场文书
国家奖学金获奖感言
2014/08/16 职场文书
就业协议书盖章的注意事项
2014/09/28 职场文书
员工试用期转正自我评价
2015/03/10 职场文书
2015年信访工作总结
2015/04/07 职场文书
Html5调用企业微信的实现
2021/04/16 HTML / CSS