python神经网络学习 使用Keras进行简单分类


Posted in Python onMay 04, 2022

学习前言

上一步讲了如何构建回归算法,这一次将怎么进行简单分类。

Keras中分类的重要函数

1、np_utils.to_categorical

np_utils.to_categorical用于将标签转化为形如(nb_samples, nb_classes)的二值序列。

假设num_classes = 10。

如将[1,2,3,……4]转化成:

[[0,1,0,0,0,0,0,0]
[0,0,1,0,0,0,0,0]
[0,0,0,1,0,0,0,0]
……
[0,0,0,0,1,0,0,0]]

这样的形态。

如将Y_train转化为二值序列,可以用如下方式:

Y_train = np_utils.to_categorical(Y_train,num_classes= 10)

2、Activation

Activation是激活函数,一般在每一层的输出使用。

当我们使用Sequential模型构建函数的时候,只需要在每一层Dense后面添加Activation就可以了。

Sequential函数也支持直接在参数中完成所有层的构建,使用方法如下。

model = Sequential([
    Dense(32,input_dim = 784),
    Activation("relu"),
    Dense(10),
    Activation("softmax")
    ]
)

其中两次Activation分别使用了relu函数和softmax函数。

3、metrics=[‘accuracy’]

在model.compile中添加metrics=[‘accuracy’]表示需要计算分类精确度,具体使用方式如下:

model.compile(
	loss = 'categorical_crossentropy',
	optimizer = rmsprop,
	metrics=['accuracy']
)

全部代码

这是一个简单的仅含有一个隐含层的神经网络,用于完成手写体识别。在本例中,使用的优化器是RMSprop,具体可以使用的优化器可以参照Keras中文文档

import numpy as np
from keras.models import Sequential
from keras.layers import Dense,Activation ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import RMSprop
# 获取训练集
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
# 首先进行标准化 
X_train = X_train.reshape(X_train.shape[0],-1)/255
X_test = X_test.reshape(X_test.shape[0],-1)/255
# 计算categorical_crossentropy需要对分类结果进行categorical
# 即需要将标签转化为形如(nb_samples, nb_classes)的二值序列
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
# 构建模型
model = Sequential([
    Dense(32,input_dim = 784),
    Activation("relu"),
    Dense(10),
    Activation("softmax")
    ]
)
rmsprop = RMSprop(lr = 0.001,rho = 0.9,epsilon = 1e-08,decay = 0)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = rmsprop,metrics=['accuracy'])
print("\ntraining")
cost = model.fit(X_train,Y_train,nb_epoch = 2,batch_size = 32)
print("\nTest")
cost,accuracy = model.evaluate(X_test,Y_test)
## W,b = model.layers[0].get_weights()
print("accuracy:",accuracy)

实验结果为:

Epoch 1/2
60000/60000 [==============================] - 12s 202us/step - loss: 0.3512 - acc: 0.9022
Epoch 2/2
60000/60000 [==============================] - 11s 183us/step - loss: 0.2037 - acc: 0.9419
Test
10000/10000 [==============================] - 1s 108us/step
accuracy: 0.9464

以上就是python神经网络学习使用Keras进行简单分类的详细内容!


Tags in this post...

Python 相关文章推荐
Python专用方法与迭代机制实例分析
Sep 15 Python
探索Python3.4中新引入的asyncio模块
Apr 08 Python
详解Python中的__getitem__方法与slice对象的切片操作
Jun 27 Python
Python网络编程 Python套接字编程
Sep 13 Python
pandas中的DataFrame按指定顺序输出所有列的方法
Apr 10 Python
在python中pandas的series合并方法
Nov 12 Python
python生成多个只含0,1元素的随机数组或列表的实例
Nov 12 Python
python交易记录整合交易类详解
Jul 03 Python
Python中的几种矩阵乘法(小结)
Jul 10 Python
python实现指定ip端口扫描方式
Dec 17 Python
Tensorflow 实现将图像与标签数据转化为tfRecord文件
Feb 17 Python
pytorch分类模型绘制混淆矩阵以及可视化详解
Apr 07 Python
python神经网络 tf.name_scope 和 tf.variable_scope 的区别
May 04 #Python
Python3使用Qt5来实现简易的五子棋小游戏
May 02 #Python
python开发制作好看的时钟效果
关于的python五子棋的算法
python开发人人对战的五子棋小游戏
python pygame 开发五子棋双人对弈
May 02 #Python
Python开发简易五子棋小游戏
May 02 #Python
You might like
php防注
2007/01/15 PHP
PHP自动选择 连接本地还是远程数据库
2010/12/02 PHP
Yii实现简单分页的方法
2016/04/29 PHP
Yii框架表单提交验证功能分析
2017/01/07 PHP
JS 自定义函数缺省值的设置方法
2010/05/05 Javascript
jquery中focus()函数实现当对象获得焦点后自动把光标移到内容最后
2013/09/29 Javascript
对JavaScript客户端应用编程的一些建议
2015/06/24 Javascript
基于JavaScript实现鼠标悬浮弹出跟随鼠标移动的带箭头的信息层
2016/01/18 Javascript
jquery UI Datepicker时间控件冲突问题解决
2016/12/16 Javascript
JavaScript 异步调用
2017/10/25 Javascript
js 原生判断内容区域是否滚动到底部的实例代码
2017/11/15 Javascript
jquery获取transform里的值实现方法
2017/12/12 jQuery
mpvue全局引入sass文件的方法步骤
2019/03/06 Javascript
详解微信小程序网络请求接口封装实例
2019/05/02 Javascript
vue使用websocket的方法实例分析
2019/06/22 Javascript
vue eslint简要配置教程详解
2019/07/26 Javascript
解决Vue动态加载本地图片问题
2019/10/09 Javascript
借助云开发实现小程序短信验证码的发送
2020/01/06 Javascript
JavaScript仿京东秒杀倒计时
2020/03/17 Javascript
Python中的闭包总结
2014/09/18 Python
Python编程中time模块的一些关键用法解析
2016/01/19 Python
Python 循环语句之 while,for语句详解
2018/04/23 Python
windows下添加Python环境变量的方法汇总
2018/05/14 Python
python实现计数排序与桶排序实例代码
2019/03/28 Python
Python装饰器用法与知识点小结
2020/03/09 Python
Python的历史与优缺点整理
2020/05/26 Python
python如何利用Mitmproxy抓包
2020/10/10 Python
绢花、人造花和人造花卉:BLOOM
2019/08/07 全球购物
什么是触发器(trigger)? 触发器有什么作用?
2013/09/18 面试题
市场部管理制度
2014/02/02 职场文书
会计电算化专业求职信
2014/06/10 职场文书
2015年乡镇环保工作总结
2015/04/22 职场文书
确保减税降费落地生根,用实实在在措施
2019/07/19 职场文书
go语言中GOPATH GOROOT的作用和设置方式
2021/05/05 Golang
Eclipse+Java+Swing+Mysql实现电影购票系统(详细代码)
2022/01/18 Java/Android
react中useState使用:如何实现在当前表格直接更改数据
2022/08/05 Javascript