Keras在mnist上的CNN实践,并且自定义loss函数曲线图操作


Posted in Python onMay 25, 2021

使用keras实现CNN,直接上代码:

from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D
from keras.utils import np_utils
from keras import backend as K
 
class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.losses = {'batch':[], 'epoch':[]}
        self.accuracy = {'batch':[], 'epoch':[]}
        self.val_loss = {'batch':[], 'epoch':[]}
        self.val_acc = {'batch':[], 'epoch':[]}
 
    def on_batch_end(self, batch, logs={}):
        self.losses['batch'].append(logs.get('loss'))
        self.accuracy['batch'].append(logs.get('acc'))
        self.val_loss['batch'].append(logs.get('val_loss'))
        self.val_acc['batch'].append(logs.get('val_acc'))
 
    def on_epoch_end(self, batch, logs={}):
        self.losses['epoch'].append(logs.get('loss'))
        self.accuracy['epoch'].append(logs.get('acc'))
        self.val_loss['epoch'].append(logs.get('val_loss'))
        self.val_acc['epoch'].append(logs.get('val_acc'))
 
    def loss_plot(self, loss_type):
        iters = range(len(self.losses[loss_type]))
        plt.figure()
        # acc
        plt.plot(iters, self.accuracy[loss_type], 'r', label='train acc')
        # loss
        plt.plot(iters, self.losses[loss_type], 'g', label='train loss')
        if loss_type == 'epoch':
            # val_acc
            plt.plot(iters, self.val_acc[loss_type], 'b', label='val acc')
            # val_loss
            plt.plot(iters, self.val_loss[loss_type], 'k', label='val loss')
        plt.grid(True)
        plt.xlabel(loss_type)
        plt.ylabel('acc-loss')
        plt.legend(loc="upper right")
        plt.show()
 
history = LossHistory()
 
batch_size = 128
nb_classes = 10
nb_epoch = 20
img_rows, img_cols = 28, 28
nb_filters = 32
pool_size = (2,2)
kernel_size = (3,3)
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
 
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')
 
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)
 
model3 = Sequential()
 
model3.add(Convolution2D(nb_filters, kernel_size[0] ,kernel_size[1],
                        border_mode='valid',
                        input_shape=input_shape))
model3.add(Activation('relu'))
 
model3.add(Convolution2D(nb_filters, kernel_size[0], kernel_size[1]))
model3.add(Activation('relu'))
 
model3.add(MaxPooling2D(pool_size=pool_size))
model3.add(Dropout(0.25))
 
model3.add(Flatten())
 
model3.add(Dense(128))
model3.add(Activation('relu'))
model3.add(Dropout(0.5))
 
model3.add(Dense(nb_classes))
model3.add(Activation('softmax'))
 
model3.summary()
 
model3.compile(loss='categorical_crossentropy',
              optimizer='adadelta',
              metrics=['accuracy'])
 
model3.fit(X_train, Y_train, batch_size=batch_size, epochs=nb_epoch,
          verbose=1, validation_data=(X_test, Y_test),callbacks=[history])
 
score = model3.evaluate(X_test, Y_test, verbose=0)
print('Test score:', score[0])
print('Test accuracy:', score[1])
 
#acc-loss
history.loss_plot('epoch')

补充:使用keras全连接网络训练mnist手写数字识别并输出可视化训练过程以及预测结果

前言

mnist 数字识别问题的可以直接使用全连接实现但是效果并不像CNN卷积神经网络好。Keras是目前最为广泛的深度学习工具之一,底层可以支持Tensorflow、MXNet、CNTK、Theano

准备工作

TensorFlow版本:1.13.1

Keras版本:2.1.6

Numpy版本:1.18.0

matplotlib版本:2.2.2

导入所需的库

from keras.layers import Dense,Flatten,Dropout
from keras.datasets import mnist
from keras import Sequential
import matplotlib.pyplot as plt
import numpy as np

Dense输入层作为全连接,Flatten用于全连接扁平化操作(也就是将二维打成一维),Dropout避免过拟合。使用datasets中的mnist的数据集,Sequential用于构建模型,plt为可视化,np用于处理数据。

划分数据集

# 训练集       训练集标签       测试集      测试集标签
(train_image,train_label),(test_image,test_label) = mnist.load_data()
print('shape:',train_image.shape)   #查看训练集的shape
plt.imshow(train_image[0])    #查看第一张图片
print('label:',train_label[0])      #查看第一张图片对应的标签
plt.show()

输出shape以及标签label结果:

Keras在mnist上的CNN实践,并且自定义loss函数曲线图操作

查看mnist数据集中第一张图片:

Keras在mnist上的CNN实践,并且自定义loss函数曲线图操作

数据归一化

train_image = train_image.astype('float32')
test_image = test_image.astype('float32')
train_image /= 255.0
test_image /= 255.0

将数据归一化,以便于训练的时候更快的收敛。

模型构建

