浅谈keras2 predict和fit_generator的坑


Posted in Python onJune 17, 2020

1、使用predict时,必须设置batch_size,否则效率奇低。

查看keras文档中,predict函数原型:

predict(self, x, batch_size=32, verbose=0)

说明:

只使用batch_size=32,也就是说每次将batch_size=32的数据通过PCI总线传到GPU,然后进行预测。在一些问题中,batch_size=32明显是非常小的。而通过PCI传数据是非常耗时的。

所以,使用的时候会发现预测数据时效率奇低,其原因就是batch_size太小了。

经验:

使用predict时,必须人为设置好batch_size,否则PCI总线之间的数据传输次数过多,性能会非常低下。

2、fit_generator

说明:keras 中 fit_generator参数steps_per_epoch已经改变含义了,目前的含义是一个epoch分成多少个batch_size。旧版的含义是一个epoch的样本数目。

如果说训练样本树N=1000,steps_per_epoch = 10,那么相当于一个batch_size=100,如果还是按照旧版来设置,那么相当于

batch_size = 1,会性能非常低。

经验:

必须明确fit_generator参数steps_per_epoch

补充知识:Keras:创建自己的generator(适用于model.fit_generator),解决内存问题

为什么要使用model.fit_generator?

在现实的机器学习中,训练一个model往往需要数量巨大的数据,如果使用fit进行数据训练,很有可能导致内存不够,无法进行训练。

fit_generator的定义如下:

fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

其中各项的具体解释,请参考Keras中文文档

我们重点关注的是generator参数:

generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例, 以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:

一个 (inputs, targets) 元组

一个 (inputs, targets, sample_weights) 元组。

那么,问题来了,如何构建这个generator呢?有以下几种办法:

自己创建一个generator生成器

自己定义一个 Sequence (keras.utils.Sequence) 对象

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory来生成一个generator

1.自己创建一个generator生成器

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory 灵活度不高,只有当数据集满足一定格式(例如,按照分类文件夹存放)或者具备一定条件时,使用才使用才较为方便。

此时,自己创建一个generator就很重要了,关于python的generator是什么原理,怎么使用,就不加赘述,可以查看python的基本语法。

此处,我们用yield来返回数据组,标签组,从而使fit_generator可以调用我们的generator来成批处理数据。

具体实现如下:

def myGenerator(batch_size):
    # loading data
    X_train,Y_train=load_data(...)
    
    # data processing
    # ................
    
    total_size=X_train.size
    #batch_size means how many data you want to train one step
    
    while 1:
      for i in range(total_size//batch_size):
        yield x_train[i*batch_size:(i+1)*batch_size], y[i*batch_size:(i+1)*batch_size]
  return myGenerator

接着你可以调用该生成器:

self._model.fit_generator(myGenerator(batch_size),steps_per_epoch=total_size//batch_size, epochs=epoch_num)

以上这篇浅谈keras2 predict和fit_generator的坑就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之不要红头文件(1)
Sep 28 Python
Python中的各种装饰器详解
Apr 11 Python
浅谈django三种缓存模式的使用及注意点
Sep 30 Python
python实现的生成word文档功能示例
Aug 23 Python
多版本python的pip 升级后, pip2 pip3 与python版本失配解决方法
Sep 11 Python
python求一个字符串的所有排列的实现方法
Feb 04 Python
Django ORM判断查询结果是否为空,判断django中的orm为空实例
Jul 09 Python
Python性能分析工具py-spy原理用法解析
Jul 27 Python
详解python中的闭包
Sep 07 Python
python matplotlib绘制三维图的示例
Sep 24 Python
python中的yield from语法快速学习
Nov 06 Python
python3实现Dijkstra算法最短路径的实现
May 12 Python
python能在浏览器能运行吗
Jun 17 #Python
python的pip有什么用
Jun 17 #Python
浅谈keras通过model.fit_generator训练模型(节省内存)
Jun 17 #Python
python用什么编辑器进行项目开发
Jun 17 #Python
在keras中model.fit_generator()和model.fit()的区别说明
Jun 17 #Python
python语言的优势是什么
Jun 17 #Python
python有几个版本
Jun 17 #Python
You might like
php UTF8 文件的签名问题
2009/10/30 PHP
Codeigniter操作数据库表的优化写法总结
2014/06/12 PHP
运用jquery实现table单双行不同显示并能单行选中
2009/07/25 Javascript
用 Javascript 验证表单(form)中的单选(radio)值
2009/09/08 Javascript
js中top/parent/frame概述及案例应用
2013/02/06 Javascript
js 实现日期灵活格式化的小例子
2013/07/14 Javascript
ECMAScript6的新特性箭头函数(Arrow Function)详细介绍
2014/06/07 Javascript
JQuery表单验证插件EasyValidator用法分析
2014/11/15 Javascript
JavaScript实现文字与图片拖拽效果的方法
2015/02/16 Javascript
AngularJS的一些基本样式初窥
2015/07/27 Javascript
NodeJs的优势和适合开发的程序
2016/08/14 NodeJs
JavaScript获取URL中参数querystring的方法详解
2016/10/11 Javascript
vue自定义指令实现v-tap插件
2016/11/03 Javascript
echarts3 使用总结(绘制各种图表,地图)
2017/01/05 Javascript
Angularjs2不同组件间的通信实例代码
2017/05/06 Javascript
jQuery操作DOM_动力节点Java学院整理
2017/07/04 jQuery
jQuery插件artDialog.js使用与关闭方法示例
2017/10/09 jQuery
Vue-Router2.X多种路由实现方式总结
2018/02/09 Javascript
使用vue打包时vendor文件过大或者是app.js文件很大的问题
2018/06/29 Javascript
three.js搭建室内场景教程
2018/12/30 Javascript
解决ant Design中Select设置initialValue时的大坑
2020/10/29 Javascript
Python标准库与第三方库详解
2014/07/22 Python
在Python的循环体中使用else语句的方法
2015/03/30 Python
python获取一组汉字拼音首字母的方法
2015/07/01 Python
python 递归遍历文件夹,并打印满足条件的文件路径实例
2017/08/30 Python
Python实现嵌套列表及字典并按某一元素去重复功能示例
2017/11/30 Python
Python txt文件加入字典并查询的方法
2019/01/15 Python
python用for循环求和的方法总结
2019/07/08 Python
python celery分布式任务队列的使用详解
2019/07/08 Python
Pycharm 2019 破解激活方法图文详解
2019/10/11 Python
python requests.get带header
2020/05/05 Python
英国知名衬衫品牌美国网站:Charles Tyrwhitt美国
2016/08/28 全球购物
俄罗斯GamePark游戏商店网站:购买游戏、游戏机和配件
2020/03/13 全球购物
中学生运动会入场词
2014/02/12 职场文书
乡镇信息公开实施方案
2014/03/23 职场文书
筑梦中国心得体会
2016/01/18 职场文书