浅谈keras2 predict和fit_generator的坑


Posted in Python onJune 17, 2020

1、使用predict时,必须设置batch_size,否则效率奇低。

查看keras文档中,predict函数原型:

predict(self, x, batch_size=32, verbose=0)

说明:

只使用batch_size=32,也就是说每次将batch_size=32的数据通过PCI总线传到GPU,然后进行预测。在一些问题中,batch_size=32明显是非常小的。而通过PCI传数据是非常耗时的。

所以,使用的时候会发现预测数据时效率奇低,其原因就是batch_size太小了。

经验:

使用predict时,必须人为设置好batch_size,否则PCI总线之间的数据传输次数过多,性能会非常低下。

2、fit_generator

说明:keras 中 fit_generator参数steps_per_epoch已经改变含义了,目前的含义是一个epoch分成多少个batch_size。旧版的含义是一个epoch的样本数目。

如果说训练样本树N=1000,steps_per_epoch = 10,那么相当于一个batch_size=100,如果还是按照旧版来设置,那么相当于

batch_size = 1,会性能非常低。

经验:

必须明确fit_generator参数steps_per_epoch

补充知识:Keras:创建自己的generator(适用于model.fit_generator),解决内存问题

为什么要使用model.fit_generator?

在现实的机器学习中,训练一个model往往需要数量巨大的数据,如果使用fit进行数据训练,很有可能导致内存不够,无法进行训练。

fit_generator的定义如下:

fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

其中各项的具体解释,请参考Keras中文文档

我们重点关注的是generator参数:

generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例, 以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:

一个 (inputs, targets) 元组

一个 (inputs, targets, sample_weights) 元组。

那么,问题来了,如何构建这个generator呢?有以下几种办法:

自己创建一个generator生成器

自己定义一个 Sequence (keras.utils.Sequence) 对象

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory来生成一个generator

1.自己创建一个generator生成器

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory 灵活度不高,只有当数据集满足一定格式(例如,按照分类文件夹存放)或者具备一定条件时,使用才使用才较为方便。

此时,自己创建一个generator就很重要了,关于python的generator是什么原理,怎么使用,就不加赘述,可以查看python的基本语法。

此处,我们用yield来返回数据组,标签组,从而使fit_generator可以调用我们的generator来成批处理数据。

具体实现如下:

def myGenerator(batch_size):
    # loading data
    X_train,Y_train=load_data(...)
    
    # data processing
    # ................
    
    total_size=X_train.size
    #batch_size means how many data you want to train one step
    
    while 1:
      for i in range(total_size//batch_size):
        yield x_train[i*batch_size:(i+1)*batch_size], y[i*batch_size:(i+1)*batch_size]
  return myGenerator

接着你可以调用该生成器:

self._model.fit_generator(myGenerator(batch_size),steps_per_epoch=total_size//batch_size, epochs=epoch_num)

以上这篇浅谈keras2 predict和fit_generator的坑就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现网页链接提取的方法分享
Feb 25 Python
跟老齐学Python之眼花缭乱的运算符
Sep 14 Python
Python发送email的3种方法
Apr 28 Python
利用python实现简单的循环购物车功能示例代码
Jul 05 Python
Python人脸识别第三方库face_recognition接口说明文档
May 03 Python
Python实现操纵控制windows注册表的方法分析
May 24 Python
实例详解python函数的对象、函数嵌套、名称空间和作用域
May 31 Python
python爬取盘搜的有效链接实现代码
Jul 20 Python
Python 实现Serial 与STM32J进行串口通讯
Dec 18 Python
使用matplotlib动态刷新指定曲线实例
Apr 23 Python
keras 使用Lambda 快速新建层 添加多个参数操作
Jun 10 Python
python爬虫线程池案例详解(梨视频短视频爬取)
Feb 20 Python
python能在浏览器能运行吗
Jun 17 #Python
python的pip有什么用
Jun 17 #Python
浅谈keras通过model.fit_generator训练模型(节省内存)
Jun 17 #Python
python用什么编辑器进行项目开发
Jun 17 #Python
在keras中model.fit_generator()和model.fit()的区别说明
Jun 17 #Python
python语言的优势是什么
Jun 17 #Python
python有几个版本
Jun 17 #Python
You might like
如何将数据从文本导入到mysql
2006/10/09 PHP
PHP DataGrid 实现代码
2009/08/12 PHP
PHP chmod 函数与批量修改文件目录权限
2010/05/10 PHP
php下将图片以二进制存入mysql数据库中并显示的实现代码
2010/05/27 PHP
php实现的简易扫雷游戏实例
2015/07/09 PHP
PHP 中使用ajax时一些常见错误总结整理
2017/02/27 PHP
yii2.0框架数据库操作简单示例【添加,修改,删除,查询,打印等】
2020/04/13 PHP
A标签触发onclick事件而不跳转的多种解决方法
2013/06/27 Javascript
Jquery多选框互相内容交换的实例代码
2013/07/04 Javascript
js单词形式的运算符
2014/05/06 Javascript
JavaScript中对象property的读取和写入方法介绍
2014/12/30 Javascript
jQuery设置指定网页元素宽度和高度的方法
2015/03/25 Javascript
js判断浏览器是否支持严格模式的方法
2016/10/04 Javascript
JS日程管理插件FullCalendar简单实例
2017/02/07 Javascript
详解Vue单元测试Karma+Mocha学习笔记
2018/01/31 Javascript
JavaScript 中定义函数用 var foo = function () {} 和 function foo()区别介绍
2018/03/01 Javascript
小程序click-scroll组件设计
2019/06/18 Javascript
javascript中的数据类型检测方法详解
2019/08/07 Javascript
Django中login_required装饰器的深入介绍
2017/11/24 Python
Python解决走迷宫问题算法示例
2018/07/27 Python
python 搭建简单的http server,可直接post文件的实例
2019/01/03 Python
python实现得到当前登录用户信息的方法
2019/06/21 Python
用pytorch的nn.Module构造简单全链接层实例
2020/01/14 Python
python转化excel数字日期为标准日期操作
2020/07/14 Python
舒适的豪华鞋:Taryn Rose
2018/05/03 全球购物
企业面试题试卷附带答案
2015/12/20 面试题
人力资源行政经理自我评价
2013/10/23 职场文书
应用英语专业自荐信
2014/01/26 职场文书
个人承诺书
2014/03/26 职场文书
就业协议书范本
2014/04/11 职场文书
2015教师年度工作总结范文
2015/04/07 职场文书
试用期转正工作总结2015
2015/05/28 职场文书
2016年学校安全教育月活动总结
2016/04/06 职场文书
2019大学生实习报告
2019/06/21 职场文书
导游词之南京栖霞山
2019/10/18 职场文书
部分武汉产收音机展览
2022/04/07 无线电