#初始化模型(模型的优化 ---> 增大网络容量,直到过拟合)
model = Sequential()
model.add(Flatten(input_shape=(28,28)))    #将二维扁平化为一维(60000,28,28)---> (60000,28*28)输入28*28个神经元
model.add(Dropout(0.1))
model.add(Dense(1024,activation='relu'))   #全连接层 输出64个神经元 ,kernel_regularizer=l2(0.0003)
model.add(Dropout(0.1))
model.add(Dense(512,activation='relu'))    #全连接层
model.add(Dropout(0.1))
model.add(Dense(256,activation='relu'))    #全连接层
model.add(Dropout(0.1))
model.add(Dense(10,activation='softmax'))  #输出层,10个类别,用softmax分类

每层使用一次Dropout防止过拟合,激活函数使用relu,最后一层Dense神经元设置为10,使用softmax作为激活函数,因为只有0-9个数字。如果是二分类问题就使用sigmod函数来处理。

编译模型

#编译模型
model.compile(
    optimizer='adam',      #优化器使用默认adam
    loss='sparse_categorical_crossentropy', #损失函数使用sparse_categorical_crossentropy
    metrics=['acc']       #评价指标
)

sparse_categorical_crossentropy与categorical_crossentropy的区别:

sparse_categorical_crossentropy要求target为非One-hot编码,函数内部进行One-hot编码实现。

categorical_crossentropy要求target为One-hot编码。

One-hot格式如: [0,0,0,0,0,1,0,0,0,0] = 5

训练模型

#训练模型
history = model.fit(
    x=train_image,                          #训练的图片
    y=train_label,                          #训练的标签
    epochs=10,                              #迭代10次
    batch_size=512,                         #划分批次
    validation_data=(test_image,test_label) #验证集
)

迭代10次后的结果:

Keras在mnist上的CNN实践,并且自定义loss函数曲线图操作

绘制loss、acc图

#绘制loss acc图
plt.figure()
plt.plot(history.history['acc'],label='training acc')
plt.plot(history.history['val_acc'],label='val acc')
plt.title('model acc')
plt.ylabel('acc')
plt.xlabel('epoch')
plt.legend(loc='lower right')
plt.figure()
plt.plot(history.history['loss'],label='training loss')
plt.plot(history.history['val_loss'],label='val loss')
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(loc='upper right')
plt.show()

绘制出的loss变化图:

Keras在mnist上的CNN实践,并且自定义loss函数曲线图操作

绘制出的acc变化图:

Keras在mnist上的CNN实践,并且自定义loss函数曲线图操作

预测结果

print("前十个图片对应的标签: ",test_label[:10]) #前十个图片对应的标签
print("取前十张图片测试集预测:",np.argmax(model.predict(test_image[:10]),axis=1)) #取前十张图片测试集预测

打印的结果:

Keras在mnist上的CNN实践,并且自定义loss函数曲线图操作

可看到在第9个数字预测错了,标签为5的,预测成了6,为了避免这种问题可以适当的加深网络结构,或使用CNN模型。

保存模型

model.save('./mnist_model.h5')

完整代码

from keras.layers import Dense,Flatten,Dropout
from keras.datasets import mnist
from keras import Sequential
import matplotlib.pyplot as plt
import numpy as np
# 训练集       训练集标签       测试集      测试集标签
(train_image,train_label),(test_image,test_label) = mnist.load_data()
# print('shape:',train_image.shape)   #查看训练集的shape
# plt.imshow(train_image[0]) #查看第一张图片
# print('label:',train_label[0])      #查看第一张图片对应的标签
# plt.show()
#归一化(收敛)
train_image = train_image.astype('float32')
test_image = test_image.astype('float32')
train_image /= 255.0
test_image /= 255.0
#初始化模型(模型的优化 ---> 增大网络容量,直到过拟合)
model = Sequential()
model.add(Flatten(input_shape=(28,28)))   #将二维扁平化为一维(60000,28,28)---> (60000,28*28)输入28*28个神经元
model.add(Dropout(0.1))
model.add(Dense(1024,activation='relu'))    #全连接层 输出64个神经元 ,kernel_regularizer=l2(0.0003)
model.add(Dropout(0.1))
model.add(Dense(512,activation='relu'))    #全连接层
model.add(Dropout(0.1))
model.add(Dense(256,activation='relu'))    #全连接层
model.add(Dropout(0.1))
model.add(Dense(10,activation='softmax')) #输出层,10个类别,用softmax分类
#编译模型
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['acc']
)
#训练模型
history = model.fit(
    x=train_image,                          #训练的图片
    y=train_label,                          #训练的标签
    epochs=10,                              #迭代10次
    batch_size=512,                         #划分批次
    validation_data=(test_image,test_label) #验证集
)
#绘制loss acc 图
plt.figure()
plt.plot(history.history['acc'],label='training acc')
plt.plot(history.history['val_acc'],label='val acc')
plt.title('model acc')
plt.ylabel('acc')
plt.xlabel('epoch')
plt.legend(loc='lower right')
plt.figure()
plt.plot(history.history['loss'],label='training loss')
plt.plot(history.history['val_loss'],label='val loss')
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(loc='upper right')
plt.show()
print("前十个图片对应的标签: ",test_label[:10]) #前十个图片对应的标签
print("取前十张图片测试集预测:",np.argmax(model.predict(test_image[:10]),axis=1)) #取前十张图片测试集预测
#优化前(一个全连接层(隐藏层))
#- 1s 12us/step - loss: 1.8765 - acc: 0.8825
# [7 2 1 0 4 1 4 3 5 4]
# [7 2 1 0 4 1 4 9 5 9]
#优化后(三个全连接层(隐藏层))
#- 1s 14us/step - loss: 0.0320 - acc: 0.9926 - val_loss: 0.2530 - val_acc: 0.9655
# [7 2 1 0 4 1 4 9 5 9]
# [7 2 1 0 4 1 4 9 5 9]
model.save('./model_nameALL.h5')

