keras 两种训练模型方式详解fit和fit_generator(节省内存)


Posted in Python onJuly 03, 2020

第一种,fit

import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split

#读取数据
x_train = np.load("D:\\machineTest\\testmulPE_win7\\data_sprase.npy")[()]
y_train = np.load("D:\\machineTest\\testmulPE_win7\\lable_sprase.npy")

# 获取分类类别总数
classes = len(np.unique(y_train))

#对label进行one-hot编码,必须的
label_encoder = LabelEncoder()
integer_encoded = label_encoder.fit_transform(y_train)
onehot_encoder = OneHotEncoder(sparse=False)
integer_encoded = integer_encoded.reshape(len(integer_encoded), 1)
y_train = onehot_encoder.fit_transform(integer_encoded)

#shuffle
X_train, X_test, y_train, y_test = train_test_split(x_train, y_train, test_size=0.3, random_state=0)

model = Sequential()
model.add(Dense(units=1000, activation='relu', input_dim=784))
model.add(Dense(units=classes, activation='softmax'))
model.compile(loss='categorical_crossentropy',
    optimizer='sgd',
    metrics=['accuracy'])
model.fit(X_train, y_train, epochs=50, batch_size=128)
score = model.evaluate(X_test, y_test, batch_size=128)
# #fit参数详情
# keras.models.fit(
# self,
# x=None, #训练数据
# y=None, #训练数据label标签
# batch_size=None, #每经过多少个sample更新一次权重,defult 32
# epochs=1, #训练的轮数epochs
# verbose=1, #0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录
# callbacks=None,#list,list中的元素为keras.callbacks.Callback对象,在训练过程中会调用list中的回调函数
# validation_split=0., #浮点数0-1,将训练集中的一部分比例作为验证集,然后下面的验证集validation_data将不会起到作用
# validation_data=None, #验证集
# shuffle=True, #布尔值和字符串,如果为布尔值,表示是否在每一次epoch训练前随机打乱输入样本的顺序,如果为"batch",为处理HDF5数据
# class_weight=None, #dict,分类问题的时候,有的类别可能需要额外关注,分错的时候给的惩罚会比较大,所以权重会调高,体现在损失函数上面
# sample_weight=None, #array,和输入样本对等长度,对输入的每个特征+个权值,如果是时序的数据,则采用(samples,sequence_length)的矩阵
# initial_epoch=0, #如果之前做了训练,则可以从指定的epoch开始训练
# steps_per_epoch=None, #将一个epoch分为多少个steps,也就是划分一个batch_size多大,比如steps_per_epoch=10,则就是将训练集分为10份,不能和batch_size共同使用
# validation_steps=None, #当steps_per_epoch被启用的时候才有用,验证集的batch_size
# **kwargs #用于和后端交互
# )
# 
# 返回的是一个History对象,可以通过History.history来查看训练过程,loss值等等

第二种,fit_generator(节省内存)

# 第二种,可以节省内存
'''
Created on 2018-4-11
fit_generate.txt,后面两列为lable,已经one-hot编码
1 2 0 1
2 3 1 0
1 3 0 1
1 4 0 1
2 4 1 0
2 5 1 0

'''
import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from sklearn.model_selection import train_test_split

count =1 
def generate_arrays_from_file(path):
 global count
 while 1:
  datas = np.loadtxt(path,delimiter=' ',dtype="int")
  x = datas[:,:2]
  y = datas[:,2:]
  print("count:"+str(count))
  count = count+1
  yield (x,y)
x_valid = np.array([[1,2],[2,3]])
y_valid = np.array([[0,1],[1,0]])
model = Sequential()
model.add(Dense(units=1000, activation='relu', input_dim=2))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
    optimizer='sgd',
    metrics=['accuracy'])

model.fit_generator(generate_arrays_from_file("D:\\fit_generate.txt"),steps_per_epoch=10, epochs=2,max_queue_size=1,validation_data=(x_valid, y_valid),workers=1)
# steps_per_epoch 每执行一次steps,就去执行一次生产函数generate_arrays_from_file
# max_queue_size 从生产函数中出来的数据时可以缓存在queue队列中
# 输出如下:
# Epoch 1/2
# count:1
# count:2
# 
# 1/10 [==>...........................] - ETA: 2s - loss: 0.7145 - acc: 0.3333count:3
# count:4
# count:5
# count:6
# count:7
# 
# 7/10 [====================>.........] - ETA: 0s - loss: 0.7001 - acc: 0.4286count:8
# count:9
# count:10
# count:11
# 
# 10/10 [==============================] - 0s 36ms/step - loss: 0.6960 - acc: 0.4500 - val_loss: 0.6794 - val_acc: 0.5000
# Epoch 2/2
# 
# 1/10 [==>...........................] - ETA: 0s - loss: 0.6829 - acc: 0.5000count:12
# count:13
# count:14
# count:15
# 
# 5/10 [==============>...............] - ETA: 0s - loss: 0.6800 - acc: 0.5000count:16
# count:17
# count:18
# count:19
# count:20
# 
# 10/10 [==============================] - 0s 11ms/step - loss: 0.6766 - acc: 0.5000 - val_loss: 0.6662 - val_acc: 0.5000

补充知识:

自动生成数据还可以继承keras.utils.Sequence,然后写自己的生成数据类:

keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练

#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

