python神经网络学习 使用Keras进行简单分类


Posted in Python onMay 04, 2022

学习前言

上一步讲了如何构建回归算法,这一次将怎么进行简单分类。

Keras中分类的重要函数

1、np_utils.to_categorical

np_utils.to_categorical用于将标签转化为形如(nb_samples, nb_classes)的二值序列。

假设num_classes = 10。

如将[1,2,3,……4]转化成:

[[0,1,0,0,0,0,0,0]
[0,0,1,0,0,0,0,0]
[0,0,0,1,0,0,0,0]
……
[0,0,0,0,1,0,0,0]]

这样的形态。

如将Y_train转化为二值序列,可以用如下方式:

Y_train = np_utils.to_categorical(Y_train,num_classes= 10)

2、Activation

Activation是激活函数,一般在每一层的输出使用。

当我们使用Sequential模型构建函数的时候,只需要在每一层Dense后面添加Activation就可以了。

Sequential函数也支持直接在参数中完成所有层的构建,使用方法如下。

model = Sequential([
    Dense(32,input_dim = 784),
    Activation("relu"),
    Dense(10),
    Activation("softmax")
    ]
)

其中两次Activation分别使用了relu函数和softmax函数。

3、metrics=[‘accuracy’]

在model.compile中添加metrics=[‘accuracy’]表示需要计算分类精确度,具体使用方式如下:

model.compile(
	loss = 'categorical_crossentropy',
	optimizer = rmsprop,
	metrics=['accuracy']
)

全部代码

这是一个简单的仅含有一个隐含层的神经网络,用于完成手写体识别。在本例中,使用的优化器是RMSprop,具体可以使用的优化器可以参照Keras中文文档

import numpy as np
from keras.models import Sequential
from keras.layers import Dense,Activation ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import RMSprop
# 获取训练集
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
# 首先进行标准化 
X_train = X_train.reshape(X_train.shape[0],-1)/255
X_test = X_test.reshape(X_test.shape[0],-1)/255
# 计算categorical_crossentropy需要对分类结果进行categorical
# 即需要将标签转化为形如(nb_samples, nb_classes)的二值序列
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
# 构建模型
model = Sequential([
    Dense(32,input_dim = 784),
    Activation("relu"),
    Dense(10),
    Activation("softmax")
    ]
)
rmsprop = RMSprop(lr = 0.001,rho = 0.9,epsilon = 1e-08,decay = 0)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = rmsprop,metrics=['accuracy'])
print("\ntraining")
cost = model.fit(X_train,Y_train,nb_epoch = 2,batch_size = 32)
print("\nTest")
cost,accuracy = model.evaluate(X_test,Y_test)
## W,b = model.layers[0].get_weights()
print("accuracy:",accuracy)

实验结果为:

Epoch 1/2
60000/60000 [==============================] - 12s 202us/step - loss: 0.3512 - acc: 0.9022
Epoch 2/2
60000/60000 [==============================] - 11s 183us/step - loss: 0.2037 - acc: 0.9419
Test
10000/10000 [==============================] - 1s 108us/step
accuracy: 0.9464

以上就是python神经网络学习使用Keras进行简单分类的详细内容!


Tags in this post...

Python 相关文章推荐
python实现2014火车票查询代码分享
Jan 10 Python
python算法演练_One Rule 算法(详解)
May 17 Python
Python使用回溯法子集树模板解决爬楼梯问题示例
Sep 08 Python
Python OpenCV实现图片上输出中文
Jan 22 Python
解决python中遇到字典里key值为None的情况,取不出来的问题
Oct 17 Python
解决pandas .to_excel不覆盖已有sheet的问题
Dec 10 Python
详解Django将秒转换为xx天xx时xx分
Sep 27 Python
Python利用多线程同步锁实现多窗口订票系统(推荐)
Dec 22 Python
Windows下Pycharm远程连接虚拟机中Centos下的Python环境(图文教程详解)
Mar 19 Python
Python使用configparser读取ini配置文件
May 25 Python
Python2及Python3如何实现兼容切换
Sep 01 Python
python3排序的实例方法
Oct 20 Python
python神经网络 tf.name_scope 和 tf.variable_scope 的区别
May 04 #Python
Python3使用Qt5来实现简易的五子棋小游戏
May 02 #Python
python开发制作好看的时钟效果
关于的python五子棋的算法
python开发人人对战的五子棋小游戏
python pygame 开发五子棋双人对弈
May 02 #Python
Python开发简易五子棋小游戏
May 02 #Python
You might like
2021年最新CPU天梯图
2021/03/04 数码科技
PHP chmod 函数与批量修改文件目录权限
2010/05/10 PHP
PHP文件操作实例总结【文件上传、下载、分页】
2018/12/08 PHP
使用Js让Html中特殊字符不被转义
2013/11/05 Javascript
图片翻转效果具体实现代码
2014/01/09 Javascript
为jQuery添加Webkit的触摸的方法分享
2014/02/02 Javascript
extjs_02_grid显示本地数据、显示跨域数据
2014/06/23 Javascript
jquery 插件实现多行文本框[textarea]自动高度
2015/03/04 Javascript
JavaScript的Date()方法使用详解
2015/06/09 Javascript
javascript实现九宫格相加数值相等
2020/05/28 Javascript
使用jquery如何获取时间
2016/10/13 Javascript
详解nodejs微信公众号开发——6.自定义菜单
2017/04/13 NodeJs
基于require.js的使用(实例讲解)
2017/09/07 Javascript
微信小程序实现打卡日历功能
2020/09/21 Javascript
微信小程序select下拉框实现效果
2019/05/15 Javascript
基于JavaScript获取base64图片大小
2019/10/18 Javascript
[51:20]完美世界DOTA2联赛PWL S2 Magma vs PXG 第一场 11.28
2020/12/01 DOTA
[01:16:12]完美世界DOTA2联赛PWL S2 FTD vs Inki 第一场 11.21
2020/11/23 DOTA
Python3读取zip文件信息的方法
2015/05/22 Python
利用python实现简单的循环购物车功能示例代码
2017/07/05 Python
对python多线程中Lock()与RLock()锁详解
2019/01/11 Python
微信小程序python用户认证的实现
2019/07/29 Python
Python 用matplotlib画以时间日期为x轴的图像
2019/08/06 Python
Python制作词云图代码实例
2019/09/09 Python
python如何将两个txt文件内容合并
2019/10/18 Python
Numpy之reshape()使用详解
2019/12/26 Python
利用HTML5画出一个坦克的形状具体实现代码
2013/06/20 HTML / CSS
HTML5 的新的表单元素(datalist/keygen/output)使用介绍
2013/07/19 HTML / CSS
html5 冒号分隔符对齐的实现
2019/07/31 HTML / CSS
惠普墨西哥官方商店:HP墨西哥
2016/12/01 全球购物
商务日语毕业生自荐信
2013/11/23 职场文书
学生会主席事迹材料
2014/01/28 职场文书
大学生党员自我批评
2014/02/14 职场文书
入职担保书怎么写
2014/05/12 职场文书
冬季作息时间调整通知
2015/04/24 职场文书
anaconda python3.8安装后降级
2021/06/11 Python