浅谈keras2 predict和fit_generator的坑


Posted in Python onJune 17, 2020

1、使用predict时,必须设置batch_size,否则效率奇低。

查看keras文档中,predict函数原型:

predict(self, x, batch_size=32, verbose=0)

说明:

只使用batch_size=32,也就是说每次将batch_size=32的数据通过PCI总线传到GPU,然后进行预测。在一些问题中,batch_size=32明显是非常小的。而通过PCI传数据是非常耗时的。

所以,使用的时候会发现预测数据时效率奇低,其原因就是batch_size太小了。

经验:

使用predict时,必须人为设置好batch_size,否则PCI总线之间的数据传输次数过多,性能会非常低下。

2、fit_generator

说明:keras 中 fit_generator参数steps_per_epoch已经改变含义了,目前的含义是一个epoch分成多少个batch_size。旧版的含义是一个epoch的样本数目。

如果说训练样本树N=1000,steps_per_epoch = 10,那么相当于一个batch_size=100,如果还是按照旧版来设置,那么相当于

batch_size = 1,会性能非常低。

经验:

必须明确fit_generator参数steps_per_epoch

补充知识:Keras:创建自己的generator(适用于model.fit_generator),解决内存问题

为什么要使用model.fit_generator?

在现实的机器学习中,训练一个model往往需要数量巨大的数据,如果使用fit进行数据训练,很有可能导致内存不够,无法进行训练。

fit_generator的定义如下:

fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

其中各项的具体解释,请参考Keras中文文档

我们重点关注的是generator参数:

generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例, 以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:

一个 (inputs, targets) 元组

一个 (inputs, targets, sample_weights) 元组。

那么,问题来了,如何构建这个generator呢?有以下几种办法:

自己创建一个generator生成器

自己定义一个 Sequence (keras.utils.Sequence) 对象

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory来生成一个generator

1.自己创建一个generator生成器

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory 灵活度不高,只有当数据集满足一定格式(例如,按照分类文件夹存放)或者具备一定条件时,使用才使用才较为方便。

此时,自己创建一个generator就很重要了,关于python的generator是什么原理,怎么使用,就不加赘述,可以查看python的基本语法。

此处,我们用yield来返回数据组,标签组,从而使fit_generator可以调用我们的generator来成批处理数据。

具体实现如下:

def myGenerator(batch_size):
    # loading data
    X_train,Y_train=load_data(...)
    
    # data processing
    # ................
    
    total_size=X_train.size
    #batch_size means how many data you want to train one step
    
    while 1:
      for i in range(total_size//batch_size):
        yield x_train[i*batch_size:(i+1)*batch_size], y[i*batch_size:(i+1)*batch_size]
  return myGenerator

接着你可以调用该生成器:

self._model.fit_generator(myGenerator(batch_size),steps_per_epoch=total_size//batch_size, epochs=epoch_num)

以上这篇浅谈keras2 predict和fit_generator的坑就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用SAX解析xml实例
Nov 21 Python
Python中利用sqrt()方法进行平方根计算的教程
May 15 Python
Python脚本实现自动发带图的微博
Apr 27 Python
Python实现将一个正整数分解质因数的方法分析
Dec 14 Python
Python使用Matplotlib实现Logos设计代码
Dec 25 Python
Python建立Map写Excel表实例解析
Jan 17 Python
Python对切片命名的实现方法
Oct 16 Python
python 堆和优先队列的使用详解
Mar 05 Python
Python集中化管理平台Ansible介绍与YAML简介
Jun 12 Python
python实现从尾到头打印单链表操作示例
Feb 22 Python
django admin 添加自定义链接方式
Mar 11 Python
Pycharm 2020.1 版配置优化的详细教程
Aug 07 Python
python能在浏览器能运行吗
Jun 17 #Python
python的pip有什么用
Jun 17 #Python
浅谈keras通过model.fit_generator训练模型(节省内存)
Jun 17 #Python
python用什么编辑器进行项目开发
Jun 17 #Python
在keras中model.fit_generator()和model.fit()的区别说明
Jun 17 #Python
python语言的优势是什么
Jun 17 #Python
python有几个版本
Jun 17 #Python
You might like
不重新编译PHP为php增加openssl模块的方法
2011/06/14 PHP
php获取网站百度快照日期的方法
2015/07/29 PHP
jQuery向下滚动即时加载内容实现的瀑布流效果
2016/01/07 PHP
Laravel最佳分割路由文件(routes.php)的方式
2016/08/04 PHP
设定php简写功能的方法
2019/11/28 PHP
深入聊聊Array的sort方法的使用技巧.详细点评protype.js中的sortBy方法
2007/04/12 Javascript
地震发生中逃生十大法则
2008/05/12 Javascript
分享jQuery网页元素拖拽插件
2020/12/01 Javascript
jQuery+HTML5实现弹出创意搜索框层
2016/12/29 Javascript
bootstrap table实现x-editable的行单元格编辑及解决数据Empty和支持多样式问题
2017/08/10 Javascript
JS中Attr的用法详解
2017/10/09 Javascript
JS+H5 Canvas实现时钟效果
2018/07/20 Javascript
vue源码中的检测方法的实现
2019/09/26 Javascript
将RGB值转换为灰度值的简单算法
2019/10/09 Javascript
vue-cli3使用mock数据的方法分析
2020/03/16 Javascript
vue+element获取el-table某行的下标,根据下标操作数组对象方式
2020/08/07 Javascript
JavaScript读取本地文件常用方法流程解析
2020/10/12 Javascript
vue3+typescript实现图片懒加载插件
2020/10/26 Javascript
python实现批量转换文件编码(批转换编码示例)
2014/01/23 Python
Python程序员开发中常犯的10个错误
2014/07/07 Python
跟老齐学Python之从if开始语句的征程
2014/09/14 Python
Python中使用bidict模块双向字典结构的奇技淫巧
2016/07/12 Python
Python实现全排列的打印
2018/08/18 Python
python3.6根据m3u8下载mp4视频
2019/06/17 Python
浅谈python3中input输入的使用
2019/08/02 Python
关于numpy中eye和identity的区别详解
2019/11/29 Python
基于python的docx模块处理word和WPS的docx格式文件方式
2020/02/13 Python
CSS3 旋转立方体问题详解
2020/01/09 HTML / CSS
SEPHORA丝芙兰德国官方购物网站:化妆品、护肤品和香水
2020/01/21 全球购物
周鸿祎:教你写创业计划书
2013/12/30 职场文书
2014年审计人员工作总结
2014/12/19 职场文书
设备技术员岗位职责
2015/04/11 职场文书
小学生安全保证书
2015/05/09 职场文书
入党积极分子党支部意见
2015/06/02 职场文书
清明节随笔
2015/08/15 职场文书
iPhone13 Pro外观确定,升级4800万镜头,4月20日发新品
2021/04/15 数码科技