class DataGenerator(keras.utils.Sequence):
 
 def __init__(self, datas, batch_size=1, shuffle=True):
  self.batch_size = batch_size
  self.datas = datas
  self.indexes = np.arange(len(self.datas))
  self.shuffle = shuffle

 def __len__(self):
  #计算每一个epoch的迭代次数
  return math.ceil(len(self.datas) / float(self.batch_size))

 def __getitem__(self, index):
  #生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
  # 生成batch_size个索引
  batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
  # 根据索引获取datas集合中的数据
  batch_datas = [self.datas[k] for k in batch_indexs]

  # 生成数据
  X, y = self.data_generation(batch_datas)

  return X, y

 def on_epoch_end(self):
  #在每一次epoch结束是否需要进行一次随机,重新随机一下index
  if self.shuffle == True:
   np.random.shuffle(self.indexes)

 def data_generation(self, batch_datas):
  images = []
  labels = []

  # 生成数据
  for i, data in enumerate(batch_datas):
   #x_train数据
   image = cv2.imread(data)
   image = list(image)
   images.append(image)
   #y_train数据 
   right = data.rfind("\\",0)
   left = data.rfind("\\",0,right)+1
   class_name = data[left:right]
   if class_name=="dog":
    labels.append([0,1])
   else: 
    labels.append([1,0])
  #如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
  return np.array(images), np.array(labels)
 
# 读取样本名称,然后根据样本名称去读取数据
class_num = 0
train_datas = [] 
for file in os.listdir("D:/xxx"):
 file_path = os.path.join("D:/xxx", file)
 if os.path.isdir(file_path):
  class_num = class_num + 1
  for sub_file in os.listdir(file_path):
   train_datas.append(os.path.join(file_path, sub_file))

# 数据生成器
training_generator = DataGenerator(train_datas)

#构建网络
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
    optimizer='sgd',
    metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])

model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)

以上这篇keras 两种训练模型方式详解fit和fit_generator(节省内存)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之眼花缭乱的运算符
Sep 14 Python
python之wxPython菜单使用详解
Sep 28 Python
Python读写txt文本文件的操作方法全解析
Jun 26 Python
实现python版本的按任意键继续/退出
Sep 26 Python
ubuntu安装mysql pycharm sublime
Feb 20 Python
Python解析并读取PDF文件内容的方法
May 08 Python
python使用folium库绘制地图点击框
Sep 21 Python
python在回调函数中获取返回值的方法
Feb 22 Python
Python自动化导出zabbix数据并发邮件脚本
Aug 16 Python
详解Python并发编程之从性能角度来初探并发编程
Aug 23 Python
使用python 将图片复制到系统剪贴中
Dec 13 Python
pytorch加载自己的图像数据集实例
Jul 07 Python
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
Jul 03 #Python
keras分类模型中的输入数据与标签的维度实例
Jul 03 #Python
keras自动编码器实现系列之卷积自动编码器操作
Jul 03 #Python
Python with语句用法原理详解
Jul 03 #Python
Keras搭建自编码器操作
Jul 03 #Python
python 识别登录验证码图片功能的实现代码(完整代码)
Jul 03 #Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 #Python
You might like
用php实现选择排序的解决方法
2013/05/04 PHP
PHP在innodb引擎下快速代建全文搜索功能简明教程【基于xunsearch】
2016/10/14 PHP
php中通用的excel导出方法实例
2017/12/30 PHP
默认让页面的第一个控件选中的javascript代码
2009/12/26 Javascript
jQuery前台数据获取实现代码
2011/03/16 Javascript
10款新鲜出炉的 jQuery 插件(Ajax 插件,有幻灯片、图片画廊、菜单等)
2011/06/08 Javascript
JS编程小常识很有用
2012/11/26 Javascript
js实时获取系统当前时间实例代码
2013/06/28 Javascript
js简单实现HTML标签Select联动带跳转
2013/10/23 Javascript
js判断IE浏览器版本过低示例代码
2013/11/22 Javascript
jquery实现图片按比例缩放示例
2014/07/01 Javascript
使用jQueryMobile实现滑动翻页效果的方法
2015/02/04 Javascript
JQuery分屏指示器图片轮换效果实例
2015/05/21 Javascript
jQuery实现按钮只点击一次后就取消点击事件绑定的方法
2015/06/26 Javascript
详解JavaScript的Date对象(制作简易钟表)
2020/04/07 Javascript
基于jQuery实现鼠标点击导航菜单水波动画效果附源码下载
2016/01/06 Javascript
js 获取本地文件及目录的方法(推荐)
2016/11/10 Javascript
详解Vue 多级组件透传新方法provide/inject
2018/05/09 Javascript
解决vue点击控制单个样式的问题
2018/09/05 Javascript
详解微信小程序缓存--缓存时效性
2019/05/02 Javascript
[00:21]DOTA2亚洲邀请赛 Logo演绎
2015/02/07 DOTA
[03:58]兄弟们,回来开黑了!DOTA2昔日战友招募宣传视频
2016/07/17 DOTA
python中 ? : 三元表达式的使用介绍
2013/10/09 Python
python计算圆周率pi的方法
2015/07/11 Python
python制作最美应用的爬虫
2015/10/28 Python
解读! Python在人工智能中的作用
2017/11/14 Python
python矩阵的转置和逆转实例
2018/12/12 Python
pycharm快捷键汇总
2020/02/14 Python
python 实现 hive中类似 lateral view explode的功能示例
2020/05/18 Python
Django nginx配置实现过程详解
2020/09/10 Python
前端canvas水印快速制作(附完整代码)
2019/09/19 HTML / CSS
英国派对礼服和连衣裙购物网站:TFNC London
2018/07/07 全球购物
物理分数没达标检讨书
2014/09/13 职场文书
高中家长意见怎么写
2015/06/03 职场文书
php中pcntl_fork详解
2021/04/01 PHP
baselines示例程序train_cartpole.py的ImportError
2022/05/20 Python