keras 两种训练模型方式详解fit和fit_generator(节省内存)


Posted in Python onJuly 03, 2020

第一种,fit

import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split

#读取数据
x_train = np.load("D:\\machineTest\\testmulPE_win7\\data_sprase.npy")[()]
y_train = np.load("D:\\machineTest\\testmulPE_win7\\lable_sprase.npy")

# 获取分类类别总数
classes = len(np.unique(y_train))

#对label进行one-hot编码,必须的
label_encoder = LabelEncoder()
integer_encoded = label_encoder.fit_transform(y_train)
onehot_encoder = OneHotEncoder(sparse=False)
integer_encoded = integer_encoded.reshape(len(integer_encoded), 1)
y_train = onehot_encoder.fit_transform(integer_encoded)

#shuffle
X_train, X_test, y_train, y_test = train_test_split(x_train, y_train, test_size=0.3, random_state=0)

model = Sequential()
model.add(Dense(units=1000, activation='relu', input_dim=784))
model.add(Dense(units=classes, activation='softmax'))
model.compile(loss='categorical_crossentropy',
    optimizer='sgd',
    metrics=['accuracy'])
model.fit(X_train, y_train, epochs=50, batch_size=128)
score = model.evaluate(X_test, y_test, batch_size=128)
# #fit参数详情
# keras.models.fit(
# self,
# x=None, #训练数据
# y=None, #训练数据label标签
# batch_size=None, #每经过多少个sample更新一次权重,defult 32
# epochs=1, #训练的轮数epochs
# verbose=1, #0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录
# callbacks=None,#list,list中的元素为keras.callbacks.Callback对象,在训练过程中会调用list中的回调函数
# validation_split=0., #浮点数0-1,将训练集中的一部分比例作为验证集,然后下面的验证集validation_data将不会起到作用
# validation_data=None, #验证集
# shuffle=True, #布尔值和字符串,如果为布尔值,表示是否在每一次epoch训练前随机打乱输入样本的顺序,如果为"batch",为处理HDF5数据
# class_weight=None, #dict,分类问题的时候,有的类别可能需要额外关注,分错的时候给的惩罚会比较大,所以权重会调高,体现在损失函数上面
# sample_weight=None, #array,和输入样本对等长度,对输入的每个特征+个权值,如果是时序的数据,则采用(samples,sequence_length)的矩阵
# initial_epoch=0, #如果之前做了训练,则可以从指定的epoch开始训练
# steps_per_epoch=None, #将一个epoch分为多少个steps,也就是划分一个batch_size多大,比如steps_per_epoch=10,则就是将训练集分为10份,不能和batch_size共同使用
# validation_steps=None, #当steps_per_epoch被启用的时候才有用,验证集的batch_size
# **kwargs #用于和后端交互
# )
# 
# 返回的是一个History对象,可以通过History.history来查看训练过程,loss值等等

第二种,fit_generator(节省内存)

# 第二种,可以节省内存
'''
Created on 2018-4-11
fit_generate.txt,后面两列为lable,已经one-hot编码
1 2 0 1
2 3 1 0
1 3 0 1
1 4 0 1
2 4 1 0
2 5 1 0

'''
import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from sklearn.model_selection import train_test_split

count =1 
def generate_arrays_from_file(path):
 global count
 while 1:
  datas = np.loadtxt(path,delimiter=' ',dtype="int")
  x = datas[:,:2]
  y = datas[:,2:]
  print("count:"+str(count))
  count = count+1
  yield (x,y)
x_valid = np.array([[1,2],[2,3]])
y_valid = np.array([[0,1],[1,0]])
model = Sequential()
model.add(Dense(units=1000, activation='relu', input_dim=2))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
    optimizer='sgd',
    metrics=['accuracy'])

model.fit_generator(generate_arrays_from_file("D:\\fit_generate.txt"),steps_per_epoch=10, epochs=2,max_queue_size=1,validation_data=(x_valid, y_valid),workers=1)
# steps_per_epoch 每执行一次steps,就去执行一次生产函数generate_arrays_from_file
# max_queue_size 从生产函数中出来的数据时可以缓存在queue队列中
# 输出如下:
# Epoch 1/2
# count:1
# count:2
# 
# 1/10 [==>...........................] - ETA: 2s - loss: 0.7145 - acc: 0.3333count:3
# count:4
# count:5
# count:6
# count:7
# 
# 7/10 [====================>.........] - ETA: 0s - loss: 0.7001 - acc: 0.4286count:8
# count:9
# count:10
# count:11
# 
# 10/10 [==============================] - 0s 36ms/step - loss: 0.6960 - acc: 0.4500 - val_loss: 0.6794 - val_acc: 0.5000
# Epoch 2/2
# 
# 1/10 [==>...........................] - ETA: 0s - loss: 0.6829 - acc: 0.5000count:12
# count:13
# count:14
# count:15
# 
# 5/10 [==============>...............] - ETA: 0s - loss: 0.6800 - acc: 0.5000count:16
# count:17
# count:18
# count:19
# count:20
# 
# 10/10 [==============================] - 0s 11ms/step - loss: 0.6766 - acc: 0.5000 - val_loss: 0.6662 - val_acc: 0.5000

补充知识:

自动生成数据还可以继承keras.utils.Sequence,然后写自己的生成数据类:

keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练

#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