总结

使用全连接层训练得到的最后结果train_loss: 0.0242 - train_acc: 0.9918 - val_loss: 0.0560 - val_acc: 0.9826,由loss acc可视化图可以看出训练有着明显的效果。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的下载8000首儿歌的代码分享
Nov 21 Python
Python的净值数据接口调用示例分享
Mar 15 Python
spyder常用快捷键(分享)
Jul 19 Python
python实现在pandas.DataFrame添加一行
Apr 04 Python
一篇文章读懂Python赋值与拷贝
Apr 19 Python
python去除拼音声调字母,替换为字母的方法
Nov 28 Python
python利用selenium进行浏览器爬虫
Apr 25 Python
python使用yield压平嵌套字典的超简单方法
Nov 02 Python
PyTorch笔记之scatter()函数的使用
Feb 12 Python
Keras 快速解决OOM超内存的问题
Jun 11 Python
Python中用xlwt制作表格实例讲解
Nov 05 Python
一劳永逸彻底解决pip install慢的办法
May 24 Python
python编写五子棋游戏
浅谈python数据类型及其操作
对Keras自带Loss Function的深入研究
May 25 #Python
pytorch中的model=model.to(device)使用说明
May 24 #Python
解决pytorch-gpu 安装失败的记录
May 24 #Python
如何解决.cuda()加载用时很长的问题
一劳永逸彻底解决pip install慢的办法
May 24 #Python
You might like
PHP调用三种数据库的方法(1)
2006/10/09 PHP
php4的session功能评述(一)
2006/10/09 PHP
php IP转换整形(ip2long)的详解
2013/06/06 PHP
php页面跳转session cookie丢失导致不能登录等问题的解决方法
2016/12/12 PHP
Laravel实现autoload方法详解
2017/05/07 PHP
PHP数据库编程之MySQL优化策略概述
2017/08/16 PHP
php在windows环境下获得cpu内存实时使用率(推荐)
2018/02/08 PHP
php实现生成带二维码图片并强制下载功能
2018/02/24 PHP
JavaScript通过function定义对象并给对象添加toString()方法实例分析
2015/03/23 Javascript
jquery.multiselect多选下拉框实现代码
2016/11/11 Javascript
Ajax基础知识详解
2017/02/17 Javascript
react-redux中connect()方法详细解析
2017/05/27 Javascript
Vue cli3 库模式搭建组件库并发布到 npm的流程
2018/10/12 Javascript
JS/HTML5游戏常用算法之碰撞检测 地图格子算法实例详解
2018/12/12 Javascript
使用 Vue cli 3.0 构建自定义组件库的方法
2019/04/30 Javascript
JS 实现发送短信验证码的“59秒后重新发送验证短信”功能
2019/08/23 Javascript
javascript实现函数柯里化与反柯里化过程解析
2019/10/08 Javascript
vue 公共列表选择组件,引用Vant-UI的样式方式
2020/11/02 Javascript
如何利用Fabric自动化你的任务
2016/10/20 Python
Python数据分析之获取双色球历史信息的方法示例
2018/02/03 Python
python实现windows倒计时锁屏功能
2019/07/30 Python
Pymysql实现往表中插入数据过程解析
2020/06/02 Python
python实现二分查找算法
2020/09/18 Python
巧用CSS3 border实现图片遮罩效果代码
2012/04/09 HTML / CSS
波兰购物网站:MALL.PL
2019/05/01 全球购物
军用级手机壳,专为冒险而建:Zizo Wireless
2019/08/07 全球购物
机械专业个人求职自荐信格式
2013/09/21 职场文书
销售员个人求职的自我评价
2014/02/10 职场文书
经典婚礼主持词
2014/03/13 职场文书
幼儿园春季开学寄语
2014/04/03 职场文书
市场营销战略计划书
2014/05/06 职场文书
法院四风对照检查材料思想汇报
2014/10/06 职场文书
小兵张嘎电影观后感
2015/06/03 职场文书
开学第一天的感想
2015/08/10 职场文书
10大幻兽系恶魔果实 蝙蝠果实上榜,第一自愈能力强
2022/03/18 日漫
SQLyog的下载、安装、破解、配置教程(MySQL可视化工具安装)
2022/09/23 MySQL