class DataGenerator(keras.utils.Sequence):
 
 def __init__(self, datas, batch_size=1, shuffle=True):
  self.batch_size = batch_size
  self.datas = datas
  self.indexes = np.arange(len(self.datas))
  self.shuffle = shuffle

 def __len__(self):
  #计算每一个epoch的迭代次数
  return math.ceil(len(self.datas) / float(self.batch_size))

 def __getitem__(self, index):
  #生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
  # 生成batch_size个索引
  batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
  # 根据索引获取datas集合中的数据
  batch_datas = [self.datas[k] for k in batch_indexs]

  # 生成数据
  X, y = self.data_generation(batch_datas)

  return X, y

 def on_epoch_end(self):
  #在每一次epoch结束是否需要进行一次随机,重新随机一下index
  if self.shuffle == True:
   np.random.shuffle(self.indexes)

 def data_generation(self, batch_datas):
  images = []
  labels = []

  # 生成数据
  for i, data in enumerate(batch_datas):
   #x_train数据
   image = cv2.imread(data)
   image = list(image)
   images.append(image)
   #y_train数据 
   right = data.rfind("\\",0)
   left = data.rfind("\\",0,right)+1
   class_name = data[left:right]
   if class_name=="dog":
    labels.append([0,1])
   else: 
    labels.append([1,0])
  #如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
  return np.array(images), np.array(labels)
 
# 读取样本名称,然后根据样本名称去读取数据
class_num = 0
train_datas = [] 
for file in os.listdir("D:/xxx"):
 file_path = os.path.join("D:/xxx", file)
 if os.path.isdir(file_path):
  class_num = class_num + 1
  for sub_file in os.listdir(file_path):
   train_datas.append(os.path.join(file_path, sub_file))

# 数据生成器
training_generator = DataGenerator(train_datas)

#构建网络
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
    optimizer='sgd',
    metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])

model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)

以上这篇keras 两种训练模型方式详解fit和fit_generator(节省内存)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python之yield表达式学习
Sep 02 Python
将Emacs打造成强大的Python代码编辑工具
Nov 20 Python
关于Tensorflow中的tf.train.batch函数的使用
Apr 24 Python
使用python制作游戏下载进度条的代码(程序说明见注释)
Oct 24 Python
使用Pytorch来拟合函数方式
Jan 14 Python
Django Admin设置应用程序及模型顺序方法详解
Apr 01 Python
python继承threading.Thread实现有返回值的子类实例
May 02 Python
Django QuerySet查询集原理及代码实例
Jun 13 Python
python操作toml文件的示例代码
Nov 27 Python
2021年最新用于图像处理的Python库总结
Jun 15 Python
Python字典的基础操作
Nov 01 Python
Python使用openpyxl模块处理Excel文件
Jun 05 Python
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
Jul 03 #Python
keras分类模型中的输入数据与标签的维度实例
Jul 03 #Python
keras自动编码器实现系列之卷积自动编码器操作
Jul 03 #Python
Python with语句用法原理详解
Jul 03 #Python
Keras搭建自编码器操作
Jul 03 #Python
python 识别登录验证码图片功能的实现代码(完整代码)
Jul 03 #Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 #Python
You might like
操作Oracle的php类
2006/10/09 PHP
php array_map array_multisort 高效处理多维数组排序
2009/06/11 PHP
php去除HTML标签实例
2013/11/06 PHP
Fleaphp常见函数功能与用法示例
2016/11/15 PHP
yii2使用GridView实现数据全选及批量删除按钮示例
2017/03/01 PHP
Gambit vs ForZe BO3 第三场 2.13
2021/03/10 DOTA
JS 获取select(多选下拉)中所选值的示例代码
2013/08/02 Javascript
jQuery中验证表单提交方式及序列化表单内容的实现
2014/01/06 Javascript
jquery跨域请求示例分享(jquery发送ajax请求)
2014/03/25 Javascript
avalonjs制作响应式瀑布流特效
2015/05/06 Javascript
Javascript验证方法大全
2015/09/21 Javascript
JavaScript常用基础知识强化学习
2015/12/09 Javascript
javascript html5实现表单验证
2016/03/01 Javascript
详解JS对象封装的常用方式
2016/12/30 Javascript
bootstrap table使用入门基本用法
2017/05/24 Javascript
详解webpack3如何正确引用并使用jQuery库
2017/08/26 jQuery
JS二分查找算法详解
2017/11/01 Javascript
jQuery实现为动态添加的元素绑定事件实例分析
2018/09/07 jQuery
在vue+element ui框架里实现lodash的debounce防抖
2019/11/13 Javascript
JavaScript中变量提升机制示例详解
2019/12/27 Javascript
详解vue beforeEach 死循环问题解决方法
2020/02/25 Javascript
js实现移动端图片滑块验证功能
2020/09/29 Javascript
[00:12]DAC2018 no[o]ne亮相SOLO赛 他是否如他的id一样无人可挡?
2018/04/06 DOTA
编写Python爬虫抓取暴走漫画上gif图片的实例分享
2016/04/20 Python
django中send_mail功能实现详解
2018/02/06 Python
Django框架模板文件使用及模板文件加载顺序分析
2019/05/23 Python
给Django Admin添加验证码和多次登录尝试限制的实现
2020/07/26 Python
优秀员工个人的自我评价
2013/11/29 职场文书
电子商务专业推荐信范文
2013/12/02 职场文书
大学生个人求职口试自我评价
2014/02/16 职场文书
高中学生自我评价范文
2014/09/23 职场文书
在职人员跳槽求职信
2015/03/20 职场文书
apache基于端口创建虚拟主机的示例
2021/04/22 Servers
go使用Gin框架利用阿里云实现短信验证码功能
2021/08/04 Golang
解决linux下redis数据库overcommit_memory问题
2022/02/24 Redis
PostgreSQL并行计算算法及参数强制并行度设置方法
2022/04/06 PostgreSQL